Merge pull request #41515 from ROCmSoftwarePlatform:google-upstream-rocm-mlir-integration-prototype
PiperOrigin-RevId: 329868219 Change-Id: I09132dcdaca4653924d7003b878eaca9f1f9971a
configure.py
@@ -1442,6 +1442,11 @@ def main():
     write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
     write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH'))
 
+  if ((environ_cp.get('TF_NEED_ROCM') == '1') and
+      (environ_cp.get('TF_ENABLE_MLIR_GENERATED_GPU_KERNELS') == '1')):
+    write_to_bazelrc(
+        'build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1')
+
   environ_cp['TF_NEED_CUDA'] = str(
       int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
   if (environ_cp.get('TF_NEED_CUDA') == '1' and
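For orientation, the two helpers above append lines to the generated .tf_configure.bazelrc. A minimal Python sketch of the effect, assuming a ROCm install at /opt/rocm (the helper bodies are simplified stand-ins, not configure.py's exact code):

    def write_to_bazelrc(line, path='.tf_configure.bazelrc'):
        # Append one raw line to the generated bazelrc fragment.
        with open(path, 'a') as f:
            f.write(line + '\n')

    def write_action_env_to_bazelrc(var_name, var, path='.tf_configure.bazelrc'):
        # Exported as --action_env so the value reaches repository rules.
        write_to_bazelrc('build --action_env %s="%s"' % (var_name, var), path)

    write_action_env_to_bazelrc('ROCM_PATH', '/opt/rocm')  # illustrative path
    write_action_env_to_bazelrc('ROCM_ROOT', '/opt/rocm')
    write_to_bazelrc('build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1')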
tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -1,5 +1,6 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
 
 package(
     default_visibility = [":friends"],
@@ -13,10 +14,10 @@ package_group(
 )
 
 cc_library(
-    name = "cubin_creator",
-    srcs = ["cubin_creator.cc"],
-    hdrs = ["cubin_creator.h"],
-    copts = if_cuda(["-DGOOGLE_CUDA=1"]),
+    name = "gpu_binary_creator",
+    srcs = ["gpu_binary_creator.cc"],
+    hdrs = ["gpu_binary_creator.h"],
+    copts = if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]),
     deps = [
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -29,6 +30,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:TargetNVVMIR",
+        "@llvm-project//mlir:TargetROCDLIR",
         "@llvm-project//mlir:Transforms",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/hlo",
@@ -44,15 +46,19 @@ cc_library(
         "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering",
         "//tensorflow/core:cuda_libdevice_path",
         "//tensorflow/core:lib",
-    ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]),
+    ] + if_cuda([
+        "//tensorflow/stream_executor/gpu:asm_compiler",
+    ]) + if_rocm([
+        "//tensorflow/core/platform:rocm_rocdl_path",
+    ]),
 )
 
 tf_cc_binary(
-    name = "tf_to_cubin",
-    srcs = ["tf_to_cubin.cc"],
+    name = "tf_to_gpu_binary",
+    srcs = ["tf_to_gpu_binary.cc"],
     visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
     deps = [
-        ":cubin_creator",
+        ":gpu_binary_creator",
         "//tensorflow/compiler/mlir:init_mlir",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc → gpu_binary_creator.cc (renamed)
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-//===- cubin_creator.cc -----------------------------------------*- C++ -*-===//
+//===- gpu_binary_creator.cc ------------------------------------*- C++ -*-===//
 //
-// This file implements the function to compile a TF kernel function to a cubin.
+// This file implements the function to compile a TF kernel function
+// to gpu binary (hsaco for AMD, cubin for NVIDIA)
 //
 //===----------------------------------------------------------------------===//
-#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.h"
 
 #include <string>
 #include <utility>
@@ -42,6 +43,7 @@ limitations under the License.
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Target/NVVMIR.h"  // from @llvm-project
+#include "mlir/Target/ROCDLIR.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
@@ -59,6 +61,8 @@ limitations under the License.
 #include "tensorflow/core/platform/path.h"
 #if GOOGLE_CUDA
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
+#elif TENSORFLOW_USE_ROCM
+#include "tensorflow/core/platform/rocm_rocdl_path.h"
 #endif
 
 namespace {
@@ -249,7 +253,8 @@ Status PropagateTensorFlowABIKnowledgeToKernel(
 
 }  // namespace
 
-StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
+StatusOr<std::vector<uint8_t>>
+tensorflow::kernel_gen::GenerateGpuBinaryForTfCode(
     llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
     llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
     llvm::ArrayRef<uint32_t> unroll_factors) {
@@ -266,13 +271,44 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
     options.use_approximations = true;
     TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options));
   }
+
+#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
+  return InternalError(
+      "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
+      " Did you specify either --config=rocm or --config=cuda ?");
+#endif
+
+#if TENSORFLOW_USE_ROCM
+  TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToROCDL(module.get()));
+#elif GOOGLE_CUDA
   TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
+#endif
+
   TF_RETURN_IF_ERROR(
       PropagateTensorFlowABIKnowledgeToKernel(module.get(), same_shape));
 
   mlir::OwningModuleRef kernel_module =
       xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
+
   llvm::LLVMContext llvmContext;
+
+#if TENSORFLOW_USE_ROCM
+  auto llvmModule = mlir::translateModuleToROCDLIR(*kernel_module, llvmContext);
+  if (!llvmModule) {
+    return InternalError("Could not translate MLIR module to ROCDL IR");
+  }
+
+  llvmModule->setModuleIdentifier("acme");
+
+  xla::HloModuleConfig config;
+  config.set_debug_options(xla::GetDebugOptionsFromFlags());
+
+  int gpu_version = compute_capability.first;
+  std::string libdevice_dir = tensorflow::RocdlRoot();
+
+  return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), gpu_version, config,
+                                          libdevice_dir);
+#elif GOOGLE_CUDA
   auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext);
   if (!llvmModule) {
     return InternalError("Could not translate MLIR module to NVVM");
@@ -295,12 +331,8 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
                                            config, libdevice_dir, enable_fusion));
   VLOG(1) << ptx;
 
-#if GOOGLE_CUDA
   return tensorflow::se::CompileGpuAsm(
       std::get<0>(compute_capability), std::get<1>(compute_capability),
       ptx.c_str(), xla::gpu::PtxOptsFromConfig(config));
-#else
-  return InternalError(
-      "GOOGLE_CUDA not defined. Did you specify --config=cuda ?");
 #endif
 }
tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h → gpu_binary_creator.h (renamed)
@@ -13,13 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-//===- cubin_creator.h ------------------------------------------*- C++ -*-===//
+//===- gpu_binary_creator.h -------------------------------------*- C++ -*-===//
 //
-// This file declares the function to compile a TF kernel function to a cubin.
+// This file declares the function to compile a TF kernel function
+// to gpu binary (hsaco for AMD, cubin for NVIDIA)
 //
 //===----------------------------------------------------------------------===//
-#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
-#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_GPU_BINARY_CREATOR_H_
+#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_GPU_BINARY_CREATOR_H_
 
 #include <utility>
 #include <vector>
@@ -30,7 +31,7 @@ limitations under the License.
 
 namespace tensorflow {
 namespace kernel_gen {
-xla::StatusOr<std::vector<uint8_t>> GenerateCubinForTfCode(
+xla::StatusOr<std::vector<uint8_t>> GenerateGpuBinaryForTfCode(
     llvm::StringRef tf_code,
     std::pair<int32_t, int32_t> compute_capability = {7, 5},
     llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
@@ -39,4 +40,4 @@ xla::StatusOr<std::vector<uint8_t>> GenerateCubinForTfCode(
 }  // namespace kernel_gen
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_GPU_BINARY_CREATOR_H_
tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc → tf_to_gpu_binary.cc (renamed)
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===//
+//===- tf_to_gpu_binary.cc --------------------------------------*- C++ -*-===//
 //
-// This file implements the entry point to compile a tf op to a cubin file.
+// This file implements the entry point to compile a tf op to a gpu binary
 //
 //===----------------------------------------------------------------------===//
 #include <string>
@@ -24,7 +24,7 @@
 #include "absl/strings/string_view.h"
 #include "llvm/Support/CommandLine.h"
 #include "tensorflow/compiler/mlir/init_mlir.h"
-#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
@@ -53,8 +53,12 @@ int main(int argc, char** argv) {
   tensorflow::InitMlir y(&argc, &argv);
   llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
 
+#if TENSORFLOW_USE_ROCM
+  std::pair<int32_t, int32_t> compute_capability(architecture, 0);
+#else
   std::pair<int32_t, int32_t> compute_capability(architecture / 10,
                                                  architecture % 10);
+#endif
 
   std::string tf_code;
   auto read_status = tensorflow::ReadFileToString(tensorflow::Env::Default(),
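The --arch flag is reused for both backends with different encodings; a small Python sketch of the mapping above (the function name is illustrative, not part of the tool):

    def parse_arch(architecture: int, use_rocm: bool) -> tuple:
        # CUDA: sm_75 is passed as 75 and split into compute capability (7, 5).
        # ROCm: gfx900 is passed as 900; the second component is unused.
        if use_rocm:
            return (architecture, 0)
        return (architecture // 10, architecture % 10)

    assert parse_arch(75, use_rocm=False) == (7, 5)
    assert parse_arch(900, use_rocm=True) == (900, 0)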
@@ -64,7 +68,7 @@ int main(int argc, char** argv) {
     return 1;
   }
 
-  auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
+  auto cubin = tensorflow::kernel_gen::GenerateGpuBinaryForTfCode(
       tf_code, compute_capability, tile_sizes, same_shape, unroll_factors);
 
   if (!cubin.ok()) {
tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -51,12 +51,14 @@ string HloModuleConfig::compilation_cache_key() const {
   string key = absl::StrCat("profiling=", hlo_profiling_enabled());
   StrAppend(&key, "::(");
   std::vector<string> params;
-  for (const ShapeLayout& param_layout :
-       entry_computation_layout_->parameter_layouts()) {
-    params.push_back(param_layout.shape().DebugString());
+  if (entry_computation_layout_.has_value()) {
+    for (const ShapeLayout& param_layout :
+         entry_computation_layout_->parameter_layouts()) {
+      params.push_back(param_layout.shape().DebugString());
+    }
+    StrAppend(&key, absl::StrJoin(params, ", "), ") => ",
+              entry_computation_layout_->result_shape().SerializeAsString());
   }
-  StrAppend(&key, absl::StrJoin(params, ", "), ") => ",
-            entry_computation_layout_->result_shape().SerializeAsString());
   if (seed() != 0) {
     // TODO(b/32083678): force recompilation to reset global state.
     static std::atomic<int> counter{0};
tensorflow/compiler/xla/service/mlir_gpu/BUILD
@@ -201,6 +201,7 @@ cc_library(
         "@llvm-project//mlir:CFGTransforms",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:GPUToNVVMTransforms",
+       "@llvm-project//mlir:GPUToROCDLTransforms",
        "@llvm-project//mlir:GPUTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
@@ -209,6 +210,7 @@ cc_library(
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:NVVMDialect",
        "@llvm-project//mlir:Pass",
+       "@llvm-project//mlir:ROCDLDialect",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:SCFToGPUPass",
        "@llvm-project//mlir:SCFTransforms",
tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"  // from @llvm-project
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"  // from @llvm-project
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"  // from @llvm-project
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"  // from @llvm-project
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
@@ -26,6 +27,7 @@ limitations under the License.
 #include "mlir/Dialect/GPU/Passes.h"  // from @llvm-project
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // from @llvm-project
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"  // from @llvm-project
 #include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/Passes.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
@@ -197,6 +199,85 @@ Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
   return Status::OK();
 }
 
+namespace {
+
+/// A pass that does the final lowering to ROCDL. It collects all the patterns
+/// that are currently required, currently mixing std, linalg and gpu.
+class LowerToROCDLPass
+    : public ::mlir::PassWrapper<
+          LowerToROCDLPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
+  void getDependentDialects(mlir::DialectRegistry& registry) const override {
+    registry.insert<mlir::ROCDL::ROCDLDialect, mlir::LLVM::LLVMDialect>();
+  }
+
+ public:
+  void runOnOperation() override {
+    ::mlir::gpu::GPUModuleOp m = getOperation();
+
+    ::mlir::OwningRewritePatternList patterns;
+    ::mlir::populateGpuRewritePatterns(m.getContext(), patterns);
+    ::mlir::applyPatternsAndFoldGreedily(m, patterns);
+    patterns.clear();
+
+    ::mlir::LLVMTypeConverter converter(m.getContext());
+    ::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
+    // TODO(b/145824979) Remove linalg once sliceop is in std.
+    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns,
+                                                   &getContext());
+    ::mlir::populateGpuToROCDLConversionPatterns(converter, patterns);
+    ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
+
+    ::mlir::ConversionTarget target(getContext());
+    target.addIllegalDialect<::mlir::gpu::GPUDialect>();
+    target
+        .addIllegalOp<mlir::LLVM::CosOp, mlir::LLVM::ExpOp, mlir::LLVM::FAbsOp,
+                      mlir::LLVM::FCeilOp, mlir::LLVM::LogOp,
+                      mlir::LLVM::Log10Op, mlir::LLVM::Log2Op>();
+    target.addIllegalOp<mlir::FuncOp>();
+    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
+    // TODO(csigg): Remove once we support replacing non-root ops.
+    target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
+                      ::mlir::gpu::YieldOp>();
+    if (failed(mlir::applyFullConversion(m, target, patterns))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) {
+  // We cannot verify as the signature of the kernel is rewritten.
+  ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false);
+  applyPassManagerCLOptions(pm);
+
+  auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) {
+    return VLOG_IS_ON(1);
+  };
+  pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
+                      /*shouldPrintAfterPass=*/enable_if_vlog_is_on,
+                      /*printModuleScope=*/false,
+                      /*printAfterOnlyOnChange=*/false,
+                      /*out=*/llvm::dbgs());
+
+  // Rewrite kernel functions to LLVM IR.
+  auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
+  kernelPm.addPass(::mlir::createLowerToCFGPass());
+  kernelPm.addPass(absl::make_unique<LowerToROCDLPass>());
+
+  // Some basic cleanup.
+  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
+  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
+  // Remove all location information to prevent a debug build.
+  kernelPm.addPass(::mlir::createStripDebugInfoPass());
+
+  if (failed(pm.run(module))) {
+    return InternalError("Lowering to ROCDL IR failed.");
+  }
+  return Status::OK();
+}
+
 StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module) {
   auto kernelModule = ::mlir::ModuleOp::create(module.getLoc());
   // TODO(b/137624192): This also needs to resolve naming conflicts.
tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h
@@ -36,6 +36,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module,
 
 Status LowerKernelBodiesToNVVM(mlir::ModuleOp module);
 
+Status LowerKernelBodiesToROCDL(mlir::ModuleOp module);
+
 StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module);
 
 }  // namespace mlir_gpu
tensorflow/core/kernels/mlir_generated/BUILD
@@ -3,6 +3,7 @@
 load("//tensorflow/core/kernels/mlir_generated:build_defs.bzl", "gen_kernel_library", "if_mlir_generated_gpu_kernels_enabled")
 load(
     "//tensorflow:tensorflow.bzl",
+    "if_cuda_or_rocm",
     "tf_kernel_library",
 )
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
@@ -10,7 +11,6 @@ load(
     "//tensorflow/core/platform:build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 
 package(
     default_visibility = ["//tensorflow/core/kernels:__subpackages__"],
@@ -33,7 +33,7 @@ tf_kernel_library(
         "cwise_op_gpu_tanh.cu.cc",
     ],
     tags = ["manual"],
-    deps = if_cuda([
+    deps = if_cuda_or_rocm([
         ":abs_kernels",
         ":tanh_kernels",
         "@com_google_absl//absl/strings",
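if_cuda_or_rocm comes from tensorflow.bzl and selects its first argument when either GPU backend is configured. A Starlark sketch of its shape, reconstructed from memory of tensorflow.bzl, so treat the config_setting labels as assumptions:

    def if_cuda_or_rocm(if_true, if_false = []):
        # True for builds configured with either CUDA or ROCm.
        return select({
            "@local_config_cuda//cuda:using_nvcc": if_true,
            "@local_config_cuda//cuda:using_clang": if_true,
            "@local_config_rocm//rocm:using_hipcc": if_true,
            "//conditions:default": if_false,
        })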
tensorflow/core/kernels/mlir_generated/build_defs.bzl
@@ -1,6 +1,12 @@
 """Generates cubin headers for TF dialect ops."""
 
-load("@local_config_cuda//cuda:build_defs.bzl", "cuda_gpu_architectures", "if_cuda")
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_gpu_architectures")
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "rocm_gpu_architectures",
+    "rocm_is_configured",
+)
+load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm")
 
 def if_mlir_generated_gpu_kernels_enabled(if_true, if_false = []):
     return select({
@@ -15,9 +21,12 @@ def _lookup_file(filegroup, path):
             return file
     return None
 
-CubinInfo = provider(fields = ["cubins"])
+GpuBinaryInfo = provider(
+    "GPU binaries in either cubin format or hsaco format",
+    fields = ["cubins", "hsacos"],
+)
 
-def _gen_kernel_cubin_impl(ctx):
+def _gen_kernel_cubin_impl_cuda(ctx):
     name = ctx.attr.name
     tile_sizes = ctx.attr.tile_size.replace("x", ",")
     cmd_args = []
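The single provider now carries either flavor of binary. A tiny runnable Python model of that flow (the class below is an illustrative stand-in, not Bazel's provider API):

    class GpuBinaryInfo:
        # Stand-in for provider(fields = ["cubins", "hsacos"]).
        def __init__(self, cubins=None, hsacos=None):
            self.cubins = cubins
            self.hsacos = hsacos

    # CUDA-configured builds return cubins, ROCm-configured builds return
    # hsacos; downstream rules read whichever field matches the backend.
    info = GpuBinaryInfo(hsacos=["abs_f32.gfx906.hsaco"])  # names illustrative
    assert info.cubins is None and info.hsacos[0].endswith(".hsaco")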
@@ -45,10 +54,37 @@ def _gen_kernel_cubin_impl(ctx):
             mnemonic = "compile",
         )
         cubins.append(cubin)
-    return [CubinInfo(cubins = cubins)]
+    return [GpuBinaryInfo(cubins = cubins)]
+
+def _gen_kernel_cubin_impl_rocm(ctx):
+    name = ctx.attr.name
+    tile_sizes = ctx.attr.tile_size.replace("x", ",")
+    cmd_args = []
+    if ctx.attr.same_shape:
+        cmd_args.append("--same_shape=%s" % ctx.attr.same_shape)
+    if ctx.attr.unroll_factors:
+        cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
+
+    hsacos = []
+    for arch in ctx.attr.gpu_archs:
+        filename = "%s.%s.hsaco" % (name, arch)
+        hsaco = ctx.actions.declare_file(filename)
+        ctx.actions.run(
+            inputs = [ctx.file.mlir_op, ctx.file._tfso],
+            outputs = [hsaco],
+            executable = ctx.executable._tool,
+            arguments = cmd_args + [
+                "--tile_sizes=%s" % tile_sizes,
+                "--arch=%s" % arch[3:],  # DDD in "gfxDDD"
+                "--input=%s" % ctx.file.mlir_op.path,
+                "--output=%s" % hsaco.path,
+            ],
+            mnemonic = "compile",
+        )
+        hsacos.append(hsaco)
+    return [GpuBinaryInfo(hsacos = hsacos)]
 
 _gen_kernel_cubin_rule = rule(
-    implementation = _gen_kernel_cubin_impl,
     attrs = {
         "mlir_op": attr.label(mandatory = True, allow_single_file = True),
         "tile_size": attr.string(mandatory = True),
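Note how the architecture is threaded through file names: the ROCm implementation names each output <name>.<arch>.hsaco, strips the "gfx" prefix for --arch, and the header rules later recover the arch by splitting the path. In Python terms (values illustrative):

    name, arch = "abs_f32_kernel", "gfx906"
    filename = "%s.%s.hsaco" % (name, arch)  # "abs_f32_kernel.gfx906.hsaco"
    assert arch[3:] == "906"                 # the DDD in "gfxDDD", passed as --arch
    assert filename.split(".")[-2] == arch   # recovered by the image-header rules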
@@ -62,16 +98,17 @@ _gen_kernel_cubin_rule = rule(
         ),
         "_tool": attr.label(
             executable = True,
-            default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
+            default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_gpu_binary"),
             cfg = "host",
         ),
     },
     output_to_genfiles = True,
+    implementation = _gen_kernel_cubin_impl_rocm if rocm_is_configured() else _gen_kernel_cubin_impl_cuda,
 )
 
-def _gen_kernel_image_hdr_impl(ctx):
+def _gen_kernel_image_hdr_impl_cuda(ctx):
     images = []
-    for cubin in ctx.attr.input[CubinInfo].cubins:
+    for cubin in ctx.attr.input[GpuBinaryInfo].cubins:
         arch = cubin.path.split(".")[-2]
         images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
 
@@ -79,8 +116,8 @@ def _gen_kernel_image_hdr_impl(ctx):
     fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
     ctx.actions.run(
         outputs = [fatbin],
-        inputs = ctx.attr.input[CubinInfo].cubins,
-        executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"),
+        inputs = ctx.attr.input[GpuBinaryInfo].cubins,
+        executable = _lookup_file(ctx.attr._gpu_root, "bin/fatbinary"),
         arguments = [
             "--64",
             "--cmdline=--compile-only",
@@ -91,7 +128,7 @@ def _gen_kernel_image_hdr_impl(ctx):
         mnemonic = "fatbinary",
     )
 
-    bin2c = _lookup_file(ctx.attr._cuda_root, "bin/bin2c")
+    bin2c = _lookup_file(ctx.attr._gpu_root, "bin/bin2c")
     ctx.actions.run_shell(
         outputs = [ctx.outputs.out],
         inputs = [fatbin],
@@ -101,29 +138,74 @@ def _gen_kernel_image_hdr_impl(ctx):
         mnemonic = "bin2c",
     )
 
+def _gen_kernel_image_hdr_impl_rocm(ctx):
+    hsaco_files = []
+    hsaco_targets = []
+
+    # Add a dummy host target triple...clang-offload-bundler requires 1 and only 1 host target triple
+    hsaco_files.append("/dev/null")
+    hsaco_targets.append("host-x86_64-unknown-linux")
+
+    hsacos = ctx.attr.input[GpuBinaryInfo].hsacos
+    for hsaco in hsacos:
+        gfx_arch = hsaco.path.split(".")[-2]
+        hsaco_files.append(hsaco.path)
+        hsaco_targets.append("hip-amdgcn-amd-amdhsa-%s" % gfx_arch)
+
+    # Generate fatbin file from all hsacos.
+    fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
+    ctx.actions.run(
+        outputs = [fatbin],
+        inputs = hsacos,
+        executable = _lookup_file(ctx.attr._gpu_root, "bin/clang-offload-bundler"),
+        arguments = [
+            "--inputs=%s" % ",".join(hsaco_files),
+            "--targets=%s" % ",".join(hsaco_targets),
+            "--type=o",
+            "--outputs=%s" % fatbin.path,
+        ],
+        mnemonic = "fatbinary",
+    )
+
+    ctx.actions.run_shell(
+        outputs = [ctx.outputs.out],
+        inputs = [fatbin],
+        command = (
+            ("echo 'static const unsigned char %s[] = {' > %s && " +
+             "hexdump -v -e '/1 \"0x%%02x, \"' %s | cat >> %s && " +
+             "echo '};' >> %s") % (
+                ctx.attr.symbol,
+                ctx.outputs.out.path,
+                fatbin.path,
+                ctx.outputs.out.path,
+                ctx.outputs.out.path,
+            )
+        ),
+    )
+
 _gen_kernel_image_hdr_rule = rule(
-    implementation = _gen_kernel_image_hdr_impl,
+    implementation = _gen_kernel_image_hdr_impl_rocm if rocm_is_configured() else _gen_kernel_image_hdr_impl_cuda,
     output_to_genfiles = True,
     attrs = {
-        "input": attr.label(mandatory = True, providers = [CubinInfo]),
+        "input": attr.label(mandatory = True, providers = [GpuBinaryInfo]),
         "out": attr.output(mandatory = True),
         "symbol": attr.string(mandatory = True),
-        "_cuda_root": attr.label(
-            default = Label("@local_config_cuda//cuda:cuda_root"),
+        "_gpu_root": attr.label(
+            default = Label("@local_config_rocm//rocm:rocm_root") if rocm_is_configured() else Label("@local_config_cuda//cuda:cuda_root"),
         ),
     },
 )
 
 def _gen_kernel_image_hdr(name, mlir_op, tile_size, same_shape = None, unroll_factors = None):
     """Generates a C header with fatbin data from a Tensorflow op."""
-    if cuda_gpu_architectures():
+    if cuda_gpu_architectures() or rocm_gpu_architectures():
         _gen_kernel_cubin_rule(
             name = name + "_cubin",
             mlir_op = mlir_op,
             tile_size = tile_size,
             same_shape = same_shape,
             unroll_factors = unroll_factors,
-            gpu_archs = cuda_gpu_architectures(),
+            gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(),
         )
         _gen_kernel_image_hdr_rule(
             name = name,
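The run_shell command in the hunk above replaces CUDA's bin2c with a plain echo/hexdump pipeline. A runnable Python equivalent of what it writes (symbol name and bytes are illustrative):

    def emit_header(symbol: str, fatbin: bytes) -> str:
        # Mirrors: echo 'static const unsigned char SYM[] = {' ;
        #          hexdump -v -e '/1 "0x%02x, "' ; echo '};'
        body = "".join("0x%02x, " % b for b in fatbin)
        return "static const unsigned char %s[] = {\n%s};\n" % (symbol, body)

    print(emit_header("kExampleKernel", b"\x7fELF"), end="")
    # static const unsigned char kExampleKernel[] = {
    # 0x7f, 0x45, 0x4c, 0x46, };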
@@ -173,7 +255,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr
         same_shape: The information about which shapes are the same, e.g. "0,1".
     """
 
-    if cuda_gpu_architectures():
+    if cuda_gpu_architectures() or rocm_gpu_architectures():
         for type in types:
             _gen_mlir_op(
                 name = name,
@@ -189,6 +271,6 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr
 
     native.cc_library(
         name = name + "_kernels",
-        hdrs = if_cuda(if_true = [":{name}_{type}_kernel".format(name = name, type = type) for type in types]),
+        hdrs = if_cuda_or_rocm(if_true = [":{name}_{type}_kernel".format(name = name, type = type) for type in types]),
         tags = tags,
     )
third_party/gpus/rocm/BUILD.tpl
@@ -143,4 +143,12 @@ cc_library(
     data = ["rocm/lib/%{hipsparse_lib}"],
 )
 
+filegroup(
+    name = "rocm_root",
+    srcs = [
+        "rocm/bin/clang-offload-bundler",
+        "rocm/bin/bin2c.py",
+    ],
+)
+
 %{copy_rules}
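This rocm_root filegroup is what the _gpu_root attribute in build_defs.bzl points at, and _lookup_file resolves tools inside it. A sketch of that helper, reconstructed around the two trailing lines visible in the build_defs.bzl hunk above (treat the body as an approximation):

    def _lookup_file(filegroup, path):
        """Extracts the file at (relative) path in filegroup."""
        for file in filegroup.files.to_list():
            if file.path.endswith(path):
                return file
        return None

    # e.g. _lookup_file(ctx.attr._gpu_root, "bin/clang-offload-bundler")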
third_party/gpus/rocm/build_defs.bzl.tpl
@@ -34,6 +34,10 @@ def rocm_is_configured():
     """Returns true if ROCm was enabled during the configure process."""
     return %{rocm_is_configured}
 
+def rocm_gpu_architectures():
+    """Returns a list of supported GPU architectures."""
+    return %{rocm_gpu_architectures}
+
 def if_rocm_is_configured(x):
     """Tests if the ROCm was enabled during the configure process.
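After rocm_configure.bzl substitutes the template (see the %{rocm_gpu_architectures} substitution in the hunks below), the generated repository file reads, for example:

    def rocm_gpu_architectures():
        """Returns a list of supported GPU architectures."""
        return ["gfx906", "gfx908"]  # str(rocm_config.amdgpu_targets); values illustrative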
third_party/gpus/rocm_configure.bzl
@@ -9,8 +9,7 @@
   * `TF_ROCM_VERSION`: The version of the ROCm toolkit. If this is blank, then
     use the system default.
   * `TF_MIOPEN_VERSION`: The version of the MIOpen library.
-  * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. Default is
-    `gfx803,gfx900`.
+  * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets.
 """
 
 load(
@@ -44,7 +43,6 @@ _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"
 _DEFAULT_ROCM_VERSION = ""
 _DEFAULT_MIOPEN_VERSION = ""
 _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"
-_DEFAULT_ROCM_AMDGPU_TARGETS = ["gfx803", "gfx900"]
 
 def verify_build_defines(params):
     """Verify all variables that crosstool/BUILD.rocm.tpl expects are substituted.
@@ -228,11 +226,14 @@ def _rocm_toolkit_path(repository_ctx, bash_bin):
     auto_configure_fail("Cannot find rocm toolkit path.")
     return rocm_toolkit_path
 
-def _amdgpu_targets(repository_ctx):
+def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
     """Returns a list of strings representing AMDGPU targets."""
     amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
     if not amdgpu_targets_str:
-        return _DEFAULT_ROCM_AMDGPU_TARGETS
+        cmd = "%s/bin/rocm_agent_enumerator" % rocm_toolkit_path
+        result = execute(repository_ctx, [bash_bin, "-c", cmd])
+        targets = [target for target in result.stdout.strip().split("\n") if target != "gfx000"]
+        amdgpu_targets_str = ",".join(targets)
     amdgpu_targets = amdgpu_targets_str.split(",")
     for amdgpu_target in amdgpu_targets:
         if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
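When TF_ROCM_AMDGPU_TARGETS is unset, targets now come from the installed toolkit instead of a hard-coded default. The filtering step in plain Python (the sample enumerator output is illustrative; gfx000 is the host CPU agent and must be dropped):

    stdout = "gfx000\ngfx906\ngfx908\n"  # sample rocm_agent_enumerator output
    targets = [t for t in stdout.strip().split("\n") if t != "gfx000"]
    assert targets == ["gfx906", "gfx908"]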
@@ -416,7 +417,7 @@ def _get_rocm_config(repository_ctx, bash_bin):
     rocm_toolkit_path = _rocm_toolkit_path(repository_ctx, bash_bin)
     return struct(
         rocm_toolkit_path = rocm_toolkit_path,
-        amdgpu_targets = _amdgpu_targets(repository_ctx),
+        amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin),
     )
 
 def _tpl_path(repository_ctx, labelname):
@@ -464,6 +465,7 @@ def _create_dummy_repository(repository_ctx):
         {
             "%{rocm_is_configured}": "False",
             "%{rocm_extra_copts}": "[]",
+            "%{rocm_gpu_architectures}": "[]",
         },
     )
     _tpl(
@@ -532,12 +534,8 @@ def _genrule(src_dir, genrule_name, command, outs):
     )
 
 def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets):
-    if False:
-        amdgpu_target_flags = ["--amdgpu-target=" +
-                               amdgpu_target for amdgpu_target in amdgpu_targets]
-    else:
-        # AMDGPU targets are handled in the "crosstool_wrapper_driver_is_not_gcc"
-        amdgpu_target_flags = []
+    amdgpu_target_flags = ["--amdgpu-target=" +
+                           amdgpu_target for amdgpu_target in amdgpu_targets]
     return str(amdgpu_target_flags)
 
 def _create_local_rocm_repository(repository_ctx):
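The dead "if False" branch is gone; the helper now always emits one flag per target. A runnable example of the resulting value (the repository_ctx parameter is dropped here since the standalone sketch does not need it):

    def _compute_rocm_extra_copts(amdgpu_targets):
        amdgpu_target_flags = ["--amdgpu-target=" + t for t in amdgpu_targets]
        return str(amdgpu_target_flags)

    assert _compute_rocm_extra_copts(["gfx906", "gfx908"]) == \
        "['--amdgpu-target=gfx906', '--amdgpu-target=gfx908']"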
@@ -611,6 +609,26 @@ def _create_local_rocm_repository(repository_ctx):
         outs = rocm_lib_outs,
     ))
 
+    clang_offload_bundler_path = rocm_toolkit_path + _if_hipcc_is_hipclang(
+        repository_ctx,
+        rocm_config,
+        bash_bin,
+        "/llvm/bin/",
+        "/hcc/bin/",
+    ) + "clang-offload-bundler"
+
+    # copy files mentioned in third_party/gpus/rocm/BUILD
+    copy_rules.append(make_copy_files_rule(
+        repository_ctx,
+        name = "rocm-bin",
+        srcs = [
+            clang_offload_bundler_path,
+        ],
+        outs = [
+            "rocm/bin/" + "clang-offload-bundler",
+        ],
+    ))
+
     # Set up BUILD file for rocm/
     repository_ctx.template(
         "rocm/build_defs.bzl",
@@ -621,6 +639,7 @@ def _create_local_rocm_repository(repository_ctx):
                 repository_ctx,
                 rocm_config.amdgpu_targets,
             ),
+            "%{rocm_gpu_architectures}": str(rocm_config.amdgpu_targets),
         },
     )
     repository_ctx.template(
@@ -719,9 +738,6 @@ def _create_local_rocm_repository(repository_ctx):
             "%{hcc_runtime_library}": "mcwamp",
             "%{crosstool_verbose}": _crosstool_verbose(repository_ctx),
             "%{gcc_host_compiler_path}": str(cc),
-            "%{rocm_amdgpu_targets}": ",".join(
-                ["\"%s\"" % c for c in rocm_config.amdgpu_targets],
-            ),
         },
     )