mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Refactor std::optional comparison in ReshapeSharding tests
PiperOrigin-RevId: 847749800
This commit is contained in:
committed by
TensorFlower Gardener
parent
12502acbf5
commit
4d0edd395f
@@ -196,8 +196,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned1) {
|
||||
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {3, 2, 2}, {1, 2, 0}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned2) {
|
||||
@@ -208,8 +207,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned2) {
|
||||
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {2, 3, 2}, {0, 2, 1}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned3) {
|
||||
@@ -220,8 +218,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned3) {
|
||||
HloSharding::PartialTile(TileAssignment({4, 3}, {2, 3, 2}, {0, 2, 1}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned4) {
|
||||
@@ -232,8 +229,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned4) {
|
||||
HloSharding::PartialTile(TileAssignment({2, 2, 3}, {3, 4}, {1, 0}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned5) {
|
||||
@@ -243,8 +239,7 @@ TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned5) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({2, 3, 2, 2});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingMaximal) {
|
||||
@@ -253,8 +248,7 @@ TEST(HloShardingUtilTest, ReshapeShardingMaximal) {
|
||||
HloSharding sharding = HloSharding::AssignDevice(7);
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), sharding);
|
||||
EXPECT_EQ(result, sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) {
|
||||
@@ -263,7 +257,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) {
|
||||
HloSharding sharding = HloSharding::IotaTile({1, 2, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, sharding);
|
||||
EXPECT_FALSE(result.has_value());
|
||||
ASSERT_FALSE(result.has_value());
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) {
|
||||
@@ -273,8 +267,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({2, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
|
||||
@@ -284,8 +277,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTiledSplit2) {
|
||||
@@ -295,8 +287,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit2) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({4, 4, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTiledSplit3) {
|
||||
@@ -307,8 +298,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit3) {
|
||||
HloSharding::PartialTile(TileAssignment({2, 1, 2}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
|
||||
@@ -318,8 +308,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) {
|
||||
@@ -328,8 +317,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) {
|
||||
HloSharding sharding = HloSharding::IotaTile({2, 1, 1, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), sharding);
|
||||
EXPECT_EQ(result, sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) {
|
||||
@@ -339,8 +327,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({1, 2, 1, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTrivialDimensionInsertedToEnd) {
|
||||
@@ -350,16 +337,14 @@ TEST(HloShardingUtilTest, ReshapeShardingTrivialDimensionInsertedToEnd) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) {
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1});
|
||||
HloSharding sharding = HloSharding::IotaTile({2, 1, 1});
|
||||
std::optional<HloSharding> result = ReshapeSharding(shape, shape, sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), sharding);
|
||||
EXPECT_EQ(result, sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingScalar) {
|
||||
@@ -368,7 +353,7 @@ TEST(HloShardingUtilTest, ReshapeShardingScalar) {
|
||||
HloSharding sharding = HloSharding::IotaTile({2, 1, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, sharding);
|
||||
EXPECT_FALSE(result.has_value());
|
||||
ASSERT_FALSE(result.has_value());
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne1) {
|
||||
@@ -379,12 +364,10 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne1) {
|
||||
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
|
||||
result = ReshapeSharding(output_shape, input_shape, output_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), input_sharding);
|
||||
EXPECT_EQ(result, input_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne2) {
|
||||
@@ -395,8 +378,7 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne2) {
|
||||
HloSharding::PartialTile(TileAssignment({4, 2, 8}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne3) {
|
||||
@@ -406,8 +388,7 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne3) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({4, 2, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne4) {
|
||||
@@ -418,8 +399,7 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne4) {
|
||||
HloSharding::PartialTile(TileAssignment({4, 2, 4}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne1) {
|
||||
@@ -429,12 +409,10 @@ TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne1) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({1, 4});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
|
||||
result = ReshapeSharding(output_shape, input_shape, output_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), input_sharding);
|
||||
EXPECT_EQ(result, input_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne2) {
|
||||
@@ -444,12 +422,10 @@ TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne2) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({2, 1});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
|
||||
result = ReshapeSharding(output_shape, input_shape, output_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), input_sharding);
|
||||
EXPECT_EQ(result, input_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTranspose1) {
|
||||
@@ -458,8 +434,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTranspose1) {
|
||||
HloSharding sharding = HloSharding::IotaTile({2, 1, 5});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), sharding);
|
||||
EXPECT_EQ(result, sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTranspose2) {
|
||||
@@ -469,8 +444,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTranspose2) {
|
||||
HloSharding output_sharding = HloSharding::IotaTile({2, 1, 13});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
ASSERT_TRUE(result.has_value());
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTranspose3) {
|
||||
@@ -479,7 +453,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTranspose3) {
|
||||
HloSharding input_sharding = HloSharding::IotaTile({1, 1, 5});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_FALSE(result.has_value());
|
||||
ASSERT_FALSE(result.has_value());
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingTranspose4) {
|
||||
@@ -490,8 +464,7 @@ TEST(HloShardingUtilTest, ReshapeShardingTranspose4) {
|
||||
HloSharding::PartialTile(TileAssignment({1, 1, 5, 1, 1, 1, 13}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingWithPadding1) {
|
||||
@@ -500,7 +473,7 @@ TEST(HloShardingUtilTest, ReshapeShardingWithPadding1) {
|
||||
HloSharding input_sharding = HloSharding::IotaTile({8});
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_FALSE(result.has_value());
|
||||
ASSERT_FALSE(result.has_value());
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, ReshapeShardingWithPadding2) {
|
||||
@@ -511,8 +484,7 @@ TEST(HloShardingUtilTest, ReshapeShardingWithPadding2) {
|
||||
HloSharding::PartialTile(TileAssignment({4, 2}));
|
||||
std::optional<HloSharding> result =
|
||||
ReshapeSharding(input_shape, output_shape, input_sharding);
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(result.value(), output_sharding);
|
||||
EXPECT_EQ(result, output_sharding);
|
||||
}
|
||||
|
||||
TEST(HloShardingUtilTest, PropagateReshapeShardingTranspose1) {
|
||||
|
||||
Reference in New Issue
Block a user