[XLA:GPU] Constant-fold broadcasts if they're the input to a gemm.

In general we don't want to fold broadcasts of scalars; it just makes the
program "less simple" and makes it require more memory.

But for the special case of A@B+C fusion, we do want to fold them.  Otherwise
we end up with situations where we run a fusion just to calculate the constant
addend C.

PiperOrigin-RevId: 452813130
This commit is contained in:
Justin Lebar
2022-06-03 11:56:57 -07:00
committed by TensorFlower Gardener
parent bde18b7101
commit bbd0b4f23d
3 changed files with 114 additions and 13 deletions

View File

@@ -1011,6 +1011,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/core:lib",

View File

@@ -15,12 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
#include <utility>
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -49,6 +52,57 @@ Status SetName(HloModule *module, HloInstruction *gemm) {
return ::tensorflow::OkStatus();
}
// Materializes a small bias that is (possibly a reshape/transpose/bitcast
// of) a broadcast of a non-scalar constant, returning the new constant
// instruction; otherwise returns `bias` unchanged.
//
// The regular constant-folding pass deliberately refuses to expand
// broadcasts of constants: doing so usually costs memory and bandwidth for
// no benefit. Here the trade-off flips — leaving the broadcast alive forces
// a separate kernel launch whose only job is to produce the gemm's addend,
// defeating the point of fusing the bias into the gemm. So we expand the
// constant ourselves, up to a size cap.
//
// TODO(b/192499646): Even better would be to use cublasLT to fuse the
// broadcasted bias, if it supports that fusion efficiently.
HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
// Size cap for the materialized literal; the value was not chosen
// carefully.
constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;
if (ShapeUtil::ByteSizeOf(bias->shape()) > kMaxMaterializeBiasBytes) {
return bias;
}
// Broadcasts of scalars are excluded: algebraic simplification would just
// collapse the expanded constant back into a broadcast anyway.
auto is_nonscalar = [](const HloInstruction *instr) {
return !ShapeUtil::IsEffectiveScalar(instr->shape());
};
// Only broadcast(constant), optionally wrapped in one
// reshape/transpose/bitcast, is handled. Anything more general would
// duplicate the legality reasoning that lives in the constant-folding
// pass.
auto bcast_of_const = m::Broadcast(m::Constant().WithPredicate(is_nonscalar));
bool shape_is_foldable = Match(bias, bcast_of_const) ||
Match(bias, m::Reshape(bcast_of_const)) ||
Match(bias, m::Transpose(bcast_of_const)) ||
Match(bias, m::Bitcast(bcast_of_const));
if (!shape_is_foldable) {
return bias;
}
HloEvaluator evaluator(/*max_loop_iterations=*/0);
Literal folded;
if (!evaluator.TryEvaluate(
bias, &folded,
/*recursively_evaluate_nonconstant_operands=*/true)) {
return bias;
}
return bias->parent()->AddInstruction(
HloInstruction::CreateConstant(std::move(folded)));
}
// The rewriting proceeds in a bottom-up way:
//
// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
@@ -182,20 +236,24 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
}
auto config =
existing_gemm->backend_config<GemmBackendConfig>().ValueOrDie();
if (config.beta() == 0 && bias->user_count() == 1 &&
existing_gemm->user_count() == 1 &&
bias->shape() == existing_gemm->shape()) {
config.set_beta(1.0);
CHECK_EQ(existing_gemm->operand_count(), 2);
std::unique_ptr<HloInstruction> gemm_call =
existing_gemm->CloneWithNewOperands(
instr->shape(), {existing_gemm->mutable_operand(0),
existing_gemm->mutable_operand(1), bias});
TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
TF_RETURN_IF_ERROR(
ReplaceWithNewInstruction(instr, std::move(gemm_call)));
if (config.beta() != 0 || bias->user_count() != 1 ||
existing_gemm->user_count() != 1 ||
bias->shape() != existing_gemm->shape()) {
return ::tensorflow::OkStatus();
}
config.set_beta(1.0);
CHECK_EQ(existing_gemm->operand_count(), 2);
std::unique_ptr<HloInstruction> gemm_call =
existing_gemm->CloneWithNewOperands(
instr->shape(), {
existing_gemm->mutable_operand(0),
existing_gemm->mutable_operand(1),
MaybeConstantFoldBias(bias),
});
TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(gemm_call)));
return ::tensorflow::OkStatus();
}
};

View File

@@ -803,6 +803,48 @@ ENTRY test {
.WithShape(F32, {4})));
}
// Verifies that GemmRewriter materializes a bias that is a broadcast of a
// non-scalar constant — either directly (sum1) or behind a single reshape
// (sum2), transpose (sum3), or bitcast (sum4) — so each fused gemm ends up
// reading a plain constant operand instead of a broadcast.
TEST_F(GemmRewriteHloTest, FoldConstantBias) {
const char* hlo_text = R"(
HloModule test
ENTRY test {
x = f32[2,2] parameter(0)
y = f32[2,2] parameter(1)
bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
bias1 = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
sum1 = add(dot1, bias1)
dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
sum2 = add(dot2, f32[2,2] reshape(bias))
dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
bias3 = f32[2,2] transpose(bias), dimensions={1,0}
sum3 = add(dot3, bias3)
dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
sum4 = add(dot4, f32[2,2] bitcast(bias))
ROOT root = tuple(sum1, sum2, sum3, sum4)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
// Run the rewriter; it must report that it changed the module.
GemmRewriter pass;
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
// Dump the rewritten module on failure to ease debugging.
SCOPED_TRACE(module->ToString());
EXPECT_TRUE(changed);
// Each add should have become a custom-call gemm whose third (bias)
// operand is a materialized constant — no broadcast should remain.
EXPECT_THAT(
module->entry_computation()->root_instruction(),
GmockMatch(m::Tuple(
m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()))));
}
} // namespace
} // namespace gpu
} // namespace xla