[xla:gpu] Only run XLA Triton passes on XLA fusions.

PiperOrigin-RevId: 713609640

Committed by: TensorFlower Gardener
Parent: a4d4a095c8
Commit: 1f58545523
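The change threads a new is_xla_fusion flag from CompileTritonToLLVM into CreateTritonPipeline so that XLA-specific Triton passes (currently mt_xla::CreateInt4ToPackedInt4RewritePass) run only for XLA-generated fusions; the Triton custom-call path in IrEmitterUnnested::EmitTritonCustomCall passes /*is_xla_fusion=*/false and skips them. Below is a minimal, self-contained sketch of that gating pattern; the PassManager struct and every pass name other than the int4 rewrite are illustrative stand-ins, not the real MLIR/XLA APIs.

// Sketch only: a toy PassManager standing in for mlir::OpPassManager, showing
// how an is_xla_fusion flag gates XLA-only passes. Pass names other than the
// int4 rewrite are placeholders.
#include <iostream>
#include <string>
#include <vector>

struct PassManager {
  std::vector<std::string> passes;
  void addPass(std::string name) { passes.push_back(std::move(name)); }
};

// Mirrors the shape of the new CreateTritonPipeline(): XLA-specific passes are
// added only when the module was produced by an XLA fusion, not by a
// user-provided Triton kernel (custom call).
void BuildTritonPipeline(PassManager& pm, bool is_xla_fusion) {
  if (is_xla_fusion) {
    // Corresponds to mt_xla::CreateInt4ToPackedInt4RewritePass() in the diff.
    pm.addPass("xla-int4-to-packed-int4-rewrite");
  }
  // Shared Triton lowering passes run on both paths (placeholder names).
  pm.addPass("make-ttir");
  pm.addPass("lower-to-llvm");
}

int main() {
  PassManager fusion_pm;
  PassManager custom_call_pm;
  BuildTritonPipeline(fusion_pm, /*is_xla_fusion=*/true);        // XLA fusion
  BuildTritonPipeline(custom_call_pm, /*is_xla_fusion=*/false);  // custom call
  std::cout << "XLA fusion pipeline: " << fusion_pm.passes.size()
            << " passes\n"  // 3
            << "custom call pipeline: " << custom_call_pm.passes.size()
            << " passes\n";  // 2
  return 0;
}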
@@ -42,7 +42,8 @@ namespace gpu {
 // use, but that's not the case currently.
 absl::Status CreateTritonPipeline(
     mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
-    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info);
+    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
+    bool is_xla_fusion);

 }  // namespace gpu
 }  // namespace xla
@@ -41,14 +41,18 @@ namespace gpu {
 namespace mt = ::mlir::triton;
 namespace mt_xla = ::mlir::triton::xla;

-absl::Status CreateTritonPipeline(
-    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
-    int num_stages, mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
+                                  std::string arch_name, int num_warps,
+                                  int num_ctas, int num_stages,
+                                  mt::nvidia_gpu::ClusterInfo& out_cluster_info,
+                                  bool is_xla_fusion) {
   auto cc = se::CudaComputeCapability(std::move(arch_name));
   const int ccAsInt = cc.major * 10 + cc.minor;
   const int threadsPerWarp = 32;

-  pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass());
+  if (is_xla_fusion) {
+    pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass());
+  }

   // Based on make_ttir() in
   // @triton//:third_party/nvidia/backend/compiler.py
@@ -55,9 +55,11 @@ using ::mlir::Type;
 using ::mlir::Value;
 using mlir::ValueRange;

-absl::Status CreateTritonPipeline(
-    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
-    int num_stages, mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
+                                  std::string arch_name, int num_warps,
+                                  int num_ctas, int num_stages,
+                                  mt::nvidia_gpu::ClusterInfo& out_cluster_info,
+                                  bool is_xla_fusion) {
   // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
   const int threadsPerWarp = 32;
   auto cc = se::RocmComputeCapability(std::move(arch_name));
@@ -24,7 +24,8 @@ namespace gpu {

 absl::Status CreateTritonPipeline(
     mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
-    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info) {
+    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
+    bool is_xla_fusion) {
   return absl::UnimplementedError("not supported for this build configuration");
 }

@@ -1212,7 +1212,8 @@ absl::StatusOr<TritonWrapperResult> TritonWrapper(
   const HloModule* hlo_module = fusion->GetModule();
   return CompileTritonToLLVM(hlo_module->config(), hlo_module->name(),
                              device_info, block_level_parameters,
-                             triton_module.get(), llvm_module, mlir_context);
+                             triton_module.get(), llvm_module, mlir_context,
+                             /*is_xla_fusion=*/true);
 }

 absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
@@ -1220,7 +1221,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     const se::DeviceDescription& device_info,
     const BlockLevelParameters& block_level_parameters,
     mlir::ModuleOp triton_module, llvm::Module* llvm_module,
-    mlir::MLIRContext& mlir_context, bool emit_kernel) {
+    mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
   const auto& cc = device_info.gpu_compute_capability();
   const std::string arch_name =
       std::visit([](auto& cc) { return cc.ToString(); }, cc);
@@ -1285,7 +1286,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
   mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
   if (!CreateTritonPipeline(&pm, arch_name, block_level_parameters.num_warps,
                             block_level_parameters.num_ctas,
-                            block_level_parameters.num_stages, cluster_info)
+                            block_level_parameters.num_stages, cluster_info,
+                            is_xla_fusion)
           .ok()) {
     return Internal("Failed to create Triton pipeline.");
   }
@@ -87,7 +87,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     const se::DeviceDescription& device_info,
     const BlockLevelParameters& block_level_parameters,
     mlir::ModuleOp triton_module, llvm::Module* llvm_module,
-    mlir::MLIRContext& mlir_context, bool emit_kernel = true);
+    mlir::MLIRContext& mlir_context, bool is_xla_fusion,
+    bool emit_kernel = true);

 std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
                              const se::DeviceDescription& device_info);
@@ -74,7 +74,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
    const se::DeviceDescription& device_info,
    const BlockLevelParameters& block_level_parameters,
    mlir::ModuleOp triton_module, llvm::Module* llvm_module,
-   mlir::MLIRContext& mlir_context, bool emit_kernel) {
+   mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
   return absl::UnimplementedError("not supported for this build configuration");
 }

@@ -44,13 +44,16 @@ TEST(TritonStub, CallStubApi) {
   LoadMlirDialectsForTriton(context);
   EXPECT_FALSE(TritonWrapper({}, nullptr, {}, {}, {}, nullptr, context).ok());
   EXPECT_FALSE(CreateTritonModule({}, nullptr, {}, {}, context).ok());
-  EXPECT_FALSE(
-      CompileTritonToLLVM({}, {}, {}, {}, {}, nullptr, context, {}).ok());
+  EXPECT_FALSE(CompileTritonToLLVM({}, {}, {}, {}, {}, nullptr, context,
+                                   /*is_xla_fusion=*/true, {})
+                   .ok());

   mlir::OpPassManager pm;
   ::mlir::triton::nvidia_gpu::ClusterInfo cluster_info;

-  EXPECT_FALSE(CreateTritonPipeline(&pm, "", 1, 1, 1, cluster_info).ok());
+  EXPECT_FALSE(CreateTritonPipeline(&pm, "", 1, 1, 1, cluster_info,
+                                    /*is_xla_fusion=*/true)
+                   .ok());
   EXPECT_EQ(GetLibdevicePath({}, {}), "");

   EmitterLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context);
@@ -1421,7 +1421,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
                         ir_emitter_context_->gpu_device_info(),
                         block_level_parameters, triton_module.get(),
                         ir_emitter_context_->llvm_module(), mlir_context,
-                        emit_kernels));
+                        /*is_xla_fusion=*/false, emit_kernels));

   TF_ASSIGN_OR_RETURN(
       auto kernel_arguments,