Migrate conditional_test to PjRt.

PiperOrigin-RevId: 847911726
This commit is contained in:
Niklas Vangerow
2025-12-22 15:58:02 -08:00
committed by TensorFlower Gardener
parent 69ea2a9308
commit dd10786acd
2 changed files with 217 additions and 155 deletions

View File

@@ -801,9 +801,12 @@ xla_test(
name = "conditional_test", name = "conditional_test",
srcs = ["conditional_test.cc"], srcs = ["conditional_test.cc"],
shard_count = 2, shard_count = 2,
tags = ["test_migrated_to_hlo_runner_pjrt"],
deps = [ deps = [
":client_library_test_runner_mixin", ":client_library_test_runner_mixin",
":hlo_test_base", ":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":literal_test_util",
":xla_internal_test_main", # fixdeps: keep ":xla_internal_test_main", # fixdeps: keep
"//xla:array2d", "//xla:array2d",
"//xla:error_spec", "//xla:error_spec",
@@ -812,11 +815,14 @@ xla_test(
"//xla:shape_util", "//xla:shape_util",
"//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation", "//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:test_helpers", "//xla/hlo/testlib:test_helpers",
"//xla/service:hlo_runner_interface",
"//xla/tsl/platform:env", "//xla/tsl/platform:env",
"//xla/tsl/platform:statusor", "//xla/tsl/platform:statusor",
"//xla/tsl/platform:test", "//xla/tsl/platform:test",
"@com_google_absl//absl/log:check", "@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@@ -13,24 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <array>
#include <cstdint> #include <cstdint>
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/log/check.h" #include "absl/log/check.h"
#include "absl/status/status_matchers.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "xla/array2d.h" #include "xla/array2d.h"
#include "xla/error_spec.h" #include "xla/error_spec.h"
#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/testlib/test_helpers.h"
#include "xla/literal.h" #include "xla/literal.h"
#include "xla/literal_util.h" #include "xla/literal_util.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/shape.h" #include "xla/shape.h"
#include "xla/shape_util.h" #include "xla/shape_util.h"
#include "xla/tests/client_library_test_runner_mixin.h" #include "xla/tests/client_library_test_runner_mixin.h"
#include "xla/tests/hlo_test_base.h" #include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/platform/env.h" #include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test.h"
@@ -39,12 +46,17 @@ limitations under the License.
namespace xla { namespace xla {
namespace { namespace {
using ::absl_testing::IsOkAndHolds;
constexpr ErrorSpec kErrorSpec{0.001}; constexpr ErrorSpec kErrorSpec{0.001};
class ConditionalOpTest : public ClientLibraryTestRunnerMixin<HloTestBase> { class ConditionalOpTest
: public ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
protected: protected:
void SetUp() override { void SetUp() override {
ClientLibraryTestRunnerMixin<HloTestBase>::SetUp(); ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>::SetUp();
mutable_debug_options()->set_xla_test_add_command_buffer_mode(true); mutable_debug_options()->set_xla_test_add_command_buffer_mode(true);
} }
@@ -212,31 +224,37 @@ TEST_F(ConditionalOpTest, Parameters0) {
// Test branch computations that do not take any parameters. // Test branch computations that do not take any parameters.
TEST_P(CaseOpTest, Parameters0) { TEST_P(CaseOpTest, Parameters0) {
int num_branches = GetParam(); const int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
XlaBuilder builder(TestName());
XlaOp branch_index;
auto branch_index_arg = CreateR0Parameter<int32_t>(
bi, 0, "branch_index_arg", &builder, &branch_index);
auto operand = Tuple(&builder, {});
XlaBuilder builder(TestName());
const XlaOp branch_index =
Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
auto operand = Tuple(&builder, {});
std::vector<XlaOp> operands(num_branches, operand); std::vector<XlaOp> operands(num_branches, operand);
std::vector<XlaComputation> branches; std::vector<XlaComputation> branches;
branches.reserve(num_branches); branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches); std::vector<const XlaComputation*> branches_p(num_branches);
for (int i = 0; i < num_branches; ++i) { for (int i = 0; i < num_branches; ++i) {
branches.emplace_back( branches.push_back(CreateR0ConstantComputation(static_cast<float>(i) * 10));
CreateR0ConstantComputation(static_cast<float>(i) * 10));
branches_p[i] = &branches[i]; branches_p[i] = &branches[i];
} }
Conditional(branch_index, branches_p, operands); Conditional(branch_index, branches_p, operands);
float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches) ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
? num_branches - 1 HloModuleFromXlaBuilder(&builder, execution_options()));
: bi); ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
ComputeAndCompareR0<float>(&builder, expected, {&branch_index_arg}, test_runner().CreateExecutable(std::move(module),
kErrorSpec); /*run_hlo_passes=*/true));
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
const Literal expected = LiteralUtil::CreateR0<float>(
10 * static_cast<float>(
(bi < 0 || bi >= num_branches) ? num_branches - 1 : bi));
const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
ASSERT_OK_AND_ASSIGN(const Literal result,
test_runner().ExecuteWithExecutable(
executable.get(), {&branch_index_arg}));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
} }
} }
@@ -255,40 +273,45 @@ TEST_F(ConditionalOpTest, Parameters1) {
// Test branch computations that take in 1 parameter. // Test branch computations that take in 1 parameter.
TEST_P(CaseOpTest, Parameters1) { TEST_P(CaseOpTest, Parameters1) {
int num_branches = GetParam(); const int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
XlaBuilder builder(TestName());
XlaOp branch_index;
auto branch_index_arg = CreateR0Parameter<int32_t>(
bi, 0, "branch_index_arg", &builder, &branch_index);
auto make_branch = [&builder, this](int i) { XlaBuilder builder(TestName());
auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i)); const XlaOp branch_index =
Add(ConstantR0<float>(sb.get(), static_cast<float>(i)), Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
Parameter(sb.get(), 0, r0f32_, "p0"));
return sb->BuildAndNoteError();
};
std::vector<XlaComputation> branches; std::vector<XlaComputation> branches;
branches.reserve(num_branches); branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches); std::vector<const XlaComputation*> branches_p(num_branches);
std::vector<XlaOp> operands; std::vector<XlaOp> operands;
operands.reserve(num_branches); operands.reserve(num_branches);
std::vector<float> expecteds(num_branches); std::vector<Literal> expecteds(num_branches);
for (int i = 0; i < num_branches; ++i) { for (int i = 0; i < num_branches; ++i) {
branches.emplace_back(make_branch(i)); std::unique_ptr<XlaBuilder> sb =
builder.CreateSubBuilder(absl::StrCat("branch_", i));
Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
Parameter(sb.get(), 0, r0f32_, "p0"));
branches.push_back(sb->BuildAndNoteError());
branches_p[i] = &branches[i]; branches_p[i] = &branches[i];
auto fi = static_cast<float>(i); const float fi = static_cast<float>(i);
operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7)); operands.push_back(ConstantR0<float>(&builder, 10 * fi + 7));
expecteds[i] = 10 * fi + 7 + fi; expecteds[i] = LiteralUtil::CreateR0<float>(10 * fi + 7 + fi);
} }
Conditional(branch_index, branches_p, operands); Conditional(branch_index, branches_p, operands);
float expected = (bi < 0 || bi >= num_branches)
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
HloModuleFromXlaBuilder(&builder, execution_options()));
ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
test_runner().CreateExecutable(std::move(module),
/*run_hlo_passes=*/true));
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
const Literal& expected = (bi < 0 || bi >= num_branches)
? expecteds[num_branches - 1] ? expecteds[num_branches - 1]
: expecteds[bi]; : expecteds[bi];
ComputeAndCompareR0<float>(&builder, expected, {&branch_index_arg}, const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
kErrorSpec); ASSERT_OK_AND_ASSIGN(const Literal result,
test_runner().ExecuteWithExecutable(
executable.get(), {&branch_index_arg}));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
} }
} }
@@ -428,38 +451,46 @@ TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
// Test branch computations that take in 2 array parameters. // Test branch computations that take in 2 array parameters.
TEST_P(CaseOpTest, Parameters2Array) { TEST_P(CaseOpTest, Parameters2Array) {
int num_branches = GetParam(); const int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp branch_index; const XlaOp branch_index =
auto branch_index_arg = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
CreateR0Parameter<int32_t>(bi, 0, "pred", &builder, &branch_index); const XlaOp operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f}); const XlaOp operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f}); const XlaOp operands = Tuple(&builder, {operand1, operand2});
auto operands = Tuple(&builder, {operand1, operand2});
auto make_branch = [&builder, this](int i) {
auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
GetTupleElement(p, 0)),
GetTupleElement(p, 1));
return sb->BuildAndNoteError();
};
std::vector<XlaComputation> branches; std::vector<XlaComputation> branches;
branches.reserve(num_branches); branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches); std::vector<const XlaComputation*> branches_p(num_branches);
for (int i = 0; i < num_branches; ++i) { for (int i = 0; i < num_branches; ++i) {
branches.emplace_back(make_branch(i)); std::unique_ptr<XlaBuilder> sb =
builder.CreateSubBuilder(absl::StrCat("branch_", i));
const XlaOp p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
GetTupleElement(p, 0)),
GetTupleElement(p, 1));
branches.push_back(sb->BuildAndNoteError());
branches_p[i] = &branches[i]; branches_p[i] = &branches[i];
} }
Conditional(branch_index, branches_p, Conditional(branch_index, branches_p,
std::vector<XlaOp>(num_branches, operands)); std::vector<XlaOp>(num_branches, operands));
auto modified_bi = static_cast<float>(
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
HloModuleFromXlaBuilder(&builder, execution_options()));
ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
test_runner().CreateExecutable(std::move(module),
/*run_hlo_passes=*/true));
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
const float modified_bi = static_cast<float>(
(bi < 0 || bi >= num_branches) ? num_branches - 1 : bi); (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
ComputeAndCompareR1<float>( const Literal expected = LiteralUtil::CreateR1<float>(
&builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11}, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11});
{&branch_index_arg}, kErrorSpec); ASSERT_OK_AND_ASSIGN(const Literal result,
test_runner().ExecuteWithExecutable(
executable.get(), {&branch_index_arg}));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
} }
} }
@@ -561,47 +592,52 @@ TEST_F(ConditionalOpTest, ReturnNestedTuple) {
XlaBuilder true_builder(TestName() + ".true"); XlaBuilder true_builder(TestName() + ".true");
{ {
Parameter(&true_builder, 0, empty_tuple_, "tuple"); Parameter(&true_builder, 0, empty_tuple_, "tuple");
auto true_constant1 = ConstantR0<float>(&true_builder, 12.2f); const XlaOp true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
auto true_constant2 = ConstantR1<float>(&true_builder, {12.8f, 14.6f}); const XlaOp true_constant2 =
auto true_constant3 = ConstantR1<float>(&true_builder, {25.4f, 29.8f}); ConstantR1<float>(&true_builder, {12.8f, 14.6f});
auto true_constant4 = ConstantR0<float>(&true_builder, 35.6f); const XlaOp true_constant3 =
ConstantR1<float>(&true_builder, {25.4f, 29.8f});
const XlaOp true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
Tuple(&true_builder, Tuple(&true_builder,
{Tuple(&true_builder, {true_constant1, true_constant2}), {Tuple(&true_builder, {true_constant1, true_constant2}),
Tuple(&true_builder, {true_constant3, true_constant4})}); Tuple(&true_builder, {true_constant3, true_constant4})});
} }
auto true_builder_result = true_builder.Build(); ASSERT_OK_AND_ASSIGN(XlaComputation true_comp, true_builder.Build());
EXPECT_IS_OK(true_builder_result.status());
XlaBuilder false_builder(TestName() + ".false"); XlaBuilder false_builder(TestName() + ".false");
{ {
Parameter(&false_builder, 0, empty_tuple_, "tuple"); Parameter(&false_builder, 0, empty_tuple_, "tuple");
auto false_constant1 = ConstantR0<float>(&false_builder, 46.6f); const XlaOp false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
auto false_constant2 = ConstantR1<float>(&false_builder, {54.4f, 58.4f}); const XlaOp false_constant2 =
auto false_constant3 = ConstantR1<float>(&false_builder, {62.1f, 67.4f}); ConstantR1<float>(&false_builder, {54.4f, 58.4f});
auto false_constant4 = ConstantR0<float>(&false_builder, 9.3f); const XlaOp false_constant3 =
ConstantR1<float>(&false_builder, {62.1f, 67.4f});
const XlaOp false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
Tuple(&false_builder, Tuple(&false_builder,
{Tuple(&false_builder, {false_constant1, false_constant2}), {Tuple(&false_builder, {false_constant1, false_constant2}),
Tuple(&false_builder, {false_constant3, false_constant4})}); Tuple(&false_builder, {false_constant3, false_constant4})});
} }
auto false_builder_result = false_builder.Build(); ASSERT_OK_AND_ASSIGN(XlaComputation false_comp, false_builder.Build());
EXPECT_IS_OK(false_builder_result.status());
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp pred; XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); const Literal pred_arg =
auto operands = Tuple(&builder, {}); CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
Conditional(pred, operands, std::move(true_builder_result).value(), operands, const XlaOp operands = Tuple(&builder, {});
std::move(false_builder_result).value()); const XlaOp result = Conditional(pred, operands, std::move(true_comp),
operands, std::move(false_comp));
// Flatten nested tuple for PjRt.
const XlaOp e0 = GetTupleElement(result, 0);
const XlaOp e1 = GetTupleElement(result, 1);
Tuple(&builder, {GetTupleElement(e0, 0), GetTupleElement(e0, 1),
GetTupleElement(e1, 0), GetTupleElement(e1, 1)});
ComputeAndCompareLiteral( ComputeAndCompareLiteral(&builder,
&builder,
LiteralUtil::MakeTupleFromSlices( LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(46.6f), {LiteralUtil::CreateR0<float>(46.6f),
LiteralUtil::CreateR1<float>({54.4f, 58.4f})}), LiteralUtil::CreateR1<float>({54.4f, 58.4f}),
LiteralUtil::MakeTupleFromSlices( LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
{LiteralUtil::CreateR1<float>({62.1f, 67.4f}), LiteralUtil::CreateR0<float>(9.3f)}),
LiteralUtil::CreateR0<float>(9.3f)})}),
{&pred_arg}, kErrorSpec); {&pred_arg}, kErrorSpec);
} }
@@ -751,21 +787,31 @@ TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
main = builder.Build().value(); main = builder.Build().value();
} }
auto test_swap = [&](float a, float b) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
XlaOp x, y; const XlaOp x = Parameter(&builder, 0, r0f32_, "x");
auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x); const XlaOp y = Parameter(&builder, 1, r0f32_, "y");
auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y); const XlaOp tuple_operand = Tuple(&builder, {x, y});
auto tuple_operand = Tuple(&builder, {x, y});
Call(&builder, main, {tuple_operand}); Call(&builder, main, {tuple_operand});
ComputeAndCompareLiteral( ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
&builder, HloModuleFromXlaBuilder(&builder, execution_options()));
LiteralUtil::MakeTupleFromSlices( ASSERT_OK_AND_ASSIGN(
{LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}), std::unique_ptr<OpaqueExecutable> executable,
{&x_arg, &y_arg}, kErrorSpec); CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
const auto test_swap =
[&, this](float a,
float b) -> absl::StatusOr<::testing::AssertionResult> {
const Literal x_arg = LiteralUtil::CreateR0<float>(a);
const Literal y_arg = LiteralUtil::CreateR0<float>(b);
const Literal expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)});
ASSIGN_OR_RETURN(const Literal result,
test_runner().ExecuteWithExecutable(executable.get(),
{&x_arg, &y_arg}));
return LiteralTestUtil::Near(expected, result, kErrorSpec);
}; };
test_swap(3.11f, 9.4f); EXPECT_THAT(test_swap(3.11f, 9.4f), IsOkAndHolds(true));
test_swap(11.24f, 5.55f); EXPECT_THAT(test_swap(11.24f, 5.55f), IsOkAndHolds(true));
} }
// Test conditional that duplicates tuple elements in the then and else // Test conditional that duplicates tuple elements in the then and else
@@ -792,35 +838,45 @@ TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
else_comp = builder.Build().value(); else_comp = builder.Build().value();
} }
{
// Pred is true case.
std::vector<Literal> args;
args.push_back(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<int32_t>(123),
LiteralUtil::CreateR0<int32_t>(-42)}));
args.push_back(LiteralUtil::CreateR0<bool>(true));
XlaBuilder builder(TestName() + ".main"); XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0"); auto p0 = Parameter(&builder, 0, scalar, "p0.0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); auto p1 = Parameter(&builder, 1, scalar, "p0.1");
auto p = Tuple(&builder, {p0, p1});
auto p_pred = Parameter(&builder, 2, ShapeUtil::MakeShape(PRED, {}), "p1");
Conditional(p_pred, p, then_comp, p, else_comp); Conditional(p_pred, p, then_comp, p, else_comp);
ComputeAndCompare(&builder, {&args[0], &args[1]}); ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
} HloModuleFromXlaBuilder(&builder, execution_options()));
{
// Pred is false case. const std::array<Literal, 4> args = {
std::vector<Literal> args; LiteralUtil::CreateR0<int32_t>(123), LiteralUtil::CreateR0<int32_t>(-42),
args.push_back(LiteralUtil::MakeTupleFromSlices( LiteralUtil::CreateR0<bool>(true), LiteralUtil::CreateR0<bool>(false)};
{LiteralUtil::CreateR0<int32_t>(123), const std::array<const Literal*, 3> true_args = {&args[0], &args[1],
LiteralUtil::CreateR0<int32_t>(-42)})); &args[2]};
args.push_back(LiteralUtil::CreateR0<bool>(false)); const std::array<const Literal*, 3> false_args = {&args[0], &args[1],
XlaBuilder builder(TestName() + ".main"); &args[3]};
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); // Compute reference values. Because this test is not parameterized, we need
Conditional(p_pred, p, then_comp, p, else_comp); // to manually invoke the test runner and reference runner.
ComputeAndCompare(&builder, {&args[0], &args[1]}); ASSERT_OK_AND_ASSIGN(Literal true_reference,
} reference_runner().Execute(module->Clone(), true_args,
/*run_hlo_passes=*/true));
ASSERT_OK_AND_ASSIGN(Literal false_reference,
reference_runner().Execute(module->Clone(), false_args,
/*run_hlo_passes=*/true));
ASSERT_OK_AND_ASSIGN(
std::unique_ptr<OpaqueExecutable> executable,
CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
ASSERT_OK_AND_ASSIGN(Literal true_result, test_runner().ExecuteWithExecutable(
executable.get(), true_args));
ASSERT_OK_AND_ASSIGN(
Literal false_result,
test_runner().ExecuteWithExecutable(executable.get(), false_args));
EXPECT_TRUE(LiteralTestUtil::Equal(true_reference, true_result));
EXPECT_TRUE(LiteralTestUtil::Equal(false_reference, false_result));
} }
using ConditionalOpHloTest = HloTestBase; using ConditionalOpHloTest = HloPjRtTestBase;
TEST_F(ConditionalOpHloTest, ParallelExecution) { TEST_F(ConditionalOpHloTest, ParallelExecution) {
// Test conditional works when an executable is executed in parallel. // Test conditional works when an executable is executed in parallel.