Refactor: Pass GpuComputeCapability to GetBlasComputationType

This change modifies `gpu::GetBlasComputationType` to accept a `GpuComputeCapability` parameter. This lets the function make platform-specific decisions at run time, such as selecting the TF32 computation type for FP32/C64 GEMMs on CUDA, instead of relying on preprocessor macros like `GOOGLE_CUDA`. Call sites are updated to pass the compute capability of the device they target.
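Condensed from the hunks below, the updated declaration and a representative call site look as follows (a sketch only; the per-file diffs show the exact edits):

    absl::StatusOr<blas::ComputationType> GetBlasComputationType(
        xla::PrecisionConfig::Algorithm algorithm, xla::PrimitiveType lhs_dtype,
        xla::PrimitiveType output_dtype, int64_t compute_precision,
        const GpuComputeCapability& cc);

    // Callers thread the capability through from the device they run on, e.g.:
    TF_ASSIGN_OR_RETURN(
        se::blas::ComputationType computation_type,
        se::gpu::GetBlasComputationType(
            precision_algorithm, lhs_type, output_type, compute_precision,
            stream->parent()->GetDeviceDescription().gpu_compute_capability()));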

PiperOrigin-RevId: 846617600
Author: Henning Becker (2025-12-19 01:11:22 -08:00), committed by TensorFlower Gardener
Parent: 35dc0c4552
Commit: 458fe6aa5a
10 changed files with 46 additions and 37 deletions

View File

@@ -104,7 +104,8 @@ CublasBackend::GetSupportedConfigs(const HloInstruction& instr) {
       out_desc.compute_type,
       se::gpu::GetBlasComputationType(
           gemm_config.precision_algorithm, gemm_config.lhs_layout.dtype,
-          gemm_config.output_layout.dtype, gemm_config.compute_precision));
+          gemm_config.output_layout.dtype, gemm_config.compute_precision,
+          target_config().device_description.gpu_compute_capability()));

   se::blas::BlasSupport* blas = stream_executor()->AsBlas();
   if (blas == nullptr) {

View File

@@ -1121,6 +1121,8 @@ cc_library(
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_blas_lt",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
@@ -1128,8 +1130,6 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
)

View File

@@ -31,7 +31,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "xla/autotuning.pb.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
@@ -50,11 +49,11 @@ limitations under the License.
#include "xla/stream_executor/gpu/gpu_blas_lt.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/types.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
namespace xla {
namespace gpu {
@@ -479,7 +478,8 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm,
 absl::StatusOr<GemmConfig::DescriptorsTuple> GemmConfig::GetMatrixDescriptors(
     se::DeviceAddressBase lhs_buf, se::DeviceAddressBase rhs_buf,
-    se::DeviceAddressBase out_buf) const {
+    se::DeviceAddressBase out_buf,
+    const se::GpuComputeCapability& gpu_version) const {
   auto create_matrix_desc = [](const se::gpu::MatrixLayout& layout,
                                se::DeviceAddressBase data)
       -> absl::StatusOr<se::gpu::MatrixDescriptor> {
@@ -512,7 +512,7 @@ absl::StatusOr<GemmConfig::DescriptorsTuple> GemmConfig::GetMatrixDescriptors(
   TF_ASSIGN_OR_RETURN(out_desc.compute_type,
                       se::gpu::GetBlasComputationType(
                           PrecisionConfig::ALG_UNSET, lhs.dtype, out.dtype,
-                          se::blas::kDefaultComputePrecision));
+                          se::blas::kDefaultComputePrecision, gpu_version));

   TF_ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor lhs_desc,
                       create_matrix_desc(lhs, lhs_buf));
@@ -541,8 +541,9 @@ absl::Status DoGemmWithAlgorithm(const se::gpu::MatrixDescriptor& lhs,
   PrimitiveType output_type = primitive_util::NativeToPrimitiveType<Output>();
   TF_ASSIGN_OR_RETURN(
       se::blas::ComputationType computation_type,
-      se::gpu::GetBlasComputationType(precision_algorithm, lhs_type,
-                                      output_type, compute_precision));
+      se::gpu::GetBlasComputationType(
+          precision_algorithm, lhs_type, output_type, compute_precision,
+          stream->parent()->GetDeviceDescription().gpu_compute_capability()));
   se::DeviceAddress<Output> output_data(output.data);

   // Set a workspace for all Blas operations launched below.
@@ -626,7 +627,9 @@ absl::Status RunGemm(const GemmConfig& config, se::DeviceAddressBase lhs_buffer,
   TF_ASSIGN_OR_RETURN(
       GemmConfig::DescriptorsTuple desc,
-      config.GetMatrixDescriptors(lhs_buffer, rhs_buffer, output_buffer));
+      config.GetMatrixDescriptors(
+          lhs_buffer, rhs_buffer, output_buffer,
+          stream->parent()->GetDeviceDescription().gpu_compute_capability()));

   se::EngineOptions engine_options{
       deterministic_ops,

View File

@@ -35,6 +35,7 @@ limitations under the License.
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_blas_lt.h"
#include "xla/stream_executor/stream.h"
#include "xla/xla_data.pb.h"
namespace xla {
@@ -146,7 +147,8 @@ struct GemmConfig : public se::gpu::GemmConfig {
   };

   absl::StatusOr<DescriptorsTuple> GetMatrixDescriptors(
       se::DeviceAddressBase lhs_buf, se::DeviceAddressBase rhs_buf,
-      se::DeviceAddressBase out_buf) const;
+      se::DeviceAddressBase out_buf,
+      const se::GpuComputeCapability& gpu_version) const;
 };

 // Run the given GEMM instruction `gemm` subject to the configuration

View File

@@ -2095,7 +2095,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         const se::blas::ComputationType compute_type,
         se::gpu::GetBlasComputationType(
             instr.precision_config().algorithm(), a_dtype, output_type,
-            stream_executor::blas::kDefaultComputePrecision));
+            stream_executor::blas::kDefaultComputePrecision, gpu_version_));

     se::blas::DataType scale_type =
         se::gpu::GetScaleType(output_dtype, compute_type);
@@ -2193,10 +2193,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       return false;
     }

-    TF_ASSIGN_OR_RETURN(
-        const se::blas::ComputationType compute_type,
-        se::gpu::GetBlasComputationType(
-            algorithm, a_dtype, instr.shape().element_type(), max_precision));
+    TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type,
+                        se::gpu::GetBlasComputationType(
+                            algorithm, a_dtype, instr.shape().element_type(),
+                            max_precision, gpu_version_));

     se::blas::DataType scale_type =
         se::gpu::GetScaleType(output_dtype, compute_type);

View File

@@ -318,10 +318,12 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg,
   auto compute_type = cfg.compute_type;
   if (!compute_type) { // obtain compute_type unless provided by the user
-    TF_ASSIGN_OR_RETURN(compute_type,
-                        gpu::GetBlasComputationType(
-                            cfg.precision_algorithm, lhs_layout.dtype,
-                            output_layout.dtype, cfg.compute_precision));
+    TF_ASSIGN_OR_RETURN(
+        compute_type,
+        gpu::GetBlasComputationType(
+            cfg.precision_algorithm, lhs_layout.dtype, output_layout.dtype,
+            cfg.compute_precision,
+            parent_->GetDeviceDescription().gpu_compute_capability()));
   }

   // FP8 matmuls have a fast accumulation mode that is less precise than the

View File

@@ -409,7 +409,6 @@ cc_library(
name = "gpu_blas_lt",
srcs = ["gpu_blas_lt.cc"],
hdrs = ["gpu_blas_lt.h"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
deps = [
":gpu_blas_lt_proto_cc",
"//xla:shape_util",
@@ -420,8 +419,11 @@ cc_library(
"//xla/service:algorithm_util",
"//xla/stream_executor:blas",
"//xla/stream_executor:device_address",
"//xla/stream_executor:device_description",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/cuda:cuda_platform_id",
"//xla/tsl/platform:statusor",
"//xla/tsl/protobuf:dnn_proto_cc",
"@com_google_absl//absl/algorithm:container",
@@ -435,9 +437,8 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:statusor",
] + if_cuda_is_configured([
"@local_tsl//tsl/platform:tensor_float_32_hdr_lib",
]) + if_static([
] + if_static([
"@local_tsl//tsl/platform:tensor_float_32_utils",
]),
)

View File

@@ -30,6 +30,7 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_blas_lt.pb.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
@@ -37,9 +38,7 @@ limitations under the License.
#include "xla/tsl/protobuf/dnn.pb.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#if GOOGLE_CUDA
#include "tsl/platform/tensor_float_32_utils.h"
#endif
namespace stream_executor {
@@ -205,7 +204,8 @@ xla::GemmConfigProto::MatrixLayout MatrixLayout::ToProto() const {
 absl::StatusOr<ComputationType> GetBlasComputationType(
     xla::PrecisionConfig::Algorithm algorithm, xla::PrimitiveType lhs_dtype,
-    xla::PrimitiveType output_dtype, int64_t compute_precision) {
+    xla::PrimitiveType output_dtype, int64_t compute_precision,
+    const GpuComputeCapability& cc) {
   if (algorithm == xla::PrecisionConfig::ALG_UNSET) {
     switch (output_dtype) {
       case PrimitiveType::F8E5M2: // fall-through
@@ -222,14 +222,12 @@ absl::StatusOr<ComputationType> GetBlasComputationType(
         return ComputationType::kF32;
       case PrimitiveType::F32: // fall-through
       case PrimitiveType::C64:
-#if GOOGLE_CUDA
-        if (tsl::tensor_float_32_execution_enabled() &&
+        if (cc.IsCuda() && tsl::tensor_float_32_execution_enabled() &&
             compute_precision <= 1 && lhs_dtype == output_dtype) {
           // CublasLt requires compute type to be F32 for F8 matmul.
           // TF32 should only be chosen for FP32 or C64 gemm
           return ComputationType::kTF32AsF32;
         }
-#endif
         return ComputationType::kF32;
       case PrimitiveType::F64: // fall-through
       case PrimitiveType::C128:
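The removed GOOGLE_CUDA guard above becomes an ordinary runtime predicate, so ROCm builds compile the same code and simply never select kTF32AsF32. A minimal, self-contained sketch of that pattern in isolation (MayUseTf32 is a hypothetical helper, not part of this change, and it assumes GpuComputeCapability exposes the IsCuda() predicate used in the hunk):

    #include <cstdint>

    #include "xla/stream_executor/device_description.h"
    #include "xla/xla_data.pb.h"
    #include "tsl/platform/tensor_float_32_utils.h"

    namespace stream_executor::gpu {

    // Hypothetical helper: true only on CUDA devices with TF32 enabled and a
    // compute precision that permits it; non-CUDA devices fall back to kF32.
    inline bool MayUseTf32(const GpuComputeCapability& cc,
                           int64_t compute_precision,
                           xla::PrimitiveType lhs_dtype,
                           xla::PrimitiveType output_dtype) {
      return cc.IsCuda() && tsl::tensor_float_32_execution_enabled() &&
             compute_precision <= 1 && lhs_dtype == output_dtype;
    }

    }  // namespace stream_executor::gpu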

View File

@@ -32,6 +32,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_blas_lt.pb.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"
@@ -44,7 +45,8 @@ absl::StatusOr<xla::PrimitiveType> AsXlaPrimitiveType(blas::DataType dtype);
 absl::StatusOr<blas::ComputationType> GetBlasComputationType(
     xla::PrecisionConfig::Algorithm algorithm, xla::PrimitiveType lhs_dtype,
-    xla::PrimitiveType output_dtype, int64_t compute_precision);
+    xla::PrimitiveType output_dtype, int64_t compute_precision,
+    const GpuComputeCapability& cc);

 // Returns the type for the alpha and beta scalars.
 blas::DataType GetScaleType(blas::DataType c_type,

View File

@@ -33,7 +33,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "Eigen/Core"
#include "rocm/include/hip/library_types.h"
#include "rocm/include/hipblas/hipblas.h"
#include "rocm/include/hipblaslt/hipblaslt.h"
@@ -55,7 +54,6 @@ limitations under the License.
#include "xla/tsl/platform/statusor.h"
#include "xla/types.h"
#include "xla/util.h"
#include "tsl/platform/ml_dtypes.h"
#define SET_ATTR(setter, handle, attr, value) \
ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter)
@@ -326,10 +324,12 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const
   auto compute_type = cfg.compute_type;
   if (!compute_type) { // obtain compute_type unless provided by the user
-    TF_ASSIGN_OR_RETURN(compute_type,
-                        gpu::GetBlasComputationType(
-                            cfg.precision_algorithm, lhs_layout.dtype,
-                            output_layout.dtype, cfg.compute_precision));
+    TF_ASSIGN_OR_RETURN(
+        compute_type,
+        gpu::GetBlasComputationType(
+            cfg.precision_algorithm, lhs_layout.dtype, output_layout.dtype,
+            cfg.compute_precision,
+            parent_->GetDeviceDescription().gpu_compute_capability()));
   }

   if (lhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) {