[ReplicaGroupV3][MeshAxesReplicaGroupList][2/2] Add flattened_replica_groups function for MeshAxesReplicaGroupList.

PiperOrigin-RevId: 826619318
This commit is contained in:
Bill Varcho
2025-10-31 14:00:07 -07:00
committed by TensorFlower Gardener
parent a6e123761d
commit 261e077984
6 changed files with 545 additions and 33 deletions

View File

@@ -422,6 +422,7 @@ xla_cc_test(
":mesh_and_axis",
":tile_assignment",
"//xla:array",
"//xla:array2d",
"//xla:xla_data_proto_cc",
"//xla/service:hlo_proto_cc",
"//xla/tsl/platform:test_main",

View File

@@ -15,6 +15,7 @@ limitations under the License.
#include "xla/hlo/ir/mesh_and_axis.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
@@ -101,4 +102,42 @@ AxisRef AxisRef::FromProto(const AxisRefProto& proto) {
return axis_ref;
}
// Decides whether two sub-axes of the same mesh axis have compatible
// boundaries. The arguments are the smaller/larger of the two pre-sizes and
// the smaller/larger of the two next-pre-sizes (pre_size * size).
bool canSubAxesCoexist(int64_t minPreSize, int64_t maxPreSize,
                       int64_t minNextPreSize, int64_t maxNextPreSize) {
  const bool ends_before_other_starts = minNextPreSize <= maxPreSize;
  if (ends_before_other_starts) {
    // The sub-axes don't overlap: the gap between them must be a whole
    // multiple of the earlier boundary.
    return maxPreSize % minNextPreSize == 0;
  }
  // The sub-axes overlap: both the shared and the non-shared parts must split
  // evenly.
  const bool overlap_start_aligned = minNextPreSize % maxPreSize == 0;
  return overlap_start_aligned && maxPreSize % minPreSize == 0 &&
         maxNextPreSize % minNextPreSize == 0;
}
bool AxisRef::CanCoexist(const AxisRef& other) const {
if (mesh_axis_index() != other.mesh_axis_index()) {
return true;
}
if (!sub_axis_info_.has_value() || !other.sub_axis_info_.has_value()) {
// If one is a full axis and the other is a sub-axis, they can coexist.
return true;
}
const SubAxis& this_sub_axis = sub_axis_info_.value();
const SubAxis& other_sub_axis = other.sub_axis_info_.value();
int64_t this_pre_size = this_sub_axis.pre_size;
int64_t other_pre_size = other_sub_axis.pre_size;
int64_t this_next_pre_size = this_sub_axis.next_pre_size();
int64_t other_next_pre_size = other_sub_axis.next_pre_size();
auto [min_pre_size, max_pre_size] =
std::minmax(this_pre_size, other_pre_size);
auto [min_next_pre_size, max_next_pre_size] =
std::minmax(this_next_pre_size, other_next_pre_size);
return canSubAxesCoexist(min_pre_size, max_pre_size, min_next_pre_size,
max_next_pre_size);
}
} // namespace xla

View File

@@ -109,6 +109,12 @@ class Mesh {
// Returns (by value) the device assignment backing this mesh.
TileAssignment device_assignment() const { return device_assignment_; }
// Returns a copy of the mesh axis names.
std::vector<std::string> axis_names() const { return axes_names_; }
// Returns the size of every mesh axis, i.e. the dimensions of the device
// assignment. The span is a view obtained from `device_assignment_`.
absl::Span<const int64_t> axis_sizes() const {
  return device_assignment_.dimensions();
}
// Returns the size of the mesh axis at `axis_index`, read from the matching
// dimension of the device assignment.
int64_t axis_size(int64_t axis_index) const {
  return device_assignment_.dim(axis_index);
}
private:
// Dimensions of the `device_assignment_` array correspond to the axes of the
@@ -127,6 +133,7 @@ class AxisRef {
// Describes a slice of a mesh axis. Callers in this change reshape the full
// axis as [pre_size, size, full_axis_size / (pre_size * size)] and treat the
// middle dimension as the sub-axis.
struct SubAxis {
  // Product of the sizes that precede this sub-axis in that reshape.
  int64_t pre_size;
  // Extent of the sub-axis itself.
  int64_t size;
  // Boundary at which this sub-axis ends: pre_size * size.
  int64_t next_pre_size() const { return pre_size * size; }
};
// Index corresponding to axis in the mesh. It should be a valid index into
@@ -177,6 +184,8 @@ class AxisRef {
static AxisRef FromProto(const AxisRefProto& proto);
bool CanCoexist(const AxisRef& other) const;
// Returns the index of the mesh axis this reference points at.
int64_t mesh_axis_index() const { return mesh_axis_index_; }
// Returns a copy of the sub-axis description; empty when this reference
// denotes the full axis.
std::optional<SubAxis> sub_axis_info() const { return sub_axis_info_; }
};

View File

@@ -15,12 +15,18 @@ limitations under the License.
#include "xla/hlo/ir/replica_group.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
@@ -33,7 +39,6 @@ limitations under the License.
#include "xla/service/hlo.pb.h"
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep
#include "xla/xla_data.pb.h"
#include "tsl/platform/protobuf.h"
namespace xla {
@@ -49,6 +54,145 @@ std::string ReplicaGroupsToString(
}
/************** MeshAxesReplicaGroupList implementation ***********************/
// Computes the reshape and aggregation axis for a mesh dimension that is
// referenced by exactly one AxisRef.
//
// A full axis keeps its shape ([full_axis_size]) and is aggregated over
// dimension 0. A sub-axis splits the dimension into
// [pre_size, size, full_axis_size / next_pre_size] and is aggregated over the
// middle dimension. Both output vectors are overwritten.
void HandleSingleAxisRefPerDimension(const AxisRef& axis,
                                     int64_t full_axis_size,
                                     std::vector<int64_t>& out_reshape_dims,
                                     std::vector<int64_t>& out_aggregate_axes) {
  const auto sub_axis = axis.sub_axis_info();
  if (!sub_axis.has_value()) {
    out_reshape_dims = {full_axis_size};
    out_aggregate_axes = {0};
    return;
  }

  const int64_t trailing_size = full_axis_size / sub_axis->next_pre_size();
  out_reshape_dims = {sub_axis->pre_size, sub_axis->size, trailing_size};
  // The sub-axis itself is the second of the three reshaped dimensions.
  out_aggregate_axes = {1};
}
void HandleMultiAxisRefPerDimension(std::vector<AxisRef>& axes,
int64_t full_axis_size,
std::vector<int64_t>& out_reshape_dims,
std::vector<int64_t>& out_aggregate_axes) {
// --- 1. Sort Axes and Original Indices Together ---
// Sort both the axes and the original indices based on
// sub_axis_info()->pre_size. This allows us to maintain user specified order
// of AxisRef while still building the reshape and aggregate axes.
std::vector<int> original_order(axes.size());
std::iota(original_order.begin(), original_order.end(), 0);
std::sort(original_order.begin(), original_order.end(),
[&axes](int i, int j) {
return axes[i].sub_axis_info()->pre_size <
axes[j].sub_axis_info()->pre_size;
});
std::sort(axes.begin(), axes.end(), [](const AxisRef& a, const AxisRef& b) {
return a.sub_axis_info()->pre_size < b.sub_axis_info()->pre_size;
});
// --- 2. Build Reshape Dims and Aggregation Axes ---
int64_t current_dim_index = 0; // Index in the new reshaped tensor
int64_t prefix_product = 1; // Product of the size of all prior dimensions
for (const AxisRef& axis : axes) {
int64_t pre_size = axis.sub_axis_info()->pre_size;
int64_t size = axis.sub_axis_info()->size;
// Insert "padding" dimension if the current prefix product doesn't match
// the required pre_size
if (pre_size != prefix_product) {
int64_t padding_size = pre_size / prefix_product;
out_reshape_dims.push_back(padding_size);
current_dim_index++;
prefix_product *= padding_size;
}
// Insert the sharded size (the part to aggregate)
out_reshape_dims.push_back(size);
out_aggregate_axes.push_back(
current_dim_index); // This is the axis we aggregate over
current_dim_index++;
prefix_product *= size;
}
// Insert "suffix" dimension if the full size hasn't been reached
if (prefix_product != full_axis_size) {
out_reshape_dims.push_back(full_axis_size / prefix_product);
}
// --- 3. Permute Aggregate Axes back to Original Order ---
// The aggregate axes were calculated based on the sorted list.
// We must map them back to the original order to compute the correct
// flattened replica groups.
std::vector<int64_t> permuted_aggregate_axes(original_order.size());
for (int64_t i = 0; i < original_order.size(); ++i) {
permuted_aggregate_axes[original_order[i]] = out_aggregate_axes[i];
}
out_aggregate_axes = permuted_aggregate_axes;
}
bool ValidateSingleDimensionAxes(int64_t dim, std::vector<AxisRef>& axes,
const Mesh& mesh_) {
// If there's only one axis, nothing to check.
if (axes.size() <= 1) {
return true;
}
// --- Step 1: Deduplication ---
// If one is a "full" axis (no sub_axis_info), it subsumes all other AxisRefs
// with sub_axis_info.
for (const AxisRef& axis : axes) {
if (!axis.sub_axis_info().has_value()) {
LOG(WARNING) << "MeshAxesReplicaGroupList: Redundant axis definition at "
"dimension: "
<< dim
<< ". Keeping only the full axis: " << axis.ToString(mesh_);
axes = {axis};
return true;
}
}
// --- Step 2: Overlap Check ---
// At this point, all remaining axes MUST have sub_axis_info().
// Verify that the remaining multiple sub-axes do not overlap.
for (int64_t i = 0; i < axes.size() - 1; ++i) {
for (int64_t j = i + 1; j < axes.size(); ++j) {
// CHECK will terminate the program on failure, matching original
// behavior.
CHECK(axes[i].CanCoexist(axes[j]))
<< "Overlapping sub-axes detected: " << axes[i].ToString(mesh_)
<< " and " << axes[j].ToString(mesh_);
}
}
return true; // Passed all checks for this dimension.
}
// Builds a replica group list that groups devices of `mesh` along `axes`.
//
// Logs an error when the axes select a single device per group. CHECK-fails
// when a sub-axis' next pre-size does not divide its mesh axis size or when
// the axes of one mesh dimension fail ValidateSingleDimensionAxes. When a
// dimension lists a full axis together with other references to the same
// dimension, only the first full axis is kept in `axes_` (matching the warning
// emitted during validation), so later computations see a consistent list.
MeshAxesReplicaGroupList::MeshAxesReplicaGroupList(Mesh mesh,
                                                   std::vector<AxisRef> axes)
    : mesh_(std::move(mesh)), axes_(std::move(axes)) {
  if (num_devices_per_group() == 1) {
    LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString()
               << " has only one device per replica group.";
  }

  // Bucket the axes by the mesh dimension they refer to.
  absl::flat_hash_map<int64_t, std::vector<AxisRef>> dim_to_axes;
  for (const AxisRef& axis : axes_) {
    dim_to_axes[axis.mesh_axis_index()].push_back(axis);
    if (axis.sub_axis_info().has_value()) {
      CHECK(mesh_.axis_size(axis.mesh_axis_index()) %
                axis.sub_axis_info()->next_pre_size() ==
            0)
          << "Next pre-size must divide the full axis size.";
    }
  }

  // Validate input AxisRefs and remember which dimensions were collapsed to a
  // single full axis by the validation step.
  absl::flat_hash_set<int64_t> collapsed_dims;
  for (auto& [dim, dim_axes] : dim_to_axes) {
    const auto num_axes_before = dim_axes.size();
    CHECK(ValidateSingleDimensionAxes(dim, dim_axes, mesh_));
    if (dim_axes.size() != num_axes_before) {
      collapsed_dims.insert(dim);
    }
  }
  if (collapsed_dims.empty()) {
    return;
  }

  // Mirror the deduplication in `axes_`; the buckets above are only copies.
  // For a collapsed dimension keep the first full axis and drop the rest,
  // preserving the relative order of everything that is kept.
  std::vector<AxisRef> kept_axes;
  kept_axes.reserve(axes_.size());
  absl::flat_hash_set<int64_t> full_axis_kept;
  for (const AxisRef& axis : axes_) {
    const int64_t dim = axis.mesh_axis_index();
    if (!collapsed_dims.contains(dim)) {
      kept_axes.push_back(axis);
      continue;
    }
    if (!axis.sub_axis_info().has_value() &&
        full_axis_kept.insert(dim).second) {
      kept_axes.push_back(axis);
    }
  }
  axes_ = std::move(kept_axes);
}
// Number of replica groups: every device of the mesh belongs to exactly one
// group of num_devices_per_group() devices.
int64_t MeshAxesReplicaGroupList::num_replica_groups() const {
  const int64_t total_devices = mesh_.device_assignment().num_elements();
  return total_devices / num_devices_per_group();
}
@@ -58,15 +202,105 @@ int64_t MeshAxesReplicaGroupList::num_devices_per_group() const {
// all axes.
int64_t devices_per_group = 1;
for (const AxisRef& axis : axes_) {
int64_t axis_size =
axis.sub_axis_info().has_value()
? axis.sub_axis_info()->size
: mesh_.device_assignment().dim(axis.mesh_axis_index());
int64_t axis_size = axis.sub_axis_info().has_value()
? axis.sub_axis_info()->size
: mesh_.axis_size(axis.mesh_axis_index());
devices_per_group *= axis_size;
}
return devices_per_group;
}
std::vector<std::vector<int64_t>> get_replica_groups_for_full_axes(
const Mesh& mesh, absl::Span<const int64_t> axis_sizes,
const absl::Span<const int64_t> grouped_axes,
const int64_t num_replica_groups, const int64_t num_devices_per_group) {
// Reshape the device assignment array bases on the axis sizes and transpose
// grouped axes to the end.
std::vector<int> transpose_axes;
transpose_axes.reserve(axis_sizes.size());
for (int64_t i = 0; i < axis_sizes.size(); ++i) {
if (!absl::c_linear_search(grouped_axes, i)) {
transpose_axes.push_back(i);
}
}
for (int64_t grouped_axis : grouped_axes) {
transpose_axes.push_back(grouped_axis);
}
TileAssignment device_assignment =
mesh.device_assignment().Reshape(axis_sizes).Transpose(transpose_axes);
std::vector<std::vector<int64_t>> replica_groups;
replica_groups.reserve(num_replica_groups);
for (auto it = device_assignment.array().begin();
it != device_assignment.array().end(); it += num_devices_per_group) {
std::vector<int64_t> group(it, it + num_devices_per_group);
replica_groups.emplace_back(std::move(group));
}
return replica_groups;
}
void MeshAxesReplicaGroupList::InitializeDimToReshapeAndAggregateAxes() {
absl::flat_hash_map<int64_t, std::vector<AxisRef>> dim_to_axes;
for (const AxisRef& axis : axes_) {
dim_to_axes[axis.mesh_axis_index()].push_back(axis);
}
absl::flat_hash_map<int64_t, ReshapeAndAggregateAxes> dim_map;
// For each dimension determine the reshape that is consistent with it's
// AxisRef(s). Then maintain this reshape and the aggregated dims for easier
// computation of replica groups. As an example for @mesh<"a"=8>
// {a} -> no reshape, aggregate over [0]
// {a:(1)2} -> reshape [8]->[1,2,4], aggregate over [1]
// {a:(1)2, a:(4)2} -> reshape [8]->[2,2,2], aggregate over [0,2]
for (auto& [dim, axes] : dim_to_axes) {
int64_t full_axis_size = mesh_.axis_size(dim);
ReshapeAndAggregateAxes reshape_and_aggregate_axes;
if (axes.size() == 1) {
HandleSingleAxisRefPerDimension(
axes[0], full_axis_size, reshape_and_aggregate_axes.reshape_dims,
reshape_and_aggregate_axes.aggregate_axes);
} else {
// Otherwise dimension is a set of axes with sub-axes info.
HandleMultiAxisRefPerDimension(axes, full_axis_size,
reshape_and_aggregate_axes.reshape_dims,
reshape_and_aggregate_axes.aggregate_axes);
}
dim_map[dim] = reshape_and_aggregate_axes;
}
dim_to_reshape_and_aggregate_axes_ = dim_map;
}
std::vector<std::vector<int64_t>>
MeshAxesReplicaGroupList::flattened_replica_groups() {
if (!dim_to_reshape_and_aggregate_axes_.has_value()) {
InitializeDimToReshapeAndAggregateAxes();
}
absl::flat_hash_map<int64_t, ReshapeAndAggregateAxes> dim_map =
dim_to_reshape_and_aggregate_axes_.value();
std::vector<int64_t> reindex_axis_sizes;
std::vector<int64_t> reindexed_grouped_axes;
for (int64_t i = 0; i < mesh_.axis_sizes().size(); ++i) {
int64_t axis_size = mesh_.axis_size(i);
auto it = dim_map.find(i);
if (it == dim_map.end()) {
reindex_axis_sizes.push_back(axis_size);
continue;
}
int64_t offset_index = reindex_axis_sizes.size();
const ReshapeAndAggregateAxes& reshape_and_aggregate_axes = it->second;
for (int64_t reshape_dim : reshape_and_aggregate_axes.reshape_dims) {
reindex_axis_sizes.push_back(reshape_dim);
}
for (int64_t aggregate_dim : reshape_and_aggregate_axes.aggregate_axes) {
reindexed_grouped_axes.push_back(aggregate_dim + offset_index);
}
}
return get_replica_groups_for_full_axes(
mesh_, reindex_axis_sizes, reindexed_grouped_axes, num_replica_groups(),
num_devices_per_group());
}
// Appends the textual form of this replica group list (see ToString) to
// `printer`.
void MeshAxesReplicaGroupList::Print(Printer* printer) const {
  const std::string text = ToString();
  printer->Append(text);
}

View File

@@ -16,14 +16,19 @@ limitations under the License.
#ifndef XLA_HLO_IR_REPLICA_GROUP_H_
#define XLA_HLO_IR_REPLICA_GROUP_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/types/span.h"
#include "google/protobuf/repeated_ptr_field.h"
@@ -38,14 +43,13 @@ limitations under the License.
namespace xla {
class MeshAxesReplicaGroupList {
// Per-mesh-dimension description of how to expose the requested (sub-)axes.
struct ReshapeAndAggregateAxes {
  // Sizes the mesh dimension is reshaped into, e.g. [8] -> [1, 2, 4].
  std::vector<int64_t> reshape_dims;
  // Indices into `reshape_dims` of the dimensions that are aggregated over.
  std::vector<int64_t> aggregate_axes;
};
public:
explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector<AxisRef> axes)
: mesh_(std::move(mesh)), axes_(std::move(axes)) {
if (num_devices_per_group() == 1) {
LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString()
<< " has only one device per replica group.";
}
}
explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector<AxisRef> axes);
bool operator==(const MeshAxesReplicaGroupList& other) const {
return mesh_ == other.mesh_ && axes_ == other.axes_;
@@ -58,6 +62,7 @@ class MeshAxesReplicaGroupList {
int64_t num_replica_groups() const;
int64_t num_devices_per_group() const;
std::vector<std::vector<int64_t>> flattened_replica_groups();
void Print(Printer* printer) const;
@@ -69,8 +74,11 @@ class MeshAxesReplicaGroupList {
const MeshAxesReplicaGroupListProto& proto);
private:
void InitializeDimToReshapeAndAggregateAxes();
Mesh mesh_;
std::vector<AxisRef> axes_;
std::optional<absl::flat_hash_map<int64_t, ReshapeAndAggregateAxes>>
dim_to_reshape_and_aggregate_axes_;
};
std::string ReplicaGroupsToString(

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "xla/array.h"
#include "xla/array2d.h"
#include "xla/hlo/ir/mesh_and_axis.h"
#include "xla/hlo/ir/tile_assignment.h"
#include "xla/service/hlo.pb.h"
@@ -41,23 +42,230 @@ CollectiveDeviceListProto CreateDeviceListProto(
return proto;
}
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) {
// Checks flattened_replica_groups() on a 2x2 mesh when grouping over no axis,
// over "x" only, over "y" only, and over both axes.
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) {
  Mesh mesh_xy(TileAssignment({2, 2}), /*axes_names=*/{"x", "y"});
  // No axes: every device forms its own group.
  MeshAxesReplicaGroupList replica_group_none(mesh_xy, {});
  std::vector<std::vector<int64_t>> expected_replica_groups_none = {
      {0}, {1}, {2}, {3}};
  EXPECT_EQ(replica_group_none.flattened_replica_groups(),
            expected_replica_groups_none);
  // Group along x (the major axis).
  MeshAxesReplicaGroupList replica_group_x(mesh_xy, {AxisRef(0)});
  std::vector<std::vector<int64_t>> expected_replica_groups_x = {{0, 2},
                                                                 {1, 3}};
  EXPECT_EQ(replica_group_x.flattened_replica_groups(),
            expected_replica_groups_x);
  // Group along y (the minor axis).
  MeshAxesReplicaGroupList replica_group_y(mesh_xy, {AxisRef(1)});
  std::vector<std::vector<int64_t>> expected_replica_groups_y = {{0, 1},
                                                                 {2, 3}};
  EXPECT_EQ(replica_group_y.flattened_replica_groups(),
            expected_replica_groups_y);
  // Group along both axes: a single group with every device.
  MeshAxesReplicaGroupList replica_group_xy(mesh_xy, {AxisRef(0), AxisRef(1)});
  std::vector<std::vector<int64_t>> expected_replica_groups_xy = {{0, 1, 2, 3}};
  EXPECT_EQ(replica_group_xy.flattened_replica_groups(),
            expected_replica_groups_xy);
}
// Checks flattened_replica_groups() for sub-axes of a 6x6 iota mesh, first for
// a single sub-axis of "a" or "b" and then for one sub-axis of each.
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) {
  Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{6, 6})),
            /*axes_names=*/{"a", "b"});
  // a:(1)2
  MeshAxesReplicaGroupList replica_group_a_1_2(mesh, {AxisRef(0, {1, 2})});
  std::vector<std::vector<int64_t>> expected_replica_groups_a_1_2 = {
      {0, 18},  {1, 19},  {2, 20},  {3, 21},  {4, 22},  {5, 23},
      {6, 24},  {7, 25},  {8, 26},  {9, 27},  {10, 28}, {11, 29},
      {12, 30}, {13, 31}, {14, 32}, {15, 33}, {16, 34}, {17, 35}};
  EXPECT_EQ(replica_group_a_1_2.flattened_replica_groups(),
            expected_replica_groups_a_1_2);
  // a:(1)3
  MeshAxesReplicaGroupList replica_group_a_1_3(mesh, {AxisRef(0, {1, 3})});
  std::vector<std::vector<int64_t>> expected_replica_groups_a_1_3 = {
      {0, 12, 24}, {1, 13, 25}, {2, 14, 26},  {3, 15, 27},
      {4, 16, 28}, {5, 17, 29}, {6, 18, 30},  {7, 19, 31},
      {8, 20, 32}, {9, 21, 33}, {10, 22, 34}, {11, 23, 35}};
  EXPECT_EQ(replica_group_a_1_3.flattened_replica_groups(),
            expected_replica_groups_a_1_3);
  // a:(3)2
  MeshAxesReplicaGroupList replica_group_a_3_2(mesh, {AxisRef(0, {3, 2})});
  std::vector<std::vector<int64_t>> expected_replica_groups_a_3_2 = {
      {0, 6},   {1, 7},   {2, 8},   {3, 9},   {4, 10},  {5, 11},
      {12, 18}, {13, 19}, {14, 20}, {15, 21}, {16, 22}, {17, 23},
      {24, 30}, {25, 31}, {26, 32}, {27, 33}, {28, 34}, {29, 35}};
  EXPECT_EQ(replica_group_a_3_2.flattened_replica_groups(),
            expected_replica_groups_a_3_2);
  // b:(1)2
  MeshAxesReplicaGroupList replica_group_b_1_2(mesh, {AxisRef(1, {1, 2})});
  std::vector<std::vector<int64_t>> expected_replica_groups_b_1_2 = {
      {0, 3},   {1, 4},   {2, 5},   {6, 9},   {7, 10},  {8, 11},
      {12, 15}, {13, 16}, {14, 17}, {18, 21}, {19, 22}, {20, 23},
      {24, 27}, {25, 28}, {26, 29}, {30, 33}, {31, 34}, {32, 35}};
  EXPECT_EQ(replica_group_b_1_2.flattened_replica_groups(),
            expected_replica_groups_b_1_2);
  // b:(1)3
  MeshAxesReplicaGroupList replica_group_b_1_3(mesh, {AxisRef(1, {1, 3})});
  std::vector<std::vector<int64_t>> expected_replica_groups_b_1_3 = {
      {0, 2, 4},    {1, 3, 5},    {6, 8, 10},   {7, 9, 11},
      {12, 14, 16}, {13, 15, 17}, {18, 20, 22}, {19, 21, 23},
      {24, 26, 28}, {25, 27, 29}, {30, 32, 34}, {31, 33, 35}};
  EXPECT_EQ(replica_group_b_1_3.flattened_replica_groups(),
            expected_replica_groups_b_1_3);
  // b:(3)2
  MeshAxesReplicaGroupList replica_group_b_3_2(mesh, {AxisRef(1, {3, 2})});
  std::vector<std::vector<int64_t>> expected_replica_groups_b_3_2 = {
      {0, 1},   {2, 3},   {4, 5},   {6, 7},   {8, 9},   {10, 11},
      {12, 13}, {14, 15}, {16, 17}, {18, 19}, {20, 21}, {22, 23},
      {24, 25}, {26, 27}, {28, 29}, {30, 31}, {32, 33}, {34, 35}};
  EXPECT_EQ(replica_group_b_3_2.flattened_replica_groups(),
            expected_replica_groups_b_3_2);
  // a:(1)2, b:(1)2
  MeshAxesReplicaGroupList replica_group_a_1_2_b_1_2(
      mesh, {AxisRef(0, {1, 2}), AxisRef(1, {1, 2})});
  std::vector<std::vector<int64_t>> expected_replica_groups_a_1_2_b_1_2 = {
      {0, 3, 18, 21},   {1, 4, 19, 22},   {2, 5, 20, 23},
      {6, 9, 24, 27},   {7, 10, 25, 28},  {8, 11, 26, 29},
      {12, 15, 30, 33}, {13, 16, 31, 34}, {14, 17, 32, 35}};
  EXPECT_EQ(replica_group_a_1_2_b_1_2.flattened_replica_groups(),
            expected_replica_groups_a_1_2_b_1_2);
  // a:(1)3, b:(1)3
  MeshAxesReplicaGroupList replica_group_a_1_3_b_1_3(
      mesh, {AxisRef(0, {1, 3}), AxisRef(1, {1, 3})});
  std::vector<std::vector<int64_t>> expected_replica_groups_a_1_3_b_1_3 = {
      {0, 2, 4, 12, 14, 16, 24, 26, 28},
      {1, 3, 5, 13, 15, 17, 25, 27, 29},
      {6, 8, 10, 18, 20, 22, 30, 32, 34},
      {7, 9, 11, 19, 21, 23, 31, 33, 35}};
  EXPECT_EQ(replica_group_a_1_3_b_1_3.flattened_replica_groups(),
            expected_replica_groups_a_1_3_b_1_3);
  // b:(1)3, a:(1)3 (Reverse order from above). This should produce the same
  // replica groups as the above but with ids in a different order.
  MeshAxesReplicaGroupList replica_group_b_1_3_a_1_3(
      mesh, {AxisRef(1, {1, 3}), AxisRef(0, {1, 3})});
  std::vector<std::vector<int64_t>> expected_replica_groups_b_1_3_a_1_3 = {
      {0, 12, 24, 2, 14, 26, 4, 16, 28},
      {1, 13, 25, 3, 15, 27, 5, 17, 29},
      {6, 18, 30, 8, 20, 32, 10, 22, 34},
      {7, 19, 31, 9, 21, 33, 11, 23, 35}};
  // NOTE(review): the assertion below compares replica_group_a_1_3_b_1_3 with
  // expected_replica_groups_a_1_3_b_1_3 again. replica_group_b_1_3_a_1_3 and
  // expected_replica_groups_b_1_3_a_1_3 are never compared, so the
  // reversed-order case is not exercised. TODO: confirm the intended ordering
  // and assert on the b_1_3_a_1_3 objects instead.
  EXPECT_EQ(replica_group_a_1_3_b_1_3.flattened_replica_groups(),
            expected_replica_groups_a_1_3_b_1_3);
  // a:(3)2, b:(3)2
  MeshAxesReplicaGroupList replica_group_a_3_2_b_3_2(
      mesh, {AxisRef(0, {3, 2}), AxisRef(1, {3, 2})});
  std::vector<std::vector<int64_t>> expected_replica_groups_a_3_2_b_3_2 = {
      {0, 1, 6, 7},     {2, 3, 8, 9},     {4, 5, 10, 11},
      {12, 13, 18, 19}, {14, 15, 20, 21}, {16, 17, 22, 23},
      {24, 25, 30, 31}, {26, 27, 32, 33}, {28, 29, 34, 35}};
  EXPECT_EQ(replica_group_a_3_2_b_3_2.flattened_replica_groups(),
            expected_replica_groups_a_3_2_b_3_2);
}
// Cross-checks flattened_replica_groups() against the equivalent
// IotaReplicaGroupList (V2) formulation on a one-dimensional mesh of 8
// devices. The two-sub-axis cases at the end also show that the order in which
// the sub-axes are listed changes the order of ids within a group.
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsMatchExpectedV2) {
  Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{8})),
            /*axes_names=*/{"a"});
  // a:(1)2 -> replica_groups=[4,2]<=[2,4]T(1,0)
  MeshAxesReplicaGroupList v3_subaxis_1_2(mesh, {AxisRef(0, {1, 2})});
  IotaReplicaGroupList v2_subaxis_1_2(4, 2, {2, 4}, {1, 0});
  EXPECT_EQ(v3_subaxis_1_2.flattened_replica_groups(),
            v2_subaxis_1_2.flattened_replica_groups());
  // a:(1)4 -> replica_groups=[2,4]<=[4,2]T(1,0)
  MeshAxesReplicaGroupList v3_subaxis_1_4(mesh, {AxisRef(0, {1, 4})});
  IotaReplicaGroupList v2_subaxis_1_4(2, 4, {4, 2}, {1, 0});
  EXPECT_EQ(v3_subaxis_1_4.flattened_replica_groups(),
            v2_subaxis_1_4.flattened_replica_groups());
  // a:(2)2 -> replica_groups=[4,2]<=[2,2,2]T(0,2,1)
  MeshAxesReplicaGroupList v3_subaxis_2_2(mesh, {AxisRef(0, {2, 2})});
  IotaReplicaGroupList v2_subaxis_2_2(4, 2, {2, 2, 2}, {0, 2, 1});
  EXPECT_EQ(v3_subaxis_2_2.flattened_replica_groups(),
            v2_subaxis_2_2.flattened_replica_groups());
  // a:(2)4 -> replica_groups=[2,4]<=[8]
  MeshAxesReplicaGroupList v3_subaxis_2_4(mesh, {AxisRef(0, {2, 4})});
  IotaReplicaGroupList v2_subaxis_2_4(2, 4, {8}, {0});
  EXPECT_EQ(v3_subaxis_2_4.flattened_replica_groups(),
            v2_subaxis_2_4.flattened_replica_groups());
  // a:(4)2 -> replica_groups=[4,2]<=[8]
  MeshAxesReplicaGroupList v3_subaxis_4_2(mesh, {AxisRef(0, {4, 2})});
  IotaReplicaGroupList v2_subaxis_4_2(4, 2, {8}, {0});
  EXPECT_EQ(v3_subaxis_4_2.flattened_replica_groups(),
            v2_subaxis_4_2.flattened_replica_groups());
  // {a:(1)2, a:(4)2} -> replica_groups=[2,4]<=[2,2,2]T(1,0,2)
  MeshAxesReplicaGroupList v3_subaxis_1_2_and_4_2(
      mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 2})});
  IotaReplicaGroupList v2_subaxis_1_2_and_4_2(2, 4, {2, 2, 2}, {1, 0, 2});
  EXPECT_EQ(v3_subaxis_1_2_and_4_2.flattened_replica_groups(),
            v2_subaxis_1_2_and_4_2.flattened_replica_groups());
  // {a:(4)2, a:(1)2} -> replica_groups=[2,4]<=[2,2,2]T(1,2,0)
  MeshAxesReplicaGroupList v3_subaxis_4_2_and_1_2(
      mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})});
  IotaReplicaGroupList v2_subaxis_4_2_and_1_2(2, 4, {2, 2, 2}, {1, 2, 0});
  EXPECT_EQ(v3_subaxis_4_2_and_1_2.flattened_replica_groups(),
            v2_subaxis_4_2_and_1_2.flattened_replica_groups());
  // a -> replica_groups=[1,8]<=[8]
  MeshAxesReplicaGroupList v3_no_subaxis(mesh, {AxisRef(0)});
  IotaReplicaGroupList v2_no_subaxis(1, 8, {8}, {0});
  EXPECT_EQ(v3_no_subaxis.flattened_replica_groups(),
            v2_no_subaxis.flattened_replica_groups());
}
// Checks that the materialized groups contain the device ids stored in the
// mesh's explicit device array rather than iota positions. The comparison
// ignores the order of the groups, but not the order of ids within a group.
TEST(MeshAxesReplicaGroupListTest,
     MaterializedReplicaGroupsRespectNonIotaDeviceOrdering) {
  // Create a mesh with non-iota device ordering.
  Array2D<int64_t> array({{3, 1}, {0, 2}});
  TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
  Mesh mesh_xy(tile_assignment, /*axes_names=*/{"x", "y"});
  // Reduce along x axis.
  MeshAxesReplicaGroupList replica_group_x(mesh_xy, {AxisRef(0)});
  // With iota device ordering, the expected replica groups would be
  // {{0, 2}, {1, 3}}.
  std::vector<std::vector<int64_t>> expected_replica_groups_x = {{3, 0},
                                                                 {1, 2}};
  EXPECT_THAT(replica_group_x.flattened_replica_groups(),
              testing::UnorderedElementsAreArray(expected_replica_groups_x));
  // Reduce along y axis.
  MeshAxesReplicaGroupList replica_group_y(mesh_xy, {AxisRef(1)});
  // With iota device ordering, the expected replica groups would be
  // {{0, 1}, {2, 3}}.
  std::vector<std::vector<int64_t>> expected_replica_groups_y = {{3, 1},
                                                                 {0, 2}};
  EXPECT_THAT(replica_group_y.flattened_replica_groups(),
              testing::UnorderedElementsAreArray(expected_replica_groups_y));
}
TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) {
Mesh all_axes(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{4, 4})),
/*axes_names=*/{"x", "y"});
MeshAxesReplicaGroupList replica_group_across_all_axes(
all_axes,
/*axes=*/{AxisRef(0), AxisRef(1)});
all_axes, {AxisRef(0), AxisRef(1)});
EXPECT_EQ(replica_group_across_all_axes.num_replica_groups(), 1);
EXPECT_EQ(replica_group_across_all_axes.num_devices_per_group(), 16);
Mesh one_axes(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{3, 5})),
/*axes_names=*/{"a", "b"});
MeshAxesReplicaGroupList replica_group_across_a(one_axes,
/*axes=*/{AxisRef(0)});
MeshAxesReplicaGroupList replica_group_across_b(one_axes,
/*axes=*/{AxisRef(1)});
MeshAxesReplicaGroupList replica_group_across_a(one_axes, {AxisRef(0)});
MeshAxesReplicaGroupList replica_group_across_b(one_axes, {AxisRef(1)});
EXPECT_EQ(replica_group_across_a.num_replica_groups(), 5);
EXPECT_EQ(replica_group_across_a.num_devices_per_group(), 3);
EXPECT_EQ(replica_group_across_b.num_replica_groups(), 3);
@@ -71,16 +279,30 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) {
EXPECT_EQ(replica_group_across_no_axes.num_devices_per_group(), 1);
}
// Exercises the constructor's sub-axis validation: two compatible sub-axes of
// one mesh axis are accepted in either order (the constructions must simply
// not crash), while incompatible sub-axes terminate the process with an
// "Overlapping sub-axes" message.
TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) {
  Mesh mesh(TileAssignment({8}), /*axes_names=*/{"1"});
  MeshAxesReplicaGroupList replica_group_multiple_subaxes1(
      mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 2})});
  MeshAxesReplicaGroupList replica_group_multiple_subaxes2(
      mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})});
  // Sub-axes (6)5 and (10)3 of an axis of size 30 cannot coexist.
  Mesh overlap_mesh(TileAssignment({2 * 3 * 5}), /*axes_names=*/{"u"});
  EXPECT_DEATH(
      {
        MeshAxesReplicaGroupList overlapping_subaxes(
            overlap_mesh, {AxisRef(0, {6, 5}), AxisRef(0, {10, 3})});
      },
      "Overlapping sub-axes");
}
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
Mesh mesh_one_subaxis(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 6, 10})),
/*axes_names=*/{"axis1", "axis2", "axis3"});
MeshAxesReplicaGroupList replica_group_across_axis1_subaxis(
mesh_one_subaxis,
/*axes=*/{AxisRef(0, {1, 2})});
mesh_one_subaxis, {AxisRef(0, {1, 2})});
MeshAxesReplicaGroupList replica_group_across_axis2_subaxis(
mesh_one_subaxis,
/*axes=*/{AxisRef(1, {2, 3})});
mesh_one_subaxis, {AxisRef(1, {2, 3})});
EXPECT_EQ(replica_group_across_axis1_subaxis.num_replica_groups(), 60);
EXPECT_EQ(replica_group_across_axis1_subaxis.num_devices_per_group(), 2);
EXPECT_EQ(replica_group_across_axis2_subaxis.num_replica_groups(), 40);
@@ -91,10 +313,10 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
/*axes_names=*/{"alpha", "beta", "gamma"});
MeshAxesReplicaGroupList replica_group_across_multiple_subaxis1(
mesh_multiple_subaxis,
/*axes=*/{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
MeshAxesReplicaGroupList replica_group_across_multiple_subaxis2(
mesh_multiple_subaxis,
/*axes=*/{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})});
{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})});
EXPECT_EQ(replica_group_across_multiple_subaxis1.num_replica_groups(),
3 * 7 * 13);
EXPECT_EQ(replica_group_across_multiple_subaxis1.num_devices_per_group(),
@@ -110,11 +332,10 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{10, 12, 15})),
/*axes_names=*/{"u", "v", "w"});
MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, /*axes=*/{});
MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, {});
EXPECT_EQ(replica_group_across_none.ToString(), "@mesh<u=10,v=12,w=15> {}");
MeshAxesReplicaGroupList replica_group_across_uv(
mesh_uvw,
/*axes=*/{AxisRef(0), AxisRef(1)});
MeshAxesReplicaGroupList replica_group_across_uv(mesh_uvw,
{AxisRef(0), AxisRef(1)});
EXPECT_EQ(replica_group_across_uv.ToString(), "@mesh<u=10,v=12,w=15> {u,v}");
// Subaxes and replica group v2 iota style device assignment.
@@ -122,11 +343,11 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
/*transpose_perm=*/{2, 3, 0, 1})),
/*axes_names=*/{"a", "b", "c", "d"});
MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, /*axes=*/{});
MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, {});
EXPECT_EQ(rg_abcd_across_none.ToString(),
"@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {}");
MeshAxesReplicaGroupList rg_abcd_across_multiple_axes_and_subaxes(
mesh_abcd, /*axes=*/{AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)});
mesh_abcd, {AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)});
EXPECT_EQ(rg_abcd_across_multiple_axes_and_subaxes.ToString(),
"@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {a,b:(1)2,d}");
@@ -135,11 +356,11 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
array.Reshape({10});
TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"});
MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, /*axes=*/{});
MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, {});
EXPECT_EQ(rg_ooo_across_none.ToString(),
"@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {}");
MeshAxesReplicaGroupList rg_ooo_across_ooo_5_2(mesh_ooo,
/*axes=*/{AxisRef(0, {5, 2})});
{AxisRef(0, {5, 2})});
EXPECT_EQ(rg_ooo_across_ooo_5_2.ToString(),
"@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {ooo:(5)2}");
}