diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 81130d672c2..195854da45f 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1011,6 +1011,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_creation_utils",
+        "//tensorflow/compiler/xla/service:hlo_evaluator",
         "//tensorflow/compiler/xla/service:hlo_pass",
         "//tensorflow/compiler/xla/service:pattern_matcher",
        "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
index 76af3760bdf..bb949ee4b6b 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
@@ -15,12 +15,15 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
 
+#include <utility>
+
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -49,6 +52,57 @@ Status SetName(HloModule *module, HloInstruction *gemm) {
   return ::tensorflow::OkStatus();
 }
 
+// If the bias is a sequence of ops that depend only on broadcasts of
+// constants, materialize the bias if it's small.
+//
+// Normally the constant-folding pass would materialize the bias if it is
+// calculated entirely from constants.  But if the bias is a broadcast of a
+// constant, constant-folding won't expand the broadcast, on the theory that
+// folding broadcasts of constants causes us to consume more memory and can
+// actually make things slower (because any op which reads the constant has
+// to read more memory).
+//
+// OTOH in our case, we don't want to run an op that just broadcasts a
+// constant so we can fuse it into this gemm.  That would defeat the whole
+// purpose of this fusion, which is to launch fewer kernels.  So if we can,
+// we expand out this constant ourselves.
+//
+// TODO(b/192499646): Even better would be to use cublasLT to fuse the
+// broadcasted bias, if it supports that fusion efficiently.
+HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
+  // This limit was not chosen carefully.
+  constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;
+
+  // Don't fold broadcasts of scalars -- algsimp will just collapse it again.
+  auto is_nonscalar = [](const HloInstruction *instr) {
+    return !ShapeUtil::IsEffectiveScalar(instr->shape());
+  };
+
+  // For now, only fold broadcast(constant) or
+  // reshape/transpose/bitcast(broadcast(constant)).  This lets us avoid the
+  // complexity in the constant-folding pass about what is and isn't legal to
+  // fold.
+  auto broadcast_of_nonscalar =
+      m::Broadcast(m::Constant().WithPredicate(is_nonscalar));
+
+  if (ShapeUtil::ByteSizeOf(bias->shape()) <= kMaxMaterializeBiasBytes &&
+      (Match(bias, broadcast_of_nonscalar) ||
+       Match(bias, m::Reshape(broadcast_of_nonscalar)) ||
+       Match(bias, m::Transpose(broadcast_of_nonscalar)) ||
+       Match(bias, m::Bitcast(broadcast_of_nonscalar)))) {
+    HloEvaluator evaluator(/*max_loop_iterations=*/0);
+    Literal result;
+    if (evaluator.TryEvaluate(
+            bias, &result,
+            /*recursively_evaluate_nonconstant_operands=*/true)) {
+      return bias->parent()->AddInstruction(
+          HloInstruction::CreateConstant(std::move(result)));
+    }
+  }
+
+  return bias;
+}
+
 // The rewriting proceeds in a bottom-up way:
 //
 // (kDot A B) is rewritten into a (kCustomCall:gemm A B)
@@ -182,20 +236,24 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     }
     auto config =
         existing_gemm->backend_config<GemmBackendConfig>().ValueOrDie();
-    if (config.beta() == 0 && bias->user_count() == 1 &&
-        existing_gemm->user_count() == 1 &&
-        bias->shape() == existing_gemm->shape()) {
-      config.set_beta(1.0);
-      CHECK_EQ(existing_gemm->operand_count(), 2);
-      std::unique_ptr<HloInstruction> gemm_call =
-          existing_gemm->CloneWithNewOperands(
-              instr->shape(), {existing_gemm->mutable_operand(0),
-                               existing_gemm->mutable_operand(1), bias});
-      TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
-      TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
-      TF_RETURN_IF_ERROR(
-          ReplaceWithNewInstruction(instr, std::move(gemm_call)));
+    if (config.beta() != 0 || bias->user_count() != 1 ||
+        existing_gemm->user_count() != 1 ||
+        bias->shape() != existing_gemm->shape()) {
+      return ::tensorflow::OkStatus();
     }
+
+    config.set_beta(1.0);
+    CHECK_EQ(existing_gemm->operand_count(), 2);
+    std::unique_ptr<HloInstruction> gemm_call =
+        existing_gemm->CloneWithNewOperands(
+            instr->shape(), {
+                                existing_gemm->mutable_operand(0),
+                                existing_gemm->mutable_operand(1),
+                                MaybeConstantFoldBias(bias),
+                            });
+    TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
+    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
+    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(gemm_call)));
     return ::tensorflow::OkStatus();
   }
 };
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
index 55ab775af2c..fa02d93ff56 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -803,6 +803,48 @@ ENTRY test {
                         .WithShape(F32, {4})));
 }
 
+TEST_F(GemmRewriteHloTest, FoldConstantBias) {
+  const char* hlo_text = R"(
+HloModule test
+ENTRY test {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
+
+  dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias1 = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
+  sum1 = add(dot1, bias1)
+
+  dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  sum2 = add(dot2, f32[2,2] reshape(bias))
+
+  dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  bias3 = f32[2,2] transpose(bias), dimensions={1,0}
+  sum3 = add(dot3, bias3)
+
+  dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  sum4 = add(dot4, f32[2,2] bitcast(bias))
+
+  ROOT root = tuple(sum1, sum2, sum3, sum4)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GemmRewriter pass;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
+  SCOPED_TRACE(module->ToString());
+  EXPECT_TRUE(changed);
+
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Tuple(
+          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
+          m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()))));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
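For readers of this patch, here is a rough sketch of what the new path does at the HLO level. The instruction names, the "__cublas$gemm" custom-call target, and the backend-config wording below are illustrative, not taken verbatim from this change.

Before the rewriter runs, the bias reaches the add as a broadcast of a constant:

  x    = f32[2,2] parameter(0)
  y    = f32[2,2] parameter(1)
  bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
  dot  = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT sum = f32[2,2] add(dot, bias)

After the rewriter runs, MaybeConstantFoldBias has evaluated the broadcast into a literal, which is passed as the third operand of the gemm custom call with beta set to 1, so no separate broadcast kernel has to be launched:

  x    = f32[2,2] parameter(0)
  y    = f32[2,2] parameter(1)
  bias = f32[2,2] constant({{0, 0}, {0, 0}})
  ROOT gemm = f32[2,2] custom-call(x, y, bias), custom_call_target="__cublas$gemm"

Biases larger than kMaxMaterializeBiasBytes, or biases that are not (reshapes, transposes, or bitcasts of) broadcasts of non-scalar constants, are passed through unchanged; the beta=1 bias fusion still applies to them, just without materializing a constant.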