From dd10786acdfd2875771b3e3a37fd80f22f4a214b Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Mon, 22 Dec 2025 15:58:02 -0800 Subject: [PATCH] Migrate conditional_test to PjRt. PiperOrigin-RevId: 847911726 --- third_party/xla/xla/tests/BUILD | 8 +- third_party/xla/xla/tests/conditional_test.cc | 364 ++++++++++-------- 2 files changed, 217 insertions(+), 155 deletions(-) diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 88057624154..a53c743d0ab 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -801,9 +801,12 @@ xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], shard_count = 2, + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ ":client_library_test_runner_mixin", - ":hlo_test_base", + ":hlo_pjrt_interpreter_reference_mixin", + ":hlo_pjrt_test_base", + ":literal_test_util", ":xla_internal_test_main", # fixdeps: keep "//xla:array2d", "//xla:error_spec", @@ -812,11 +815,14 @@ xla_test( "//xla:shape_util", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/ir:hlo", "//xla/hlo/testlib:test_helpers", + "//xla/service:hlo_runner_interface", "//xla/tsl/platform:env", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", ], ) diff --git a/third_party/xla/xla/tests/conditional_test.cc b/third_party/xla/xla/tests/conditional_test.cc index 157c89c4241..c293e12186e 100644 --- a/third_party/xla/xla/tests/conditional_test.cc +++ b/third_party/xla/xla/tests/conditional_test.cc @@ -13,24 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include #include #include #include "absl/log/check.h" +#include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" #include "xla/array2d.h" #include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/hlo_runner_interface.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_runner_mixin.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tests/literal_test_util.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -39,12 +46,17 @@ limitations under the License. namespace xla { namespace { +using ::absl_testing::IsOkAndHolds; + constexpr ErrorSpec kErrorSpec{0.001}; -class ConditionalOpTest : public ClientLibraryTestRunnerMixin { +class ConditionalOpTest + : public ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin> { protected: void SetUp() override { - ClientLibraryTestRunnerMixin::SetUp(); + ClientLibraryTestRunnerMixin< + HloPjRtInterpreterReferenceMixin>::SetUp(); mutable_debug_options()->set_xla_test_add_command_buffer_mode(true); } @@ -212,31 +224,37 @@ TEST_F(ConditionalOpTest, Parameters0) { // Test branch computations that do not take any parameters. TEST_P(CaseOpTest, Parameters0) { - int num_branches = GetParam(); + const int num_branches = GetParam(); + + XlaBuilder builder(TestName()); + const XlaOp branch_index = + Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg"); + auto operand = Tuple(&builder, {}); + std::vector operands(num_branches, operand); + std::vector branches; + branches.reserve(num_branches); + std::vector branches_p(num_branches); + for (int i = 0; i < num_branches; ++i) { + branches.push_back(CreateR0ConstantComputation(static_cast(i) * 10)); + branches_p[i] = &branches[i]; + } + Conditional(branch_index, branches_p, operands); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloModuleFromXlaBuilder(&builder, execution_options())); + ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + test_runner().CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); for (int bi = -1; bi <= num_branches; ++bi) { SCOPED_TRACE(bi); - XlaBuilder builder(TestName()); - XlaOp branch_index; - auto branch_index_arg = CreateR0Parameter( - bi, 0, "branch_index_arg", &builder, &branch_index); - auto operand = Tuple(&builder, {}); - - std::vector operands(num_branches, operand); - std::vector branches; - branches.reserve(num_branches); - std::vector branches_p(num_branches); - for (int i = 0; i < num_branches; ++i) { - branches.emplace_back( - CreateR0ConstantComputation(static_cast(i) * 10)); - branches_p[i] = &branches[i]; - } - Conditional(branch_index, branches_p, operands); - - float expected = 10 * static_cast((bi < 0 || bi >= num_branches) - ? num_branches - 1 - : bi); - ComputeAndCompareR0(&builder, expected, {&branch_index_arg}, - kErrorSpec); + const Literal expected = LiteralUtil::CreateR0( + 10 * static_cast( + (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi)); + const Literal branch_index_arg = LiteralUtil::CreateR0(bi); + ASSERT_OK_AND_ASSIGN(const Literal result, + test_runner().ExecuteWithExecutable( + executable.get(), {&branch_index_arg})); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec)); } } @@ -255,40 +273,45 @@ TEST_F(ConditionalOpTest, Parameters1) { // Test branch computations that take in 1 parameter. TEST_P(CaseOpTest, Parameters1) { - int num_branches = GetParam(); + const int num_branches = GetParam(); + + XlaBuilder builder(TestName()); + const XlaOp branch_index = + Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg"); + std::vector branches; + branches.reserve(num_branches); + std::vector branches_p(num_branches); + std::vector operands; + operands.reserve(num_branches); + std::vector expecteds(num_branches); + for (int i = 0; i < num_branches; ++i) { + std::unique_ptr sb = + builder.CreateSubBuilder(absl::StrCat("branch_", i)); + Add(ConstantR0(sb.get(), static_cast(i)), + Parameter(sb.get(), 0, r0f32_, "p0")); + branches.push_back(sb->BuildAndNoteError()); + branches_p[i] = &branches[i]; + const float fi = static_cast(i); + operands.push_back(ConstantR0(&builder, 10 * fi + 7)); + expecteds[i] = LiteralUtil::CreateR0(10 * fi + 7 + fi); + } + Conditional(branch_index, branches_p, operands); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloModuleFromXlaBuilder(&builder, execution_options())); + ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + test_runner().CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); for (int bi = -1; bi <= num_branches; ++bi) { SCOPED_TRACE(bi); - XlaBuilder builder(TestName()); - XlaOp branch_index; - auto branch_index_arg = CreateR0Parameter( - bi, 0, "branch_index_arg", &builder, &branch_index); - - auto make_branch = [&builder, this](int i) { - auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i)); - Add(ConstantR0(sb.get(), static_cast(i)), - Parameter(sb.get(), 0, r0f32_, "p0")); - return sb->BuildAndNoteError(); - }; - std::vector branches; - branches.reserve(num_branches); - std::vector branches_p(num_branches); - std::vector operands; - operands.reserve(num_branches); - std::vector expecteds(num_branches); - for (int i = 0; i < num_branches; ++i) { - branches.emplace_back(make_branch(i)); - branches_p[i] = &branches[i]; - auto fi = static_cast(i); - operands.emplace_back(ConstantR0(&builder, 10 * fi + 7)); - expecteds[i] = 10 * fi + 7 + fi; - } - - Conditional(branch_index, branches_p, operands); - float expected = (bi < 0 || bi >= num_branches) - ? expecteds[num_branches - 1] - : expecteds[bi]; - ComputeAndCompareR0(&builder, expected, {&branch_index_arg}, - kErrorSpec); + const Literal& expected = (bi < 0 || bi >= num_branches) + ? expecteds[num_branches - 1] + : expecteds[bi]; + const Literal branch_index_arg = LiteralUtil::CreateR0(bi); + ASSERT_OK_AND_ASSIGN(const Literal result, + test_runner().ExecuteWithExecutable( + executable.get(), {&branch_index_arg})); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec)); } } @@ -428,38 +451,46 @@ TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { // Test branch computations that take in 2 array parameters. TEST_P(CaseOpTest, Parameters2Array) { - int num_branches = GetParam(); + const int num_branches = GetParam(); + + XlaBuilder builder(TestName()); + const XlaOp branch_index = + Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg"); + const XlaOp operand1 = ConstantR1(&builder, {24.0f, 56.0f}); + const XlaOp operand2 = ConstantR1(&builder, {10.0f, 11.0f}); + const XlaOp operands = Tuple(&builder, {operand1, operand2}); + std::vector branches; + branches.reserve(num_branches); + std::vector branches_p(num_branches); + for (int i = 0; i < num_branches; ++i) { + std::unique_ptr sb = + builder.CreateSubBuilder(absl::StrCat("branch_", i)); + const XlaOp p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0"); + Add(Mul(ConstantR0(sb.get(), static_cast(i)), + GetTupleElement(p, 0)), + GetTupleElement(p, 1)); + branches.push_back(sb->BuildAndNoteError()); + branches_p[i] = &branches[i]; + } + Conditional(branch_index, branches_p, + std::vector(num_branches, operands)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloModuleFromXlaBuilder(&builder, execution_options())); + ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + test_runner().CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); for (int bi = -1; bi <= num_branches; ++bi) { SCOPED_TRACE(bi); - XlaBuilder builder(TestName()); - XlaOp branch_index; - auto branch_index_arg = - CreateR0Parameter(bi, 0, "pred", &builder, &branch_index); - auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); - auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); - auto operands = Tuple(&builder, {operand1, operand2}); - auto make_branch = [&builder, this](int i) { - auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i)); - auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0"); - Add(Mul(ConstantR0(sb.get(), static_cast(i)), - GetTupleElement(p, 0)), - GetTupleElement(p, 1)); - return sb->BuildAndNoteError(); - }; - std::vector branches; - branches.reserve(num_branches); - std::vector branches_p(num_branches); - for (int i = 0; i < num_branches; ++i) { - branches.emplace_back(make_branch(i)); - branches_p[i] = &branches[i]; - } - Conditional(branch_index, branches_p, - std::vector(num_branches, operands)); - auto modified_bi = static_cast( + const Literal branch_index_arg = LiteralUtil::CreateR0(bi); + const float modified_bi = static_cast( (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi); - ComputeAndCompareR1( - &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11}, - {&branch_index_arg}, kErrorSpec); + const Literal expected = LiteralUtil::CreateR1( + {24.0f * modified_bi + 10, 56.0f * modified_bi + 11}); + ASSERT_OK_AND_ASSIGN(const Literal result, + test_runner().ExecuteWithExecutable( + executable.get(), {&branch_index_arg})); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec)); } } @@ -561,48 +592,53 @@ TEST_F(ConditionalOpTest, ReturnNestedTuple) { XlaBuilder true_builder(TestName() + ".true"); { Parameter(&true_builder, 0, empty_tuple_, "tuple"); - auto true_constant1 = ConstantR0(&true_builder, 12.2f); - auto true_constant2 = ConstantR1(&true_builder, {12.8f, 14.6f}); - auto true_constant3 = ConstantR1(&true_builder, {25.4f, 29.8f}); - auto true_constant4 = ConstantR0(&true_builder, 35.6f); + const XlaOp true_constant1 = ConstantR0(&true_builder, 12.2f); + const XlaOp true_constant2 = + ConstantR1(&true_builder, {12.8f, 14.6f}); + const XlaOp true_constant3 = + ConstantR1(&true_builder, {25.4f, 29.8f}); + const XlaOp true_constant4 = ConstantR0(&true_builder, 35.6f); Tuple(&true_builder, {Tuple(&true_builder, {true_constant1, true_constant2}), Tuple(&true_builder, {true_constant3, true_constant4})}); } - auto true_builder_result = true_builder.Build(); - EXPECT_IS_OK(true_builder_result.status()); + ASSERT_OK_AND_ASSIGN(XlaComputation true_comp, true_builder.Build()); XlaBuilder false_builder(TestName() + ".false"); { Parameter(&false_builder, 0, empty_tuple_, "tuple"); - auto false_constant1 = ConstantR0(&false_builder, 46.6f); - auto false_constant2 = ConstantR1(&false_builder, {54.4f, 58.4f}); - auto false_constant3 = ConstantR1(&false_builder, {62.1f, 67.4f}); - auto false_constant4 = ConstantR0(&false_builder, 9.3f); + const XlaOp false_constant1 = ConstantR0(&false_builder, 46.6f); + const XlaOp false_constant2 = + ConstantR1(&false_builder, {54.4f, 58.4f}); + const XlaOp false_constant3 = + ConstantR1(&false_builder, {62.1f, 67.4f}); + const XlaOp false_constant4 = ConstantR0(&false_builder, 9.3f); Tuple(&false_builder, {Tuple(&false_builder, {false_constant1, false_constant2}), Tuple(&false_builder, {false_constant3, false_constant4})}); } - auto false_builder_result = false_builder.Build(); - EXPECT_IS_OK(false_builder_result.status()); + ASSERT_OK_AND_ASSIGN(XlaComputation false_comp, false_builder.Build()); XlaBuilder builder(TestName()); XlaOp pred; - auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); - auto operands = Tuple(&builder, {}); - Conditional(pred, operands, std::move(true_builder_result).value(), operands, - std::move(false_builder_result).value()); + const Literal pred_arg = + CreateR0Parameter(false, 0, "pred", &builder, &pred); + const XlaOp operands = Tuple(&builder, {}); + const XlaOp result = Conditional(pred, operands, std::move(true_comp), + operands, std::move(false_comp)); + // Flatten nested tuple for PjRt. + const XlaOp e0 = GetTupleElement(result, 0); + const XlaOp e1 = GetTupleElement(result, 1); + Tuple(&builder, {GetTupleElement(e0, 0), GetTupleElement(e0, 1), + GetTupleElement(e1, 0), GetTupleElement(e1, 1)}); - ComputeAndCompareLiteral( - &builder, - LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR0(46.6f), - LiteralUtil::CreateR1({54.4f, 58.4f})}), - LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR1({62.1f, 67.4f}), - LiteralUtil::CreateR0(9.3f)})}), - {&pred_arg}, kErrorSpec); + ComputeAndCompareLiteral(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(46.6f), + LiteralUtil::CreateR1({54.4f, 58.4f}), + LiteralUtil::CreateR1({62.1f, 67.4f}), + LiteralUtil::CreateR0(9.3f)}), + {&pred_arg}, kErrorSpec); } // Test conditional that takes in scalar operands in the form of external @@ -751,21 +787,31 @@ TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { main = builder.Build().value(); } - auto test_swap = [&](float a, float b) { - XlaBuilder builder(TestName()); - XlaOp x, y; - auto x_arg = CreateR0Parameter(a, 0, "x", &builder, &x); - auto y_arg = CreateR0Parameter(b, 1, "y", &builder, &y); - auto tuple_operand = Tuple(&builder, {x, y}); - Call(&builder, main, {tuple_operand}); - ComputeAndCompareLiteral( - &builder, - LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR0(a), LiteralUtil::CreateR0(b)}), - {&x_arg, &y_arg}, kErrorSpec); + XlaBuilder builder(TestName()); + const XlaOp x = Parameter(&builder, 0, r0f32_, "x"); + const XlaOp y = Parameter(&builder, 1, r0f32_, "y"); + const XlaOp tuple_operand = Tuple(&builder, {x, y}); + Call(&builder, main, {tuple_operand}); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloModuleFromXlaBuilder(&builder, execution_options())); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + CreateExecutable(std::move(module), /*run_hlo_passes=*/true)); + + const auto test_swap = + [&, this](float a, + float b) -> absl::StatusOr<::testing::AssertionResult> { + const Literal x_arg = LiteralUtil::CreateR0(a); + const Literal y_arg = LiteralUtil::CreateR0(b); + const Literal expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(a), LiteralUtil::CreateR0(b)}); + ASSIGN_OR_RETURN(const Literal result, + test_runner().ExecuteWithExecutable(executable.get(), + {&x_arg, &y_arg})); + return LiteralTestUtil::Near(expected, result, kErrorSpec); }; - test_swap(3.11f, 9.4f); - test_swap(11.24f, 5.55f); + EXPECT_THAT(test_swap(3.11f, 9.4f), IsOkAndHolds(true)); + EXPECT_THAT(test_swap(11.24f, 5.55f), IsOkAndHolds(true)); } // Test conditional that duplicates tuple elements in the then and else @@ -792,35 +838,45 @@ TEST_F(ConditionalOpTest, DuplicateElementsConditional) { else_comp = builder.Build().value(); } - { - // Pred is true case. - std::vector args; - args.push_back(LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR0(123), - LiteralUtil::CreateR0(-42)})); - args.push_back(LiteralUtil::CreateR0(true)); - XlaBuilder builder(TestName() + ".main"); - auto p = Parameter(&builder, 0, tuple2, "p0"); - auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); - Conditional(p_pred, p, then_comp, p, else_comp); - ComputeAndCompare(&builder, {&args[0], &args[1]}); - } - { - // Pred is false case. - std::vector args; - args.push_back(LiteralUtil::MakeTupleFromSlices( - {LiteralUtil::CreateR0(123), - LiteralUtil::CreateR0(-42)})); - args.push_back(LiteralUtil::CreateR0(false)); - XlaBuilder builder(TestName() + ".main"); - auto p = Parameter(&builder, 0, tuple2, "p0"); - auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); - Conditional(p_pred, p, then_comp, p, else_comp); - ComputeAndCompare(&builder, {&args[0], &args[1]}); - } + XlaBuilder builder(TestName() + ".main"); + auto p0 = Parameter(&builder, 0, scalar, "p0.0"); + auto p1 = Parameter(&builder, 1, scalar, "p0.1"); + auto p = Tuple(&builder, {p0, p1}); + auto p_pred = Parameter(&builder, 2, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloModuleFromXlaBuilder(&builder, execution_options())); + + const std::array args = { + LiteralUtil::CreateR0(123), LiteralUtil::CreateR0(-42), + LiteralUtil::CreateR0(true), LiteralUtil::CreateR0(false)}; + const std::array true_args = {&args[0], &args[1], + &args[2]}; + const std::array false_args = {&args[0], &args[1], + &args[3]}; + + // Compute reference values. Because this test is not parameterized, we need + // to manually invoke the test runner and reference runner. + ASSERT_OK_AND_ASSIGN(Literal true_reference, + reference_runner().Execute(module->Clone(), true_args, + /*run_hlo_passes=*/true)); + ASSERT_OK_AND_ASSIGN(Literal false_reference, + reference_runner().Execute(module->Clone(), false_args, + /*run_hlo_passes=*/true)); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + CreateExecutable(std::move(module), /*run_hlo_passes=*/true)); + ASSERT_OK_AND_ASSIGN(Literal true_result, test_runner().ExecuteWithExecutable( + executable.get(), true_args)); + ASSERT_OK_AND_ASSIGN( + Literal false_result, + test_runner().ExecuteWithExecutable(executable.get(), false_args)); + EXPECT_TRUE(LiteralTestUtil::Equal(true_reference, true_result)); + EXPECT_TRUE(LiteralTestUtil::Equal(false_reference, false_result)); } -using ConditionalOpHloTest = HloTestBase; +using ConditionalOpHloTest = HloPjRtTestBase; TEST_F(ConditionalOpHloTest, ParallelExecution) { // Test conditional works when an executable is executed in parallel.