diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc
index 4673a90e0a9..38479d688d7 100644
--- a/tensorflow/compiler/plugin/executor/executable.cc
+++ b/tensorflow/compiler/plugin/executor/executable.cc
@@ -30,8 +30,8 @@ ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module)
 
 ExecutorExecutable::~ExecutorExecutable() {}
 
-static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor,
-                                                 const Literal& literal) {
+static se::DeviceMemoryBase AllocateSingleOutput(
+    sep::ExecutorExecutor* executor, const Literal& literal) {
   int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape()));
   void* buf = executor->Allocate(size);
   const void* src = literal.InternalData();
@@ -39,8 +39,8 @@ static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor
   return se::DeviceMemoryBase(buf, size);
 }
 
-static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExecutorExecutor* executor,
-                                                 const Literal& literal) {
+static se::DeviceMemoryBase AllocateOutputBuffer(
+    sep::ExecutorExecutor* executor, const Literal& literal) {
   const Shape& shape = literal.shape();
   if (shape.element_type() != xla::TUPLE) {
     return AllocateSingleOutput(executor, literal);
@@ -96,7 +96,7 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
   // Execute the graph using the evaluator
   HloEvaluator evaluator;
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> output,
-                      evaluator.Evaluate(computation, arg_literals_ptrs));
+                      evaluator.Evaluate(*computation, arg_literals_ptrs));
 
   // Copy the result into the return buffer
   perftools::gputools::StreamExecutor* executor(stream->parent());
@@ -139,6 +139,5 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream(
   return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
 }
 
-
 }  // namespace executorplugin
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 1a2eed5f602..a51b1021798 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -51,9 +51,12 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
         computation->root_instruction() != instruction) {
       continue;
     }
-    // Skip Constant and Parameter operation.
+    // Skip Constant, Parameter, and Reduce operations.
+    // TODO(b/35975797): Enable Reduce operation once arbitrary computations
+    // are supported by the evaluator.
     if (instruction->opcode() == HloOpcode::kParameter ||
-        instruction->opcode() == HloOpcode::kConstant) {
+        instruction->opcode() == HloOpcode::kConstant ||
+        instruction->opcode() == HloOpcode::kReduce) {
      continue;
    }
    // Skip instructions with non-constant operands.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index a0c5cbe9160..e2a807595b4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -654,12 +654,262 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     return Status::OK();
   };
 
+  Status HandleDynamicSlice(HloInstruction* dynamic_slice,
+                            HloInstruction* operand,
+                            HloInstruction* start_indices) override {
+    auto result_shape = dynamic_slice->shape();
+    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+                        ShapeInference::InferDynamicSliceShape(
+                            operand->shape(), start_indices->shape(),
+                            dynamic_slice->dynamic_slice_sizes()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
+        << " but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+    TF_RET_CHECK(
+        primitive_util::IsIntegralType(start_indices->shape().element_type()));
+
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& start_indices_literal =
+        parent_->GetEvaluatedLiteralFor(start_indices);
+
+    switch (start_indices->shape().element_type()) {
+      case S32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<int32>(operand_literal, start_indices_literal,
+                                result_shape));
+      } break;
+      case S64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<int64>(operand_literal, start_indices_literal,
+                                result_shape));
+      } break;
+      case U32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<uint32>(operand_literal, start_indices_literal,
+                                 result_shape));
+      } break;
+      case U64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<uint64>(operand_literal, start_indices_literal,
+                                 result_shape));
+      } break;
+      default:
+        LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
+                      "start_indices: "
+                   << PrimitiveType_Name(start_indices->shape().element_type());
+    }
+
+    return Status::OK();
+  };
+
+  Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
+                                  HloInstruction* operand,
+                                  HloInstruction* update,
+                                  HloInstruction* start_indices) override {
+    auto result_shape = dynamic_update_slice->shape();
+    TF_ASSIGN_OR_RETURN(
+        auto inferred_return_shape,
+        ShapeInference::InferDynamicUpdateSliceShape(
+            operand->shape(), update->shape(), start_indices->shape()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
+        << " but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+    TF_RET_CHECK(
+        primitive_util::IsIntegralType(start_indices->shape().element_type()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));
+
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
+    const Literal& start_indices_literal =
+        parent_->GetEvaluatedLiteralFor(start_indices);
+
+    switch (start_indices->shape().element_type()) {
+      case S32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<int32>(operand_literal, update_literal,
+                                      start_indices_literal));
+      } break;
+      case S64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<int64>(operand_literal, update_literal,
+                                      start_indices_literal));
+      } break;
+      case U32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<uint32>(operand_literal, update_literal,
+                                       start_indices_literal));
+      } break;
+      case U64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<uint64>(operand_literal, update_literal,
+                                       start_indices_literal));
+      } break;
+      default:
+        LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
+                      "start_indices: "
+                   << PrimitiveType_Name(start_indices->shape().element_type());
+    }
+
+    return Status::OK();
+  };
+
+  Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+                      HloInstruction* init_value,
+                      tensorflow::gtl::ArraySlice<int64> dimensions,
+                      HloComputation* function) override {
+    TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
+                 ShapeUtil::Rank(arg->shape()) - dimensions.size());
+    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+                        ShapeInference::InferReduceShape(
+                            /*arg=*/arg->shape(),
+                            /*init_value=*/init_value->shape(),
+                            /*dimensions_to_reduce=*/dimensions,
+                            /*to_apply=*/function->ComputeProgramShape()));
+    TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
+        << " but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+
+    const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
+    VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
+    const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
+    VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
+    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+    auto init_scalar = init_literal.Get<ReturnT>({});
+
+    auto result = Literal::CreateFromShape(reduce->shape());
+
+    const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
+    std::vector<int64> arg_dim_steps(arg_dimensions.size());
+    std::vector<int64> arg_dim_counts(arg_dimensions.size());
+    for (const int64 dim : dimensions) {
+      arg_dim_steps[dim] = 1;
+      arg_dim_counts[dim] = arg_dimensions[dim];
+    }
+
+    // Create mapping from result index to arg index.
+    const int64 result_rank = ShapeUtil::Rank(result->shape());
+    int64 result_dim = 0;
+    std::vector<int64> result_to_arg_index(result_rank);
+    for (int64 i = 0; i < arg_dimensions.size(); ++i) {
+      if (arg_dim_steps[i] == 0) {
+        result_to_arg_index[result_dim] = i;
+        ++result_dim;
+      }
+    }
+
+    // For each resulting dimension, calculate and assign computed value.
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          ReturnT result_val = init_scalar;
+
+          std::vector<int64> base(arg_dimensions.size());
+          for (int64 i = 0; i < multi_index.size(); ++i) {
+            base[result_to_arg_index[i]] = multi_index[i];
+          }
+
+          auto func = [&](const std::vector<int64>& input_index) {
+            auto curr_val = arg_literal.Get<ReturnT>(input_index);
+
+            // Evaluate computation with specified literal operands.
+            auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
+            auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
+            std::vector<const Literal*> args = {curr_val_literal.get(),
+                                                result_val_literal.get()};
+
+            // We need a new visitor for each evaluation, so that the same
+            // computation can be visited more than once (with different
+            // inputs).
+            HloEvaluator embedded_evaluator;
+            std::unique_ptr<Literal> computed_result =
+                embedded_evaluator.Evaluate(*function, args)
+                    .ConsumeValueOrDie();
+
+            // Assign computed result to result_val.
+            result_val = computed_result->Get<ReturnT>({});
+
+            return true;
+          };
+
+          ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
+                                  arg_dim_steps, func);
+
+          return result_val;
+        }));
+
+    parent_->evaluated_[reduce] = std::move(result);
+    return Status::OK();
+  };
+
   Status Preprocess(HloInstruction* hlo) override {
     VLOG(2) << hlo->ToString();
     return Status::OK();
   };
 
  private:
+  template <typename IndexT>
+  StatusOr<std::unique_ptr<Literal>> DynamicSlice(
+      const Literal& operand_literal, const Literal& start_indices_literal,
+      const Shape& result_shape) {
+    const auto& start_indices_typed =
+        start_indices_literal.GetArraySlice<IndexT>();
+    std::vector<int64> start(start_indices_typed.begin(),
+                             start_indices_typed.end());
+
+    std::vector<int64> operand_indices(start.size(), 0);
+
+    auto result = Literal::CreateFromShape(result_shape);
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          std::transform(multi_index.begin(), multi_index.end(), start.begin(),
+                         operand_indices.begin(), std::plus<int64>());
+
+          return operand_literal.Get<ReturnT>(operand_indices);
+        }));
+
+    return std::move(result);
+  }
+
+  template <typename IndexT>
+  StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
+      const Literal& operand_literal, const Literal& update_literal,
+      const Literal& start_indices_literal) {
+    const auto& start_indices_typed =
+        start_indices_literal.GetArraySlice<IndexT>();
+    const std::vector<int64> start(start_indices_typed.begin(),
+                                   start_indices_typed.end());
+
+    auto result = MakeUnique<Literal>(operand_literal);
+    std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
+
+    auto func = [&](const std::vector<int64>& update_index) {
+      std::transform(update_index.begin(), update_index.end(), start.begin(),
+                     result_index.begin(), std::plus<int64>());
+
+      result->Set<ReturnT>(result_index,
+                           update_literal.Get<ReturnT>(update_index));
+      return true;
+    };
+
+    std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
+    std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
+    ShapeUtil::ForEachIndex(update_literal.shape(), base,
+                            AsInt64Slice(update_literal.shape().dimensions()),
+                            step, func);
+
+    return std::move(result);
+  }
+
   StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
       HloInstruction* instruction,
       const std::function<ReturnT(ReturnT)>& unary_op) {
@@ -771,14 +1021,28 @@ HloEvaluator::HloEvaluator() {
 }
 
 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
-    HloComputation* computation,
-    tensorflow::gtl::ArraySlice<const Literal*> args) {
-  arg_literals_ = args;
+    const HloModule& module,
+    tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
+  XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
+
+  arg_literals_ = arg_literals;
   evaluated_.clear();
 
-  TF_RETURN_IF_ERROR(computation->Accept(this));
+  TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
+
   return MakeUnique<Literal>(
-      GetEvaluatedLiteralFor(computation->root_instruction()));
+      GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()));
+}
+
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+    const HloComputation& computation,
+    tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
+  arg_literals_ = arg_literals;
+  evaluated_.clear();
+
+  TF_RETURN_IF_ERROR(computation.Accept(this));
+  return MakeUnique<Literal>(
+      GetEvaluatedLiteralFor(computation.root_instruction()));
 }
 
 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
@@ -930,7 +1194,8 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite,
       break;
     }
     default:
-      LOG(FATAL) << "unknown/unhandled primitive type.";
+      LOG(FATAL) << "HandleIsFinite: unknown/unhandled primitive type: "
+                 << PrimitiveType_Name(operand->shape().element_type());
   }
 
   return Status::OK();
@@ -1009,7 +1274,8 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
          Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
     } break;
     default:
-      LOG(FATAL) << "unknown primitive type.";
+      LOG(FATAL) << "HandleCompare: unknown primitive type: "
+                 << PrimitiveType_Name(lhs->shape().element_type());
   }
 
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 976a2325ea9..fbb385c40fa 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -36,9 +37,17 @@ namespace xla {
 class HloEvaluator : public DfsHloVisitorWithDefault {
  public:
   HloEvaluator();
-  // Evaluates a HLO computation and an array of pointers to literals.
-  // Return the evaluated result as literal if successful.
-  // Precondition: argument literals are corresponds to the input computation's
+  // Evaluates an HLO module and an array of pointers to literals.
+  // Returns the evaluated result as a literal if successful.
+  // Precondition: argument literals correspond to the entry computation's
+  // parameters in their post-ordering. See comment below for an example.
+  StatusOr<std::unique_ptr<Literal>> Evaluate(
+      const HloModule& module,
+      tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
+
+  // Evaluates an HLO computation and an array of pointers to literals.
+  // Returns the evaluated result as a literal if successful.
+  // Precondition: argument literals correspond to the input computation's
   // parameters in their post-ordering. For e.g., consider the following graph:
   //
   //                *
@@ -51,7 +60,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
   // The input literals array will have its first literal map to Parameter0 and
   // the second map to Parameter1.
   StatusOr<std::unique_ptr<Literal>> Evaluate(
-      HloComputation* computation,
+      const HloComputation& computation,
       tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
 
   // Evaluates a single HLO instruction and an array of pointers to literals.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 7269fbeffc5..d2770f6a612 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -188,7 +188,7 @@ TEST_F(HloEvaluatorTest, DoesAbs) {
 
 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
 // constant operands.
 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
   auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
   auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
   auto rhs2 = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
@@ -205,9 +205,9 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
 
   auto root_instruction = HloInstruction::CreateBinary(
       shape, HloOpcode::kAdd, lhs_instruction.get(), param_rhs2.get());
-  builder.AddInstruction(std::move(root_instruction));
+  b.AddInstruction(std::move(root_instruction));
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), args).ConsumeValueOrDie();
 
   auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
 
@@ -216,22 +216,22 @@
 // Verifies Reshape operation is correctly evaluated.
 TEST_F(HloEvaluatorTest, DoesReshape) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
   const int64 dimensions[] = {11, 8, 7, 5, 9};
   TF_ASSERT_OK_AND_ASSIGN(auto literal,
                           LiteralTestUtil::CreateRandomLiteral<F32>(
                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
   auto literal_clone = literal->CloneToUnique();
-  HloInstruction* literal_instruction = builder.AddInstruction(
-      HloInstruction::CreateConstant(std::move(literal)));
+  HloInstruction* literal_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
 
   Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
   const int64 permutation[] = {1, 2, 0, 4, 3};
-  builder.AddInstruction(
+  b.AddInstruction(
       HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
   result->EachCell<NativeT>(
@@ -243,24 +243,24 @@
 // Verifies Broadcast operation is correctly evaluated.
 TEST_F(HloEvaluatorTest, DoesBroadcast) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
   auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
   auto output_literal = Literal::CreateR3<int32>(
       {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
-  HloInstruction* literal_instruction = builder.AddInstruction(
+  HloInstruction* literal_instruction = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
-  builder.AddInstruction(HloInstruction::CreateBroadcast(
+  b.AddInstruction(HloInstruction::CreateBroadcast(
       output_literal->shape(), literal_instruction, {1, 2}));
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   LiteralTestUtil::ExpectEqual(*result, *output_literal);
 }
 
 TEST_F(HloEvaluatorTest, ConvertWithSameLayout) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
 
   auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
   auto expected =
@@ -268,19 +268,18 @@ TEST_F(HloEvaluatorTest, ConvertWithSameLayout) {
   ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
                                                expected->shape()));
 
-  HloInstruction* constant = builder.AddInstruction(
+  HloInstruction* constant = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
-  builder.AddInstruction(
-      HloInstruction::CreateConvert(expected->shape(), constant));
+  b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   LiteralTestUtil::ExpectEqual(*result, *expected);
 }
 
 TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
 
   auto input_literal = Literal::CreateR2WithLayout<int32>(
       {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
@@ -289,13 +288,12 @@ TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) {
   ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
                                                 expected->shape()));
 
-  HloInstruction* constant = builder.AddInstruction(
+  HloInstruction* constant = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
-  builder.AddInstruction(
-      HloInstruction::CreateConvert(expected->shape(), constant));
+  b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   LiteralTestUtil::ExpectEqual(*result, *expected);
 }
 
@@ -355,7 +353,7 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
       shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   auto expected_array = MakeUnique<Array4D<float>>(8, 5, 1, 1);
   expected_array->Fill(kPadValue);
@@ -398,7 +396,7 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) {
                                      r2_padding_on_dim0_dim1));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
   auto expected_array = MakeUnique<Array2D<float>>(1, 5);
@@ -442,7 +440,7 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
                                      r2_padding_on_dim0_dim1));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   auto expected_array = MakeUnique<Array2D<float>>(0, 9);
   auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
@@ -477,7 +475,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) {
       shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   // clang-format off
   auto expected_array = Array2D<float>({
@@ -519,7 +517,7 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) {
       shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   auto expected = Literal::CreateR1<float>({22.f, 28.f});
 
@@ -559,10 +557,13 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) {
       shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   auto expected_array = Array2D<float>({
-      {22.f, 28.f}, {58.f, 76.f}, {94.f, 124.f}, {130.f, 172.f},
+      {22.f, 28.f},
+      {58.f, 76.f},
+      {94.f, 124.f},
+      {130.f, 172.f},
   });
   auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
 
@@ -606,7 +607,7 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
   auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
@@ -660,7 +661,7 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array4D<float> expected_array(1, 1, 4, 4);
   // clang-format off
@@ -736,7 +737,7 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   // clang-format off
   // Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -793,7 +794,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
      shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array4D<float> expected_array(1, 1, 7, 7);
   expected_array.FillWithYX(Array2D<float>({
@@ -856,7 +857,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array4D<float> expected_array(1, 1, 8, 8);
   expected_array.FillWithYX(Array2D<float>({
@@ -927,7 +928,7 @@ TEST_F(HloEvaluatorTest,
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array4D<float> expected_array(1, 1, 9, 3);
   expected_array.FillWithYX(Array2D<float>({
@@ -946,5 +947,115 @@ TEST_F(HloEvaluatorTest,
   LiteralTestUtil::ExpectEqual(*expected, *result);
 }
 
+TEST_F(HloEvaluatorTest, ReduceAdd) {
+  HloComputation::Builder b(TestName());
+
+  // arg:
+  // f32[2,3] {
+  //  { 1, 2, 3 },
+  //  { 5, 6, 7 },
+  // }
+  auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+  arg_array->FillUnique(1.0f);
+  auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
+
+  HloInstruction* arg_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+  auto init_value = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+
+  HloComputation::Builder add_computation("add");
+  Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+  auto param_lhs = add_computation.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+  auto param_rhs = add_computation.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+  add_computation.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
+  auto add_func = add_computation.Build();
+
+  Shape shape = ShapeUtil::MakeShape(F32, {2});
+  b.AddInstruction(HloInstruction::CreateReduce(
+      shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1},
+      add_func.get()));
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+  auto expected = Literal::CreateR1<float>({6, 18});
+
+  LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
+TEST_F(HloEvaluatorTest, DynamicSlice) {
+  HloComputation::Builder b(TestName());
+
+  // arg:
+  // f32[2,4] {
+  //  { 1, 2, 3, 4 },
+  //  { 5, 6, 7, 8 },
+  // }
+  auto operand_array = MakeUnique<Array2D<float>>(2, 4);
+  operand_array->FillUnique(1.0f);
+  auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
+
+  HloInstruction* operand = b.AddInstruction(
+      HloInstruction::CreateConstant(std::move(operand_literal)));
+
+  auto start_indices = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<int32>({0, 1})));
+
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand,
+                                                      start_indices, {2, 3}));
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+  auto expected = Literal::CreateR2<float>({
+      {2, 3, 4},
+      {6, 7, 8},
+  });
+
+  LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
+TEST_F(HloEvaluatorTest, DynamicSliceUpdate) {
+  HloComputation::Builder b(TestName());
+
+  // arg:
+  // f64[2,3] {
+  //  { 1, 2, 3 },
+  //  { 5, 6, 7 },
+  // }
+  auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+  operand_array->FillUnique(1.0);
+  auto operand_literal = Literal::CreateR2FromArray2D<double>(*operand_array);
+
+  HloInstruction* operand = b.AddInstruction(
+      HloInstruction::CreateConstant(std::move(operand_literal)));
+
+  auto start_indices = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
+
+  auto update = b.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));
+
+  Shape shape = ShapeUtil::MakeShape(F64, {2, 3});
+  b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+      shape, operand, update, start_indices));
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+  auto expected = Literal::CreateR2<double>({
+      {1, -2, -3},
+      {5, -6, -7},
+  });
+
+  LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
 }  // namespace
 }  // namespace xla
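
Usage sketch (illustrative only, modeled on the updated DoesTraverseInstructions test above): after this change, HloEvaluator::Evaluate takes the computation by const reference and the argument literals as an array of pointers. The builder name, shape, and values below are assumptions made up for the example, not part of the patch.

  // Build a two-parameter computation: result = lhs + rhs.
  HloComputation::Builder b("example");
  Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
  auto param_lhs =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
  auto param_rhs =
      b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
  b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
                                                param_lhs, param_rhs));

  // Argument literals map to parameters in post-order: the first literal to
  // Parameter0, the second to Parameter1.
  auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
  auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
  std::vector<const Literal*> args = {lhs.get(), rhs.get()};

  HloEvaluator evaluator;
  std::unique_ptr<Literal> result =
      evaluator.Evaluate(*b.Build(), args).ConsumeValueOrDie();
  // result now holds {{3, 4}, {-96, 8}}.

Passing the computation by const reference rather than a raw pointer makes explicit that the evaluator only reads the computation and takes no ownership of it.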