mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
[XLA:GPU] Constant-fold broadcasts if they're the input to a gemm.
In general we don't want to fold broadcasts of scalars; it just makes the program "less simple" and makes it require more memory. But for the special case of A@B+C fusion, we do want to fold them. Otherwise we end up with situations where we run a fusion just to calculate the constant addend C. PiperOrigin-RevId: 452813130
This commit is contained in:
committed by
TensorFlower Gardener
parent
bde18b7101
commit
bbd0b4f23d
@@ -1011,6 +1011,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_creation_utils",
|
||||
"//tensorflow/compiler/xla/service:hlo_evaluator",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:pattern_matcher",
|
||||
"//tensorflow/core:lib",
|
||||
|
||||
@@ -15,12 +15,15 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
@@ -49,6 +52,57 @@ Status SetName(HloModule *module, HloInstruction *gemm) {
|
||||
return ::tensorflow::OkStatus();
|
||||
}
|
||||
|
||||
// If the bias is a sequence of ops that depend only on broadcasts of
|
||||
// constants, materialize the bias if it's small.
|
||||
//
|
||||
// Normally the constant-folding pass would materialize the bias if it is
|
||||
// calculated entirely from constants. But if the bias is a broadcast of a
|
||||
// constant, constant-folding won't expand the broadcast, on the theory that
|
||||
// folding broadcasts of constants causes us to consume more memory and can
|
||||
// actually make things slower (because any op which reads the constant has
|
||||
// to read more memory).
|
||||
//
|
||||
// OTOH in our case, we don't want to run an op that just broadcasts a
|
||||
// constant so we can fuse it into this gemm. That would defeat the whole
|
||||
// purpose of this fusion, which is to launch fewer kernels. So if we can,
|
||||
// we expand out this constant ourselves.
|
||||
//
|
||||
// TODO(b/192499646): Even better would be to use cublasLT to fuse the
|
||||
// broadcasted bias, if it supports that fusion efficiently.
|
||||
HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
|
||||
// This limit was not chosen carefully.
|
||||
constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;
|
||||
|
||||
// Don't fold broadcasts of scalars -- algsimp will just collapse it again.
|
||||
auto is_nonscalar = [](const HloInstruction *instr) {
|
||||
return !ShapeUtil::IsEffectiveScalar(instr->shape());
|
||||
};
|
||||
|
||||
// For now, only fold broadcast(constant) or
|
||||
// reshape/transpose/bitcast(broadcast(constant)). This lets us avoid the
|
||||
// complexity in the constant-folding pass about what is and isn't legal to
|
||||
// fold.
|
||||
auto broadcast_of_nonscalar =
|
||||
m::Broadcast(m::Constant().WithPredicate(is_nonscalar));
|
||||
|
||||
if (ShapeUtil::ByteSizeOf(bias->shape()) <= kMaxMaterializeBiasBytes &&
|
||||
(Match(bias, broadcast_of_nonscalar) ||
|
||||
Match(bias, m::Reshape(broadcast_of_nonscalar)) ||
|
||||
Match(bias, m::Transpose(broadcast_of_nonscalar)) ||
|
||||
Match(bias, m::Bitcast(broadcast_of_nonscalar)))) {
|
||||
HloEvaluator evaluator(/*max_loop_iterations=*/0);
|
||||
Literal result;
|
||||
if (evaluator.TryEvaluate(
|
||||
bias, &result,
|
||||
/*recursively_evaluate_nonconstant_operands=*/true)) {
|
||||
return bias->parent()->AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(result)));
|
||||
}
|
||||
}
|
||||
|
||||
return bias;
|
||||
}
|
||||
|
||||
// The rewriting proceeds in a bottom-up way:
|
||||
//
|
||||
// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
|
||||
@@ -182,20 +236,24 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
}
|
||||
auto config =
|
||||
existing_gemm->backend_config<GemmBackendConfig>().ValueOrDie();
|
||||
if (config.beta() == 0 && bias->user_count() == 1 &&
|
||||
existing_gemm->user_count() == 1 &&
|
||||
bias->shape() == existing_gemm->shape()) {
|
||||
config.set_beta(1.0);
|
||||
CHECK_EQ(existing_gemm->operand_count(), 2);
|
||||
std::unique_ptr<HloInstruction> gemm_call =
|
||||
existing_gemm->CloneWithNewOperands(
|
||||
instr->shape(), {existing_gemm->mutable_operand(0),
|
||||
existing_gemm->mutable_operand(1), bias});
|
||||
TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
|
||||
TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReplaceWithNewInstruction(instr, std::move(gemm_call)));
|
||||
if (config.beta() != 0 || bias->user_count() != 1 ||
|
||||
existing_gemm->user_count() != 1 ||
|
||||
bias->shape() != existing_gemm->shape()) {
|
||||
return ::tensorflow::OkStatus();
|
||||
}
|
||||
|
||||
config.set_beta(1.0);
|
||||
CHECK_EQ(existing_gemm->operand_count(), 2);
|
||||
std::unique_ptr<HloInstruction> gemm_call =
|
||||
existing_gemm->CloneWithNewOperands(
|
||||
instr->shape(), {
|
||||
existing_gemm->mutable_operand(0),
|
||||
existing_gemm->mutable_operand(1),
|
||||
MaybeConstantFoldBias(bias),
|
||||
});
|
||||
TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config));
|
||||
TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
|
||||
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(gemm_call)));
|
||||
return ::tensorflow::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -803,6 +803,48 @@ ENTRY test {
|
||||
.WithShape(F32, {4})));
|
||||
}
|
||||
|
||||
TEST_F(GemmRewriteHloTest, FoldConstantBias) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule test
|
||||
ENTRY test {
|
||||
x = f32[2,2] parameter(0)
|
||||
y = f32[2,2] parameter(1)
|
||||
bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
|
||||
|
||||
dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
bias1 = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0}
|
||||
sum1 = add(dot1, bias1)
|
||||
|
||||
dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
sum2 = add(dot2, f32[2,2] reshape(bias))
|
||||
|
||||
dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
bias3 = f32[2,2] transpose(bias), dimensions={1,0}
|
||||
sum3 = add(dot3, bias3)
|
||||
|
||||
dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
sum4 = add(dot4, f32[2,2] bitcast(bias))
|
||||
|
||||
ROOT root = tuple(sum1, sum2, sum3, sum4)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_text));
|
||||
GemmRewriter pass;
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
|
||||
SCOPED_TRACE(module->ToString());
|
||||
EXPECT_TRUE(changed);
|
||||
|
||||
EXPECT_THAT(
|
||||
module->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Tuple(
|
||||
m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
|
||||
m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
|
||||
m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()),
|
||||
m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
Reference in New Issue
Block a user