[xla:gpu] Only run XLA Triton passes on XLA fusions.

PiperOrigin-RevId: 713609640
Author: Chris Jones
Date: 2025-01-09 03:25:21 -08:00
Committed by: TensorFlower Gardener
Parent: a4d4a095c8
Commit: 1f58545523

9 changed files with 32 additions and 18 deletions
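The gist of the change, as a minimal standalone C++ sketch (not XLA source; the toy PassManager and the pass-name strings are hypothetical): XLA-only rewrites are appended to the Triton pipeline only when the module being compiled is an XLA-generated fusion, while the shared lowering always runs.

// Standalone sketch, not XLA code: illustrates how an is_xla_fusion flag can
// gate XLA-specific passes. PassManager and the pass names are made up.
#include <iostream>
#include <string>
#include <vector>

struct PassManager {
  std::vector<std::string> passes;
  void addPass(const std::string& name) { passes.push_back(name); }
};

// Mirrors the shape of the new parameter: the XLA fusion path passes
// is_xla_fusion=true; the Triton custom-call path passes false.
void BuildPipeline(PassManager& pm, bool is_xla_fusion) {
  if (is_xla_fusion) {
    // XLA-only rewrite (analogous to the int4 -> packed-int4 pass gated below).
    pm.addPass("xla-int4-to-packed-int4-rewrite");
  }
  pm.addPass("common-triton-lowering");  // Shared Triton lowering always runs.
}

int main() {
  PassManager fusion_pm, custom_call_pm;
  BuildPipeline(fusion_pm, /*is_xla_fusion=*/true);
  BuildPipeline(custom_call_pm, /*is_xla_fusion=*/false);
  std::cout << "fusion pipeline passes: " << fusion_pm.passes.size() << "\n"
            << "custom-call pipeline passes: " << custom_call_pm.passes.size()
            << "\n";
  return 0;
}

In the actual diff below, TritonWrapper passes /*is_xla_fusion=*/true and EmitTritonCustomCall passes /*is_xla_fusion=*/false, so the XLA-specific CreateInt4ToPackedInt4RewritePass no longer runs on user-written Triton kernels compiled through the custom-call path.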


@@ -42,7 +42,8 @@ namespace gpu {
 // use, but that's not the case currently.
 absl::Status CreateTritonPipeline(
     mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
-    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info);
+    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
+    bool is_xla_fusion);
 
 }  // namespace gpu
 }  // namespace xla


@@ -41,14 +41,18 @@ namespace gpu {
 namespace mt = ::mlir::triton;
 namespace mt_xla = ::mlir::triton::xla;
 
-absl::Status CreateTritonPipeline(
-    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
-    int num_stages, mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
+                                  std::string arch_name, int num_warps,
+                                  int num_ctas, int num_stages,
+                                  mt::nvidia_gpu::ClusterInfo& out_cluster_info,
+                                  bool is_xla_fusion) {
   auto cc = se::CudaComputeCapability(std::move(arch_name));
   const int ccAsInt = cc.major * 10 + cc.minor;
   const int threadsPerWarp = 32;
 
-  pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass());
+  if (is_xla_fusion) {
+    pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass());
+  }
 
   // Based on make_ttir() in
   // @triton//:third_party/nvidia/backend/compiler.py

@@ -55,9 +55,11 @@ using ::mlir::Type;
 using ::mlir::Value;
 using mlir::ValueRange;
 
-absl::Status CreateTritonPipeline(
-    mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
-    int num_stages, mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
+                                  std::string arch_name, int num_warps,
+                                  int num_ctas, int num_stages,
+                                  mt::nvidia_gpu::ClusterInfo& out_cluster_info,
+                                  bool is_xla_fusion) {
   // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
   const int threadsPerWarp = 32;
   auto cc = se::RocmComputeCapability(std::move(arch_name));

@@ -24,7 +24,8 @@ namespace gpu {
 absl::Status CreateTritonPipeline(
     mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
-    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info) {
+    int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
+    bool is_xla_fusion) {
   return absl::UnimplementedError("not supported for this build configuration");
 }


@@ -1212,7 +1212,8 @@ absl::StatusOr<TritonWrapperResult> TritonWrapper(
   const HloModule* hlo_module = fusion->GetModule();
   return CompileTritonToLLVM(hlo_module->config(), hlo_module->name(),
                              device_info, block_level_parameters,
-                             triton_module.get(), llvm_module, mlir_context);
+                             triton_module.get(), llvm_module, mlir_context,
+                             /*is_xla_fusion=*/true);
 }
 
 absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
@@ -1220,7 +1221,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     const se::DeviceDescription& device_info,
     const BlockLevelParameters& block_level_parameters,
     mlir::ModuleOp triton_module, llvm::Module* llvm_module,
-    mlir::MLIRContext& mlir_context, bool emit_kernel) {
+    mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
   const auto& cc = device_info.gpu_compute_capability();
   const std::string arch_name =
       std::visit([](auto& cc) { return cc.ToString(); }, cc);
@@ -1285,7 +1286,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
   mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
   if (!CreateTritonPipeline(&pm, arch_name, block_level_parameters.num_warps,
                             block_level_parameters.num_ctas,
-                            block_level_parameters.num_stages, cluster_info)
+                            block_level_parameters.num_stages, cluster_info,
+                            is_xla_fusion)
            .ok()) {
     return Internal("Failed to create Triton pipeline.");
   }


@@ -87,7 +87,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     const se::DeviceDescription& device_info,
     const BlockLevelParameters& block_level_parameters,
     mlir::ModuleOp triton_module, llvm::Module* llvm_module,
-    mlir::MLIRContext& mlir_context, bool emit_kernel = true);
+    mlir::MLIRContext& mlir_context, bool is_xla_fusion,
+    bool emit_kernel = true);
 
 std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
                              const se::DeviceDescription& device_info);


@@ -74,7 +74,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     const se::DeviceDescription& device_info,
     const BlockLevelParameters& block_level_parameters,
     mlir::ModuleOp triton_module, llvm::Module* llvm_module,
-    mlir::MLIRContext& mlir_context, bool emit_kernel) {
+    mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
   return absl::UnimplementedError("not supported for this build configuration");
 }


@@ -44,13 +44,16 @@ TEST(TritonStub, CallStubApi) {
   LoadMlirDialectsForTriton(context);
 
   EXPECT_FALSE(TritonWrapper({}, nullptr, {}, {}, {}, nullptr, context).ok());
   EXPECT_FALSE(CreateTritonModule({}, nullptr, {}, {}, context).ok());
-  EXPECT_FALSE(
-      CompileTritonToLLVM({}, {}, {}, {}, {}, nullptr, context, {}).ok());
+  EXPECT_FALSE(CompileTritonToLLVM({}, {}, {}, {}, {}, nullptr, context,
+                                   /*is_xla_fusion=*/true, {})
+                   .ok());
 
   mlir::OpPassManager pm;
   ::mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
-  EXPECT_FALSE(CreateTritonPipeline(&pm, "", 1, 1, 1, cluster_info).ok());
+  EXPECT_FALSE(CreateTritonPipeline(&pm, "", 1, 1, 1, cluster_info,
+                                    /*is_xla_fusion=*/true)
+                   .ok());
 
   EXPECT_EQ(GetLibdevicePath({}, {}), "");
 
   EmitterLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context);

@@ -1421,7 +1421,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
                             ir_emitter_context_->gpu_device_info(),
                             block_level_parameters, triton_module.get(),
                             ir_emitter_context_->llvm_module(), mlir_context,
-                            emit_kernels));
+                            /*is_xla_fusion=*/false, emit_kernels));
 
     TF_ASSIGN_OR_RETURN(
         auto kernel_arguments,