mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Migrate conditional_test to PjRt.
PiperOrigin-RevId: 847911726
This commit is contained in:
committed by
TensorFlower Gardener
parent
69ea2a9308
commit
dd10786acd
8
third_party/xla/xla/tests/BUILD
vendored
8
third_party/xla/xla/tests/BUILD
vendored
@@ -801,9 +801,12 @@ xla_test(
|
|||||||
name = "conditional_test",
|
name = "conditional_test",
|
||||||
srcs = ["conditional_test.cc"],
|
srcs = ["conditional_test.cc"],
|
||||||
shard_count = 2,
|
shard_count = 2,
|
||||||
|
tags = ["test_migrated_to_hlo_runner_pjrt"],
|
||||||
deps = [
|
deps = [
|
||||||
":client_library_test_runner_mixin",
|
":client_library_test_runner_mixin",
|
||||||
":hlo_test_base",
|
":hlo_pjrt_interpreter_reference_mixin",
|
||||||
|
":hlo_pjrt_test_base",
|
||||||
|
":literal_test_util",
|
||||||
":xla_internal_test_main", # fixdeps: keep
|
":xla_internal_test_main", # fixdeps: keep
|
||||||
"//xla:array2d",
|
"//xla:array2d",
|
||||||
"//xla:error_spec",
|
"//xla:error_spec",
|
||||||
@@ -812,11 +815,14 @@ xla_test(
|
|||||||
"//xla:shape_util",
|
"//xla:shape_util",
|
||||||
"//xla/hlo/builder:xla_builder",
|
"//xla/hlo/builder:xla_builder",
|
||||||
"//xla/hlo/builder:xla_computation",
|
"//xla/hlo/builder:xla_computation",
|
||||||
|
"//xla/hlo/ir:hlo",
|
||||||
"//xla/hlo/testlib:test_helpers",
|
"//xla/hlo/testlib:test_helpers",
|
||||||
|
"//xla/service:hlo_runner_interface",
|
||||||
"//xla/tsl/platform:env",
|
"//xla/tsl/platform:env",
|
||||||
"//xla/tsl/platform:statusor",
|
"//xla/tsl/platform:statusor",
|
||||||
"//xla/tsl/platform:test",
|
"//xla/tsl/platform:test",
|
||||||
"@com_google_absl//absl/log:check",
|
"@com_google_absl//absl/log:check",
|
||||||
|
"@com_google_absl//absl/status:status_matchers",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
364
third_party/xla/xla/tests/conditional_test.cc
vendored
364
third_party/xla/xla/tests/conditional_test.cc
vendored
@@ -13,24 +13,31 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/log/check.h"
|
#include "absl/log/check.h"
|
||||||
|
#include "absl/status/status_matchers.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "xla/array2d.h"
|
#include "xla/array2d.h"
|
||||||
#include "xla/error_spec.h"
|
#include "xla/error_spec.h"
|
||||||
#include "xla/hlo/builder/xla_builder.h"
|
#include "xla/hlo/builder/xla_builder.h"
|
||||||
#include "xla/hlo/builder/xla_computation.h"
|
#include "xla/hlo/builder/xla_computation.h"
|
||||||
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
#include "xla/hlo/testlib/test_helpers.h"
|
#include "xla/hlo/testlib/test_helpers.h"
|
||||||
#include "xla/literal.h"
|
#include "xla/literal.h"
|
||||||
#include "xla/literal_util.h"
|
#include "xla/literal_util.h"
|
||||||
|
#include "xla/service/hlo_runner_interface.h"
|
||||||
#include "xla/shape.h"
|
#include "xla/shape.h"
|
||||||
#include "xla/shape_util.h"
|
#include "xla/shape_util.h"
|
||||||
#include "xla/tests/client_library_test_runner_mixin.h"
|
#include "xla/tests/client_library_test_runner_mixin.h"
|
||||||
#include "xla/tests/hlo_test_base.h"
|
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
|
||||||
|
#include "xla/tests/hlo_pjrt_test_base.h"
|
||||||
|
#include "xla/tests/literal_test_util.h"
|
||||||
#include "xla/tsl/platform/env.h"
|
#include "xla/tsl/platform/env.h"
|
||||||
#include "xla/tsl/platform/statusor.h"
|
#include "xla/tsl/platform/statusor.h"
|
||||||
#include "xla/tsl/platform/test.h"
|
#include "xla/tsl/platform/test.h"
|
||||||
@@ -39,12 +46,17 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::absl_testing::IsOkAndHolds;
|
||||||
|
|
||||||
constexpr ErrorSpec kErrorSpec{0.001};
|
constexpr ErrorSpec kErrorSpec{0.001};
|
||||||
|
|
||||||
class ConditionalOpTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
|
class ConditionalOpTest
|
||||||
|
: public ClientLibraryTestRunnerMixin<
|
||||||
|
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
|
||||||
protected:
|
protected:
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
ClientLibraryTestRunnerMixin<HloTestBase>::SetUp();
|
ClientLibraryTestRunnerMixin<
|
||||||
|
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>>::SetUp();
|
||||||
mutable_debug_options()->set_xla_test_add_command_buffer_mode(true);
|
mutable_debug_options()->set_xla_test_add_command_buffer_mode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,31 +224,37 @@ TEST_F(ConditionalOpTest, Parameters0) {
|
|||||||
|
|
||||||
// Test branch computations that do not take any parameters.
|
// Test branch computations that do not take any parameters.
|
||||||
TEST_P(CaseOpTest, Parameters0) {
|
TEST_P(CaseOpTest, Parameters0) {
|
||||||
int num_branches = GetParam();
|
const int num_branches = GetParam();
|
||||||
|
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
const XlaOp branch_index =
|
||||||
|
Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
|
||||||
|
auto operand = Tuple(&builder, {});
|
||||||
|
std::vector<XlaOp> operands(num_branches, operand);
|
||||||
|
std::vector<XlaComputation> branches;
|
||||||
|
branches.reserve(num_branches);
|
||||||
|
std::vector<const XlaComputation*> branches_p(num_branches);
|
||||||
|
for (int i = 0; i < num_branches; ++i) {
|
||||||
|
branches.push_back(CreateR0ConstantComputation(static_cast<float>(i) * 10));
|
||||||
|
branches_p[i] = &branches[i];
|
||||||
|
}
|
||||||
|
Conditional(branch_index, branches_p, operands);
|
||||||
|
|
||||||
|
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
HloModuleFromXlaBuilder(&builder, execution_options()));
|
||||||
|
ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
|
||||||
|
test_runner().CreateExecutable(std::move(module),
|
||||||
|
/*run_hlo_passes=*/true));
|
||||||
for (int bi = -1; bi <= num_branches; ++bi) {
|
for (int bi = -1; bi <= num_branches; ++bi) {
|
||||||
SCOPED_TRACE(bi);
|
SCOPED_TRACE(bi);
|
||||||
XlaBuilder builder(TestName());
|
const Literal expected = LiteralUtil::CreateR0<float>(
|
||||||
XlaOp branch_index;
|
10 * static_cast<float>(
|
||||||
auto branch_index_arg = CreateR0Parameter<int32_t>(
|
(bi < 0 || bi >= num_branches) ? num_branches - 1 : bi));
|
||||||
bi, 0, "branch_index_arg", &builder, &branch_index);
|
const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
|
||||||
auto operand = Tuple(&builder, {});
|
ASSERT_OK_AND_ASSIGN(const Literal result,
|
||||||
|
test_runner().ExecuteWithExecutable(
|
||||||
std::vector<XlaOp> operands(num_branches, operand);
|
executable.get(), {&branch_index_arg}));
|
||||||
std::vector<XlaComputation> branches;
|
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
|
||||||
branches.reserve(num_branches);
|
|
||||||
std::vector<const XlaComputation*> branches_p(num_branches);
|
|
||||||
for (int i = 0; i < num_branches; ++i) {
|
|
||||||
branches.emplace_back(
|
|
||||||
CreateR0ConstantComputation(static_cast<float>(i) * 10));
|
|
||||||
branches_p[i] = &branches[i];
|
|
||||||
}
|
|
||||||
Conditional(branch_index, branches_p, operands);
|
|
||||||
|
|
||||||
float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches)
|
|
||||||
? num_branches - 1
|
|
||||||
: bi);
|
|
||||||
ComputeAndCompareR0<float>(&builder, expected, {&branch_index_arg},
|
|
||||||
kErrorSpec);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -255,40 +273,45 @@ TEST_F(ConditionalOpTest, Parameters1) {
|
|||||||
|
|
||||||
// Test branch computations that take in 1 parameter.
|
// Test branch computations that take in 1 parameter.
|
||||||
TEST_P(CaseOpTest, Parameters1) {
|
TEST_P(CaseOpTest, Parameters1) {
|
||||||
int num_branches = GetParam();
|
const int num_branches = GetParam();
|
||||||
|
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
const XlaOp branch_index =
|
||||||
|
Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
|
||||||
|
std::vector<XlaComputation> branches;
|
||||||
|
branches.reserve(num_branches);
|
||||||
|
std::vector<const XlaComputation*> branches_p(num_branches);
|
||||||
|
std::vector<XlaOp> operands;
|
||||||
|
operands.reserve(num_branches);
|
||||||
|
std::vector<Literal> expecteds(num_branches);
|
||||||
|
for (int i = 0; i < num_branches; ++i) {
|
||||||
|
std::unique_ptr<XlaBuilder> sb =
|
||||||
|
builder.CreateSubBuilder(absl::StrCat("branch_", i));
|
||||||
|
Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
|
||||||
|
Parameter(sb.get(), 0, r0f32_, "p0"));
|
||||||
|
branches.push_back(sb->BuildAndNoteError());
|
||||||
|
branches_p[i] = &branches[i];
|
||||||
|
const float fi = static_cast<float>(i);
|
||||||
|
operands.push_back(ConstantR0<float>(&builder, 10 * fi + 7));
|
||||||
|
expecteds[i] = LiteralUtil::CreateR0<float>(10 * fi + 7 + fi);
|
||||||
|
}
|
||||||
|
Conditional(branch_index, branches_p, operands);
|
||||||
|
|
||||||
|
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
HloModuleFromXlaBuilder(&builder, execution_options()));
|
||||||
|
ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
|
||||||
|
test_runner().CreateExecutable(std::move(module),
|
||||||
|
/*run_hlo_passes=*/true));
|
||||||
for (int bi = -1; bi <= num_branches; ++bi) {
|
for (int bi = -1; bi <= num_branches; ++bi) {
|
||||||
SCOPED_TRACE(bi);
|
SCOPED_TRACE(bi);
|
||||||
XlaBuilder builder(TestName());
|
const Literal& expected = (bi < 0 || bi >= num_branches)
|
||||||
XlaOp branch_index;
|
? expecteds[num_branches - 1]
|
||||||
auto branch_index_arg = CreateR0Parameter<int32_t>(
|
: expecteds[bi];
|
||||||
bi, 0, "branch_index_arg", &builder, &branch_index);
|
const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
|
||||||
|
ASSERT_OK_AND_ASSIGN(const Literal result,
|
||||||
auto make_branch = [&builder, this](int i) {
|
test_runner().ExecuteWithExecutable(
|
||||||
auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
|
executable.get(), {&branch_index_arg}));
|
||||||
Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
|
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
|
||||||
Parameter(sb.get(), 0, r0f32_, "p0"));
|
|
||||||
return sb->BuildAndNoteError();
|
|
||||||
};
|
|
||||||
std::vector<XlaComputation> branches;
|
|
||||||
branches.reserve(num_branches);
|
|
||||||
std::vector<const XlaComputation*> branches_p(num_branches);
|
|
||||||
std::vector<XlaOp> operands;
|
|
||||||
operands.reserve(num_branches);
|
|
||||||
std::vector<float> expecteds(num_branches);
|
|
||||||
for (int i = 0; i < num_branches; ++i) {
|
|
||||||
branches.emplace_back(make_branch(i));
|
|
||||||
branches_p[i] = &branches[i];
|
|
||||||
auto fi = static_cast<float>(i);
|
|
||||||
operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7));
|
|
||||||
expecteds[i] = 10 * fi + 7 + fi;
|
|
||||||
}
|
|
||||||
|
|
||||||
Conditional(branch_index, branches_p, operands);
|
|
||||||
float expected = (bi < 0 || bi >= num_branches)
|
|
||||||
? expecteds[num_branches - 1]
|
|
||||||
: expecteds[bi];
|
|
||||||
ComputeAndCompareR0<float>(&builder, expected, {&branch_index_arg},
|
|
||||||
kErrorSpec);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,38 +451,46 @@ TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
|
|||||||
|
|
||||||
// Test branch computations that take in 2 array parameters.
|
// Test branch computations that take in 2 array parameters.
|
||||||
TEST_P(CaseOpTest, Parameters2Array) {
|
TEST_P(CaseOpTest, Parameters2Array) {
|
||||||
int num_branches = GetParam();
|
const int num_branches = GetParam();
|
||||||
|
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
const XlaOp branch_index =
|
||||||
|
Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "branch_index_arg");
|
||||||
|
const XlaOp operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
|
||||||
|
const XlaOp operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
|
||||||
|
const XlaOp operands = Tuple(&builder, {operand1, operand2});
|
||||||
|
std::vector<XlaComputation> branches;
|
||||||
|
branches.reserve(num_branches);
|
||||||
|
std::vector<const XlaComputation*> branches_p(num_branches);
|
||||||
|
for (int i = 0; i < num_branches; ++i) {
|
||||||
|
std::unique_ptr<XlaBuilder> sb =
|
||||||
|
builder.CreateSubBuilder(absl::StrCat("branch_", i));
|
||||||
|
const XlaOp p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
|
||||||
|
Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
|
||||||
|
GetTupleElement(p, 0)),
|
||||||
|
GetTupleElement(p, 1));
|
||||||
|
branches.push_back(sb->BuildAndNoteError());
|
||||||
|
branches_p[i] = &branches[i];
|
||||||
|
}
|
||||||
|
Conditional(branch_index, branches_p,
|
||||||
|
std::vector<XlaOp>(num_branches, operands));
|
||||||
|
|
||||||
|
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
HloModuleFromXlaBuilder(&builder, execution_options()));
|
||||||
|
ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> executable,
|
||||||
|
test_runner().CreateExecutable(std::move(module),
|
||||||
|
/*run_hlo_passes=*/true));
|
||||||
for (int bi = -1; bi <= num_branches; ++bi) {
|
for (int bi = -1; bi <= num_branches; ++bi) {
|
||||||
SCOPED_TRACE(bi);
|
SCOPED_TRACE(bi);
|
||||||
XlaBuilder builder(TestName());
|
const Literal branch_index_arg = LiteralUtil::CreateR0<int32_t>(bi);
|
||||||
XlaOp branch_index;
|
const float modified_bi = static_cast<float>(
|
||||||
auto branch_index_arg =
|
|
||||||
CreateR0Parameter<int32_t>(bi, 0, "pred", &builder, &branch_index);
|
|
||||||
auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
|
|
||||||
auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
|
|
||||||
auto operands = Tuple(&builder, {operand1, operand2});
|
|
||||||
auto make_branch = [&builder, this](int i) {
|
|
||||||
auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
|
|
||||||
auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
|
|
||||||
Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
|
|
||||||
GetTupleElement(p, 0)),
|
|
||||||
GetTupleElement(p, 1));
|
|
||||||
return sb->BuildAndNoteError();
|
|
||||||
};
|
|
||||||
std::vector<XlaComputation> branches;
|
|
||||||
branches.reserve(num_branches);
|
|
||||||
std::vector<const XlaComputation*> branches_p(num_branches);
|
|
||||||
for (int i = 0; i < num_branches; ++i) {
|
|
||||||
branches.emplace_back(make_branch(i));
|
|
||||||
branches_p[i] = &branches[i];
|
|
||||||
}
|
|
||||||
Conditional(branch_index, branches_p,
|
|
||||||
std::vector<XlaOp>(num_branches, operands));
|
|
||||||
auto modified_bi = static_cast<float>(
|
|
||||||
(bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
|
(bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
|
||||||
ComputeAndCompareR1<float>(
|
const Literal expected = LiteralUtil::CreateR1<float>(
|
||||||
&builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
|
{24.0f * modified_bi + 10, 56.0f * modified_bi + 11});
|
||||||
{&branch_index_arg}, kErrorSpec);
|
ASSERT_OK_AND_ASSIGN(const Literal result,
|
||||||
|
test_runner().ExecuteWithExecutable(
|
||||||
|
executable.get(), {&branch_index_arg}));
|
||||||
|
EXPECT_TRUE(LiteralTestUtil::Near(expected, result, kErrorSpec));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -561,48 +592,53 @@ TEST_F(ConditionalOpTest, ReturnNestedTuple) {
|
|||||||
XlaBuilder true_builder(TestName() + ".true");
|
XlaBuilder true_builder(TestName() + ".true");
|
||||||
{
|
{
|
||||||
Parameter(&true_builder, 0, empty_tuple_, "tuple");
|
Parameter(&true_builder, 0, empty_tuple_, "tuple");
|
||||||
auto true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
|
const XlaOp true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
|
||||||
auto true_constant2 = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
|
const XlaOp true_constant2 =
|
||||||
auto true_constant3 = ConstantR1<float>(&true_builder, {25.4f, 29.8f});
|
ConstantR1<float>(&true_builder, {12.8f, 14.6f});
|
||||||
auto true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
|
const XlaOp true_constant3 =
|
||||||
|
ConstantR1<float>(&true_builder, {25.4f, 29.8f});
|
||||||
|
const XlaOp true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
|
||||||
Tuple(&true_builder,
|
Tuple(&true_builder,
|
||||||
{Tuple(&true_builder, {true_constant1, true_constant2}),
|
{Tuple(&true_builder, {true_constant1, true_constant2}),
|
||||||
Tuple(&true_builder, {true_constant3, true_constant4})});
|
Tuple(&true_builder, {true_constant3, true_constant4})});
|
||||||
}
|
}
|
||||||
auto true_builder_result = true_builder.Build();
|
ASSERT_OK_AND_ASSIGN(XlaComputation true_comp, true_builder.Build());
|
||||||
EXPECT_IS_OK(true_builder_result.status());
|
|
||||||
|
|
||||||
XlaBuilder false_builder(TestName() + ".false");
|
XlaBuilder false_builder(TestName() + ".false");
|
||||||
{
|
{
|
||||||
Parameter(&false_builder, 0, empty_tuple_, "tuple");
|
Parameter(&false_builder, 0, empty_tuple_, "tuple");
|
||||||
auto false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
|
const XlaOp false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
|
||||||
auto false_constant2 = ConstantR1<float>(&false_builder, {54.4f, 58.4f});
|
const XlaOp false_constant2 =
|
||||||
auto false_constant3 = ConstantR1<float>(&false_builder, {62.1f, 67.4f});
|
ConstantR1<float>(&false_builder, {54.4f, 58.4f});
|
||||||
auto false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
|
const XlaOp false_constant3 =
|
||||||
|
ConstantR1<float>(&false_builder, {62.1f, 67.4f});
|
||||||
|
const XlaOp false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
|
||||||
Tuple(&false_builder,
|
Tuple(&false_builder,
|
||||||
{Tuple(&false_builder, {false_constant1, false_constant2}),
|
{Tuple(&false_builder, {false_constant1, false_constant2}),
|
||||||
Tuple(&false_builder, {false_constant3, false_constant4})});
|
Tuple(&false_builder, {false_constant3, false_constant4})});
|
||||||
}
|
}
|
||||||
auto false_builder_result = false_builder.Build();
|
ASSERT_OK_AND_ASSIGN(XlaComputation false_comp, false_builder.Build());
|
||||||
EXPECT_IS_OK(false_builder_result.status());
|
|
||||||
|
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
XlaOp pred;
|
XlaOp pred;
|
||||||
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
|
const Literal pred_arg =
|
||||||
auto operands = Tuple(&builder, {});
|
CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
|
||||||
Conditional(pred, operands, std::move(true_builder_result).value(), operands,
|
const XlaOp operands = Tuple(&builder, {});
|
||||||
std::move(false_builder_result).value());
|
const XlaOp result = Conditional(pred, operands, std::move(true_comp),
|
||||||
|
operands, std::move(false_comp));
|
||||||
|
// Flatten nested tuple for PjRt.
|
||||||
|
const XlaOp e0 = GetTupleElement(result, 0);
|
||||||
|
const XlaOp e1 = GetTupleElement(result, 1);
|
||||||
|
Tuple(&builder, {GetTupleElement(e0, 0), GetTupleElement(e0, 1),
|
||||||
|
GetTupleElement(e1, 0), GetTupleElement(e1, 1)});
|
||||||
|
|
||||||
ComputeAndCompareLiteral(
|
ComputeAndCompareLiteral(&builder,
|
||||||
&builder,
|
LiteralUtil::MakeTupleFromSlices(
|
||||||
LiteralUtil::MakeTupleFromSlices(
|
{LiteralUtil::CreateR0<float>(46.6f),
|
||||||
{LiteralUtil::MakeTupleFromSlices(
|
LiteralUtil::CreateR1<float>({54.4f, 58.4f}),
|
||||||
{LiteralUtil::CreateR0<float>(46.6f),
|
LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
|
||||||
LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
|
LiteralUtil::CreateR0<float>(9.3f)}),
|
||||||
LiteralUtil::MakeTupleFromSlices(
|
{&pred_arg}, kErrorSpec);
|
||||||
{LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
|
|
||||||
LiteralUtil::CreateR0<float>(9.3f)})}),
|
|
||||||
{&pred_arg}, kErrorSpec);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test conditional that takes in scalar operands in the form of external
|
// Test conditional that takes in scalar operands in the form of external
|
||||||
@@ -751,21 +787,31 @@ TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
|
|||||||
main = builder.Build().value();
|
main = builder.Build().value();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto test_swap = [&](float a, float b) {
|
XlaBuilder builder(TestName());
|
||||||
XlaBuilder builder(TestName());
|
const XlaOp x = Parameter(&builder, 0, r0f32_, "x");
|
||||||
XlaOp x, y;
|
const XlaOp y = Parameter(&builder, 1, r0f32_, "y");
|
||||||
auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x);
|
const XlaOp tuple_operand = Tuple(&builder, {x, y});
|
||||||
auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y);
|
Call(&builder, main, {tuple_operand});
|
||||||
auto tuple_operand = Tuple(&builder, {x, y});
|
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
Call(&builder, main, {tuple_operand});
|
HloModuleFromXlaBuilder(&builder, execution_options()));
|
||||||
ComputeAndCompareLiteral(
|
ASSERT_OK_AND_ASSIGN(
|
||||||
&builder,
|
std::unique_ptr<OpaqueExecutable> executable,
|
||||||
LiteralUtil::MakeTupleFromSlices(
|
CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
|
||||||
{LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
|
|
||||||
{&x_arg, &y_arg}, kErrorSpec);
|
const auto test_swap =
|
||||||
|
[&, this](float a,
|
||||||
|
float b) -> absl::StatusOr<::testing::AssertionResult> {
|
||||||
|
const Literal x_arg = LiteralUtil::CreateR0<float>(a);
|
||||||
|
const Literal y_arg = LiteralUtil::CreateR0<float>(b);
|
||||||
|
const Literal expected = LiteralUtil::MakeTupleFromSlices(
|
||||||
|
{LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)});
|
||||||
|
ASSIGN_OR_RETURN(const Literal result,
|
||||||
|
test_runner().ExecuteWithExecutable(executable.get(),
|
||||||
|
{&x_arg, &y_arg}));
|
||||||
|
return LiteralTestUtil::Near(expected, result, kErrorSpec);
|
||||||
};
|
};
|
||||||
test_swap(3.11f, 9.4f);
|
EXPECT_THAT(test_swap(3.11f, 9.4f), IsOkAndHolds(true));
|
||||||
test_swap(11.24f, 5.55f);
|
EXPECT_THAT(test_swap(11.24f, 5.55f), IsOkAndHolds(true));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test conditional that duplicates tuple elements in the then and else
|
// Test conditional that duplicates tuple elements in the then and else
|
||||||
@@ -792,35 +838,45 @@ TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
|
|||||||
else_comp = builder.Build().value();
|
else_comp = builder.Build().value();
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
XlaBuilder builder(TestName() + ".main");
|
||||||
// Pred is true case.
|
auto p0 = Parameter(&builder, 0, scalar, "p0.0");
|
||||||
std::vector<Literal> args;
|
auto p1 = Parameter(&builder, 1, scalar, "p0.1");
|
||||||
args.push_back(LiteralUtil::MakeTupleFromSlices(
|
auto p = Tuple(&builder, {p0, p1});
|
||||||
{LiteralUtil::CreateR0<int32_t>(123),
|
auto p_pred = Parameter(&builder, 2, ShapeUtil::MakeShape(PRED, {}), "p1");
|
||||||
LiteralUtil::CreateR0<int32_t>(-42)}));
|
Conditional(p_pred, p, then_comp, p, else_comp);
|
||||||
args.push_back(LiteralUtil::CreateR0<bool>(true));
|
ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
XlaBuilder builder(TestName() + ".main");
|
HloModuleFromXlaBuilder(&builder, execution_options()));
|
||||||
auto p = Parameter(&builder, 0, tuple2, "p0");
|
|
||||||
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
|
const std::array<Literal, 4> args = {
|
||||||
Conditional(p_pred, p, then_comp, p, else_comp);
|
LiteralUtil::CreateR0<int32_t>(123), LiteralUtil::CreateR0<int32_t>(-42),
|
||||||
ComputeAndCompare(&builder, {&args[0], &args[1]});
|
LiteralUtil::CreateR0<bool>(true), LiteralUtil::CreateR0<bool>(false)};
|
||||||
}
|
const std::array<const Literal*, 3> true_args = {&args[0], &args[1],
|
||||||
{
|
&args[2]};
|
||||||
// Pred is false case.
|
const std::array<const Literal*, 3> false_args = {&args[0], &args[1],
|
||||||
std::vector<Literal> args;
|
&args[3]};
|
||||||
args.push_back(LiteralUtil::MakeTupleFromSlices(
|
|
||||||
{LiteralUtil::CreateR0<int32_t>(123),
|
// Compute reference values. Because this test is not parameterized, we need
|
||||||
LiteralUtil::CreateR0<int32_t>(-42)}));
|
// to manually invoke the test runner and reference runner.
|
||||||
args.push_back(LiteralUtil::CreateR0<bool>(false));
|
ASSERT_OK_AND_ASSIGN(Literal true_reference,
|
||||||
XlaBuilder builder(TestName() + ".main");
|
reference_runner().Execute(module->Clone(), true_args,
|
||||||
auto p = Parameter(&builder, 0, tuple2, "p0");
|
/*run_hlo_passes=*/true));
|
||||||
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
|
ASSERT_OK_AND_ASSIGN(Literal false_reference,
|
||||||
Conditional(p_pred, p, then_comp, p, else_comp);
|
reference_runner().Execute(module->Clone(), false_args,
|
||||||
ComputeAndCompare(&builder, {&args[0], &args[1]});
|
/*run_hlo_passes=*/true));
|
||||||
}
|
|
||||||
|
ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<OpaqueExecutable> executable,
|
||||||
|
CreateExecutable(std::move(module), /*run_hlo_passes=*/true));
|
||||||
|
ASSERT_OK_AND_ASSIGN(Literal true_result, test_runner().ExecuteWithExecutable(
|
||||||
|
executable.get(), true_args));
|
||||||
|
ASSERT_OK_AND_ASSIGN(
|
||||||
|
Literal false_result,
|
||||||
|
test_runner().ExecuteWithExecutable(executable.get(), false_args));
|
||||||
|
EXPECT_TRUE(LiteralTestUtil::Equal(true_reference, true_result));
|
||||||
|
EXPECT_TRUE(LiteralTestUtil::Equal(false_reference, false_result));
|
||||||
}
|
}
|
||||||
|
|
||||||
using ConditionalOpHloTest = HloTestBase;
|
using ConditionalOpHloTest = HloPjRtTestBase;
|
||||||
|
|
||||||
TEST_F(ConditionalOpHloTest, ParallelExecution) {
|
TEST_F(ConditionalOpHloTest, ParallelExecution) {
|
||||||
// Test conditional works when an executable is executed in parallel.
|
// Test conditional works when an executable is executed in parallel.
|
||||||
|
|||||||
Reference in New Issue
Block a user