Refactor std::optional comparison in ReshapeSharding tests

PiperOrigin-RevId: 847749800
This commit is contained in:
Kanish Anand
2025-12-22 07:00:54 -08:00
committed by TensorFlower Gardener
parent 12502acbf5
commit 4d0edd395f

View File

@@ -196,8 +196,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned1) {
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {3, 2, 2}, {1, 2, 0}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned2) {
@@ -208,8 +207,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned2) {
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {2, 3, 2}, {0, 2, 1}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned3) {
@@ -220,8 +218,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned3) {
HloSharding::PartialTile(TileAssignment({4, 3}, {2, 3, 2}, {0, 2, 1}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned4) {
@@ -232,8 +229,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned4) {
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {3, 4}, {1, 0}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned5) {
@@ -243,8 +239,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned5) {
HloSharding output_sharding = HloSharding::IotaTile({2, 3, 2, 2});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingMaximal) {
@@ -253,8 +248,7 @@ TEST(HloShardingUtilTest, ReshapeShardingMaximal) {
HloSharding sharding = HloSharding::AssignDevice(7);
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), sharding);
EXPECT_EQ(result, sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) {
@@ -263,7 +257,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) {
HloSharding sharding = HloSharding::IotaTile({1, 2, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_FALSE(result.has_value());
ASSERT_FALSE(result.has_value());
}
TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) {
@@ -273,8 +267,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) {
HloSharding output_sharding = HloSharding::IotaTile({2, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
@@ -284,8 +277,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledSplit2) {
@@ -295,8 +287,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit2) {
HloSharding output_sharding = HloSharding::IotaTile({4, 4, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledSplit3) {
@@ -307,8 +298,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit3) {
HloSharding::PartialTile(TileAssignment({2, 1, 2}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
@@ -318,8 +308,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) {
@@ -328,8 +317,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) {
HloSharding sharding = HloSharding::IotaTile({2, 1, 1, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), sharding);
EXPECT_EQ(result, sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) {
@@ -339,8 +327,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) {
HloSharding output_sharding = HloSharding::IotaTile({1, 2, 1, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTrivialDimensionInsertedToEnd) {
@@ -350,16 +337,14 @@ TEST(HloShardingUtilTest, ReshapeShardingTrivialDimensionInsertedToEnd) {
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) {
Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1});
HloSharding sharding = HloSharding::IotaTile({2, 1, 1});
std::optional<HloSharding> result = ReshapeSharding(shape, shape, sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), sharding);
EXPECT_EQ(result, sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingScalar) {
@@ -368,7 +353,7 @@ TEST(HloShardingUtilTest, ReshapeShardingScalar) {
HloSharding sharding = HloSharding::IotaTile({2, 1, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_FALSE(result.has_value());
ASSERT_FALSE(result.has_value());
}
TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne1) {
@@ -379,12 +364,10 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne1) {
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
result = ReshapeSharding(output_shape, input_shape, output_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), input_sharding);
EXPECT_EQ(result, input_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne2) {
@@ -395,8 +378,7 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne2) {
HloSharding::PartialTile(TileAssignment({4, 2, 8}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne3) {
@@ -406,8 +388,7 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne3) {
HloSharding output_sharding = HloSharding::IotaTile({4, 2, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne4) {
@@ -418,8 +399,7 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne4) {
HloSharding::PartialTile(TileAssignment({4, 2, 4}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne1) {
@@ -429,12 +409,10 @@ TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne1) {
HloSharding output_sharding = HloSharding::IotaTile({1, 4});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
result = ReshapeSharding(output_shape, input_shape, output_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), input_sharding);
EXPECT_EQ(result, input_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne2) {
@@ -444,12 +422,10 @@ TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne2) {
HloSharding output_sharding = HloSharding::IotaTile({2, 1});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
result = ReshapeSharding(output_shape, input_shape, output_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), input_sharding);
EXPECT_EQ(result, input_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTranspose1) {
@@ -458,8 +434,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTranspose1) {
HloSharding sharding = HloSharding::IotaTile({2, 1, 5});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), sharding);
EXPECT_EQ(result, sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTranspose2) {
@@ -469,8 +444,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTranspose2) {
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 13});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
ASSERT_TRUE(result.has_value());
}
TEST(HloShardingUtilTest, ReshapeShardingTranspose3) {
@@ -479,7 +453,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTranspose3) {
HloSharding input_sharding = HloSharding::IotaTile({1, 1, 5});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_FALSE(result.has_value());
ASSERT_FALSE(result.has_value());
}
TEST(HloShardingUtilTest, ReshapeShardingTranspose4) {
@@ -490,8 +464,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTranspose4) {
HloSharding::PartialTile(TileAssignment({1, 1, 5, 1, 1, 1, 13}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingWithPadding1) {
@@ -500,7 +473,7 @@ TEST(HloShardingUtilTest, ReshapeShardingWithPadding1) {
HloSharding input_sharding = HloSharding::IotaTile({8});
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_FALSE(result.has_value());
ASSERT_FALSE(result.has_value());
}
TEST(HloShardingUtilTest, ReshapeShardingWithPadding2) {
@@ -511,8 +484,7 @@ TEST(HloShardingUtilTest, ReshapeShardingWithPadding2) {
HloSharding::PartialTile(TileAssignment({4, 2}));
std::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
EXPECT_EQ(result, output_sharding);
}
TEST(HloShardingUtilTest, PropagateReshapeShardingTranspose1) {