mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
[ReplicaGroupV3][MeshAxesReplicaGroupList][2/2] Add flattened_replica_groups function for MeshAxesReplicaGroupList.
PiperOrigin-RevId: 826619318
This commit is contained in:
committed by
TensorFlower Gardener
parent
a6e123761d
commit
261e077984
1
third_party/xla/xla/hlo/ir/BUILD
vendored
1
third_party/xla/xla/hlo/ir/BUILD
vendored
@@ -422,6 +422,7 @@ xla_cc_test(
|
||||
":mesh_and_axis",
|
||||
":tile_assignment",
|
||||
"//xla:array",
|
||||
"//xla:array2d",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/service:hlo_proto_cc",
|
||||
"//xla/tsl/platform:test_main",
|
||||
|
||||
39
third_party/xla/xla/hlo/ir/mesh_and_axis.cc
vendored
39
third_party/xla/xla/hlo/ir/mesh_and_axis.cc
vendored
@@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
@@ -101,4 +102,42 @@ AxisRef AxisRef::FromProto(const AxisRefProto& proto) {
|
||||
return axis_ref;
|
||||
}
|
||||
|
||||
// Decides whether two sub-axes of the same mesh axis are compatible.
//
// The arguments are the smaller and larger of the two pre-sizes, followed by
// the smaller and larger of the two next pre-sizes (pre_size * size).
bool canSubAxesCoexist(int64_t minPreSize, int64_t maxPreSize,
                       int64_t minNextPreSize, int64_t maxNextPreSize) {
  const bool ranges_intersect = minNextPreSize > maxPreSize;
  if (!ranges_intersect) {
    // The sub-axes are disjoint: the gap separating them must itself be a
    // whole factor.
    return maxPreSize % minNextPreSize == 0;
  }

  // The sub-axes overlap: the shared part and the leftovers on either side
  // must all be whole factors. The checks run in the same order as a
  // short-circuiting conjunction would evaluate them.
  if (minNextPreSize % maxPreSize != 0) {
    return false;
  }
  if (maxPreSize % minPreSize != 0) {
    return false;
  }
  return maxNextPreSize % minNextPreSize == 0;
}
|
||||
|
||||
bool AxisRef::CanCoexist(const AxisRef& other) const {
|
||||
if (mesh_axis_index() != other.mesh_axis_index()) {
|
||||
return true;
|
||||
}
|
||||
if (!sub_axis_info_.has_value() || !other.sub_axis_info_.has_value()) {
|
||||
// If one is a full axis and the other is a sub-axis, they can coexist.
|
||||
return true;
|
||||
}
|
||||
|
||||
const SubAxis& this_sub_axis = sub_axis_info_.value();
|
||||
const SubAxis& other_sub_axis = other.sub_axis_info_.value();
|
||||
|
||||
int64_t this_pre_size = this_sub_axis.pre_size;
|
||||
int64_t other_pre_size = other_sub_axis.pre_size;
|
||||
int64_t this_next_pre_size = this_sub_axis.next_pre_size();
|
||||
int64_t other_next_pre_size = other_sub_axis.next_pre_size();
|
||||
|
||||
auto [min_pre_size, max_pre_size] =
|
||||
std::minmax(this_pre_size, other_pre_size);
|
||||
auto [min_next_pre_size, max_next_pre_size] =
|
||||
std::minmax(this_next_pre_size, other_next_pre_size);
|
||||
|
||||
return canSubAxesCoexist(min_pre_size, max_pre_size, min_next_pre_size,
|
||||
max_next_pre_size);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
9
third_party/xla/xla/hlo/ir/mesh_and_axis.h
vendored
9
third_party/xla/xla/hlo/ir/mesh_and_axis.h
vendored
@@ -109,6 +109,12 @@ class Mesh {
|
||||
|
||||
// Returns a copy of the device assignment stored in this mesh.
TileAssignment device_assignment() const { return device_assignment_; }
|
||||
// Returns a copy of the stored axis names.
std::vector<std::string> axis_names() const { return axes_names_; }
|
||||
// Returns the dimensions of `device_assignment_`, used as the per-axis sizes
// of the mesh.
// NOTE(review): the span presumably views storage owned by
// `device_assignment_` - confirm its lifetime before holding on to it.
absl::Span<const int64_t> axis_sizes() const {
|
||||
return device_assignment_.dimensions();
|
||||
}
|
||||
// Returns the size of the mesh axis at `axis_index`, read from the matching
// dimension of `device_assignment_`.
int64_t axis_size(int64_t axis_index) const {
|
||||
return device_assignment_.dim(axis_index);
|
||||
}
|
||||
|
||||
private:
|
||||
// Dimensions of the `device_assignment_` array correspond to the axes of the
|
||||
@@ -127,6 +133,7 @@ class AxisRef {
|
||||
// Describes a slice of a mesh axis. For example, on an axis of size 8 the
// sub-axis written a:(1)2 has pre_size 1 and size 2.
struct SubAxis {
  // Product of the sizes that precede this sub-axis within the full axis.
  int64_t pre_size;
  // Number of elements covered by this sub-axis.
  int64_t size;

  // Pre-size of a sub-axis that would start right after this one.
  int64_t next_pre_size() const {
    const int64_t product = pre_size * size;
    return product;
  }
};
|
||||
|
||||
// Index corresponding to axis in the mesh. It should be a valid index into
|
||||
@@ -177,6 +184,8 @@ class AxisRef {
|
||||
|
||||
static AxisRef FromProto(const AxisRefProto& proto);
|
||||
|
||||
bool CanCoexist(const AxisRef& other) const;
|
||||
|
||||
// Returns the index of the referenced axis within the mesh.
int64_t mesh_axis_index() const { return mesh_axis_index_; }
|
||||
// Returns a copy of the optional sub-axis description. An empty optional
// means the reference covers the full axis.
std::optional<SubAxis> sub_axis_info() const { return sub_axis_info_; }
|
||||
};
|
||||
|
||||
244
third_party/xla/xla/hlo/ir/replica_group.cc
vendored
244
third_party/xla/xla/hlo/ir/replica_group.cc
vendored
@@ -15,12 +15,18 @@ limitations under the License.
|
||||
|
||||
#include "xla/hlo/ir/replica_group.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
@@ -33,7 +39,6 @@ limitations under the License.
|
||||
#include "xla/service/hlo.pb.h"
|
||||
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep
|
||||
#include "xla/xla_data.pb.h"
|
||||
#include "tsl/platform/protobuf.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@@ -49,6 +54,145 @@ std::string ReplicaGroupsToString(
|
||||
}
|
||||
|
||||
/************** MeshAxesReplicaGroupList implementation ***********************/
|
||||
|
||||
void HandleSingleAxisRefPerDimension(const AxisRef& axis,
|
||||
int64_t full_axis_size,
|
||||
std::vector<int64_t>& out_reshape_dims,
|
||||
std::vector<int64_t>& out_aggregate_axes) {
|
||||
if (axis.sub_axis_info().has_value()) {
|
||||
out_reshape_dims = {axis.sub_axis_info()->pre_size,
|
||||
axis.sub_axis_info()->size,
|
||||
full_axis_size / axis.sub_axis_info()->next_pre_size()};
|
||||
// The aggregation axis is the second dimension.
|
||||
out_aggregate_axes = {1};
|
||||
} else {
|
||||
out_reshape_dims = {full_axis_size};
|
||||
out_aggregate_axes = {0};
|
||||
}
|
||||
}
|
||||
|
||||
void HandleMultiAxisRefPerDimension(std::vector<AxisRef>& axes,
|
||||
int64_t full_axis_size,
|
||||
std::vector<int64_t>& out_reshape_dims,
|
||||
std::vector<int64_t>& out_aggregate_axes) {
|
||||
// --- 1. Sort Axes and Original Indices Together ---
|
||||
// Sort both the axes and the original indices based on
|
||||
// sub_axis_info()->pre_size. This allows us to maintain user specified order
|
||||
// of AxisRef while still building the reshape and aggregate axes.
|
||||
std::vector<int> original_order(axes.size());
|
||||
std::iota(original_order.begin(), original_order.end(), 0);
|
||||
std::sort(original_order.begin(), original_order.end(),
|
||||
[&axes](int i, int j) {
|
||||
return axes[i].sub_axis_info()->pre_size <
|
||||
axes[j].sub_axis_info()->pre_size;
|
||||
});
|
||||
std::sort(axes.begin(), axes.end(), [](const AxisRef& a, const AxisRef& b) {
|
||||
return a.sub_axis_info()->pre_size < b.sub_axis_info()->pre_size;
|
||||
});
|
||||
|
||||
// --- 2. Build Reshape Dims and Aggregation Axes ---
|
||||
int64_t current_dim_index = 0; // Index in the new reshaped tensor
|
||||
int64_t prefix_product = 1; // Product of the size of all prior dimensions
|
||||
|
||||
for (const AxisRef& axis : axes) {
|
||||
int64_t pre_size = axis.sub_axis_info()->pre_size;
|
||||
int64_t size = axis.sub_axis_info()->size;
|
||||
|
||||
// Insert "padding" dimension if the current prefix product doesn't match
|
||||
// the required pre_size
|
||||
if (pre_size != prefix_product) {
|
||||
int64_t padding_size = pre_size / prefix_product;
|
||||
out_reshape_dims.push_back(padding_size);
|
||||
current_dim_index++;
|
||||
prefix_product *= padding_size;
|
||||
}
|
||||
|
||||
// Insert the sharded size (the part to aggregate)
|
||||
out_reshape_dims.push_back(size);
|
||||
out_aggregate_axes.push_back(
|
||||
current_dim_index); // This is the axis we aggregate over
|
||||
current_dim_index++;
|
||||
prefix_product *= size;
|
||||
}
|
||||
|
||||
// Insert "suffix" dimension if the full size hasn't been reached
|
||||
if (prefix_product != full_axis_size) {
|
||||
out_reshape_dims.push_back(full_axis_size / prefix_product);
|
||||
}
|
||||
|
||||
// --- 3. Permute Aggregate Axes back to Original Order ---
|
||||
// The aggregate axes were calculated based on the sorted list.
|
||||
// We must map them back to the original order to compute the correct
|
||||
// flattened replica groups.
|
||||
std::vector<int64_t> permuted_aggregate_axes(original_order.size());
|
||||
for (int64_t i = 0; i < original_order.size(); ++i) {
|
||||
permuted_aggregate_axes[original_order[i]] = out_aggregate_axes[i];
|
||||
}
|
||||
out_aggregate_axes = permuted_aggregate_axes;
|
||||
}
|
||||
|
||||
bool ValidateSingleDimensionAxes(int64_t dim, std::vector<AxisRef>& axes,
|
||||
const Mesh& mesh_) {
|
||||
// If there's only one axis, nothing to check.
|
||||
if (axes.size() <= 1) {
|
||||
return true;
|
||||
}
|
||||
// --- Step 1: Deduplication ---
|
||||
// If one is a "full" axis (no sub_axis_info), it subsumes all other AxisRefs
|
||||
// with sub_axis_info.
|
||||
for (const AxisRef& axis : axes) {
|
||||
if (!axis.sub_axis_info().has_value()) {
|
||||
LOG(WARNING) << "MeshAxesReplicaGroupList: Redundant axis definition at "
|
||||
"dimension: "
|
||||
<< dim
|
||||
<< ". Keeping only the full axis: " << axis.ToString(mesh_);
|
||||
axes = {axis};
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// --- Step 2: Overlap Check ---
|
||||
// At this point, all remaining axes MUST have sub_axis_info().
|
||||
// Verify that the remaining multiple sub-axes do not overlap.
|
||||
for (int64_t i = 0; i < axes.size() - 1; ++i) {
|
||||
for (int64_t j = i + 1; j < axes.size(); ++j) {
|
||||
// CHECK will terminate the program on failure, matching original
|
||||
// behavior.
|
||||
CHECK(axes[i].CanCoexist(axes[j]))
|
||||
<< "Overlapping sub-axes detected: " << axes[i].ToString(mesh_)
|
||||
<< " and " << axes[j].ToString(mesh_);
|
||||
}
|
||||
}
|
||||
return true; // Passed all checks for this dimension.
|
||||
}
|
||||
|
||||
// Builds a replica group list over `axes` of `mesh` and validates the axes.
//
// Validation:
//   * every sub-axis must have a next pre-size that divides the size of its
//     mesh axis (CHECK),
//   * the axes that share a mesh dimension are checked by
//     ValidateSingleDimensionAxes() (CHECK on conflicting sub-axes).
// When a dimension is referenced by a full axis together with other axes,
// only the first full axis is kept in `axes_`. This makes the stored axes
// match the "Keeping only the full axis" warning, so that later computations
// (group size, flattened groups) never see a full axis mixed with sub-axes.
MeshAxesReplicaGroupList::MeshAxesReplicaGroupList(Mesh mesh,
                                                   std::vector<AxisRef> axes)
    : mesh_(std::move(mesh)), axes_(std::move(axes)) {
  if (num_devices_per_group() == 1) {
    LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString()
               << " has only one device per replica group.";
  }

  absl::flat_hash_map<int64_t, std::vector<AxisRef>> dim_to_axes;
  for (const AxisRef& axis : axes_) {
    dim_to_axes[axis.mesh_axis_index()].push_back(axis);
    if (axis.sub_axis_info().has_value()) {
      CHECK(mesh_.axis_size(axis.mesh_axis_index()) %
                axis.sub_axis_info()->next_pre_size() ==
            0)
          << "Next pre-size must divide the full axis size.";
    }
  }

  // Validate input AxisRefs and remember the dimensions whose axes were
  // collapsed to a single full axis.
  absl::flat_hash_set<int64_t> collapsed_dims;
  for (auto& [dim, dim_axes] : dim_to_axes) {
    const size_t num_axes_before = dim_axes.size();
    CHECK(ValidateSingleDimensionAxes(dim, dim_axes, mesh_));
    if (dim_axes.size() != num_axes_before) {
      collapsed_dims.insert(dim);
    }
  }

  if (collapsed_dims.empty()) {
    return;
  }

  // Apply the deduplication to the stored axes, preserving the order of the
  // surviving entries: for a collapsed dimension keep the first full axis and
  // drop every other reference to that dimension.
  std::vector<AxisRef> deduplicated_axes;
  deduplicated_axes.reserve(axes_.size());
  absl::flat_hash_set<int64_t> kept_full_axis_dims;
  for (const AxisRef& axis : axes_) {
    const int64_t dim = axis.mesh_axis_index();
    if (collapsed_dims.contains(dim)) {
      if (axis.sub_axis_info().has_value()) {
        continue;
      }
      if (!kept_full_axis_dims.insert(dim).second) {
        continue;
      }
    }
    deduplicated_axes.push_back(axis);
  }
  axes_ = std::move(deduplicated_axes);
}
|
||||
|
||||
int64_t MeshAxesReplicaGroupList::num_replica_groups() const {
|
||||
return mesh_.device_assignment().num_elements() / num_devices_per_group();
|
||||
}
|
||||
@@ -58,15 +202,105 @@ int64_t MeshAxesReplicaGroupList::num_devices_per_group() const {
|
||||
// all axes.
|
||||
int64_t devices_per_group = 1;
|
||||
for (const AxisRef& axis : axes_) {
|
||||
int64_t axis_size =
|
||||
axis.sub_axis_info().has_value()
|
||||
? axis.sub_axis_info()->size
|
||||
: mesh_.device_assignment().dim(axis.mesh_axis_index());
|
||||
int64_t axis_size = axis.sub_axis_info().has_value()
|
||||
? axis.sub_axis_info()->size
|
||||
: mesh_.axis_size(axis.mesh_axis_index());
|
||||
devices_per_group *= axis_size;
|
||||
}
|
||||
return devices_per_group;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> get_replica_groups_for_full_axes(
|
||||
const Mesh& mesh, absl::Span<const int64_t> axis_sizes,
|
||||
const absl::Span<const int64_t> grouped_axes,
|
||||
const int64_t num_replica_groups, const int64_t num_devices_per_group) {
|
||||
// Reshape the device assignment array bases on the axis sizes and transpose
|
||||
// grouped axes to the end.
|
||||
std::vector<int> transpose_axes;
|
||||
transpose_axes.reserve(axis_sizes.size());
|
||||
for (int64_t i = 0; i < axis_sizes.size(); ++i) {
|
||||
if (!absl::c_linear_search(grouped_axes, i)) {
|
||||
transpose_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
for (int64_t grouped_axis : grouped_axes) {
|
||||
transpose_axes.push_back(grouped_axis);
|
||||
}
|
||||
|
||||
TileAssignment device_assignment =
|
||||
mesh.device_assignment().Reshape(axis_sizes).Transpose(transpose_axes);
|
||||
|
||||
std::vector<std::vector<int64_t>> replica_groups;
|
||||
replica_groups.reserve(num_replica_groups);
|
||||
for (auto it = device_assignment.array().begin();
|
||||
it != device_assignment.array().end(); it += num_devices_per_group) {
|
||||
std::vector<int64_t> group(it, it + num_devices_per_group);
|
||||
replica_groups.emplace_back(std::move(group));
|
||||
}
|
||||
return replica_groups;
|
||||
}
|
||||
|
||||
void MeshAxesReplicaGroupList::InitializeDimToReshapeAndAggregateAxes() {
|
||||
absl::flat_hash_map<int64_t, std::vector<AxisRef>> dim_to_axes;
|
||||
for (const AxisRef& axis : axes_) {
|
||||
dim_to_axes[axis.mesh_axis_index()].push_back(axis);
|
||||
}
|
||||
absl::flat_hash_map<int64_t, ReshapeAndAggregateAxes> dim_map;
|
||||
// For each dimension determine the reshape that is consistent with it's
|
||||
// AxisRef(s). Then maintain this reshape and the aggregated dims for easier
|
||||
// computation of replica groups. As an example for @mesh<"a"=8>
|
||||
// {a} -> no reshape, aggregate over [0]
|
||||
// {a:(1)2} -> reshape [8]->[1,2,4], aggregate over [1]
|
||||
// {a:(1)2, a:(4)2} -> reshape [8]->[2,2,2], aggregate over [0,2]
|
||||
for (auto& [dim, axes] : dim_to_axes) {
|
||||
int64_t full_axis_size = mesh_.axis_size(dim);
|
||||
ReshapeAndAggregateAxes reshape_and_aggregate_axes;
|
||||
if (axes.size() == 1) {
|
||||
HandleSingleAxisRefPerDimension(
|
||||
axes[0], full_axis_size, reshape_and_aggregate_axes.reshape_dims,
|
||||
reshape_and_aggregate_axes.aggregate_axes);
|
||||
} else {
|
||||
// Otherwise dimension is a set of axes with sub-axes info.
|
||||
HandleMultiAxisRefPerDimension(axes, full_axis_size,
|
||||
reshape_and_aggregate_axes.reshape_dims,
|
||||
reshape_and_aggregate_axes.aggregate_axes);
|
||||
}
|
||||
dim_map[dim] = reshape_and_aggregate_axes;
|
||||
}
|
||||
dim_to_reshape_and_aggregate_axes_ = dim_map;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>>
|
||||
MeshAxesReplicaGroupList::flattened_replica_groups() {
|
||||
if (!dim_to_reshape_and_aggregate_axes_.has_value()) {
|
||||
InitializeDimToReshapeAndAggregateAxes();
|
||||
}
|
||||
|
||||
absl::flat_hash_map<int64_t, ReshapeAndAggregateAxes> dim_map =
|
||||
dim_to_reshape_and_aggregate_axes_.value();
|
||||
std::vector<int64_t> reindex_axis_sizes;
|
||||
std::vector<int64_t> reindexed_grouped_axes;
|
||||
for (int64_t i = 0; i < mesh_.axis_sizes().size(); ++i) {
|
||||
int64_t axis_size = mesh_.axis_size(i);
|
||||
auto it = dim_map.find(i);
|
||||
if (it == dim_map.end()) {
|
||||
reindex_axis_sizes.push_back(axis_size);
|
||||
continue;
|
||||
}
|
||||
int64_t offset_index = reindex_axis_sizes.size();
|
||||
const ReshapeAndAggregateAxes& reshape_and_aggregate_axes = it->second;
|
||||
for (int64_t reshape_dim : reshape_and_aggregate_axes.reshape_dims) {
|
||||
reindex_axis_sizes.push_back(reshape_dim);
|
||||
}
|
||||
for (int64_t aggregate_dim : reshape_and_aggregate_axes.aggregate_axes) {
|
||||
reindexed_grouped_axes.push_back(aggregate_dim + offset_index);
|
||||
}
|
||||
}
|
||||
return get_replica_groups_for_full_axes(
|
||||
mesh_, reindex_axis_sizes, reindexed_grouped_axes, num_replica_groups(),
|
||||
num_devices_per_group());
|
||||
}
|
||||
|
||||
// Appends the string form of this replica group list to `printer`.
void MeshAxesReplicaGroupList::Print(Printer* printer) const {
  const std::string text = ToString();
  printer->Append(text);
}
|
||||
|
||||
22
third_party/xla/xla/hlo/ir/replica_group.h
vendored
22
third_party/xla/xla/hlo/ir/replica_group.h
vendored
@@ -16,14 +16,19 @@ limitations under the License.
|
||||
#ifndef XLA_HLO_IR_REPLICA_GROUP_H_
|
||||
#define XLA_HLO_IR_REPLICA_GROUP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "google/protobuf/repeated_ptr_field.h"
|
||||
@@ -38,14 +43,13 @@ limitations under the License.
|
||||
namespace xla {
|
||||
|
||||
class MeshAxesReplicaGroupList {
|
||||
// Reshape information for one mesh dimension, used when materializing
// replica groups.
struct ReshapeAndAggregateAxes {
|
||||
// Sizes the mesh dimension is split into.
std::vector<int64_t> reshape_dims;
|
||||
// Indices into `reshape_dims` of the entries to aggregate over.
std::vector<int64_t> aggregate_axes;
|
||||
};
|
||||
|
||||
public:
|
||||
explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector<AxisRef> axes)
|
||||
: mesh_(std::move(mesh)), axes_(std::move(axes)) {
|
||||
if (num_devices_per_group() == 1) {
|
||||
LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString()
|
||||
<< " has only one device per replica group.";
|
||||
}
|
||||
}
|
||||
explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector<AxisRef> axes);
|
||||
|
||||
bool operator==(const MeshAxesReplicaGroupList& other) const {
|
||||
return mesh_ == other.mesh_ && axes_ == other.axes_;
|
||||
@@ -58,6 +62,7 @@ class MeshAxesReplicaGroupList {
|
||||
|
||||
int64_t num_replica_groups() const;
|
||||
int64_t num_devices_per_group() const;
|
||||
std::vector<std::vector<int64_t>> flattened_replica_groups();
|
||||
|
||||
void Print(Printer* printer) const;
|
||||
|
||||
@@ -69,8 +74,11 @@ class MeshAxesReplicaGroupList {
|
||||
const MeshAxesReplicaGroupListProto& proto);
|
||||
|
||||
private:
|
||||
void InitializeDimToReshapeAndAggregateAxes();
|
||||
Mesh mesh_;
|
||||
std::vector<AxisRef> axes_;
|
||||
std::optional<absl::flat_hash_map<int64_t, ReshapeAndAggregateAxes>>
|
||||
dim_to_reshape_and_aggregate_axes_;
|
||||
};
|
||||
|
||||
std::string ReplicaGroupsToString(
|
||||
|
||||
263
third_party/xla/xla/hlo/ir/replica_group_test.cc
vendored
263
third_party/xla/xla/hlo/ir/replica_group_test.cc
vendored
@@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "xla/array.h"
|
||||
#include "xla/array2d.h"
|
||||
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||
#include "xla/hlo/ir/tile_assignment.h"
|
||||
#include "xla/service/hlo.pb.h"
|
||||
@@ -41,23 +42,230 @@ CollectiveDeviceListProto CreateDeviceListProto(
|
||||
return proto;
|
||||
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) {
|
||||
// Checks the materialized groups for every combination of full axes on a
// 2x2 mesh.
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) {
  using Groups = std::vector<std::vector<int64_t>>;
  Mesh mesh_xy(TileAssignment({2, 2}), /*axes_names=*/{"x", "y"});

  {
    // No axes: each device forms its own group.
    MeshAxesReplicaGroupList across_none(mesh_xy, {});
    Groups want_none = {{0}, {1}, {2}, {3}};
    EXPECT_EQ(across_none.flattened_replica_groups(), want_none);
  }
  {
    // Axis x only.
    MeshAxesReplicaGroupList across_x(mesh_xy, {AxisRef(0)});
    Groups want_x = {{0, 2}, {1, 3}};
    EXPECT_EQ(across_x.flattened_replica_groups(), want_x);
  }
  {
    // Axis y only.
    MeshAxesReplicaGroupList across_y(mesh_xy, {AxisRef(1)});
    Groups want_y = {{0, 1}, {2, 3}};
    EXPECT_EQ(across_y.flattened_replica_groups(), want_y);
  }
  {
    // Both axes: a single group with every device.
    MeshAxesReplicaGroupList across_xy(mesh_xy, {AxisRef(0), AxisRef(1)});
    Groups want_xy = {{0, 1, 2, 3}};
    EXPECT_EQ(across_xy.flattened_replica_groups(), want_xy);
  }
}
|
||||
|
||||
// Checks the materialized groups for single sub-axes and for pairs of
// sub-axes on a 6x6 iota mesh with axes "a" and "b".
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) {
  using Groups = std::vector<std::vector<int64_t>>;
  Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{6, 6})),
            /*axes_names=*/{"a", "b"});

  // a:(1)2
  MeshAxesReplicaGroupList list_a_1_2(mesh, {AxisRef(0, {1, 2})});
  Groups want_a_1_2 = {
      {0, 18},  {1, 19},  {2, 20},  {3, 21},  {4, 22},  {5, 23},
      {6, 24},  {7, 25},  {8, 26},  {9, 27},  {10, 28}, {11, 29},
      {12, 30}, {13, 31}, {14, 32}, {15, 33}, {16, 34}, {17, 35}};
  EXPECT_EQ(list_a_1_2.flattened_replica_groups(), want_a_1_2);

  // a:(1)3
  MeshAxesReplicaGroupList list_a_1_3(mesh, {AxisRef(0, {1, 3})});
  Groups want_a_1_3 = {
      {0, 12, 24}, {1, 13, 25}, {2, 14, 26},  {3, 15, 27},
      {4, 16, 28}, {5, 17, 29}, {6, 18, 30},  {7, 19, 31},
      {8, 20, 32}, {9, 21, 33}, {10, 22, 34}, {11, 23, 35}};
  EXPECT_EQ(list_a_1_3.flattened_replica_groups(), want_a_1_3);

  // a:(3)2
  MeshAxesReplicaGroupList list_a_3_2(mesh, {AxisRef(0, {3, 2})});
  Groups want_a_3_2 = {
      {0, 6},   {1, 7},   {2, 8},   {3, 9},   {4, 10},  {5, 11},
      {12, 18}, {13, 19}, {14, 20}, {15, 21}, {16, 22}, {17, 23},
      {24, 30}, {25, 31}, {26, 32}, {27, 33}, {28, 34}, {29, 35}};
  EXPECT_EQ(list_a_3_2.flattened_replica_groups(), want_a_3_2);

  // b:(1)2
  MeshAxesReplicaGroupList list_b_1_2(mesh, {AxisRef(1, {1, 2})});
  Groups want_b_1_2 = {
      {0, 3},   {1, 4},   {2, 5},   {6, 9},   {7, 10},  {8, 11},
      {12, 15}, {13, 16}, {14, 17}, {18, 21}, {19, 22}, {20, 23},
      {24, 27}, {25, 28}, {26, 29}, {30, 33}, {31, 34}, {32, 35}};
  EXPECT_EQ(list_b_1_2.flattened_replica_groups(), want_b_1_2);

  // b:(1)3
  MeshAxesReplicaGroupList list_b_1_3(mesh, {AxisRef(1, {1, 3})});
  Groups want_b_1_3 = {
      {0, 2, 4},    {1, 3, 5},    {6, 8, 10},   {7, 9, 11},
      {12, 14, 16}, {13, 15, 17}, {18, 20, 22}, {19, 21, 23},
      {24, 26, 28}, {25, 27, 29}, {30, 32, 34}, {31, 33, 35}};
  EXPECT_EQ(list_b_1_3.flattened_replica_groups(), want_b_1_3);

  // b:(3)2
  MeshAxesReplicaGroupList list_b_3_2(mesh, {AxisRef(1, {3, 2})});
  Groups want_b_3_2 = {
      {0, 1},   {2, 3},   {4, 5},   {6, 7},   {8, 9},   {10, 11},
      {12, 13}, {14, 15}, {16, 17}, {18, 19}, {20, 21}, {22, 23},
      {24, 25}, {26, 27}, {28, 29}, {30, 31}, {32, 33}, {34, 35}};
  EXPECT_EQ(list_b_3_2.flattened_replica_groups(), want_b_3_2);

  // a:(1)2, b:(1)2
  MeshAxesReplicaGroupList list_a_1_2_b_1_2(
      mesh, {AxisRef(0, {1, 2}), AxisRef(1, {1, 2})});
  Groups want_a_1_2_b_1_2 = {
      {0, 3, 18, 21},   {1, 4, 19, 22},   {2, 5, 20, 23},
      {6, 9, 24, 27},   {7, 10, 25, 28},  {8, 11, 26, 29},
      {12, 15, 30, 33}, {13, 16, 31, 34}, {14, 17, 32, 35}};
  EXPECT_EQ(list_a_1_2_b_1_2.flattened_replica_groups(), want_a_1_2_b_1_2);

  // a:(1)3, b:(1)3
  MeshAxesReplicaGroupList list_a_1_3_b_1_3(
      mesh, {AxisRef(0, {1, 3}), AxisRef(1, {1, 3})});
  Groups want_a_1_3_b_1_3 = {{0, 2, 4, 12, 14, 16, 24, 26, 28},
                             {1, 3, 5, 13, 15, 17, 25, 27, 29},
                             {6, 8, 10, 18, 20, 22, 30, 32, 34},
                             {7, 9, 11, 19, 21, 23, 31, 33, 35}};
  EXPECT_EQ(list_a_1_3_b_1_3.flattened_replica_groups(), want_a_1_3_b_1_3);

  // b:(1)3, a:(1)3 (reverse order from above). This should produce the same
  // replica groups as the above but with ids in a different order.
  MeshAxesReplicaGroupList list_b_1_3_a_1_3(
      mesh, {AxisRef(1, {1, 3}), AxisRef(0, {1, 3})});
  Groups want_b_1_3_a_1_3 = {{0, 12, 24, 2, 14, 26, 4, 16, 28},
                             {1, 13, 25, 3, 15, 27, 5, 17, 29},
                             {6, 18, 30, 8, 20, 32, 10, 22, 34},
                             {7, 19, 31, 9, 21, 33, 11, 23, 35}};
  // NOTE(review): this assertion re-checks the a-then-b list against the
  // a-then-b expectation; `list_b_1_3_a_1_3` and `want_b_1_3_a_1_3` are never
  // compared. Confirm whether the b-then-a pair was meant to be asserted.
  EXPECT_EQ(list_a_1_3_b_1_3.flattened_replica_groups(), want_a_1_3_b_1_3);

  // a:(3)2, b:(3)2
  MeshAxesReplicaGroupList list_a_3_2_b_3_2(
      mesh, {AxisRef(0, {3, 2}), AxisRef(1, {3, 2})});
  Groups want_a_3_2_b_3_2 = {
      {0, 1, 6, 7},     {2, 3, 8, 9},     {4, 5, 10, 11},
      {12, 13, 18, 19}, {14, 15, 20, 21}, {16, 17, 22, 23},
      {24, 25, 30, 31}, {26, 27, 32, 33}, {28, 29, 34, 35}};
  EXPECT_EQ(list_a_3_2_b_3_2.flattened_replica_groups(), want_a_3_2_b_3_2);
}
|
||||
|
||||
// Compares mesh-axes replica groups on a one-dimensional mesh of 8 devices
// with the equivalent iota (V2) replica group lists.
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsMatchExpectedV2) {
  Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{8})),
            /*axes_names=*/{"a"});

  {
    // a:(1)2 -> replica_groups=[4,2]<=[2,4]T(1,0)
    MeshAxesReplicaGroupList mesh_axes_list(mesh, {AxisRef(0, {1, 2})});
    IotaReplicaGroupList iota_list(4, 2, {2, 4}, {1, 0});
    EXPECT_EQ(mesh_axes_list.flattened_replica_groups(),
              iota_list.flattened_replica_groups());
  }
  {
    // a:(1)4 -> replica_groups=[2,4]<=[4,2]T(1,0)
    MeshAxesReplicaGroupList mesh_axes_list(mesh, {AxisRef(0, {1, 4})});
    IotaReplicaGroupList iota_list(2, 4, {4, 2}, {1, 0});
    EXPECT_EQ(mesh_axes_list.flattened_replica_groups(),
              iota_list.flattened_replica_groups());
  }
  {
    // a:(2)2 -> replica_groups=[4,2]<=[2,2,2]T(0,2,1)
    MeshAxesReplicaGroupList mesh_axes_list(mesh, {AxisRef(0, {2, 2})});
    IotaReplicaGroupList iota_list(4, 2, {2, 2, 2}, {0, 2, 1});
    EXPECT_EQ(mesh_axes_list.flattened_replica_groups(),
              iota_list.flattened_replica_groups());
  }
  {
    // a:(2)4 -> replica_groups=[2,4]<=[8]
    MeshAxesReplicaGroupList mesh_axes_list(mesh, {AxisRef(0, {2, 4})});
    IotaReplicaGroupList iota_list(2, 4, {8}, {0});
    EXPECT_EQ(mesh_axes_list.flattened_replica_groups(),
              iota_list.flattened_replica_groups());
  }
  {
    // a:(4)2 -> replica_groups=[4,2]<=[8]
    MeshAxesReplicaGroupList mesh_axes_list(mesh, {AxisRef(0, {4, 2})});
    IotaReplicaGroupList iota_list(4, 2, {8}, {0});
    EXPECT_EQ(mesh_axes_list.flattened_replica_groups(),
              iota_list.flattened_replica_groups());
  }
  {
    // {a:(1)2, a:(4)2} -> replica_groups=[2,4]<=[2,2,2]T(1,0,2)
    MeshAxesReplicaGroupList mesh_axes_list(
        mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 2})});
    IotaReplicaGroupList iota_list(2, 4, {2, 2, 2}, {1, 0, 2});
    EXPECT_EQ(mesh_axes_list.flattened_replica_groups(),
              iota_list.flattened_replica_groups());
  }
  {
    // {a:(4)2, a:(1)2} -> replica_groups=[2,4]<=[2,2,2]T(1,2,0)
    MeshAxesReplicaGroupList mesh_axes_list(
        mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})});
    IotaReplicaGroupList iota_list(2, 4, {2, 2, 2}, {1, 2, 0});
    EXPECT_EQ(mesh_axes_list.flattened_replica_groups(),
              iota_list.flattened_replica_groups());
  }
  {
    // a -> replica_groups=[1,8]<=[8]
    MeshAxesReplicaGroupList mesh_axes_list(mesh, {AxisRef(0)});
    IotaReplicaGroupList iota_list(1, 8, {8}, {0});
    EXPECT_EQ(mesh_axes_list.flattened_replica_groups(),
              iota_list.flattened_replica_groups());
  }
}
|
||||
|
||||
// Checks that the materialized groups use the device ids stored in the mesh
// rather than assuming an iota ordering.
TEST(MeshAxesReplicaGroupListTest,
     MaterializedReplicaGroupsRespectNonIotaDeviceOrdering) {
  using Groups = std::vector<std::vector<int64_t>>;

  // Mesh whose device ids are not in iota order.
  Array2D<int64_t> array({{3, 1}, {0, 2}});
  TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
  Mesh mesh_xy(tile_assignment, /*axes_names=*/{"x", "y"});

  {
    // Reduce along the x axis. With iota device ordering the groups would be
    // {{0, 2}, {1, 3}}.
    MeshAxesReplicaGroupList across_x(mesh_xy, {AxisRef(0)});
    Groups want_x = {{3, 0}, {1, 2}};
    EXPECT_THAT(across_x.flattened_replica_groups(),
                testing::UnorderedElementsAreArray(want_x));
  }
  {
    // Reduce along the y axis. With iota device ordering the groups would be
    // {{0, 1}, {2, 3}}.
    MeshAxesReplicaGroupList across_y(mesh_xy, {AxisRef(1)});
    Groups want_y = {{3, 1}, {0, 2}};
    EXPECT_THAT(across_y.flattened_replica_groups(),
                testing::UnorderedElementsAreArray(want_y));
  }
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) {
|
||||
Mesh all_axes(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{4, 4})),
|
||||
/*axes_names=*/{"x", "y"});
|
||||
MeshAxesReplicaGroupList replica_group_across_all_axes(
|
||||
all_axes,
|
||||
/*axes=*/{AxisRef(0), AxisRef(1)});
|
||||
all_axes, {AxisRef(0), AxisRef(1)});
|
||||
EXPECT_EQ(replica_group_across_all_axes.num_replica_groups(), 1);
|
||||
EXPECT_EQ(replica_group_across_all_axes.num_devices_per_group(), 16);
|
||||
|
||||
Mesh one_axes(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{3, 5})),
|
||||
/*axes_names=*/{"a", "b"});
|
||||
MeshAxesReplicaGroupList replica_group_across_a(one_axes,
|
||||
/*axes=*/{AxisRef(0)});
|
||||
MeshAxesReplicaGroupList replica_group_across_b(one_axes,
|
||||
/*axes=*/{AxisRef(1)});
|
||||
MeshAxesReplicaGroupList replica_group_across_a(one_axes, {AxisRef(0)});
|
||||
MeshAxesReplicaGroupList replica_group_across_b(one_axes, {AxisRef(1)});
|
||||
EXPECT_EQ(replica_group_across_a.num_replica_groups(), 5);
|
||||
EXPECT_EQ(replica_group_across_a.num_devices_per_group(), 3);
|
||||
EXPECT_EQ(replica_group_across_b.num_replica_groups(), 3);
|
||||
@@ -71,16 +279,30 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) {
|
||||
EXPECT_EQ(replica_group_across_no_axes.num_devices_per_group(), 1);
|
||||
}
|
||||
|
||||
// Checks that compatible sub-axes of one mesh axis are accepted in either
// order and that incompatible sub-axes abort construction.
TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) {
  // Accepted: a:(1)2 together with a:(4)2 on an axis of size 8, both orders.
  Mesh mesh(TileAssignment({8}), /*axes_names=*/{"1"});
  MeshAxesReplicaGroupList ascending_sub_axes(
      mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 2})});
  MeshAxesReplicaGroupList descending_sub_axes(
      mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})});

  // Rejected: u:(6)5 together with u:(10)3 on an axis of size 30.
  Mesh overlap_mesh(TileAssignment({2 * 3 * 5}), /*axes_names=*/{"u"});
  EXPECT_DEATH(
      {
        MeshAxesReplicaGroupList conflicting_sub_axes(
            overlap_mesh, {AxisRef(0, {6, 5}), AxisRef(0, {10, 3})});
      },
      "Overlapping sub-axes");
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
|
||||
Mesh mesh_one_subaxis(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{2, 6, 10})),
|
||||
/*axes_names=*/{"axis1", "axis2", "axis3"});
|
||||
MeshAxesReplicaGroupList replica_group_across_axis1_subaxis(
|
||||
mesh_one_subaxis,
|
||||
/*axes=*/{AxisRef(0, {1, 2})});
|
||||
mesh_one_subaxis, {AxisRef(0, {1, 2})});
|
||||
MeshAxesReplicaGroupList replica_group_across_axis2_subaxis(
|
||||
mesh_one_subaxis,
|
||||
/*axes=*/{AxisRef(1, {2, 3})});
|
||||
mesh_one_subaxis, {AxisRef(1, {2, 3})});
|
||||
EXPECT_EQ(replica_group_across_axis1_subaxis.num_replica_groups(), 60);
|
||||
EXPECT_EQ(replica_group_across_axis1_subaxis.num_devices_per_group(), 2);
|
||||
EXPECT_EQ(replica_group_across_axis2_subaxis.num_replica_groups(), 40);
|
||||
@@ -91,10 +313,10 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
|
||||
/*axes_names=*/{"alpha", "beta", "gamma"});
|
||||
MeshAxesReplicaGroupList replica_group_across_multiple_subaxis1(
|
||||
mesh_multiple_subaxis,
|
||||
/*axes=*/{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
|
||||
{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
|
||||
MeshAxesReplicaGroupList replica_group_across_multiple_subaxis2(
|
||||
mesh_multiple_subaxis,
|
||||
/*axes=*/{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})});
|
||||
{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})});
|
||||
EXPECT_EQ(replica_group_across_multiple_subaxis1.num_replica_groups(),
|
||||
3 * 7 * 13);
|
||||
EXPECT_EQ(replica_group_across_multiple_subaxis1.num_devices_per_group(),
|
||||
@@ -110,11 +332,10 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
|
||||
Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{10, 12, 15})),
|
||||
/*axes_names=*/{"u", "v", "w"});
|
||||
MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, /*axes=*/{});
|
||||
MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, {});
|
||||
EXPECT_EQ(replica_group_across_none.ToString(), "@mesh<u=10,v=12,w=15> {}");
|
||||
MeshAxesReplicaGroupList replica_group_across_uv(
|
||||
mesh_uvw,
|
||||
/*axes=*/{AxisRef(0), AxisRef(1)});
|
||||
MeshAxesReplicaGroupList replica_group_across_uv(mesh_uvw,
|
||||
{AxisRef(0), AxisRef(1)});
|
||||
EXPECT_EQ(replica_group_across_uv.ToString(), "@mesh<u=10,v=12,w=15> {u,v}");
|
||||
|
||||
// Subaxes and replica group v2 iota style device assignment.
|
||||
@@ -122,11 +343,11 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
|
||||
/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
|
||||
/*transpose_perm=*/{2, 3, 0, 1})),
|
||||
/*axes_names=*/{"a", "b", "c", "d"});
|
||||
MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, /*axes=*/{});
|
||||
MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, {});
|
||||
EXPECT_EQ(rg_abcd_across_none.ToString(),
|
||||
"@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {}");
|
||||
MeshAxesReplicaGroupList rg_abcd_across_multiple_axes_and_subaxes(
|
||||
mesh_abcd, /*axes=*/{AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)});
|
||||
mesh_abcd, {AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)});
|
||||
EXPECT_EQ(rg_abcd_across_multiple_axes_and_subaxes.ToString(),
|
||||
"@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {a,b:(1)2,d}");
|
||||
|
||||
@@ -135,11 +356,11 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
|
||||
array.Reshape({10});
|
||||
TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
|
||||
Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"});
|
||||
MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, /*axes=*/{});
|
||||
MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, {});
|
||||
EXPECT_EQ(rg_ooo_across_none.ToString(),
|
||||
"@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {}");
|
||||
MeshAxesReplicaGroupList rg_ooo_across_ooo_5_2(mesh_ooo,
|
||||
/*axes=*/{AxisRef(0, {5, 2})});
|
||||
{AxisRef(0, {5, 2})});
|
||||
EXPECT_EQ(rg_ooo_across_ooo_5_2.ToString(),
|
||||
"@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {ooo:(5)2}");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user