Refactor: Pass GpuComputeCapability to GetBlasComputationType
This change modifies `gpu::GetBlasComputationType` to accept a `GpuComputeCapability`, which lets the function make platform-specific decisions (such as enabling TF32) at runtime instead of relying on preprocessor macros like `GOOGLE_CUDA`. Call sites are updated to pass the appropriate compute capability.

PiperOrigin-RevId: 846617600
Commit 458fe6aa5a (parent 35dc0c4552), committed by TensorFlower Gardener.
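In outline, the API change looks like this. The snippet below is an illustrative sketch, not a line from the diff; `algorithm`, `lhs_dtype`, `output_dtype`, `compute_precision`, and `device_desc` are stand-ins for whatever the call site already has in scope (for example via `stream->parent()->GetDeviceDescription()`):

// Before: platform-specific behavior (e.g. choosing TF32) was decided at
// compile time inside GetBlasComputationType behind #if GOOGLE_CUDA.
TF_ASSIGN_OR_RETURN(
    se::blas::ComputationType compute_type,
    se::gpu::GetBlasComputationType(algorithm, lhs_dtype, output_dtype,
                                    compute_precision));

// After: the caller threads the device's compute capability through, so
// the same binary can decide per device at runtime.
TF_ASSIGN_OR_RETURN(
    se::blas::ComputationType compute_type,
    se::gpu::GetBlasComputationType(
        algorithm, lhs_dtype, output_dtype, compute_precision,
        device_desc.gpu_compute_capability()));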
@@ -104,7 +104,8 @@ CublasBackend::GetSupportedConfigs(const HloInstruction& instr) {
       out_desc.compute_type,
       se::gpu::GetBlasComputationType(
           gemm_config.precision_algorithm, gemm_config.lhs_layout.dtype,
-          gemm_config.output_layout.dtype, gemm_config.compute_precision));
+          gemm_config.output_layout.dtype, gemm_config.compute_precision,
+          target_config().device_description.gpu_compute_capability()));

   se::blas::BlasSupport* blas = stream_executor()->AsBlas();
   if (blas == nullptr) {
third_party/xla/xla/service/gpu/BUILD
@@ -1121,6 +1121,8 @@ cc_library(
         "//xla/stream_executor:stream",
         "//xla/stream_executor:stream_executor_h",
         "//xla/stream_executor/gpu:gpu_blas_lt",
+        "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
@@ -1128,8 +1130,6 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
-        "@local_tsl//tsl/platform:errors",
-        "@local_tsl//tsl/platform:statusor",
     ],
 )
third_party/xla/xla/service/gpu/matmul_utils.cc
@@ -31,7 +31,6 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
 #include "xla/autotuning.pb.h"
-#include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
@@ -50,11 +49,11 @@ limitations under the License.
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/types.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/statusor.h"

 namespace xla {
 namespace gpu {
@@ -479,7 +478,8 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm,

 absl::StatusOr<GemmConfig::DescriptorsTuple> GemmConfig::GetMatrixDescriptors(
     se::DeviceAddressBase lhs_buf, se::DeviceAddressBase rhs_buf,
-    se::DeviceAddressBase out_buf) const {
+    se::DeviceAddressBase out_buf,
+    const se::GpuComputeCapability& gpu_version) const {
   auto create_matrix_desc = [](const se::gpu::MatrixLayout& layout,
                                se::DeviceAddressBase data)
       -> absl::StatusOr<se::gpu::MatrixDescriptor> {
@@ -512,7 +512,7 @@ absl::StatusOr<GemmConfig::DescriptorsTuple> GemmConfig::GetMatrixDescriptors(
   TF_ASSIGN_OR_RETURN(out_desc.compute_type,
                       se::gpu::GetBlasComputationType(
                           PrecisionConfig::ALG_UNSET, lhs.dtype, out.dtype,
-                          se::blas::kDefaultComputePrecision));
+                          se::blas::kDefaultComputePrecision, gpu_version));

   TF_ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor lhs_desc,
                       create_matrix_desc(lhs, lhs_buf));
@@ -541,8 +541,9 @@ absl::Status DoGemmWithAlgorithm(const se::gpu::MatrixDescriptor& lhs,
   PrimitiveType output_type = primitive_util::NativeToPrimitiveType<Output>();
   TF_ASSIGN_OR_RETURN(
       se::blas::ComputationType computation_type,
-      se::gpu::GetBlasComputationType(precision_algorithm, lhs_type,
-                                      output_type, compute_precision));
+      se::gpu::GetBlasComputationType(
+          precision_algorithm, lhs_type, output_type, compute_precision,
+          stream->parent()->GetDeviceDescription().gpu_compute_capability()));
   se::DeviceAddress<Output> output_data(output.data);

   // Set a workspace for all Blas operations launched below.
@@ -626,7 +627,9 @@ absl::Status RunGemm(const GemmConfig& config, se::DeviceAddressBase lhs_buffer,

   TF_ASSIGN_OR_RETURN(
       GemmConfig::DescriptorsTuple desc,
-      config.GetMatrixDescriptors(lhs_buffer, rhs_buffer, output_buffer));
+      config.GetMatrixDescriptors(
+          lhs_buffer, rhs_buffer, output_buffer,
+          stream->parent()->GetDeviceDescription().gpu_compute_capability()));

   se::EngineOptions engine_options{
       deterministic_ops,
third_party/xla/xla/service/gpu/matmul_utils.h
@@ -35,6 +35,7 @@ limitations under the License.
 #include "xla/stream_executor/device_address.h"
+#include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/xla_data.pb.h"

 namespace xla {
@@ -146,7 +147,8 @@ struct GemmConfig : public se::gpu::GemmConfig {
   };
   absl::StatusOr<DescriptorsTuple> GetMatrixDescriptors(
       se::DeviceAddressBase lhs_buf, se::DeviceAddressBase rhs_buf,
-      se::DeviceAddressBase out_buf) const;
+      se::DeviceAddressBase out_buf,
+      const se::GpuComputeCapability& gpu_version) const;
 };

 // Run the given GEMM instruction `gemm` subject to the configuration
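For reference, this is how a caller of the widened `GetMatrixDescriptors` now looks; the snippet restates the `RunGemm` hunk above and assumes `config`, `stream`, and the three device buffers are in scope:

// The capability comes from the stream's executor and rides along to the
// descriptor computation, which forwards it to GetBlasComputationType.
TF_ASSIGN_OR_RETURN(
    GemmConfig::DescriptorsTuple desc,
    config.GetMatrixDescriptors(
        lhs_buffer, rhs_buffer, output_buffer,
        stream->parent()->GetDeviceDescription().gpu_compute_capability()));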
third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -2095,7 +2095,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         const se::blas::ComputationType compute_type,
         se::gpu::GetBlasComputationType(
             instr.precision_config().algorithm(), a_dtype, output_type,
-            stream_executor::blas::kDefaultComputePrecision));
+            stream_executor::blas::kDefaultComputePrecision, gpu_version_));
     se::blas::DataType scale_type =
         se::gpu::GetScaleType(output_dtype, compute_type);

@@ -2193,10 +2193,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       return false;
     }

-    TF_ASSIGN_OR_RETURN(
-        const se::blas::ComputationType compute_type,
-        se::gpu::GetBlasComputationType(
-            algorithm, a_dtype, instr.shape().element_type(), max_precision));
+    TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type,
+                        se::gpu::GetBlasComputationType(
+                            algorithm, a_dtype, instr.shape().element_type(),
+                            max_precision, gpu_version_));
     se::blas::DataType scale_type =
         se::gpu::GetScaleType(output_dtype, compute_type);

third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc
@@ -318,10 +318,12 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg,

   auto compute_type = cfg.compute_type;
   if (!compute_type) {  // obtain compute_type unless provided by the user
-    TF_ASSIGN_OR_RETURN(compute_type,
-                        gpu::GetBlasComputationType(
-                            cfg.precision_algorithm, lhs_layout.dtype,
-                            output_layout.dtype, cfg.compute_precision));
+    TF_ASSIGN_OR_RETURN(
+        compute_type,
+        gpu::GetBlasComputationType(
+            cfg.precision_algorithm, lhs_layout.dtype, output_layout.dtype,
+            cfg.compute_precision,
+            parent_->GetDeviceDescription().gpu_compute_capability()));
   }

   // FP8 matmuls have a fast accumulation mode that is less precise than the
third_party/xla/xla/stream_executor/gpu/BUILD
@@ -409,7 +409,6 @@ cc_library(
     name = "gpu_blas_lt",
     srcs = ["gpu_blas_lt.cc"],
     hdrs = ["gpu_blas_lt.h"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
     deps = [
         ":gpu_blas_lt_proto_cc",
         "//xla:shape_util",
@@ -420,8 +419,11 @@ cc_library(
         "//xla/service:algorithm_util",
         "//xla/stream_executor:blas",
         "//xla/stream_executor:device_address",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:platform",
         "//xla/stream_executor:stream",
         "//xla/stream_executor:stream_executor_h",
+        "//xla/stream_executor/cuda:cuda_platform_id",
         "//xla/tsl/platform:statusor",
         "//xla/tsl/protobuf:dnn_proto_cc",
         "@com_google_absl//absl/algorithm:container",
@@ -435,9 +437,8 @@ cc_library(
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:statusor",
-    ] + if_cuda_is_configured([
         "@local_tsl//tsl/platform:tensor_float_32_hdr_lib",
-    ]) + if_static([
+    ] + if_static([
         "@local_tsl//tsl/platform:tensor_float_32_utils",
     ]),
 )
third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "xla/primitive_util.h"
 #include "xla/service/algorithm_util.h"
 #include "xla/stream_executor/blas.h"
+#include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.pb.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor.h"
@@ -37,9 +38,7 @@ limitations under the License.
 #include "xla/tsl/protobuf/dnn.pb.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
-#if GOOGLE_CUDA
 #include "tsl/platform/tensor_float_32_utils.h"
-#endif

 namespace stream_executor {

@@ -205,7 +204,8 @@ xla::GemmConfigProto::MatrixLayout MatrixLayout::ToProto() const {

 absl::StatusOr<ComputationType> GetBlasComputationType(
     xla::PrecisionConfig::Algorithm algorithm, xla::PrimitiveType lhs_dtype,
-    xla::PrimitiveType output_dtype, int64_t compute_precision) {
+    xla::PrimitiveType output_dtype, int64_t compute_precision,
+    const GpuComputeCapability& cc) {
   if (algorithm == xla::PrecisionConfig::ALG_UNSET) {
     switch (output_dtype) {
       case PrimitiveType::F8E5M2:  // fall-through
@@ -222,14 +222,12 @@ absl::StatusOr<ComputationType> GetBlasComputationType(
         return ComputationType::kF32;
       case PrimitiveType::F32:  // fall-through
       case PrimitiveType::C64:
-#if GOOGLE_CUDA
-        if (tsl::tensor_float_32_execution_enabled() &&
+        if (cc.IsCuda() && tsl::tensor_float_32_execution_enabled() &&
             compute_precision <= 1 && lhs_dtype == output_dtype) {
           // CublasLt requires compute type to be F32 for F8 matmul.
           // TF32 should only be chosen for FP32 or C64 gemm
           return ComputationType::kTF32AsF32;
         }
-#endif
         return ComputationType::kF32;
       case PrimitiveType::F64:  // fall-through
       case PrimitiveType::C128:
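The last hunk above is the heart of the change: a compile-time #if GOOGLE_CUDA guard becomes a runtime cc.IsCuda() check, so the ROCm build compiles the same function body and simply never takes the TF32 branch. As a rough sketch of why that query is cheap (the real GpuComputeCapability is declared in xla/stream_executor/device_description.h; the mock types below are illustrative only, not the real definitions):

// Illustrative mock: a vendor-tagged capability type makes "is this a
// CUDA device?" an ordinary runtime predicate instead of a build flag.
#include <string>
#include <variant>

struct CudaComputeCapability { int major = 0; int minor = 0; };
struct RocmComputeCapability { std::string gcn_arch_name; };  // e.g. "gfx90a"

class GpuComputeCapability {
 public:
  bool IsCuda() const {
    return std::holds_alternative<CudaComputeCapability>(capability_);
  }
  bool IsRocm() const { return !IsCuda(); }

 private:
  std::variant<CudaComputeCapability, RocmComputeCapability> capability_;
};

This is also why the BUILD hunks make tensor_float_32_hdr_lib an unconditional dependency: tsl::tensor_float_32_execution_enabled() is now referenced on every platform, and the cc.IsCuda() short-circuit keeps it from ever firing on ROCm.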
third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
@@ -32,6 +32,7 @@ limitations under the License.
 #include "absl/synchronization/mutex.h"
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/device_address.h"
+#include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.pb.h"
 #include "xla/types.h"
 #include "xla/xla_data.pb.h"
@@ -44,7 +45,8 @@ absl::StatusOr<xla::PrimitiveType> AsXlaPrimitiveType(blas::DataType dtype);

 absl::StatusOr<blas::ComputationType> GetBlasComputationType(
     xla::PrecisionConfig::Algorithm algorithm, xla::PrimitiveType lhs_dtype,
-    xla::PrimitiveType output_dtype, int64_t compute_precision);
+    xla::PrimitiveType output_dtype, int64_t compute_precision,
+    const GpuComputeCapability& cc);

 // Returns the type for the alpha and beta scalars.
 blas::DataType GetScaleType(blas::DataType c_type,
third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
@@ -33,7 +33,6 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/time/time.h"
-#include "Eigen/Core"
 #include "rocm/include/hip/library_types.h"
 #include "rocm/include/hipblas/hipblas.h"
 #include "rocm/include/hipblaslt/hipblaslt.h"
@@ -55,7 +54,6 @@ limitations under the License.
 #include "xla/tsl/platform/statusor.h"
 #include "xla/types.h"
 #include "xla/util.h"
-#include "tsl/platform/ml_dtypes.h"

 #define SET_ATTR(setter, handle, attr, value) \
   ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter)
@@ -326,10 +324,12 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg, Epilogue epilogue) const

   auto compute_type = cfg.compute_type;
   if (!compute_type) {  // obtain compute_type unless provided by the user
-    TF_ASSIGN_OR_RETURN(compute_type,
-                        gpu::GetBlasComputationType(
-                            cfg.precision_algorithm, lhs_layout.dtype,
-                            output_layout.dtype, cfg.compute_precision));
+    TF_ASSIGN_OR_RETURN(
+        compute_type,
+        gpu::GetBlasComputationType(
+            cfg.precision_algorithm, lhs_layout.dtype, output_layout.dtype,
+            cfg.compute_precision,
+            parent_->GetDeviceDescription().gpu_compute_capability()));
   }

   if (lhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) {