diff --git a/configure.py b/configure.py index 9524eada3cd..96cc70a494b 100644 --- a/configure.py +++ b/configure.py @@ -1442,6 +1442,11 @@ def main(): write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH')) + if ((environ_cp.get('TF_NEED_ROCM') == '1') and + (environ_cp.get('TF_ENABLE_MLIR_GENERATED_GPU_KERNELS') == '1')): + write_to_bazelrc( + 'build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1') + environ_cp['TF_NEED_CUDA'] = str( int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) if (environ_cp.get('TF_NEED_CUDA') == '1' and diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 3c88318064b..2a0e0abfceb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -1,5 +1,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") package( default_visibility = [":friends"], @@ -13,10 +14,10 @@ package_group( ) cc_library( - name = "cubin_creator", - srcs = ["cubin_creator.cc"], - hdrs = ["cubin_creator.h"], - copts = if_cuda(["-DGOOGLE_CUDA=1"]), + name = "gpu_binary_creator", + srcs = ["gpu_binary_creator.cc"], + hdrs = ["gpu_binary_creator.h"], + copts = if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -29,6 +30,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:TargetNVVMIR", + "@llvm-project//mlir:TargetROCDLIR", "@llvm-project//mlir:Transforms", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/hlo", @@ -44,15 +46,19 @@ cc_library( "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering", "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", - ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]), + ] + if_cuda([ + "//tensorflow/stream_executor/gpu:asm_compiler", + ]) + if_rocm([ + "//tensorflow/core/platform:rocm_rocdl_path", + ]), ) tf_cc_binary( - name = "tf_to_cubin", - srcs = ["tf_to_cubin.cc"], + name = "tf_to_gpu_binary", + srcs = ["tf_to_gpu_binary.cc"], visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"], deps = [ - ":cubin_creator", + ":gpu_binary_creator", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/core:lib", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.cc similarity index 90% rename from tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.cc index 3b6af7f699c..611c082cd25 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -//===- cubin_creator.cc -----------------------------------------*- C++ -*-===// +//===- gpu_binary_creator.cc ------------------------------------*- C++ -*-===// // -// This file implements the function to compile a TF kernel function to a cubin. +// This file implements the function to compile a TF kernel function +// to gpu binary (hsaco for AMD, cubin for NVIDIA) // //===----------------------------------------------------------------------===// -#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.h" #include #include @@ -42,6 +43,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Target/NVVMIR.h" // from @llvm-project +#include "mlir/Target/ROCDLIR.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" @@ -59,6 +61,8 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #if GOOGLE_CUDA #include "tensorflow/stream_executor/gpu/asm_compiler.h" +#elif TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/rocm_rocdl_path.h" #endif namespace { @@ -249,7 +253,8 @@ Status PropagateTensorFlowABIKnowledgeToKernel( } // namespace -StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( +StatusOr> +tensorflow::kernel_gen::GenerateGpuBinaryForTfCode( llvm::StringRef tf_code, std::pair compute_capability, llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, llvm::ArrayRef unroll_factors) { @@ -266,13 +271,44 @@ StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( options.use_approximations = true; TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options)); } + +#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA) + return InternalError( + "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." + " Did you specify either --config=rocm or --config=cuda ?"); +#endif + +#if TENSORFLOW_USE_ROCM + TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToROCDL(module.get())); +#elif GOOGLE_CUDA TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); +#endif + TF_RETURN_IF_ERROR( PropagateTensorFlowABIKnowledgeToKernel(module.get(), same_shape)); mlir::OwningModuleRef kernel_module = xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); + llvm::LLVMContext llvmContext; + +#if TENSORFLOW_USE_ROCM + auto llvmModule = mlir::translateModuleToROCDLIR(*kernel_module, llvmContext); + if (!llvmModule) { + return InternalError("Could not translate MLIR module to ROCDL IR"); + } + + llvmModule->setModuleIdentifier("acme"); + + xla::HloModuleConfig config; + config.set_debug_options(xla::GetDebugOptionsFromFlags()); + + int gpu_version = compute_capability.first; + std::string libdevice_dir = tensorflow::RocdlRoot(); + + return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), gpu_version, config, + libdevice_dir); +#elif GOOGLE_CUDA auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext); if (!llvmModule) { return InternalError("Could not translate MLIR module to NVVM"); @@ -295,12 +331,8 @@ StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( config, libdevice_dir, enable_fusion)); VLOG(1) << ptx; -#if GOOGLE_CUDA return tensorflow::se::CompileGpuAsm( std::get<0>(compute_capability), std::get<1>(compute_capability), ptx.c_str(), xla::gpu::PtxOptsFromConfig(config)); -#else - return InternalError( - "GOOGLE_CUDA not defined. Did you specify --config=cuda ?"); #endif } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.h similarity index 76% rename from tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h rename to tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.h index 47626ba9d0d..7b7b799aa3f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.h @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -//===- cubin_creator.h ------------------------------------------*- C++ -*-===// +//===- gpu_binary_creator.h -------------------------------------*- C++ -*-===// // -// This file declares the function to compile a TF kernel function to a cubin. +// This file declares the function to compile a TF kernel function +// to gpu binary (hsaco for AMD, cubin for NVIDIA) // //===----------------------------------------------------------------------===// -#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ -#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_GPU_BINARY_CREATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_GPU_BINARY_CREATOR_H_ #include #include @@ -30,7 +31,7 @@ limitations under the License. namespace tensorflow { namespace kernel_gen { -xla::StatusOr> GenerateCubinForTfCode( +xla::StatusOr> GenerateGpuBinaryForTfCode( llvm::StringRef tf_code, std::pair compute_capability = {7, 5}, llvm::ArrayRef tile_sizes = {16, 64}, @@ -39,4 +40,4 @@ xla::StatusOr> GenerateCubinForTfCode( } // namespace kernel_gen } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_GPU_BINARY_CREATOR_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc similarity index 90% rename from tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc rename to tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc index 96831689600..84979c62a97 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===// +//===- tf_to_gpu_binary.cc --------------------------------------*- C++ -*-===// // -// This file implements the entry point to compile a tf op to a cubin file. +// This file implements the entry point to compile a tf op to a gpu binary // //===----------------------------------------------------------------------===// #include @@ -24,7 +24,7 @@ #include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" #include "tensorflow/compiler/mlir/init_mlir.h" -#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/gpu_binary_creator.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -53,8 +53,12 @@ int main(int argc, char** argv) { tensorflow::InitMlir y(&argc, &argv); llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); +#if TENSORFLOW_USE_ROCM + std::pair compute_capability(architecture, 0); +#else std::pair compute_capability(architecture / 10, architecture % 10); +#endif std::string tf_code; auto read_status = tensorflow::ReadFileToString(tensorflow::Env::Default(), @@ -64,7 +68,7 @@ int main(int argc, char** argv) { return 1; } - auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode( + auto cubin = tensorflow::kernel_gen::GenerateGpuBinaryForTfCode( tf_code, compute_capability, tile_sizes, same_shape, unroll_factors); if (!cubin.ok()) { diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index eaed707607d..8158d198799 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -51,12 +51,14 @@ string HloModuleConfig::compilation_cache_key() const { string key = absl::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; - for (const ShapeLayout& param_layout : - entry_computation_layout_->parameter_layouts()) { - params.push_back(param_layout.shape().DebugString()); + if (entry_computation_layout_.has_value()) { + for (const ShapeLayout& param_layout : + entry_computation_layout_->parameter_layouts()) { + params.push_back(param_layout.shape().DebugString()); + } + StrAppend(&key, absl::StrJoin(params, ", "), ") => ", + entry_computation_layout_->result_shape().SerializeAsString()); } - StrAppend(&key, absl::StrJoin(params, ", "), ") => ", - entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index af670eb059f..dd4fdf5e1b9 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -201,6 +201,7 @@ cc_library( "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", @@ -209,6 +210,7 @@ cc_library( "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToGPUPass", "@llvm-project//mlir:SCFTransforms", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 1b2edec7d61..8014332e5c0 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" // from @llvm-project #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project @@ -26,6 +27,7 @@ limitations under the License. #include "mlir/Dialect/GPU/Passes.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project @@ -197,6 +199,85 @@ Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) { return Status::OK(); } +namespace { + +/// A pass that does the final lowering to ROCDL. It collects all the patterns +/// that are currently required, currently mixing std, linalg and gpu. +class LowerToROCDLPass + : public ::mlir::PassWrapper< + LowerToROCDLPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + } + + public: + void runOnOperation() override { + ::mlir::gpu::GPUModuleOp m = getOperation(); + + ::mlir::OwningRewritePatternList patterns; + ::mlir::populateGpuRewritePatterns(m.getContext(), patterns); + ::mlir::applyPatternsAndFoldGreedily(m, patterns); + patterns.clear(); + + ::mlir::LLVMTypeConverter converter(m.getContext()); + ::mlir::populateStdToLLVMConversionPatterns(converter, patterns); + // TODO(b/145824979) Remove linalg once sliceop is in std. + ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns, + &getContext()); + ::mlir::populateGpuToROCDLConversionPatterns(converter, patterns); + ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext()); + + ::mlir::ConversionTarget target(getContext()); + target.addIllegalDialect<::mlir::gpu::GPUDialect>(); + target + .addIllegalOp(); + target.addIllegalOp(); + target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); + // TODO(csigg): Remove once we support replacing non-root ops. + target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp, + ::mlir::gpu::YieldOp>(); + if (failed(mlir::applyFullConversion(m, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) { + // We cannot verify as the signature of the kernel is rewritten. + ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); + applyPassManagerCLOptions(pm); + + auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, + /*out=*/llvm::dbgs()); + + // Rewrite kernel functions to LLVM IR. + auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>(); + kernelPm.addPass(::mlir::createLowerToCFGPass()); + kernelPm.addPass(absl::make_unique()); + + // Some basic cleanup. + kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Remove all location information to prevent a debug build. + kernelPm.addPass(::mlir::createStripDebugInfoPass()); + + if (failed(pm.run(module))) { + return InternalError("Lowering to ROCDL IR failed."); + } + return Status::OK(); +} + StatusOr ExtractKernelModule(mlir::ModuleOp module) { auto kernelModule = ::mlir::ModuleOp::create(module.getLoc()); // TODO(b/137624192): This also needs to resolve naming conflicts. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h index bd633bb06cb..290550142ec 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h @@ -36,6 +36,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, Status LowerKernelBodiesToNVVM(mlir::ModuleOp module); +Status LowerKernelBodiesToROCDL(mlir::ModuleOp module); + StatusOr ExtractKernelModule(mlir::ModuleOp module); } // namespace mlir_gpu diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 9f3efe9d972..9ebb39b4ec8 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -3,6 +3,7 @@ load("//tensorflow/core/kernels/mlir_generated:build_defs.bzl", "gen_kernel_library", "if_mlir_generated_gpu_kernels_enabled") load( "//tensorflow:tensorflow.bzl", + "if_cuda_or_rocm", "tf_kernel_library", ) load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") @@ -10,7 +11,6 @@ load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") package( default_visibility = ["//tensorflow/core/kernels:__subpackages__"], @@ -33,7 +33,7 @@ tf_kernel_library( "cwise_op_gpu_tanh.cu.cc", ], tags = ["manual"], - deps = if_cuda([ + deps = if_cuda_or_rocm([ ":abs_kernels", ":tanh_kernels", "@com_google_absl//absl/strings", diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl index 2bf6e8fa3bb..f59df3d7f3c 100644 --- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl +++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl @@ -1,6 +1,12 @@ """Generates cubin headers for TF dialect ops.""" -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_gpu_architectures", "if_cuda") +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_gpu_architectures") +load( + "@local_config_rocm//rocm:build_defs.bzl", + "rocm_gpu_architectures", + "rocm_is_configured", +) +load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm") def if_mlir_generated_gpu_kernels_enabled(if_true, if_false = []): return select({ @@ -15,9 +21,12 @@ def _lookup_file(filegroup, path): return file return None -CubinInfo = provider(fields = ["cubins"]) +GpuBinaryInfo = provider( + "GPU binaries in either cubin format or hsaco format", + fields = ["cubins", "hsacos"], +) -def _gen_kernel_cubin_impl(ctx): +def _gen_kernel_cubin_impl_cuda(ctx): name = ctx.attr.name tile_sizes = ctx.attr.tile_size.replace("x", ",") cmd_args = [] @@ -45,10 +54,37 @@ def _gen_kernel_cubin_impl(ctx): mnemonic = "compile", ) cubins.append(cubin) - return [CubinInfo(cubins = cubins)] + return [GpuBinaryInfo(cubins = cubins)] + +def _gen_kernel_cubin_impl_rocm(ctx): + name = ctx.attr.name + tile_sizes = ctx.attr.tile_size.replace("x", ",") + cmd_args = [] + if ctx.attr.same_shape: + cmd_args.append("--same_shape=%s" % ctx.attr.same_shape) + if ctx.attr.unroll_factors: + cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors) + + hsacos = [] + for arch in ctx.attr.gpu_archs: + filename = "%s.%s.hsaco" % (name, arch) + hsaco = ctx.actions.declare_file(filename) + ctx.actions.run( + inputs = [ctx.file.mlir_op, ctx.file._tfso], + outputs = [hsaco], + executable = ctx.executable._tool, + arguments = cmd_args + [ + "--tile_sizes=%s" % tile_sizes, + "--arch=%s" % arch[3:], # DDD in "gfxDDD" + "--input=%s" % ctx.file.mlir_op.path, + "--output=%s" % hsaco.path, + ], + mnemonic = "compile", + ) + hsacos.append(hsaco) + return [GpuBinaryInfo(hsacos = hsacos)] _gen_kernel_cubin_rule = rule( - implementation = _gen_kernel_cubin_impl, attrs = { "mlir_op": attr.label(mandatory = True, allow_single_file = True), "tile_size": attr.string(mandatory = True), @@ -62,16 +98,17 @@ _gen_kernel_cubin_rule = rule( ), "_tool": attr.label( executable = True, - default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"), + default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_gpu_binary"), cfg = "host", ), }, output_to_genfiles = True, + implementation = _gen_kernel_cubin_impl_rocm if rocm_is_configured() else _gen_kernel_cubin_impl_cuda, ) -def _gen_kernel_image_hdr_impl(ctx): +def _gen_kernel_image_hdr_impl_cuda(ctx): images = [] - for cubin in ctx.attr.input[CubinInfo].cubins: + for cubin in ctx.attr.input[GpuBinaryInfo].cubins: arch = cubin.path.split(".")[-2] images.append("--image=profile=%s,file=%s" % (arch, cubin.path)) @@ -79,8 +116,8 @@ def _gen_kernel_image_hdr_impl(ctx): fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name) ctx.actions.run( outputs = [fatbin], - inputs = ctx.attr.input[CubinInfo].cubins, - executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"), + inputs = ctx.attr.input[GpuBinaryInfo].cubins, + executable = _lookup_file(ctx.attr._gpu_root, "bin/fatbinary"), arguments = [ "--64", "--cmdline=--compile-only", @@ -91,7 +128,7 @@ def _gen_kernel_image_hdr_impl(ctx): mnemonic = "fatbinary", ) - bin2c = _lookup_file(ctx.attr._cuda_root, "bin/bin2c") + bin2c = _lookup_file(ctx.attr._gpu_root, "bin/bin2c") ctx.actions.run_shell( outputs = [ctx.outputs.out], inputs = [fatbin], @@ -101,29 +138,74 @@ def _gen_kernel_image_hdr_impl(ctx): mnemonic = "bin2c", ) +def _gen_kernel_image_hdr_impl_rocm(ctx): + hsaco_files = [] + hsaco_targets = [] + + # Add a dummy host target triple...clang-offload-bundler requires 1 and only 1 host target triple + hsaco_files.append("/dev/null") + hsaco_targets.append("host-x86_64-unknown-linux") + + hsacos = ctx.attr.input[GpuBinaryInfo].hsacos + for hsaco in hsacos: + gfx_arch = hsaco.path.split(".")[-2] + hsaco_files.append(hsaco.path) + hsaco_targets.append("hip-amdgcn-amd-amdhsa-%s" % gfx_arch) + + # Generate fatbin file from all hsacos. + fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name) + ctx.actions.run( + outputs = [fatbin], + inputs = hsacos, + executable = _lookup_file(ctx.attr._gpu_root, "bin/clang-offload-bundler"), + arguments = [ + "--inputs=%s" % ",".join(hsaco_files), + "--targets=%s" % ",".join(hsaco_targets), + "--type=o", + "--outputs=%s" % fatbin.path, + ], + mnemonic = "fatbinary", + ) + + ctx.actions.run_shell( + outputs = [ctx.outputs.out], + inputs = [fatbin], + command = ( + ("echo 'static const unsigned char %s[] = {' > %s && " + + "hexdump -v -e \'/1 \"0x%%02x, \"\' %s | cat >> %s && " + + "echo '};' >> %s") % ( + ctx.attr.symbol, + ctx.outputs.out.path, + fatbin.path, + ctx.outputs.out.path, + ctx.outputs.out.path, + ) + ), + ) + _gen_kernel_image_hdr_rule = rule( - implementation = _gen_kernel_image_hdr_impl, + implementation = _gen_kernel_image_hdr_impl_rocm if rocm_is_configured() else _gen_kernel_image_hdr_impl_cuda, output_to_genfiles = True, attrs = { - "input": attr.label(mandatory = True, providers = [CubinInfo]), + "input": attr.label(mandatory = True, providers = [GpuBinaryInfo]), "out": attr.output(mandatory = True), "symbol": attr.string(mandatory = True), - "_cuda_root": attr.label( - default = Label("@local_config_cuda//cuda:cuda_root"), + "_gpu_root": attr.label( + default = Label("@local_config_rocm//rocm:rocm_root") if rocm_is_configured() else Label("@local_config_cuda//cuda:cuda_root"), ), }, ) def _gen_kernel_image_hdr(name, mlir_op, tile_size, same_shape = None, unroll_factors = None): """Generates a C header with fatbin data from a Tensorflow op.""" - if cuda_gpu_architectures(): + if cuda_gpu_architectures() or rocm_gpu_architectures(): _gen_kernel_cubin_rule( name = name + "_cubin", mlir_op = mlir_op, tile_size = tile_size, same_shape = same_shape, unroll_factors = unroll_factors, - gpu_archs = cuda_gpu_architectures(), + gpu_archs = rocm_gpu_architectures() if rocm_is_configured() else cuda_gpu_architectures(), ) _gen_kernel_image_hdr_rule( name = name, @@ -173,7 +255,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr same_shape: The information about which shapes are the same, e.g. "0,1". """ - if cuda_gpu_architectures(): + if cuda_gpu_architectures() or rocm_gpu_architectures(): for type in types: _gen_mlir_op( name = name, @@ -189,6 +271,6 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr native.cc_library( name = name + "_kernels", - hdrs = if_cuda(if_true = [":{name}_{type}_kernel".format(name = name, type = type) for type in types]), + hdrs = if_cuda_or_rocm(if_true = [":{name}_{type}_kernel".format(name = name, type = type) for type in types]), tags = tags, ) diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index 3c233b4f5b0..d2533a08de1 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -143,4 +143,12 @@ cc_library( data = ["rocm/lib/%{hipsparse_lib}"], ) +filegroup( + name = "rocm_root", + srcs = [ + "rocm/bin/clang-offload-bundler", + "rocm/bin/bin2c.py", + ], +) + %{copy_rules} diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index 08c59f95a07..ce4c1b04399 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -34,6 +34,10 @@ def rocm_is_configured(): """Returns true if ROCm was enabled during the configure process.""" return %{rocm_is_configured} +def rocm_gpu_architectures(): + """Returns a list of supported GPU architectures.""" + return %{rocm_gpu_architectures} + def if_rocm_is_configured(x): """Tests if the ROCm was enabled during the configure process. diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 752f48aa25b..1312574f0aa 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -9,8 +9,7 @@ * `TF_ROCM_VERSION`: The version of the ROCm toolkit. If this is blank, then use the system default. * `TF_MIOPEN_VERSION`: The version of the MIOpen library. - * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. Default is - `gfx803,gfx900`. + * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ load( @@ -44,7 +43,6 @@ _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" _DEFAULT_ROCM_VERSION = "" _DEFAULT_MIOPEN_VERSION = "" _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" -_DEFAULT_ROCM_AMDGPU_TARGETS = ["gfx803", "gfx900"] def verify_build_defines(params): """Verify all variables that crosstool/BUILD.rocm.tpl expects are substituted. @@ -228,11 +226,14 @@ def _rocm_toolkit_path(repository_ctx, bash_bin): auto_configure_fail("Cannot find rocm toolkit path.") return rocm_toolkit_path -def _amdgpu_targets(repository_ctx): +def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin): """Returns a list of strings representing AMDGPU targets.""" amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS) if not amdgpu_targets_str: - return _DEFAULT_ROCM_AMDGPU_TARGETS + cmd = "%s/bin/rocm_agent_enumerator" % rocm_toolkit_path + result = execute(repository_ctx, [bash_bin, "-c", cmd]) + targets = [target for target in result.stdout.strip().split("\n") if target != "gfx000"] + amdgpu_targets_str = ",".join(targets) amdgpu_targets = amdgpu_targets_str.split(",") for amdgpu_target in amdgpu_targets: if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit(): @@ -416,7 +417,7 @@ def _get_rocm_config(repository_ctx, bash_bin): rocm_toolkit_path = _rocm_toolkit_path(repository_ctx, bash_bin) return struct( rocm_toolkit_path = rocm_toolkit_path, - amdgpu_targets = _amdgpu_targets(repository_ctx), + amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin), ) def _tpl_path(repository_ctx, labelname): @@ -464,6 +465,7 @@ def _create_dummy_repository(repository_ctx): { "%{rocm_is_configured}": "False", "%{rocm_extra_copts}": "[]", + "%{rocm_gpu_architectures}": "[]", }, ) _tpl( @@ -532,12 +534,8 @@ def _genrule(src_dir, genrule_name, command, outs): ) def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): - if False: - amdgpu_target_flags = ["--amdgpu-target=" + - amdgpu_target for amdgpu_target in amdgpu_targets] - else: - # AMDGPU targets are handled in the "crosstool_wrapper_driver_is_not_gcc" - amdgpu_target_flags = [] + amdgpu_target_flags = ["--amdgpu-target=" + + amdgpu_target for amdgpu_target in amdgpu_targets] return str(amdgpu_target_flags) def _create_local_rocm_repository(repository_ctx): @@ -611,6 +609,26 @@ def _create_local_rocm_repository(repository_ctx): outs = rocm_lib_outs, )) + clang_offload_bundler_path = rocm_toolkit_path + _if_hipcc_is_hipclang( + repository_ctx, + rocm_config, + bash_bin, + "/llvm/bin/", + "/hcc/bin/", + ) + "clang-offload-bundler" + + # copy files mentioned in third_party/gpus/rocm/BUILD + copy_rules.append(make_copy_files_rule( + repository_ctx, + name = "rocm-bin", + srcs = [ + clang_offload_bundler_path, + ], + outs = [ + "rocm/bin/" + "clang-offload-bundler", + ], + )) + # Set up BUILD file for rocm/ repository_ctx.template( "rocm/build_defs.bzl", @@ -621,6 +639,7 @@ def _create_local_rocm_repository(repository_ctx): repository_ctx, rocm_config.amdgpu_targets, ), + "%{rocm_gpu_architectures}": str(rocm_config.amdgpu_targets), }, ) repository_ctx.template( @@ -719,9 +738,6 @@ def _create_local_rocm_repository(repository_ctx): "%{hcc_runtime_library}": "mcwamp", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), - "%{rocm_amdgpu_targets}": ",".join( - ["\"%s\"" % c for c in rocm_config.amdgpu_targets], - ), }, )