Migrate conditional_test to PjRt.

PiperOrigin-RevId: 847911726
This commit is contained in:
Niklas Vangerow
2025-12-22 15:58:02 -08:00
committed by TensorFlower Gardener
parent 69ea2a9308
commit dd10786acd
2 changed files with 217 additions and 155 deletions

View File

@@ -801,9 +801,12 @@ xla_test(
name = "conditional_test",
srcs = ["conditional_test.cc"],
shard_count = 2,
tags = ["test_migrated_to_hlo_runner_pjrt"],
deps = [
":client_library_test_runner_mixin",
":hlo_test_base",
":hlo_pjrt_interpreter_reference_mixin",
":hlo_pjrt_test_base",
":literal_test_util",
":xla_internal_test_main", # fixdeps: keep
"//xla:array2d",
"//xla:error_spec",
@@ -812,11 +815,14 @@ xla_test(
"//xla:shape_util",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:test_helpers",
"//xla/service:hlo_runner_interface",
"//xla/tsl/platform:env",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/strings",
],
)

View File

@@ -13,24 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/status/status_matchers.h"
#include "absl/strings/str_cat.h"
#include "xla/array2d.h"
#include "xla/error_spec.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/testlib/test_helpers.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/client_library_test_runner_mixin.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
@@ -39,12 +46,17 @@ limitations under the License.
namespace xla {
namespace {
using ::absl_testing::IsOkAndHolds;
constexpr ErrorSpec kErrorSpec{0.001};
class ConditionalOpTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
class ConditionalOpTest
: public ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
protected:
void SetUp() override {
ClientLibraryTestRunnerMixin<HloTestBase>::SetUp();
ClientLibraryTestRunnerMixin<
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>::SetUp();
mutable_debug_options()->set_xla_test_add_command_buffer_mode(true);
}
@@ -212,31 +224,37 @@ TEST_F(ConditionalOpTest, Parameters0) {
// Test branch computations that do not take any parameters.
TEST_P(CaseOpTest, Parameters0) {
int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
XlaBuilder builder(TestName());
XlaOp branch_index;
auto branch_index_arg = CreateR0Parameter<int32_t>(
bi, 0, "branch_index_arg", &builder, &branch_index);
auto operand = Tuple(&builder, {});
const int num_branches = GetParam();
XlaBuilder builder(TestName());
const XlaOp branch_index =
Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
auto operand = Tuple(&builder, {});
std::vector<XlaOp> operands(num_branches, operand);
std::vector<XlaComputation> branches;
branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches);
for (int i = 0; i < num_branches; ++i) {
branches.emplace_back(
CreateR0ConstantComputation(static_cast<float>(i) * 10));
branches.push_back(CreateR0ConstantComputation(static_cast<float>(i) * 10));
branches_p[i] = &branches[i];
}
Conditional(branch_index, branches_p, operands);
float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches)
? num_branches - 1
: bi);
ComputeAndCompareR0<float>(&builder, expected, {&branch_index_arg},
kErrorSpec);
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
HloModuleFromXlaBuilder(&builder, execution_options()));
ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
test_runner().CreateExecutable(std::move(module),
/*run_hlo_passes=*/true));
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
const Literal expected = LiteralUtil::CreateR0<float>(
10 * static_cast<float>(
(bi < 0 || bi >= num_branches) ? num_branches - 1 : bi));
const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
ASSERT_OK_AND_ASSIGN(const Literal result,
test_runner().ExecuteWithExecutable(
executable.get(), {&branch_index_arg}));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
}
}
@@ -255,40 +273,45 @@ TEST_F(ConditionalOpTest, Parameters1) {
// Test branch computations that take in 1 parameter.
TEST_P(CaseOpTest, Parameters1) {
int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
XlaBuilder builder(TestName());
XlaOp branch_index;
auto branch_index_arg = CreateR0Parameter<int32_t>(
bi, 0, "branch_index_arg", &builder, &branch_index);
const int num_branches = GetParam();
auto make_branch = [&builder, this](int i) {
auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
Parameter(sb.get(), 0, r0f32_, "p0"));
return sb->BuildAndNoteError();
};
XlaBuilder builder(TestName());
const XlaOp branch_index =
Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
std::vector<XlaComputation> branches;
branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches);
std::vector<XlaOp> operands;
operands.reserve(num_branches);
std::vector<float> expecteds(num_branches);
std::vector<Literal> expecteds(num_branches);
for (int i = 0; i < num_branches; ++i) {
branches.emplace_back(make_branch(i));
std::unique_ptr<XlaBuilder> sb =
builder.CreateSubBuilder(absl::StrCat("branch_", i));
Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
Parameter(sb.get(), 0, r0f32_, "p0"));
branches.push_back(sb->BuildAndNoteError());
branches_p[i] = &branches[i];
auto fi = static_cast<float>(i);
operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7));
expecteds[i] = 10 * fi + 7 + fi;
const float fi = static_cast<float>(i);
operands.push_back(ConstantR0<float>(&builder, 10 * fi + 7));
expecteds[i] = LiteralUtil::CreateR0<float>(10 * fi + 7 + fi);
}
Conditional(branch_index, branches_p, operands);
float expected = (bi < 0 || bi >= num_branches)
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
HloModuleFromXlaBuilder(&builder, execution_options()));
ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
test_runner().CreateExecutable(std::move(module),
/*run_hlo_passes=*/true));
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
const Literal& expected = (bi < 0 || bi >= num_branches)
? expecteds[num_branches - 1]
: expecteds[bi];
ComputeAndCompareR0<float>(&builder, expected, {&branch_index_arg},
kErrorSpec);
const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
ASSERT_OK_AND_ASSIGN(const Literal result,
test_runner().ExecuteWithExecutable(
executable.get(), {&branch_index_arg}));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
}
}
@@ -428,38 +451,46 @@ TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
// Test branch computations that take in 2 array parameters.
TEST_P(CaseOpTest, Parameters2Array) {
int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
const int num_branches = GetParam();
XlaBuilder builder(TestName());
XlaOp branch_index;
auto branch_index_arg =
CreateR0Parameter<int32_t>(bi, 0, "pred", &builder, &branch_index);
auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
auto operands = Tuple(&builder, {operand1, operand2});
auto make_branch = [&builder, this](int i) {
auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
GetTupleElement(p, 0)),
GetTupleElement(p, 1));
return sb->BuildAndNoteError();
};
const XlaOp branch_index =
Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
const XlaOp operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
const XlaOp operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
const XlaOp operands = Tuple(&builder, {operand1, operand2});
std::vector<XlaComputation> branches;
branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches);
for (int i = 0; i < num_branches; ++i) {
branches.emplace_back(make_branch(i));
std::unique_ptr<XlaBuilder> sb =
builder.CreateSubBuilder(absl::StrCat("branch_", i));
const XlaOp p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
GetTupleElement(p, 0)),
GetTupleElement(p, 1));
branches.push_back(sb->BuildAndNoteError());
branches_p[i] = &branches[i];
}
Conditional(branch_index, branches_p,
std::vector<XlaOp>(num_branches, operands));
auto modified_bi = static_cast<float>(
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
HloModuleFromXlaBuilder(&builder, execution_options()));
ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
test_runner().CreateExecutable(std::move(module),
/*run_hlo_passes=*/true));
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
const float modified_bi = static_cast<float>(
(bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
ComputeAndCompareR1<float>(
&builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
{&branch_index_arg}, kErrorSpec);
const Literal expected = LiteralUtil::CreateR1<float>(
{24.0f * modified_bi + 10, 56.0f * modified_bi + 11});
ASSERT_OK_AND_ASSIGN(const Literal result,
test_runner().ExecuteWithExecutable(
executable.get(), {&branch_index_arg}));
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
}
}
@@ -561,47 +592,52 @@ TEST_F(ConditionalOpTest, ReturnNestedTuple) {
XlaBuilder true_builder(TestName() + ".true");
{
Parameter(&true_builder, 0, empty_tuple_, "tuple");
auto true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
auto true_constant2 = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
auto true_constant3 = ConstantR1<float>(&true_builder, {25.4f, 29.8f});
auto true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
const XlaOp true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
const XlaOp true_constant2 =
ConstantR1<float>(&true_builder, {12.8f, 14.6f});
const XlaOp true_constant3 =
ConstantR1<float>(&true_builder, {25.4f, 29.8f});
const XlaOp true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
Tuple(&true_builder,
{Tuple(&true_builder, {true_constant1, true_constant2}),
Tuple(&true_builder, {true_constant3, true_constant4})});
}
auto true_builder_result = true_builder.Build();
EXPECT_IS_OK(true_builder_result.status());
ASSERT_OK_AND_ASSIGN(XlaComputation true_comp, true_builder.Build());
XlaBuilder false_builder(TestName() + ".false");
{
Parameter(&false_builder, 0, empty_tuple_, "tuple");
auto false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
auto false_constant2 = ConstantR1<float>(&false_builder, {54.4f, 58.4f});
auto false_constant3 = ConstantR1<float>(&false_builder, {62.1f, 67.4f});
auto false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
const XlaOp false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
const XlaOp false_constant2 =
ConstantR1<float>(&false_builder, {54.4f, 58.4f});
const XlaOp false_constant3 =
ConstantR1<float>(&false_builder, {62.1f, 67.4f});
const XlaOp false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
Tuple(&false_builder,
{Tuple(&false_builder, {false_constant1, false_constant2}),
Tuple(&false_builder, {false_constant3, false_constant4})});
}
auto false_builder_result = false_builder.Build();
EXPECT_IS_OK(false_builder_result.status());
ASSERT_OK_AND_ASSIGN(XlaComputation false_comp, false_builder.Build());
XlaBuilder builder(TestName());
XlaOp pred;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {});
Conditional(pred, operands, std::move(true_builder_result).value(), operands,
std::move(false_builder_result).value());
const Literal pred_arg =
CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
const XlaOp operands = Tuple(&builder, {});
const XlaOp result = Conditional(pred, operands, std::move(true_comp),
operands, std::move(false_comp));
// Flatten nested tuple for PjRt.
const XlaOp e0 = GetTupleElement(result, 0);
const XlaOp e1 = GetTupleElement(result, 1);
Tuple(&builder, {GetTupleElement(e0, 0), GetTupleElement(e0, 1),
GetTupleElement(e1, 0), GetTupleElement(e1, 1)});
ComputeAndCompareLiteral(
&builder,
ComputeAndCompareLiteral(&builder,
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(46.6f),
LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
LiteralUtil::CreateR0<float>(9.3f)})}),
LiteralUtil::CreateR1<float>({54.4f, 58.4f}),
LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
LiteralUtil::CreateR0<float>(9.3f)}),
{&pred_arg}, kErrorSpec);
}
@@ -751,21 +787,31 @@ TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
main = builder.Build().value();
}
auto test_swap = [&](float a, float b) {
XlaBuilder builder(TestName());
XlaOp x, y;
auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x);
auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y);
auto tuple_operand = Tuple(&builder, {x, y});
const XlaOp x = Parameter(&builder, 0, r0f32_, "x");
const XlaOp y = Parameter(&builder, 1, r0f32_, "y");
const XlaOp tuple_operand = Tuple(&builder, {x, y});
Call(&builder, main, {tuple_operand});
ComputeAndCompareLiteral(
&builder,
LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
{&x_arg, &y_arg}, kErrorSpec);
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
HloModuleFromXlaBuilder(&builder, execution_options()));
ASSERT_OK_AND_ASSIGN(
std::unique_ptr<OpaqueExecutable> executable,
CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
const auto test_swap =
[&, this](float a,
float b) -> absl::StatusOr<::testing::AssertionResult> {
const Literal x_arg = LiteralUtil::CreateR0<float>(a);
const Literal y_arg = LiteralUtil::CreateR0<float>(b);
const Literal expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)});
ASSIGN_OR_RETURN(const Literal result,
test_runner().ExecuteWithExecutable(executable.get(),
{&x_arg, &y_arg}));
return LiteralTestUtil::Near(expected, result, kErrorSpec);
};
test_swap(3.11f, 9.4f);
test_swap(11.24f, 5.55f);
EXPECT_THAT(test_swap(3.11f, 9.4f), IsOkAndHolds(true));
EXPECT_THAT(test_swap(11.24f, 5.55f), IsOkAndHolds(true));
}
// Test conditional that duplicates tuple elements in the then and else
@@ -792,35 +838,45 @@ TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
else_comp = builder.Build().value();
}
{
// Pred is true case.
std::vector<Literal> args;
args.push_back(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<int32_t>(123),
LiteralUtil::CreateR0<int32_t>(-42)}));
args.push_back(LiteralUtil::CreateR0<bool>(true));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
auto p0 = Parameter(&builder, 0, scalar, "p0.0");
auto p1 = Parameter(&builder, 1, scalar, "p0.1");
auto p = Tuple(&builder, {p0, p1});
auto p_pred = Parameter(&builder, 2, ShapeUtil::MakeShape(PRED, {}), "p1");
Conditional(p_pred, p, then_comp, p, else_comp);
ComputeAndCompare(&builder, {&args[0], &args[1]});
}
{
// Pred is false case.
std::vector<Literal> args;
args.push_back(LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<int32_t>(123),
LiteralUtil::CreateR0<int32_t>(-42)}));
args.push_back(LiteralUtil::CreateR0<bool>(false));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
Conditional(p_pred, p, then_comp, p, else_comp);
ComputeAndCompare(&builder, {&args[0], &args[1]});
}
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
HloModuleFromXlaBuilder(&builder, execution_options()));
const std::array<Literal, 4> args = {
LiteralUtil::CreateR0<int32_t>(123), LiteralUtil::CreateR0<int32_t>(-42),
LiteralUtil::CreateR0<bool>(true), LiteralUtil::CreateR0<bool>(false)};
const std::array<const Literal*, 3> true_args = {&args[0], &args[1],
&args[2]};
const std::array<const Literal*, 3> false_args = {&args[0], &args[1],
&args[3]};
// Compute reference values. Because this test is not parameterized, we need
// to manually invoke the test runner and reference runner.
ASSERT_OK_AND_ASSIGN(Literal true_reference,
reference_runner().Execute(module->Clone(), true_args,
/*run_hlo_passes=*/true));
ASSERT_OK_AND_ASSIGN(Literal false_reference,
reference_runner().Execute(module->Clone(), false_args,
/*run_hlo_passes=*/true));
ASSERT_OK_AND_ASSIGN(
std::unique_ptr<OpaqueExecutable> executable,
CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
ASSERT_OK_AND_ASSIGN(Literal true_result, test_runner().ExecuteWithExecutable(
executable.get(), true_args));
ASSERT_OK_AND_ASSIGN(
Literal false_result,
test_runner().ExecuteWithExecutable(executable.get(), false_args));
EXPECT_TRUE(LiteralTestUtil::Equal(true_reference, true_result));
EXPECT_TRUE(LiteralTestUtil::Equal(false_reference, false_result));
}
using ConditionalOpHloTest = HloTestBase;
using ConditionalOpHloTest = HloPjRtTestBase;
TEST_F(ConditionalOpHloTest, ParallelExecution) {
// Test conditional works when an executable is executed in parallel.