diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index bd8c592cb48..29f62b139bd 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -422,6 +422,7 @@ xla_cc_test( ":mesh_and_axis", ":tile_assignment", "//xla:array", + "//xla:array2d", "//xla:xla_data_proto_cc", "//xla/service:hlo_proto_cc", "//xla/tsl/platform:test_main", diff --git a/third_party/xla/xla/hlo/ir/mesh_and_axis.cc b/third_party/xla/xla/hlo/ir/mesh_and_axis.cc index 72117120a9c..a897519d7e8 100644 --- a/third_party/xla/xla/hlo/ir/mesh_and_axis.cc +++ b/third_party/xla/xla/hlo/ir/mesh_and_axis.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/ir/mesh_and_axis.h" +#include #include #include #include @@ -101,4 +102,42 @@ AxisRef AxisRef::FromProto(const AxisRefProto& proto) { return axis_ref; } +bool canSubAxesCoexist(int64_t minPreSize, int64_t maxPreSize, + int64_t minNextPreSize, int64_t maxNextPreSize) { + if (minNextPreSize > maxPreSize) { + // Sub-axes overlap, check if overlapping and non-overlapping parts are + // valid. + return minNextPreSize % maxPreSize == 0 && maxPreSize % minPreSize == 0 && + maxNextPreSize % minNextPreSize == 0; + } + // Sub-axes don't overlap, check if the gap is valid. + return maxPreSize % minNextPreSize == 0; +} + +bool AxisRef::CanCoexist(const AxisRef& other) const { + if (mesh_axis_index() != other.mesh_axis_index()) { + return true; + } + if (!sub_axis_info_.has_value() || !other.sub_axis_info_.has_value()) { + // If one is a full axis and the other is a sub-axis, they can coexist. 
+ return true; + } + + const SubAxis& this_sub_axis = sub_axis_info_.value(); + const SubAxis& other_sub_axis = other.sub_axis_info_.value(); + + int64_t this_pre_size = this_sub_axis.pre_size; + int64_t other_pre_size = other_sub_axis.pre_size; + int64_t this_next_pre_size = this_sub_axis.next_pre_size(); + int64_t other_next_pre_size = other_sub_axis.next_pre_size(); + + auto [min_pre_size, max_pre_size] = + std::minmax(this_pre_size, other_pre_size); + auto [min_next_pre_size, max_next_pre_size] = + std::minmax(this_next_pre_size, other_next_pre_size); + + return canSubAxesCoexist(min_pre_size, max_pre_size, min_next_pre_size, + max_next_pre_size); +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/mesh_and_axis.h b/third_party/xla/xla/hlo/ir/mesh_and_axis.h index affea3ac436..5dc8c9651db 100644 --- a/third_party/xla/xla/hlo/ir/mesh_and_axis.h +++ b/third_party/xla/xla/hlo/ir/mesh_and_axis.h @@ -109,6 +109,12 @@ class Mesh { TileAssignment device_assignment() const { return device_assignment_; } std::vector axis_names() const { return axes_names_; } + absl::Span axis_sizes() const { + return device_assignment_.dimensions(); + } + int64_t axis_size(int64_t axis_index) const { + return device_assignment_.dim(axis_index); + } private: // Dimensions of the `device_assignment_` array correspond to the axes of the @@ -127,6 +133,7 @@ class AxisRef { struct SubAxis { int64_t pre_size; int64_t size; + int64_t next_pre_size() const { return pre_size * size; } }; // Index corresponding to axis in the mesh. 
It should be a valid index into @@ -177,6 +184,8 @@ class AxisRef { static AxisRef FromProto(const AxisRefProto& proto); + bool CanCoexist(const AxisRef& other) const; + int64_t mesh_axis_index() const { return mesh_axis_index_; } std::optional sub_axis_info() const { return sub_axis_info_; } }; diff --git a/third_party/xla/xla/hlo/ir/replica_group.cc b/third_party/xla/xla/hlo/ir/replica_group.cc index eba41f94d30..0a02c175b8b 100644 --- a/third_party/xla/xla/hlo/ir/replica_group.cc +++ b/third_party/xla/xla/hlo/ir/replica_group.cc @@ -15,12 +15,18 @@ limitations under the License. #include "xla/hlo/ir/replica_group.h" +#include #include #include +#include #include #include +#include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" @@ -33,7 +39,6 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/logging.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" -#include "tsl/platform/protobuf.h" namespace xla { @@ -49,6 +54,145 @@ std::string ReplicaGroupsToString( } /************** MeshAxesReplicaGroupList implementation ***********************/ + +void HandleSingleAxisRefPerDimension(const AxisRef& axis, + int64_t full_axis_size, + std::vector& out_reshape_dims, + std::vector& out_aggregate_axes) { + if (axis.sub_axis_info().has_value()) { + out_reshape_dims = {axis.sub_axis_info()->pre_size, + axis.sub_axis_info()->size, + full_axis_size / axis.sub_axis_info()->next_pre_size()}; + // The aggregation axis is the second dimension. + out_aggregate_axes = {1}; + } else { + out_reshape_dims = {full_axis_size}; + out_aggregate_axes = {0}; + } +} + +void HandleMultiAxisRefPerDimension(std::vector& axes, + int64_t full_axis_size, + std::vector& out_reshape_dims, + std::vector& out_aggregate_axes) { + // --- 1. 
Sort Axes and Original Indices Together --- + // Sort both the axes and the original indices based on + // sub_axis_info()->pre_size. This allows us to maintain user specified order + // of AxisRef while still building the reshape and aggregate axes. + std::vector original_order(axes.size()); + std::iota(original_order.begin(), original_order.end(), 0); + std::sort(original_order.begin(), original_order.end(), + [&axes](int i, int j) { + return axes[i].sub_axis_info()->pre_size < + axes[j].sub_axis_info()->pre_size; + }); + std::sort(axes.begin(), axes.end(), [](const AxisRef& a, const AxisRef& b) { + return a.sub_axis_info()->pre_size < b.sub_axis_info()->pre_size; + }); + + // --- 2. Build Reshape Dims and Aggregation Axes --- + int64_t current_dim_index = 0; // Index in the new reshaped tensor + int64_t prefix_product = 1; // Product of the size of all prior dimensions + + for (const AxisRef& axis : axes) { + int64_t pre_size = axis.sub_axis_info()->pre_size; + int64_t size = axis.sub_axis_info()->size; + + // Insert "padding" dimension if the current prefix product doesn't match + // the required pre_size + if (pre_size != prefix_product) { + int64_t padding_size = pre_size / prefix_product; + out_reshape_dims.push_back(padding_size); + current_dim_index++; + prefix_product *= padding_size; + } + + // Insert the sharded size (the part to aggregate) + out_reshape_dims.push_back(size); + out_aggregate_axes.push_back( + current_dim_index); // This is the axis we aggregate over + current_dim_index++; + prefix_product *= size; + } + + // Insert "suffix" dimension if the full size hasn't been reached + if (prefix_product != full_axis_size) { + out_reshape_dims.push_back(full_axis_size / prefix_product); + } + + // --- 3. Permute Aggregate Axes back to Original Order --- + // The aggregate axes were calculated based on the sorted list. + // We must map them back to the original order to compute the correct + // flattened replica groups. 
+ std::vector permuted_aggregate_axes(original_order.size()); + for (int64_t i = 0; i < original_order.size(); ++i) { + permuted_aggregate_axes[original_order[i]] = out_aggregate_axes[i]; + } + out_aggregate_axes = permuted_aggregate_axes; +} + +bool ValidateSingleDimensionAxes(int64_t dim, std::vector& axes, + const Mesh& mesh_) { + // If there's only one axis, nothing to check. + if (axes.size() <= 1) { + return true; + } + // --- Step 1: Deduplication --- + // If one is a "full" axis (no sub_axis_info), it subsumes all other AxisRefs + // with sub_axis_info. + for (const AxisRef& axis : axes) { + if (!axis.sub_axis_info().has_value()) { + LOG(WARNING) << "MeshAxesReplicaGroupList: Redundant axis definition at " + "dimension: " + << dim + << ". Keeping only the full axis: " << axis.ToString(mesh_); + axes = {axis}; + return true; + } + } + // --- Step 2: Overlap Check --- + // At this point, all remaining axes MUST have sub_axis_info(). + // Verify that the remaining multiple sub-axes do not overlap. + for (int64_t i = 0; i < axes.size() - 1; ++i) { + for (int64_t j = i + 1; j < axes.size(); ++j) { + // CHECK will terminate the program on failure, matching original + // behavior. + CHECK(axes[i].CanCoexist(axes[j])) + << "Overlapping sub-axes detected: " << axes[i].ToString(mesh_) + << " and " << axes[j].ToString(mesh_); + } + } + return true; // Passed all checks for this dimension. 
+} + +MeshAxesReplicaGroupList::MeshAxesReplicaGroupList(Mesh mesh, + std::vector axes) + : mesh_(std::move(mesh)), axes_(std::move(axes)) { + if (num_devices_per_group() == 1) { + LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString() + << " has only one device per replica group."; + } + + absl::flat_hash_set dimensions; + absl::flat_hash_map> dim_to_axes; + for (const AxisRef& axis : axes_) { + dim_to_axes[axis.mesh_axis_index()].push_back(axis); + dimensions.insert(axis.mesh_axis_index()); + if (axis.sub_axis_info().has_value()) { + CHECK(mesh_.axis_size(axis.mesh_axis_index()) % + axis.sub_axis_info()->next_pre_size() == + 0) + << "Next pre-size must divide the full axis size."; + } + } + + // Validate input AxisRefs. + for (int64_t dim : dimensions) { + std::vector& axes = dim_to_axes[dim]; + CHECK(ValidateSingleDimensionAxes(dim, axes, mesh_)); + } +} + int64_t MeshAxesReplicaGroupList::num_replica_groups() const { return mesh_.device_assignment().num_elements() / num_devices_per_group(); } @@ -58,15 +202,105 @@ int64_t MeshAxesReplicaGroupList::num_devices_per_group() const { // all axes. int64_t devices_per_group = 1; for (const AxisRef& axis : axes_) { - int64_t axis_size = - axis.sub_axis_info().has_value() - ? axis.sub_axis_info()->size - : mesh_.device_assignment().dim(axis.mesh_axis_index()); + int64_t axis_size = axis.sub_axis_info().has_value() + ? axis.sub_axis_info()->size + : mesh_.axis_size(axis.mesh_axis_index()); devices_per_group *= axis_size; } return devices_per_group; } +std::vector> get_replica_groups_for_full_axes( + const Mesh& mesh, absl::Span axis_sizes, + const absl::Span grouped_axes, + const int64_t num_replica_groups, const int64_t num_devices_per_group) { + // Reshape the device assignment array based on the axis sizes and transpose + // grouped axes to the end. 
+ std::vector transpose_axes; + transpose_axes.reserve(axis_sizes.size()); + for (int64_t i = 0; i < axis_sizes.size(); ++i) { + if (!absl::c_linear_search(grouped_axes, i)) { + transpose_axes.push_back(i); + } + } + for (int64_t grouped_axis : grouped_axes) { + transpose_axes.push_back(grouped_axis); + } + + TileAssignment device_assignment = + mesh.device_assignment().Reshape(axis_sizes).Transpose(transpose_axes); + + std::vector> replica_groups; + replica_groups.reserve(num_replica_groups); + for (auto it = device_assignment.array().begin(); + it != device_assignment.array().end(); it += num_devices_per_group) { + std::vector group(it, it + num_devices_per_group); + replica_groups.emplace_back(std::move(group)); + } + return replica_groups; +} + +void MeshAxesReplicaGroupList::InitializeDimToReshapeAndAggregateAxes() { + absl::flat_hash_map> dim_to_axes; + for (const AxisRef& axis : axes_) { + dim_to_axes[axis.mesh_axis_index()].push_back(axis); + } + absl::flat_hash_map dim_map; + // For each dimension determine the reshape that is consistent with its + // AxisRef(s). Then maintain this reshape and the aggregated dims for easier + // computation of replica groups. As an example for @mesh<"a"=8> + // {a} -> no reshape, aggregate over [0] + // {a:(1)2} -> reshape [8]->[1,2,4], aggregate over [1] + // {a:(1)2, a:(4)2} -> reshape [8]->[2,2,2], aggregate over [0,2] + for (auto& [dim, axes] : dim_to_axes) { + int64_t full_axis_size = mesh_.axis_size(dim); + ReshapeAndAggregateAxes reshape_and_aggregate_axes; + if (axes.size() == 1) { + HandleSingleAxisRefPerDimension( + axes[0], full_axis_size, reshape_and_aggregate_axes.reshape_dims, + reshape_and_aggregate_axes.aggregate_axes); + } else { + // Otherwise dimension is a set of axes with sub-axes info. 
+ HandleMultiAxisRefPerDimension(axes, full_axis_size, + reshape_and_aggregate_axes.reshape_dims, + reshape_and_aggregate_axes.aggregate_axes); + } + dim_map[dim] = reshape_and_aggregate_axes; + } + dim_to_reshape_and_aggregate_axes_ = dim_map; +} + +std::vector> +MeshAxesReplicaGroupList::flattened_replica_groups() { + if (!dim_to_reshape_and_aggregate_axes_.has_value()) { + InitializeDimToReshapeAndAggregateAxes(); + } + + absl::flat_hash_map dim_map = + dim_to_reshape_and_aggregate_axes_.value(); + std::vector reindex_axis_sizes; + std::vector reindexed_grouped_axes; + for (int64_t i = 0; i < mesh_.axis_sizes().size(); ++i) { + int64_t axis_size = mesh_.axis_size(i); + auto it = dim_map.find(i); + if (it == dim_map.end()) { + reindex_axis_sizes.push_back(axis_size); + continue; + } + int64_t offset_index = reindex_axis_sizes.size(); + const ReshapeAndAggregateAxes& reshape_and_aggregate_axes = it->second; + for (int64_t reshape_dim : reshape_and_aggregate_axes.reshape_dims) { + reindex_axis_sizes.push_back(reshape_dim); + } + for (int64_t aggregate_dim : reshape_and_aggregate_axes.aggregate_axes) { + reindexed_grouped_axes.push_back(aggregate_dim + offset_index); + } + } + return get_replica_groups_for_full_axes( + mesh_, reindex_axis_sizes, reindexed_grouped_axes, num_replica_groups(), + num_devices_per_group()); +} + void MeshAxesReplicaGroupList::Print(Printer* printer) const { printer->Append(ToString()); } diff --git a/third_party/xla/xla/hlo/ir/replica_group.h b/third_party/xla/xla/hlo/ir/replica_group.h index de8412df22a..06497aa0ddf 100644 --- a/third_party/xla/xla/hlo/ir/replica_group.h +++ b/third_party/xla/xla/hlo/ir/replica_group.h @@ -16,14 +16,19 @@ limitations under the License. 
#ifndef XLA_HLO_IR_REPLICA_GROUP_H_ #define XLA_HLO_IR_REPLICA_GROUP_H_ +#include #include #include #include +#include #include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/types/span.h" #include "google/protobuf/repeated_ptr_field.h" @@ -38,14 +43,13 @@ limitations under the License. namespace xla { class MeshAxesReplicaGroupList { + struct ReshapeAndAggregateAxes { + std::vector reshape_dims; + std::vector aggregate_axes; + }; + public: - explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector axes) - : mesh_(std::move(mesh)), axes_(std::move(axes)) { - if (num_devices_per_group() == 1) { - LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString() - << " has only one device per replica group."; - } - } + explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector axes); bool operator==(const MeshAxesReplicaGroupList& other) const { return mesh_ == other.mesh_ && axes_ == other.axes_; @@ -58,6 +62,7 @@ class MeshAxesReplicaGroupList { int64_t num_replica_groups() const; int64_t num_devices_per_group() const; + std::vector> flattened_replica_groups(); void Print(Printer* printer) const; @@ -69,8 +74,11 @@ class MeshAxesReplicaGroupList { const MeshAxesReplicaGroupListProto& proto); private: + void InitializeDimToReshapeAndAggregateAxes(); Mesh mesh_; std::vector axes_; + std::optional> + dim_to_reshape_and_aggregate_axes_; }; std::string ReplicaGroupsToString( diff --git a/third_party/xla/xla/hlo/ir/replica_group_test.cc b/third_party/xla/xla/hlo/ir/replica_group_test.cc index 5cd40ad68dc..9ae886a8e92 100644 --- a/third_party/xla/xla/hlo/ir/replica_group_test.cc +++ b/third_party/xla/xla/hlo/ir/replica_group_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include #include "xla/array.h" +#include "xla/array2d.h" #include "xla/hlo/ir/mesh_and_axis.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/service/hlo.pb.h" @@ -41,23 +42,230 @@ CollectiveDeviceListProto CreateDeviceListProto( return proto; } -TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) { +TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) { + Mesh mesh_xy(TileAssignment({2, 2}), /*axes_names=*/{"x", "y"}); + + MeshAxesReplicaGroupList replica_group_none(mesh_xy, {}); + std::vector> expected_replica_groups_none = { + {0}, {1}, {2}, {3}}; + EXPECT_EQ(replica_group_none.flattened_replica_groups(), + expected_replica_groups_none); + + MeshAxesReplicaGroupList replica_group_x(mesh_xy, {AxisRef(0)}); + std::vector> expected_replica_groups_x = {{0, 2}, + {1, 3}}; + EXPECT_EQ(replica_group_x.flattened_replica_groups(), + expected_replica_groups_x); + + MeshAxesReplicaGroupList replica_group_y(mesh_xy, {AxisRef(1)}); + std::vector> expected_replica_groups_y = {{0, 1}, + {2, 3}}; + EXPECT_EQ(replica_group_y.flattened_replica_groups(), + expected_replica_groups_y); + + MeshAxesReplicaGroupList replica_group_xy(mesh_xy, {AxisRef(0), AxisRef(1)}); + std::vector> expected_replica_groups_xy = {{0, 1, 2, 3}}; + EXPECT_EQ(replica_group_xy.flattened_replica_groups(), + expected_replica_groups_xy); +} + +TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) { + Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{6, 6})), + /*axes_names=*/{"a", "b"}); + + // a:(1)2 + MeshAxesReplicaGroupList replica_group_a_1_2(mesh, {AxisRef(0, {1, 2})}); + std::vector> expected_replica_groups_a_1_2 = { + {0, 18}, {1, 19}, {2, 20}, {3, 21}, {4, 22}, {5, 23}, + {6, 24}, {7, 25}, {8, 26}, {9, 27}, {10, 28}, {11, 29}, + {12, 30}, {13, 31}, {14, 32}, {15, 33}, {16, 34}, {17, 35}}; + EXPECT_EQ(replica_group_a_1_2.flattened_replica_groups(), + expected_replica_groups_a_1_2); + + // a:(1)3 + MeshAxesReplicaGroupList 
replica_group_a_1_3(mesh, {AxisRef(0, {1, 3})}); + std::vector> expected_replica_groups_a_1_3 = { + {0, 12, 24}, {1, 13, 25}, {2, 14, 26}, {3, 15, 27}, + {4, 16, 28}, {5, 17, 29}, {6, 18, 30}, {7, 19, 31}, + {8, 20, 32}, {9, 21, 33}, {10, 22, 34}, {11, 23, 35}}; + EXPECT_EQ(replica_group_a_1_3.flattened_replica_groups(), + expected_replica_groups_a_1_3); + + // a:(3)2 + MeshAxesReplicaGroupList replica_group_a_3_2(mesh, {AxisRef(0, {3, 2})}); + std::vector> expected_replica_groups_a_3_2 = { + {0, 6}, {1, 7}, {2, 8}, {3, 9}, {4, 10}, {5, 11}, + {12, 18}, {13, 19}, {14, 20}, {15, 21}, {16, 22}, {17, 23}, + {24, 30}, {25, 31}, {26, 32}, {27, 33}, {28, 34}, {29, 35}}; + EXPECT_EQ(replica_group_a_3_2.flattened_replica_groups(), + expected_replica_groups_a_3_2); + + // b:(1)2 + MeshAxesReplicaGroupList replica_group_b_1_2(mesh, {AxisRef(1, {1, 2})}); + std::vector> expected_replica_groups_b_1_2 = { + {0, 3}, {1, 4}, {2, 5}, {6, 9}, {7, 10}, {8, 11}, + {12, 15}, {13, 16}, {14, 17}, {18, 21}, {19, 22}, {20, 23}, + {24, 27}, {25, 28}, {26, 29}, {30, 33}, {31, 34}, {32, 35}}; + EXPECT_EQ(replica_group_b_1_2.flattened_replica_groups(), + expected_replica_groups_b_1_2); + + // b:(1)3 + MeshAxesReplicaGroupList replica_group_b_1_3(mesh, {AxisRef(1, {1, 3})}); + std::vector> expected_replica_groups_b_1_3 = { + {0, 2, 4}, {1, 3, 5}, {6, 8, 10}, {7, 9, 11}, + {12, 14, 16}, {13, 15, 17}, {18, 20, 22}, {19, 21, 23}, + {24, 26, 28}, {25, 27, 29}, {30, 32, 34}, {31, 33, 35}}; + EXPECT_EQ(replica_group_b_1_3.flattened_replica_groups(), + expected_replica_groups_b_1_3); + + // b:(3)2 + MeshAxesReplicaGroupList replica_group_b_3_2(mesh, {AxisRef(1, {3, 2})}); + std::vector> expected_replica_groups_b_3_2 = { + {0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}, + {12, 13}, {14, 15}, {16, 17}, {18, 19}, {20, 21}, {22, 23}, + {24, 25}, {26, 27}, {28, 29}, {30, 31}, {32, 33}, {34, 35}}; + EXPECT_EQ(replica_group_b_3_2.flattened_replica_groups(), + expected_replica_groups_b_3_2); + + // a:(1)2, 
b:(1)2 + MeshAxesReplicaGroupList replica_group_a_1_2_b_1_2( + mesh, {AxisRef(0, {1, 2}), AxisRef(1, {1, 2})}); + std::vector> expected_replica_groups_a_1_2_b_1_2 = { + {0, 3, 18, 21}, {1, 4, 19, 22}, {2, 5, 20, 23}, + {6, 9, 24, 27}, {7, 10, 25, 28}, {8, 11, 26, 29}, + {12, 15, 30, 33}, {13, 16, 31, 34}, {14, 17, 32, 35}}; + EXPECT_EQ(replica_group_a_1_2_b_1_2.flattened_replica_groups(), + expected_replica_groups_a_1_2_b_1_2); + + // a:(1)3, b:(1)3 + MeshAxesReplicaGroupList replica_group_a_1_3_b_1_3( + mesh, {AxisRef(0, {1, 3}), AxisRef(1, {1, 3})}); + std::vector> expected_replica_groups_a_1_3_b_1_3 = { + {0, 2, 4, 12, 14, 16, 24, 26, 28}, + {1, 3, 5, 13, 15, 17, 25, 27, 29}, + {6, 8, 10, 18, 20, 22, 30, 32, 34}, + {7, 9, 11, 19, 21, 23, 31, 33, 35}}; + EXPECT_EQ(replica_group_a_1_3_b_1_3.flattened_replica_groups(), + expected_replica_groups_a_1_3_b_1_3); + + // b:(1)3, a:(1)3 (Reverse order from above). This should produce the same + // replica groups as the above but with ids in a different order. 
+ MeshAxesReplicaGroupList replica_group_b_1_3_a_1_3( + mesh, {AxisRef(1, {1, 3}), AxisRef(0, {1, 3})}); + std::vector> expected_replica_groups_b_1_3_a_1_3 = { + {0, 12, 24, 2, 14, 26, 4, 16, 28}, + {1, 13, 25, 3, 15, 27, 5, 17, 29}, + {6, 18, 30, 8, 20, 32, 10, 22, 34}, + {7, 19, 31, 9, 21, 33, 11, 23, 35}}; + EXPECT_EQ(replica_group_a_1_3_b_1_3.flattened_replica_groups(), + expected_replica_groups_a_1_3_b_1_3); + + // a:(3)2, b:(3)2 + MeshAxesReplicaGroupList replica_group_a_3_2_b_3_2( + mesh, {AxisRef(0, {3, 2}), AxisRef(1, {3, 2})}); + std::vector> expected_replica_groups_a_3_2_b_3_2 = { + {0, 1, 6, 7}, {2, 3, 8, 9}, {4, 5, 10, 11}, + {12, 13, 18, 19}, {14, 15, 20, 21}, {16, 17, 22, 23}, + {24, 25, 30, 31}, {26, 27, 32, 33}, {28, 29, 34, 35}}; + EXPECT_EQ(replica_group_a_3_2_b_3_2.flattened_replica_groups(), + expected_replica_groups_a_3_2_b_3_2); +} + +TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsMatchExpectedV2) { + Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{8})), + /*axes_names=*/{"a"}); + + // a:(1)2 -> replica_groups=[4,2]<=[2,4]T(1,0) + MeshAxesReplicaGroupList v3_subaxis_1_2(mesh, {AxisRef(0, {1, 2})}); + IotaReplicaGroupList v2_subaxis_1_2(4, 2, {2, 4}, {1, 0}); + EXPECT_EQ(v3_subaxis_1_2.flattened_replica_groups(), + v2_subaxis_1_2.flattened_replica_groups()); + + // a:(1)4 -> replica_groups=[2,4]<=[4,2]T(1,0) + MeshAxesReplicaGroupList v3_subaxis_1_4(mesh, {AxisRef(0, {1, 4})}); + IotaReplicaGroupList v2_subaxis_1_4(2, 4, {4, 2}, {1, 0}); + EXPECT_EQ(v3_subaxis_1_4.flattened_replica_groups(), + v2_subaxis_1_4.flattened_replica_groups()); + + // a:(2)2 -> replica_groups=[4,2]<=[2,2,2]T(0,2,1) + MeshAxesReplicaGroupList v3_subaxis_2_2(mesh, {AxisRef(0, {2, 2})}); + IotaReplicaGroupList v2_subaxis_2_2(4, 2, {2, 2, 2}, {0, 2, 1}); + EXPECT_EQ(v3_subaxis_2_2.flattened_replica_groups(), + v2_subaxis_2_2.flattened_replica_groups()); + + // a:(2)4 -> replica_groups=[2,4]<=[8] + MeshAxesReplicaGroupList v3_subaxis_2_4(mesh, 
{AxisRef(0, {2, 4})}); + IotaReplicaGroupList v2_subaxis_2_4(2, 4, {8}, {0}); + EXPECT_EQ(v3_subaxis_2_4.flattened_replica_groups(), + v2_subaxis_2_4.flattened_replica_groups()); + + // a:(4)2 -> replica_groups=[4,2]<=[8] + MeshAxesReplicaGroupList v3_subaxis_4_2(mesh, {AxisRef(0, {4, 2})}); + IotaReplicaGroupList v2_subaxis_4_2(4, 2, {8}, {0}); + EXPECT_EQ(v3_subaxis_4_2.flattened_replica_groups(), + v2_subaxis_4_2.flattened_replica_groups()); + + // {a:(1)2, a:(4)2} -> replica_groups=[2,4]<=[2,2,2]T(1,0,2) + MeshAxesReplicaGroupList v3_subaxis_1_2_and_4_2( + mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 2})}); + IotaReplicaGroupList v2_subaxis_1_2_and_4_2(2, 4, {2, 2, 2}, {1, 0, 2}); + EXPECT_EQ(v3_subaxis_1_2_and_4_2.flattened_replica_groups(), + v2_subaxis_1_2_and_4_2.flattened_replica_groups()); + + // {a:(4)2, a:(1)2} -> replica_groups=[2,4]<=[2,2,2]T(1,2,0) + MeshAxesReplicaGroupList v3_subaxis_4_2_and_1_2( + mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})}); + IotaReplicaGroupList v2_subaxis_4_2_and_1_2(2, 4, {2, 2, 2}, {1, 2, 0}); + EXPECT_EQ(v3_subaxis_4_2_and_1_2.flattened_replica_groups(), + v2_subaxis_4_2_and_1_2.flattened_replica_groups()); + + // a -> replica_groups=[1,8]<=[8] + MeshAxesReplicaGroupList v3_no_subaxis(mesh, {AxisRef(0)}); + IotaReplicaGroupList v2_no_subaxis(1, 8, {8}, {0}); + EXPECT_EQ(v3_no_subaxis.flattened_replica_groups(), + v2_no_subaxis.flattened_replica_groups()); +} + +TEST(MeshAxesReplicaGroupListTest, + MaterializedReplicaGroupsRespectNonIotaDeviceOrdering) { + // Create a mesh with non-iota device ordering. + Array2D array({{3, 1}, {0, 2}}); + TileAssignment tile_assignment(std::make_shared>(array)); + Mesh mesh_xy(tile_assignment, /*axes_names=*/{"x", "y"}); + + // Reduce along x axis. + MeshAxesReplicaGroupList replica_group_x(mesh_xy, {AxisRef(0)}); + // With iota device ordering, the expected replica groups would be + // {{0, 2}, {1, 3}}. 
+ std::vector> expected_replica_groups_x = {{3, 0}, + {1, 2}}; + EXPECT_THAT(replica_group_x.flattened_replica_groups(), + testing::UnorderedElementsAreArray(expected_replica_groups_x)); + + // Reduce along y axis. + MeshAxesReplicaGroupList replica_group_y(mesh_xy, {AxisRef(1)}); + // With iota device ordering, the expected replica groups would be + // {{0, 1}, {2, 3}}. + std::vector> expected_replica_groups_y = {{3, 1}, + {0, 2}}; + EXPECT_THAT(replica_group_y.flattened_replica_groups(), + testing::UnorderedElementsAreArray(expected_replica_groups_y)); +} + +TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) { Mesh all_axes(TileAssignment(IotaTileAssignment::Create( /*dims=*/{4, 4})), /*axes_names=*/{"x", "y"}); MeshAxesReplicaGroupList replica_group_across_all_axes( - all_axes, - /*axes=*/{AxisRef(0), AxisRef(1)}); + all_axes, {AxisRef(0), AxisRef(1)}); EXPECT_EQ(replica_group_across_all_axes.num_replica_groups(), 1); EXPECT_EQ(replica_group_across_all_axes.num_devices_per_group(), 16); Mesh one_axes(TileAssignment(IotaTileAssignment::Create( /*dims=*/{3, 5})), /*axes_names=*/{"a", "b"}); - MeshAxesReplicaGroupList replica_group_across_a(one_axes, - /*axes=*/{AxisRef(0)}); - MeshAxesReplicaGroupList replica_group_across_b(one_axes, - /*axes=*/{AxisRef(1)}); + MeshAxesReplicaGroupList replica_group_across_a(one_axes, {AxisRef(0)}); + MeshAxesReplicaGroupList replica_group_across_b(one_axes, {AxisRef(1)}); EXPECT_EQ(replica_group_across_a.num_replica_groups(), 5); EXPECT_EQ(replica_group_across_a.num_devices_per_group(), 3); EXPECT_EQ(replica_group_across_b.num_replica_groups(), 3); @@ -71,16 +279,30 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) { EXPECT_EQ(replica_group_across_no_axes.num_devices_per_group(), 1); } +TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) { + Mesh mesh(TileAssignment({8}), /*axes_names=*/{"1"}); + MeshAxesReplicaGroupList replica_group_multiple_subaxes1( + mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 
2})}); + MeshAxesReplicaGroupList replica_group_multiple_subaxes2( + mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})}); + + Mesh overlap_mesh(TileAssignment({2 * 3 * 5}), /*axes_names=*/{"u"}); + EXPECT_DEATH( + { + MeshAxesReplicaGroupList overlapping_subaxes( + overlap_mesh, {AxisRef(0, {6, 5}), AxisRef(0, {10, 3})}); + }, + "Overlapping sub-axes"); +} + TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) { Mesh mesh_one_subaxis(TileAssignment(IotaTileAssignment::Create( /*dims=*/{2, 6, 10})), /*axes_names=*/{"axis1", "axis2", "axis3"}); MeshAxesReplicaGroupList replica_group_across_axis1_subaxis( - mesh_one_subaxis, - /*axes=*/{AxisRef(0, {1, 2})}); + mesh_one_subaxis, {AxisRef(0, {1, 2})}); MeshAxesReplicaGroupList replica_group_across_axis2_subaxis( - mesh_one_subaxis, - /*axes=*/{AxisRef(1, {2, 3})}); + mesh_one_subaxis, {AxisRef(1, {2, 3})}); EXPECT_EQ(replica_group_across_axis1_subaxis.num_replica_groups(), 60); EXPECT_EQ(replica_group_across_axis1_subaxis.num_devices_per_group(), 2); EXPECT_EQ(replica_group_across_axis2_subaxis.num_replica_groups(), 40); @@ -91,10 +313,10 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) { /*axes_names=*/{"alpha", "beta", "gamma"}); MeshAxesReplicaGroupList replica_group_across_multiple_subaxis1( mesh_multiple_subaxis, - /*axes=*/{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})}); + {AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})}); MeshAxesReplicaGroupList replica_group_across_multiple_subaxis2( mesh_multiple_subaxis, - /*axes=*/{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})}); + {AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})}); EXPECT_EQ(replica_group_across_multiple_subaxis1.num_replica_groups(), 3 * 7 * 13); EXPECT_EQ(replica_group_across_multiple_subaxis1.num_devices_per_group(), @@ -110,11 +332,10 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create( 
/*dims=*/{10, 12, 15})), /*axes_names=*/{"u", "v", "w"}); - MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, /*axes=*/{}); + MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, {}); EXPECT_EQ(replica_group_across_none.ToString(), "@mesh {}"); - MeshAxesReplicaGroupList replica_group_across_uv( - mesh_uvw, - /*axes=*/{AxisRef(0), AxisRef(1)}); + MeshAxesReplicaGroupList replica_group_across_uv(mesh_uvw, + {AxisRef(0), AxisRef(1)}); EXPECT_EQ(replica_group_across_uv.ToString(), "@mesh {u,v}"); // Subaxes and replica group v2 iota style device assignment. @@ -122,11 +343,11 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16}, /*transpose_perm=*/{2, 3, 0, 1})), /*axes_names=*/{"a", "b", "c", "d"}); - MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, /*axes=*/{}); + MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, {}); EXPECT_EQ(rg_abcd_across_none.ToString(), "@mesh([4,16]T(1,0)) {}"); MeshAxesReplicaGroupList rg_abcd_across_multiple_axes_and_subaxes( - mesh_abcd, /*axes=*/{AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)}); + mesh_abcd, {AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)}); EXPECT_EQ(rg_abcd_across_multiple_axes_and_subaxes.ToString(), "@mesh([4,16]T(1,0)) {a,b:(1)2,d}"); @@ -135,11 +356,11 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { array.Reshape({10}); TileAssignment tile_assignment(std::make_shared>(array)); Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"}); - MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, /*axes=*/{}); + MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, {}); EXPECT_EQ(rg_ooo_across_none.ToString(), "@mesh(8,3,7,5,4,2,6,0,1,9) {}"); MeshAxesReplicaGroupList rg_ooo_across_ooo_5_2(mesh_ooo, - /*axes=*/{AxisRef(0, {5, 2})}); + {AxisRef(0, {5, 2})}); EXPECT_EQ(rg_ooo_across_ooo_5_2.ToString(), "@mesh(8,3,7,5,4,2,6,0,1,9) {ooo:(5)2}"); }