diff --git a/configure.py b/configure.py index 135001ed103..67d3016cb05 100644 --- a/configure.py +++ b/configure.py @@ -718,7 +718,8 @@ def create_android_sdk_rule(environ_cp): def get_ndk_api_level(environ_cp, android_ndk_home_path): - """Gets the appropriate NDK API level to use for the provided Android NDK path.""" + """Gets the appropriate NDK API level to use for the provided Android NDK path. + """ # First check to see if we're using a blessed version of the NDK. properties_path = '%s/source.properties' % android_ndk_home_path @@ -1215,8 +1216,6 @@ def main(): if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')): write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) - write_action_env_to_bazelrc('ROCBLAS_TENSILE_LIBPATH', - environ_cp.get('ROCM_PATH') + '/lib/library') if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('HIP_PLATFORM')): write_action_env_to_bazelrc('HIP_PLATFORM', environ_cp.get('HIP_PLATFORM')) diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.cc b/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.cc index 6553add829f..e6a66c916df 100644 --- a/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.cc @@ -1013,7 +1013,12 @@ tsl::Status RocmActivityCallbackImpl::operator()(const char* begin, } RETURN_IF_ROCTRACER_ERROR(static_cast( - roctracer_next_record(record, &record))); +#if TF_ROCM_VERSION >= 50300 + se::wrap::roctracer_next_record(record, &record) +#else + roctracer_next_record(record, &record) +#endif + )); } return tsl::OkStatus(); @@ -1486,7 +1491,11 @@ void RocmTracer::ActivityCallbackHandler(const char* begin, const char* end) { while (record < end_record) { DumpActivityRecord(record, "activity_tracing_enabled_ is false. Dropped!"); +#if TF_ROCM_VERSION >= 50300 + se::wrap::roctracer_next_record(record, &record); +#else roctracer_next_record(record, &record); +#endif } VLOG(3) << "Dropped Activity Records End"; } diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h index ce564c78409..5e1215f198a 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h @@ -30,8 +30,12 @@ limitations under the License. // Common place for all collective thunks to include nccl/rccl headers. #if TENSORFLOW_USE_ROCM +#if (TF_ROCM_VERSION >= 50200) #include "rocm/include/rccl/rccl.h" #else +#include "rocm/include/rccl.h" +#endif +#else #include "third_party/nccl/nccl.h" #endif diff --git a/tensorflow/compiler/xla/stream_executor/rocm/hipsparse_wrapper.h b/tensorflow/compiler/xla/stream_executor/rocm/hipsparse_wrapper.h index 93c890b00b8..4286cc1d56b 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/hipsparse_wrapper.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/hipsparse_wrapper.h @@ -20,7 +20,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ +#if (TF_ROCM_VERSION >= 50200) #include "rocm/include/hipsparse/hipsparse.h" +#else +#include "rocm/include/hipsparse.h" +#endif #include "tensorflow/compiler/xla/stream_executor/lib/env.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.h b/tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.h index 8c59f45cfdc..6bbc77d74da 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.h @@ -22,7 +22,11 @@ limitations under the License. #if TENSORFLOW_USE_ROCM +#if (TF_ROCM_VERSION >= 50200) #include "rocm/include/hipfft/hipfft.h" +#else +#include "rocm/include/hipfft.h" +#endif #include "rocm/rocm_config.h" #endif diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocsolver_wrapper.h b/tensorflow/compiler/xla/stream_executor/rocm/rocsolver_wrapper.h index 5d3574de9cb..4471f04bfe5 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocsolver_wrapper.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocsolver_wrapper.h @@ -20,7 +20,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCSOLVER_WRAPPER_H_ #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCSOLVER_WRAPPER_H_ +#include "rocm/rocm_config.h" +#if (TF_ROCM_VERSION >= 50200) #include "rocm/include/rocsolver/rocsolver.h" +#else +#include "rocm/include/rocsolver.h" +#endif + #include "tensorflow/compiler/xla/stream_executor/lib/env.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" diff --git a/tensorflow/compiler/xla/stream_executor/rocm/roctracer_wrapper.h b/tensorflow/compiler/xla/stream_executor/rocm/roctracer_wrapper.h index 6e61064a5ab..032af383acc 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/roctracer_wrapper.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/roctracer_wrapper.h @@ -23,6 +23,7 @@ limitations under the License. #include "rocm/include/roctracer/roctracer.h" #include "rocm/include/roctracer/roctracer_hcc.h" #include "rocm/include/roctracer/roctracer_hip.h" +#include "rocm/rocm_config.h" #include "tensorflow/compiler/xla/stream_executor/lib/env.h" #include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" @@ -60,6 +61,25 @@ namespace wrap { #endif // PLATFORM_GOOGLE +#if TF_ROCM_VERSION >= 50300 +#define FOREACH_ROCTRACER_API(DO_FUNC) \ + DO_FUNC(roctracer_default_pool_expl) \ + DO_FUNC(roctracer_disable_domain_activity) \ + DO_FUNC(roctracer_disable_domain_callback) \ + DO_FUNC(roctracer_disable_op_activity) \ + DO_FUNC(roctracer_disable_op_callback) \ + DO_FUNC(roctracer_enable_domain_activity_expl) \ + DO_FUNC(roctracer_enable_domain_callback) \ + DO_FUNC(roctracer_enable_op_activity_expl) \ + DO_FUNC(roctracer_enable_op_callback) \ + DO_FUNC(roctracer_error_string) \ + DO_FUNC(roctracer_flush_activity_expl) \ + DO_FUNC(roctracer_get_timestamp) \ + DO_FUNC(roctracer_op_string) \ + DO_FUNC(roctracer_open_pool_expl) \ + DO_FUNC(roctracer_set_properties) \ + DO_FUNC(roctracer_next_record) +#else #define FOREACH_ROCTRACER_API(DO_FUNC) \ DO_FUNC(roctracer_default_pool_expl) \ DO_FUNC(roctracer_disable_domain_activity) \ @@ -76,7 +96,7 @@ namespace wrap { DO_FUNC(roctracer_op_string) \ DO_FUNC(roctracer_open_pool_expl) \ DO_FUNC(roctracer_set_properties) - +#endif FOREACH_ROCTRACER_API(ROCTRACER_API_WRAPPER) #undef FOREACH_ROCTRACER_API diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h index 9449147aab5..bef22b50ada 100644 --- a/tensorflow/core/kernels/gpu_prim.h +++ b/tensorflow/core/kernels/gpu_prim.h @@ -80,11 +80,29 @@ struct NumericTraits } // namespace cub #elif TENSORFLOW_USE_ROCM #include "rocm/include/hipcub/hipcub.hpp" +#include "rocm/rocm_config.h" namespace gpuprim = ::hipcub; // Required for sorting Eigen::half and bfloat16. namespace rocprim { namespace detail { +#if (TF_ROCM_VERSION >= 50200) +template <> +struct float_bit_mask { + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7C00; + static constexpr uint16_t mantissa = 0x03FF; + using bit_type = uint16_t; +}; + +template <> +struct float_bit_mask { + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7F80; + static constexpr uint16_t mantissa = 0x007F; + using bit_type = uint16_t; +}; +#endif template <> struct radix_key_codec_base : radix_key_codec_floating {}; diff --git a/tensorflow/core/kernels/nccl_ops.cc b/tensorflow/core/kernels/nccl_ops.cc index e6af9115141..b8ade086a8f 100644 --- a/tensorflow/core/kernels/nccl_ops.cc +++ b/tensorflow/core/kernels/nccl_ops.cc @@ -20,7 +20,11 @@ limitations under the License. #if GOOGLE_CUDA #include "third_party/nccl/nccl.h" #elif TENSORFLOW_USE_ROCM +#if (TF_ROCM_VERSION >= 50200) #include "rocm/include/rccl/rccl.h" +#else +#include "rocm/include/rccl.h" +#endif #endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/nccl/nccl_manager.h" diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc index fec0d7514b6..572b7ef7d1b 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc @@ -735,14 +735,12 @@ namespace { template struct GPUDataType; -// GPUDataType templates are currently not instantiated in the ROCm flow -// So leaving out the #elif TENSORFLOW_USE_ROCM blocks for now -// hipblas library is not (yet) being pulled in via rocm_configure.bzl -// so cannot reference tyeps from hipblas headers here template <> struct GPUDataType { #if GOOGLE_CUDA static constexpr cudaDataType_t type = CUDA_R_16F; +#else + static constexpr hipDataType type = HIP_R_16F; #endif }; @@ -750,6 +748,8 @@ template <> struct GPUDataType { #if GOOGLE_CUDA static constexpr cudaDataType_t type = CUDA_R_32F; +#else + static constexpr hipDataType type = HIP_R_32F; #endif }; @@ -757,6 +757,8 @@ template <> struct GPUDataType> { #if GOOGLE_CUDA static constexpr cudaDataType_t type = CUDA_C_32F; +#else + static constexpr hipDataType type = HIP_C_32F; #endif }; @@ -764,6 +766,8 @@ template <> struct GPUDataType { #if GOOGLE_CUDA static constexpr cudaDataType_t type = CUDA_R_64F; +#else + static constexpr hipDataType type = HIP_R_64F; #endif }; @@ -771,6 +775,8 @@ template <> struct GPUDataType> { #if GOOGLE_CUDA static constexpr cudaDataType_t type = CUDA_C_64F; +#else + static constexpr hipDataType type = HIP_C_64F; #endif }; @@ -883,16 +889,16 @@ class CSRSparseMatrixMatMul { TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateCsr( &matA, m, k, nnz, const_cast(a.row_ptr.data()), const_cast(a.col_ind.data()), const_cast(a.values.data()), - CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, + HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, GPUDataType::type)); TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateDnMat( &matB, n, k, ldb, const_cast(b.data()), GPUDataType::type, - HIPSPARSE_ORDER_COL)); + HIPSPARSE_ORDER_COLUMN)); TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateDnMat( &matC, m, n, ldc, c.data(), GPUDataType::type, - HIPSPARSE_ORDER_COL)); + HIPSPARSE_ORDER_COLUMN)); size_t bufferSize = 0; TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize( @@ -905,7 +911,7 @@ class CSRSparseMatrixMatMul { DCHECK(buffer.flat().data() != nullptr); TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB, - &beta, matC, CUSPARSE_MM_ALG_DEFAULT, + &beta, matC, HIPSPARSE_MM_ALG_DEFAULT, buffer.flat().data())); TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseDestroyDnMat(matB)); diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h index 1f30e2d4acd..1de57c9a852 100644 --- a/tensorflow/core/nccl/nccl_manager.h +++ b/tensorflow/core/nccl/nccl_manager.h @@ -30,7 +30,11 @@ limitations under the License. #if GOOGLE_CUDA #include "third_party/nccl/nccl.h" #elif TENSORFLOW_USE_ROCM +#if (TF_ROCM_VERSION >= 50200) #include "rocm/include/rccl/rccl.h" +#else +#include "rocm/include/rccl.h" +#endif #include "tensorflow/core/common_runtime/gpu_device_context.h" #endif #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index ccfb931a7e5..a893fb5cc43 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -733,6 +733,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/stream_executor/rocm:rocblas_wrapper", "//tensorflow/compiler/xla/stream_executor/rocm:rocm_gpu_executor", "//tensorflow/compiler/xla/stream_executor/rocm:rocsolver_wrapper", + "//tensorflow/compiler/xla/stream_executor/rocm:hipsolver_wrapper", ] + if_rocm([ "@local_config_rocm//rocm:rocprim", ]), @@ -755,6 +756,7 @@ tf_kernel_library( "@local_config_cuda//cuda:cub_headers", ]) + if_rocm([ "//tensorflow/compiler/xla/stream_executor/rocm:rocsolver_wrapper", + "//tensorflow/compiler/xla/stream_executor/rocm:hipsolver_wrapper", "//tensorflow/compiler/xla/stream_executor/rocm:hipsparse_wrapper", ]), ) diff --git a/tensorflow/core/util/cuda_solvers.cc b/tensorflow/core/util/cuda_solvers.cc index bda7689f970..f5383a3f94b 100644 --- a/tensorflow/core/util/cuda_solvers.cc +++ b/tensorflow/core/util/cuda_solvers.cc @@ -490,6 +490,17 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context, return OkStatus(); } +#if TENSORFLOW_USE_ROCM +#define GETRS_INSTANCE(Scalar, type_prefix) \ + template <> \ + Status GpuSolver::Getrs( \ + cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda, \ + int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const { \ + return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_, \ + cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \ + ldb, dev_lapack_info); \ + } +#else #define GETRS_INSTANCE(Scalar, type_prefix) \ template <> \ Status GpuSolver::Getrs( \ @@ -499,6 +510,7 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context, cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \ ldb, dev_lapack_info); \ } +#endif TF_CALL_LAPACK_TYPES(GETRS_INSTANCE); diff --git a/tensorflow/core/util/cuda_sparse.h b/tensorflow/core/util/cuda_sparse.h index 5d170364287..375bbc0f747 100644 --- a/tensorflow/core/util/cuda_sparse.h +++ b/tensorflow/core/util/cuda_sparse.h @@ -46,6 +46,7 @@ using gpusparseSpMMAlg_t = cusparseSpMMAlg_t; #elif TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" #include "tensorflow/compiler/xla/stream_executor/rocm/hipsparse_wrapper.h" using gpusparseStatus_t = hipsparseStatus_t; diff --git a/tensorflow/core/util/gpu_solvers.h b/tensorflow/core/util/gpu_solvers.h index e8310888d8d..e06c7b1a7b1 100644 --- a/tensorflow/core/util/gpu_solvers.h +++ b/tensorflow/core/util/gpu_solvers.h @@ -33,7 +33,11 @@ limitations under the License. #else #include "rocm/include/hip/hip_complex.h" #include "rocm/include/rocblas.h" +#include "rocm/rocm_config.h" #include "tensorflow/compiler/xla/stream_executor/blas.h" +#if TF_ROCM_VERSION >= 40500 +#include "tensorflow/compiler/xla/stream_executor/rocm/hipsolver_wrapper.h" +#endif #include "tensorflow/compiler/xla/stream_executor/rocm/rocsolver_wrapper.h" #endif #include "tensorflow/core/framework/op_kernel.h" @@ -283,7 +287,7 @@ class GpuSolver { // Uses LU factorization to solve A * X = B. template Status Getrs(const gpuSolverOp_t trans, int n, int nrhs, Scalar* A, int lda, - const int* dev_pivots, Scalar* B, int ldb, int* dev_lapack_info); + int* dev_pivots, Scalar* B, int ldb, int* dev_lapack_info); template Status GetrfBatched(int n, Scalar** dev_A, int lda, int* dev_pivots, @@ -364,8 +368,9 @@ class GpuSolver { #if TF_ROCM_VERSION >= 40500 // Hermitian (Symmetric) Eigen decomposition. template - Status Heevd(gpuSolverOp_t jobz, gpuSolverFill_t uplo, int n, Scalar* dev_A, - int lda, typename Eigen::NumTraits::Real* dev_W, + Status Heevd(hipsolverEigMode_t jobz, gpuSolverFill_t uplo, int n, + Scalar* dev_A, int lda, + typename Eigen::NumTraits::Real* dev_W, int* dev_lapack_info); #endif @@ -550,6 +555,9 @@ class GpuSolver { #else // TENSORFLOW_USE_ROCM hipStream_t hip_stream_; rocblas_handle rocm_blas_handle_; +#if TF_ROCM_VERSION >= 40500 + hipsolverHandle_t hipsolver_handle_; +#endif #endif std::vector scratch_tensor_refs_; diff --git a/tensorflow/core/util/rocm_solvers.cc b/tensorflow/core/util/rocm_solvers.cc index 4b48664912f..c3395d3d8ea 100644 --- a/tensorflow/core/util/rocm_solvers.cc +++ b/tensorflow/core/util/rocm_solvers.cc @@ -582,7 +582,7 @@ TF_CALL_LAPACK_TYPES(POTRF_INSTANCE); #define GETRS_INSTANCE(Scalar, type_prefix) \ template <> \ Status GpuSolver::Getrs(rocblas_operation trans, int n, int nrhs, \ - Scalar* A, int lda, const int* dev_pivots, \ + Scalar* A, int lda, int* dev_pivots, \ Scalar* B, int ldb, int* dev_lapack_info) { \ mutex_lock lock(handle_map_mutex); \ using ROCmScalar = typename ROCmComplexT::type; \ @@ -645,7 +645,7 @@ TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE); pivots.bytes())) { \ return errors::Internal("GetriBatched: Failed to copy ptrs to device"); \ } \ - TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getri_batched, type_prefix)( \ + TF_RETURN_IF_ROCBLAS_ERROR(ROCSOLVER_FN(getri_batched, type_prefix)( \ rocm_blas_handle_, n, \ reinterpret_cast(dev_a.mutable_data()), lda, \ reinterpret_cast(pivots.mutable_data()), stride, \ @@ -668,7 +668,7 @@ TF_CALL_ROCSOLV_TYPES(GETRI_BATCHED_INSTANCE); if (!CopyHostToDevice(context_, dev_a.mutable_data(), A, dev_a.bytes())) { \ return errors::Internal("GetrfBatched: Failed to copy ptrs to device"); \ } \ - TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrf_batched, type_prefix)( \ + TF_RETURN_IF_ROCBLAS_ERROR(ROCSOLVER_FN(getrf_batched, type_prefix)( \ rocm_blas_handle_, n, n, \ reinterpret_cast(dev_a.mutable_data()), lda, dev_pivots, \ stride, dev_info->mutable_data(), batch_size)); \ @@ -696,7 +696,7 @@ TF_CALL_ROCSOLV_TYPES(GETRF_BATCHED_INSTANCE); if (!CopyHostToDevice(context_, dev_b.mutable_data(), B, dev_b.bytes())) { \ return errors::Internal("GetrfBatched: Failed to copy ptrs to device"); \ } \ - TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrs_batched, type_prefix)( \ + TF_RETURN_IF_ROCBLAS_ERROR(ROCSOLVER_FN(getrs_batched, type_prefix)( \ rocm_blas_handle_, trans, n, nrhs, \ reinterpret_cast(dev_a.mutable_data()), lda, dev_pivots, \ stride, reinterpret_cast(dev_b.mutable_data()), ldb, \ diff --git a/tensorflow/core/util/rocm_sparse.cc b/tensorflow/core/util/rocm_sparse.cc index 49774583de0..f2c1247cb75 100644 --- a/tensorflow/core/util/rocm_sparse.cc +++ b/tensorflow/core/util/rocm_sparse.cc @@ -152,9 +152,9 @@ Status GpuSparse::Initialize() { return OkStatus(); } -#define TF_CALL_HIPSPARSE_DTYPES(m) \ - m(float, ROCM_R_32F) m(double, ROCM_R_64F) \ - m(std::complex, ROCM_C_32F) m(std::complex, ROCM_C_64F) +#define TF_CALL_HIPSPARSE_DTYPES(m) \ + m(float, HIP_R_32F) m(double, HIP_R_64F) m(std::complex, HIP_C_32F) \ + m(std::complex, HIP_C_64F) // Macro that specializes a sparse method for all 4 standard // numeric types. @@ -357,14 +357,13 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op, const hipsparseMatDescr_t descrA, Scalar* csrVal, const int* csrRowPtr, int* csrColInd) { - GpuSparseCsrSortingConversionInfo info; - TF_RETURN_IF_ERROR(info.Initialize()); + csru2csrInfo_t info; size_t pBufferSizeInBytes = 0; TF_RETURN_IF_GPUSPARSE_ERROR( buffer_size_op(hipsparse_handle, m, n, nnz, AsHipComplex(csrVal), - csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes)); + csrRowPtr, csrColInd, info, &pBufferSizeInBytes)); Tensor pBuffer_t; TF_RETURN_IF_ERROR(context->allocate_temp( @@ -375,7 +374,7 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op, TF_RETURN_IF_GPUSPARSE_ERROR(op(hipsparse_handle, m, n, nnz, descrA, AsHipComplex(csrVal), csrRowPtr, csrColInd, - info.info(), pBuffer.data())); + info, pBuffer.data())); return OkStatus(); } diff --git a/tensorflow/tools/ci_build/Dockerfile.rocm b/tensorflow/tools/ci_build/Dockerfile.rocm index 47997e7e9da..b950b9bc4a8 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rocm +++ b/tensorflow/tools/ci_build/Dockerfile.rocm @@ -3,10 +3,10 @@ FROM ubuntu:focal MAINTAINER Jeff Poznanovic -ARG ROCM_DEB_REPO=https://repo.radeon.com/rocm/apt/5.1/ +ARG ROCM_DEB_REPO=https://repo.radeon.com/rocm/apt/5.3/ ARG ROCM_BUILD_NAME=ubuntu ARG ROCM_BUILD_NUM=main -ARG ROCM_PATH=/opt/rocm-5.1.0 +ARG ROCM_PATH=/opt/rocm-5.3.0 ENV DEBIAN_FRONTEND noninteractive ENV TF_NEED_ROCM 1 @@ -54,7 +54,7 @@ RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteracti libnuma-dev \ pciutils \ virtualenv \ - python-pip \ + python3-pip \ libxml2 \ libxml2-dev \ wget && \ diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh index 7d60d2bd9a9..2a37a59b15c 100755 --- a/tensorflow/tools/ci_build/install/install_bazel.sh +++ b/tensorflow/tools/ci_build/install/install_bazel.sh @@ -15,7 +15,7 @@ # ============================================================================== # Select bazel version. -BAZEL_VERSION="5.1.1" +BAZEL_VERSION="5.3.0" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') diff --git a/tensorflow/tools/ci_build/linux/rocm/run_gpu_multi.sh b/tensorflow/tools/ci_build/linux/rocm/run_gpu_multi.sh index f12992e9dd5..d88b31de152 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_gpu_multi.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_gpu_multi.sh @@ -25,7 +25,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} echo "" # First positional argument (if any) specifies the ROCM_INSTALL_DIR -ROCM_INSTALL_DIR=/opt/rocm-5.1.0 +ROCM_INSTALL_DIR=/opt/rocm-5.3.0 if [[ -n $1 ]]; then ROCM_INSTALL_DIR=$1 fi diff --git a/tensorflow/tools/ci_build/linux/rocm/run_gpu_single.sh b/tensorflow/tools/ci_build/linux/rocm/run_gpu_single.sh index 5ea5e2e8f0b..8f7164fc50e 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_gpu_single.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_gpu_single.sh @@ -27,7 +27,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} echo "" # First positional argument (if any) specifies the ROCM_INSTALL_DIR -ROCM_INSTALL_DIR=/opt/rocm-5.1.0 +ROCM_INSTALL_DIR=/opt/rocm-5.3.0 if [[ -n $1 ]]; then ROCM_INSTALL_DIR=$1 fi diff --git a/third_party/gpus/find_rocm_config.py b/third_party/gpus/find_rocm_config.py index 91674c4da6d..cd64efe6495 100644 --- a/third_party/gpus/find_rocm_config.py +++ b/third_party/gpus/find_rocm_config.py @@ -69,36 +69,27 @@ def _get_header_version(path, name): def _find_rocm_config(rocm_install_path): - def rocm_version_numbers_pre_rocm50(path, prior_err): - version_file = os.path.join(path, ".info/version-dev") - if not os.path.exists(version_file): - raise ConfigError("{} ROCm version file ".format(prior_err) + - '"{}" not found either.'.format(version_file)) - version_numbers = [] - with open(version_file) as f: - version_string = f.read().strip() - version_numbers = version_string.split(".") - major = int(version_numbers[0]) - minor = int(version_numbers[1]) - patch = int(version_numbers[2].split("-")[0]) - return major, minor, patch + def rocm_version_numbers(path): + possible_version_files = [ + "include/rocm-core/rocm_version.h", # ROCm 5.2 + "include/rocm_version.h", # ROCm 5.1 and prior + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: + raise ConfigError( + "ROCm version file not found in {}".format(possible_version_files)) - def rocm_version_numbers_post_rocm50(path): - version_file = os.path.join(path, "include/rocm_version.h") - if not os.path.exists(version_file): - return False, 'ROCm version file "{}" not found.'.format(version_file) +\ - " Trying an alternate approach to determine the ROCm version.", 0,0,0 major = _get_header_version(version_file, "ROCM_VERSION_MAJOR") minor = _get_header_version(version_file, "ROCM_VERSION_MINOR") patch = _get_header_version(version_file, "ROCM_VERSION_PATCH") - return True, "", major, minor, patch + return major, minor, patch - status, error_msg, major, minor, patch = \ - rocm_version_numbers_post_rocm50(rocm_install_path) - - if not status: - major, minor, patch = \ - rocm_version_numbers_pre_rocm50(rocm_install_path, error_msg) + major, minor, patch = rocm_version_numbers(rocm_install_path) rocm_config = { "rocm_version_number": _get_composite_version_number(major, minor, patch) @@ -110,10 +101,20 @@ def _find_rocm_config(rocm_install_path): def _find_hipruntime_config(rocm_install_path): def hipruntime_version_number(path): - version_file = os.path.join(path, "hip/include/hip/hip_version.h") - if not os.path.exists(version_file): - raise ConfigError( - 'HIP Runtime version file "{}" not found'.format(version_file)) + possible_version_files = [ + "include/hip/hip_version.h", # ROCm 5.2 + "hip/include/hip/hip_version.h", # ROCm 5.1 and prior + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: + raise ConfigError("HIP Runtime version file not found in {}".format( + possible_version_files)) + # This header file has an explicit #define for HIP_VERSION, whose value # is (HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR) # Retreive the major + minor and re-calculate here, since we do not @@ -132,8 +133,17 @@ def _find_hipruntime_config(rocm_install_path): def _find_miopen_config(rocm_install_path): def miopen_version_numbers(path): - version_file = os.path.join(path, "miopen/include/miopen/version.h") - if not os.path.exists(version_file): + possible_version_files = [ + "include/miopen/version.h", # ROCm 5.2 and prior + "miopen/include/miopen/version.h", # ROCm 5.1 and prior + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: raise ConfigError( 'MIOpen version file "{}" not found'.format(version_file)) major = _get_header_version(version_file, "MIOPEN_VERSION_MAJOR") @@ -155,8 +165,8 @@ def _find_rocblas_config(rocm_install_path): def rocblas_version_numbers(path): possible_version_files = [ - "rocblas/include/rocblas-version.h", # ROCm 3.7 and prior - "rocblas/include/internal/rocblas-version.h", # ROCm 3.8 + "include/rocblas/internal/rocblas-version.h", # ROCm 5.2 + "rocblas/include/internal/rocblas-version.h", # ROCm 5.1 and prior ] version_file = None for f in possible_version_files: @@ -187,8 +197,8 @@ def _find_rocrand_config(rocm_install_path): def rocrand_version_number(path): possible_version_files = [ - "rocrand/include/rocrand_version.h", # ROCm 5.0 and prior "include/rocrand/rocrand_version.h", # ROCm 5.1 + "rocrand/include/rocrand_version.h", # ROCm 5.0 and prior ] version_file = None for f in possible_version_files: @@ -212,10 +222,19 @@ def _find_rocrand_config(rocm_install_path): def _find_rocfft_config(rocm_install_path): def rocfft_version_numbers(path): - version_file = os.path.join(path, "rocfft/include/rocfft-version.h") - if not os.path.exists(version_file): + possible_version_files = [ + "include/rocfft/rocfft-version.h", # ROCm 5.2 + "rocfft/include/rocfft-version.h", # ROCm 5.1 and prior + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: raise ConfigError( - 'rocfft version file "{}" not found'.format(version_file)) + "rocfft version file not found in {}".format(possible_version_files)) major = _get_header_version(version_file, "rocfft_version_major") minor = _get_header_version(version_file, "rocfft_version_minor") patch = _get_header_version(version_file, "rocfft_version_patch") @@ -234,10 +253,19 @@ def _find_rocfft_config(rocm_install_path): def _find_hipfft_config(rocm_install_path): def hipfft_version_numbers(path): - version_file = os.path.join(path, "hipfft/include/hipfft-version.h") - if not os.path.exists(version_file): + possible_version_files = [ + "include/hipfft/hipfft-version.h", # ROCm 5.2 + "hipfft/include/hipfft-version.h", # ROCm 5.1 and prior + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: raise ConfigError( - 'hipfft version file "{}" not found'.format(version_file)) + "hipfft version file not found in {}".format(possible_version_files)) major = _get_header_version(version_file, "hipfftVersionMajor") minor = _get_header_version(version_file, "hipfftVersionMinor") patch = _get_header_version(version_file, "hipfftVersionPatch") @@ -256,10 +284,19 @@ def _find_hipfft_config(rocm_install_path): def _find_roctracer_config(rocm_install_path): def roctracer_version_numbers(path): - version_file = os.path.join(path, "roctracer/include/roctracer.h") - if not os.path.exists(version_file): - raise ConfigError( - 'roctracer version file "{}" not found'.format(version_file)) + possible_version_files = [ + "include/roctracer/roctracer.h", # ROCm 5.2 + "roctracer/include/roctracer.h", # ROCm 5.1 and prior + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: + raise ConfigError("roctracer version file not found in {}".format( + possible_version_files)) major = _get_header_version(version_file, "ROCTRACER_VERSION_MAJOR") minor = _get_header_version(version_file, "ROCTRACER_VERSION_MINOR") # roctracer header does not have a patch version number @@ -279,10 +316,19 @@ def _find_roctracer_config(rocm_install_path): def _find_hipsparse_config(rocm_install_path): def hipsparse_version_numbers(path): - version_file = os.path.join(path, "hipsparse/include/hipsparse-version.h") - if not os.path.exists(version_file): - raise ConfigError( - 'hipsparse version file "{}" not found'.format(version_file)) + possible_version_files = [ + "include/hipsparse/hipsparse-version.h", # ROCm 5.2 + "hipsparse/include/hipsparse-version.h", # ROCm 5.1 and prior + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: + raise ConfigError("hipsparse version file not found in {}".format( + possible_version_files)) major = _get_header_version(version_file, "hipsparseVersionMajor") minor = _get_header_version(version_file, "hipsparseVersionMinor") patch = _get_header_version(version_file, "hipsparseVersionPatch") @@ -301,8 +347,9 @@ def _find_hipsolver_config(rocm_install_path): def hipsolver_version_numbers(path): possible_version_files = [ - "hipsolver/include/hipsolver-version.h", # ROCm 5.0 and prior + "include/hipsolver/internal/hipsolver-version.h", # ROCm 5.2 "hipsolver/include/internal/hipsolver-version.h", # ROCm 5.1 + "hipsolver/include/hipsolver-version.h", # ROCm 5.0 and prior ] version_file = None for f in possible_version_files: @@ -331,10 +378,19 @@ def _find_hipsolver_config(rocm_install_path): def _find_rocsolver_config(rocm_install_path): def rocsolver_version_numbers(path): - version_file = os.path.join(path, "rocsolver/include/rocsolver-version.h") - if not os.path.exists(version_file): - raise ConfigError( - 'rocsolver version file "{}" not found'.format(version_file)) + possible_version_files = [ + "include/rocsolver/rocsolver-version.h", # ROCm 5.2 + "rocsolver/include/rocsolver-version.h", # ROCm 5.1 and prior + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: + raise ConfigError("rocsolver version file not found in {}".format( + possible_version_files)) major = _get_header_version(version_file, "ROCSOLVER_VERSION_MAJOR") minor = _get_header_version(version_file, "ROCSOLVER_VERSION_MINOR") patch = _get_header_version(version_file, "ROCSOLVER_VERSION_PATCH") diff --git a/third_party/gpus/find_rocm_config.py.gz.base64 b/third_party/gpus/find_rocm_config.py.gz.base64 index 4d838d5eb57..fee67ec553d 100644 --- a/third_party/gpus/find_rocm_config.py.gz.base64 +++ b/third_party/gpus/find_rocm_config.py.gz.base64 @@ -1 +1 @@ -eJztW21v2zgS/q5fQSgoIl8cxeldcYsccoA3TVHftUlhZ7tYNIEh27StrSz6SCqpUfS/3wxJyaQs+SV2+mmzWNSWhg+HM888HCnMEbli8wWPJ1NJXrdet8jdlJI7mgrG3yXsibQzOWVchKSdJKSLZoJ0qaD8kY5C78g7Ih/iIZjTEcnSEeVEwvj2PBrCP+ZOk3ymXMQsJa/DFgnQwDe3/Ma/AGHBMjKLFiRlkmSCAkQsyDhOKKHfhnQuSZySIZvNkzhKh5Q8xXKqpjEg4Ab5w0CwgYzAOgL7OXwb23Ykksph/JlKOb84O3t6egoj5WzI+OQs0Ybi7EPn6vqmd30KDqshv6UJFYJw+r8s5rDUwYJEc/BnGA3AyyR6IoyTaMIp3JMM/X3isYzTSZMINpZPEaeAMoqF5PEgk06wcu9gzbYBhCtKid/ukU7PJ7+2e51eEzB+79y9v/3tjvze7nbbN3ed6x657ZKr25u3nbvO7Q18e0faN3+Q/3Zu3jYJhVDBNPTbnKP/4GSMYVSpIz1KHQfGTDsk5nQYj+MhrCudZNGEkgl7pDyF5ZA55bNYYDIFuDcClCSexTKS6srKonCay4P+eL7vf+JxijS8vZrB9AMe8QU6Q6Y0wvlHkKKhZDymykfyqNkHlGLgIAZWrXIhJJ2FnoeEF0MeA88EjThwQahQ1MEjMYWL0oSMY9Sk8ODiDCkwohJDlaoQxzx3QgHNtf84fsjScTzJuAogjhNyxDIZKq/mERKd5eDIEJMbpNmUs2wyRZLQ9DHmLJ3RVJLHiMeKlAH4/7H/qX33vhF6nTEUF9xL4lFpytiEpamXo+OQO6jcoZyrVHMqM67STuASBGjIRtSNn4y+Ur2uPAcLy2MoGrxV+FXpd2jjJYx91cnQsdf5zHOiE6GqfRrx0Sn6M4IcSqh7T2QDmwdjzmZkEAkTVCMMS98Kf0MCsVq6COEBVfIKQxUmKMszNpdnnA1nPppkKH8R+CIh7+MoS3A9SUY9ZKvnQc0xDulj+Scm8k+gC+YTMMnzvGESQZ1eqRRdY5SDayWBkKrGhUfAe4FmMAvpT6jsm+n66EoflxYoM50r2017kDIGUskoSaxB4OvbnLU60nnKTdpmxAzS1MGRIS4Q5isjksta/8A8HhO/iLGPKWQiNGxAT6oBlzZfrNEPYH9UZ69c5DRK1NQrRo1lqFbu2QHDzYeJWNK+KeN+ms0GlAez6E/GmwQihv/AsOHUjv95C37I34gyIyf4Hb+hNXxT5vY0muT5HAF60SRpNKN5eromH6Dmc3AYJB2AYLhKGETxytYokxqsngR1CO7HLGRzmiP73IcdIoWyAWW/9DM5Pv3Fb+j4z9A3iCGnofoYcP9IT0ReCXIS3I9OGj55pbxrKvyGGgeZVfYahegyABx1MZyAZM2D84a5aaIEWhMou4aHoYtiqFK7Ao7zmf3vP1S56V3tPj0OYXGAHKgYkRMD6/74RPUVuAGizkAUVGfx/Qfsm/epn0MoOtRDxCCHei1aEVAFwW86oRhdKJoo8Rt5LsHZkWa81toK4l3gUtFY3XJZJfrgrBr/pmVSBYLMeB9k16QnH6CWsmT6nywukhvG6ZidGcPTEX30iwyh//kQ+g16DhHYgI08eRWpgLApIci3M+VAkYelmzWRxJ9jlUb0Qe/HOpthAeK40nCWa+IDK/7yoG6oVlBR2hmFWjwuKGjuYGsFHcwlGaMijIJGiFfmQaNkt5zEHRkK6Jxk4IcmkLqoLzV93bFfWg/GRpV6tc25sZmbUquyef2Qz3rqNwpUUzcV4rOeVUxIm1bbcylOh0k2omc2bDh9BqG04++iBJ8KjleZ5FKjhhPk5L4g1zG54wtMK9RjlEAdppGk2JtzBl39SiPmkDc8bpJWE/5z0lmlxPb0TbN1fb7u9qDh7n9s/+e267vp3hmjc1Ng5HTYFQM2w6v3vkOQO56hGah8DVVAkGQmmrqf68/EpNISnLlfbsrrWFW5wXoFRfR0F8tw18xTN9NSFVcmstagd5Gl+gLsdwXrV6D65OIZWzzA/fBKvYOezNkCpvGcZ6mMZ3SLjcAyLnmwW60Czller/gZ/t+rZFf2AEvYj993PpGu9npdIa/T9iNyh4/7VuMC3bx+yviGT9jQU+f7P7YyMGNO+SZ5mjJwTrfaGguQAstEVyf0XNh5nRDnDtZc7gI0VpzGj1oj8m5NV7Pug0+HUTLMEhQX2KygqkSsXkXAkw5TTwga5ymCFgPohGLOFNggE9idCXzkmEewetQqjtsWPpSLXaVnZW27K89KEFzROHc6VoWLFF3hclFWRHGumrv+xRpeV6qFXVcrczrVNYtx49+isoxhSU52LCsNUlSW+fpihfWxcwv4z62pHQgFE326vtmXU2WU5+1nJZSKHa1mH6veTGryXkE7hLDZZHO7EsW/sJK19/7hzF1+iBgkkdjuOUJZrmE5+CfiQbL0UL9AgWa6WIxvYM6sjg+/ny6J3lR6iW3U38N/5q+zQCNqIfAxCdqyZAPWLwrhoaoeb1iq9R03gDE+wlUvpdzx48XyywCrpsd55w9Vu6Zi+1YIK5xbMS0MB/CY8dWWBdt0KxnIY+nqwPLRCSIBmpALQemJqzpGuysE5OfXD+3eARpeF+bZPa8Ds69I1NVNtUq49WjLRDXOYXXCnb0sFBxqcTuhUJb1TeaWOoEotk7YqE5tvwlbVTpRGroB4vwveVjtEkzI1srDcaU8rBUHlxpb12W3ffM2r0u3Il1AU0gWX0uFVMFPaCRriLuxi3SnKlfNeCy3Kxo03K+F1CB2ycBXazc8cAup8X9CC1kKjhq5+/5QRsGRu28PJRQ1cr/doSrvtZvDkk0lSq+iHHxrWM5dfgmxHceN4X4c1yD2C4gX5bjG/wkc1xN91hc/Po/hLsbz+O1gfNqT3TUZr2a3w6PSw/9Ls9uZu6zgkkdDCNhWIm5s99ZxjWNLub7yIiquoX8CyWEPv+u2r667B+j1y0BWt3+0zERxZINR/Wu1afRISWQYmq/YtA12qbT2kfQaFtSqusOwkrBXYh1c2x0PyvIu8L3ilq+Yje3eIq9xbJ3XV15U6vUUP0ft9Vx7C74L82zNd2AOIPs1NKhVfodiJfGvxDq4/jselAuAJY9b7QBL2z1fkxVATgWoK5Wvt2oegVdhipdlm/D+eh6u0QgVtEM/Ee+oHcqH/bXDgXm+dtgwh9CO6gqq1w67OsvaUYV1eO2wPSi3j1uLx9J27/axVPTFlRd9GVBVGS/VRvZuP3w+SBtZBnr2S+MS0AFeG+9UB2WilbrIn1AHZQ9MHaycVSufNYzIKB7ieU88csvGev9RjqQUDwmb9eBps02HQSvOnXp1HK8qxA00P+4Vx3mXB3gVw4vHCwVe0Hx1Dp0rKvAEL2Tox/LrF316RDKWfI2lsvYfNBFK50bzEWE2H0WSBtscB2zUjNrmBEnd2E2/H68bt/F3jmsGrv8dxJqBa19RGZY4iSiVygP59yX5R+u81TI0qQ7mxmlq3Nv0jmFN/jY9nW2/uDcbFrdpJ1uzvs1DLSFREEY+ZhFsbKo0JV9cFF3nV7po5udlUyIYl3QUrEpNCFI2E0Gj2MTUHx4E/itxQV4JPF8cLJGU/+YPgqz6x7Oepj0UCxHqP1gI8c9vaODfp9fd7m33Akr5PrVO+wrJAwBsFMNAGCQeTPY8yEW/jweK+31yeUn8fh/X2O8rNdbL9f4PWsTtHA== \ No newline at end of file +eJztW21v2zgS/q5fQSgoam9dJSlugUMOOcCbZlHftUlhZ7tYtIVB27TNrSz6SCppUPS/7wxJyRIt2YqjvX6JgTa2NPNwOC8Ph7Z4RC7E+l7yxVKTVyevTsjNkpEblighf43FHemneimkikg/jskQxRQZMsXkLZtFwVFwRN7yKYizGUmTGZNEg35/Tafwx93pkQ9MKi4S8io6IR0UCN2tsPsvQLgXKVnRe5IITVLFAIIrMucxI+zrlK014QmZitU65jSZMnLH9dIM40DADPKHgxATTUGagvwaPs2LcoRqYzC+llqvz46P7+7uImqMjYRcHMdWUB2/HVxcXo0uX4LBRuW3JGZKEcn+l3IJU53cE7oGe6Z0AlbG9I4ISehCMrinBdp7J7nmyaJHlJjrOyoZoMy40pJPUl1yVmYdzLkoAO6iCQn7IzIYheSX/mgw6gHG74ObN9e/3ZDf+8Nh/+pmcDki10NycX31enAzuL6CT7+S/tUf5L+Dq9c9wsBVMAz7upZoPxjJ0Y0mdGTEWMmAubAGqTWb8jmfwrySRUoXjCzELZMJTIesmVxxhcFUYN4MUGK+4ppqc2VrUjjMeauvIAzD95InmIbXFysYfiKpvEdjyJJRHH8GIZpqITkzNpJbm32QUgIMRMeaWd4rzVZREGDCq6nkkGeKUQm5oIwr6uAxMVUZpQcRR69pFcDFFabAjGl0VWJczGVmhAFaW/tRfyqSOV+k0jgQ9ZSeiVRHxqo1xUQXGThmiIsNptlSinSxxCRhyS2XIlmxRJNbKrlJyg7Y/278vn/zphsFgzkUF9yL+cwbkju39Ox0rB8yA405TEoTasl0Kk3YCVwCB03FjJX9p+kXZueVxeC+YDEUDd7K7aq0OyrixUJ8scGwvrfxzGJiA2GqfUnl7CXaM4MYaqj7QKWTYh7MpViRCVXOqY4YNrbl9kYEfLUxEdwDrBTkgsZNUJbHYq2PpZiuQhRJkf4o2KIh7nOaxjifOGUBZmsQQM0JCeET2TuhsnfAC+4dZFIQBNOYQp1emBBdopc7l4YCIVTds4CA9QrFYBQyXjA9dsON0ZQxTq1jxGysimYWlYwwJJWmcVxQAltfZ1lrPZ2F3IVtRZySTR3UjHCCMJ6PSM5r7QNxPidh7uMQQyhU5LIBLakG3Mh8LGh/BvmjOnljomQ0NkNvCXU3rtq6V3QYLj5Ccc3GrozHSbqaMNlZ0T+F7BHwGP4Btemy6P/TE3iRn4gRIy/wM35CafhkxIvD2CTPxuigFT2S0BXLwjN08QA2X4PBQOkABOomYODFiyJHudBg9cTIQ3Cfi0isWYYcyhBWiATKBpj9PEz1/OU/w671/wptAx9KFpm3HRke2YHIM0VedD7NXnRD8sxY1zP4XaMHkTXyFoXYMgAcczFaAGWtO6ddd9N5CbimY+S6AbqOcqjSYgU8z0YOv3035WZXtU/J8wgmB8gd4yPywsGWXyExfQUugMgz4AXTWXz7DuvmpyTMIEw61ENwoEM7F8sIyIJgN1sw9C4UDY3DbhZLMHZmM95ybUXineFUUdjcKmeV6jgRHBvyTnHgxVzGst45+ZibGvJkGqczZgodWFDad5lGtIQwQ4mYAv45elWtVyN9mi1YQhq9z+b/oi1gypVImLmOuTZHF1dbnSdF4aJfrH8KnuXnPMsTSKrsNvsKLZLqbCF0z/JpecZtieaCEyCGL1nWYkyLohnedjoWUiQ0XsqWdjMg4tg+A9wA+Zrll5dY1Q7qmgogji/OK1mhqNBzNPrhcjiC5m/8rv+f62Fo3WZZ5gCMwVWOsXYs8FAMIOaLNw7DFXkFU+JcKy4j61QVRQV7ZyuPLTNQ/ObcHFYghGcHkDngfQ+8VcKOVir2JV/LNNF8xRqUfEHYs+DAugdA/Le/4FGwmdJT3e+s+zeD92RoY/g3lP8RbMJhmSms5dDg2sb7K246oc3MlkT0PBiTVV6P3C0F2G27T4sFSJ2CiCUJaEOwGXlBSnew9LtODXoNyfit3SJmDYwlFdsavpzSeJpCLwjWMQkcoLjZnUPzL0zTbHHuKKy6UHe4XAoDNkkVNiwKu/A1hfnDvhK6f0gH3KeqhzLg1tweToBbTihz12mpiTO4GKitoi8SUG2RAw3VE0Blh1okoK0xSzS04tjeNaAgJ/i4vqPEQRbxuMQm7rVhIo9UchCnvAvriZkaMdPzd4Nr8F2ZlEzfnDNT3jUXwbvdh5YdDPT+8uqxleejHNZ8eCiPbT9qqqO6ASnVXJEBKlHCs0KwHt2OlMb2dx+TmKpmGxAj2d4eBOGOcXMkExpnF17ubU42mhapIcITH+zsVJzz2u1SHkgVEKlf3vZHLWxTyjAH71RKMC1sVioLqHa/UihMb8tSgdMuYZRH9xlDQiU1Ywwj2c7mxaEde6h+lZd4wih4ADWKJ0/0sIceZOE3kt30sJMPytnQuBSH/avXWSmWi7AM6GqnkKJe7VSkJLTaNbm6t88uD+UXynyum9UJCra3sAKa+9NoNUXxsvbTEnpgjYDvfugK6qWS0Xz4AuqjoObD108PxWg+bvmsqpLa1XNTex4BbKO0vnZuxva//WvGCE6wJUawaO7PfkZw4mXtJ0Y4hBGs734oI1gTPtiL7w7jgzLGYWxQwnj/SC6oqY9qLihVnffV29/NBaWx/e5ASzoFhzVqEJxsez2CBdy829khOOEt9Sc2eGB/YN32ozfZN8P+xeWwhW22D1TYaB9tsjZ/3kow+5v4kt4yQl01Z75w7XuRVk4e0yzUVExtv1CqRq9lqMRqvWsoWeA3Dgp/AWn4q6GTba99sICbd42aCKe0BfPUShzYSlj3/ehuwlrx6IaiDHNwT1GCaaGtqCmd2s6iVJZec1GJ1Xp/UbLAJw0R3zbqMDayLZKGAdx8PZ9fasYembb3Jf8+lNNdKPuUn74E3OKgjRcfTDxtko6x4PGkU4I5nHSKMG2QTnXp1ZNOsax90qnCap90ihb4+5rGrLORbW9f4yo+f9foG1CPJvYpP7Uq+/Y5BzJGm6wBsRpdv/3Qyj7HBzr4B0UPqIWfFB/EHX5xetuc/wN3+BY47th6Etp/kp2SGZ/iaQI80CHmthKNIQnDIyhuPjyZi31HDSpONQR5tnslVUVeewrg+Sg/LLI5HmKerMn3vwY8f7xmewwbK6bwfAhE6Pvm40f7wKoWIv7CtZEOP2cPwpZOJWQaUbqeUc064wYPm3drtJo8tVqnu+9Rszq9vQ+m7FDc/fv0DsWd3867LCkFwiuVz+Tf5+QfJ6cnJy5Nqp25d5ga8/Z9Ybgjfvu+Pmg+uZ/3TG7f6r9jfvtVC0RiIBx9rCisj6Y0tbw/y1fdL+y+l53GSIgSUrNZZ5tqIqCylep08+XSHGvrhM/UGXmm8PRKZ4Nk7HfHTQv1j6e63PKo7lVkj8NFeLiTdcJPyeVweD08g1L+lBTOkigtOwDYzdWAGDQeewkCiMV4jMdVxmNyfk7C8RjnOB4bNrbTDf4CEAHDYg== \ No newline at end of file diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 001cb1805bd..072329d00eb 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -178,7 +178,13 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include") # Add HIP headers (needs to match $HIP_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include") + if int(rocm_config.rocm_version_number) >= 50200: + inc_dirs.append(rocm_config.rocm_toolkit_path + "/include") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocprim") + inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocsolver") + else: + inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include") # Add HIP-Clang headers (realpath relative to compiler binary) rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin) @@ -189,6 +195,7 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/12.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/13.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/14.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") @@ -314,7 +321,7 @@ def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin): return libs -def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin): +def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin): """Returns the ROCm libraries on the system. Args: @@ -329,19 +336,19 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin): (name, _rocm_lib_paths(repository_ctx, name, path)) for name, path in [ ("amdhip64", rocm_config.rocm_toolkit_path + "/hip"), - ("rocblas", rocm_config.rocm_toolkit_path + "/rocblas"), - (hipfft_or_rocfft, rocm_config.rocm_toolkit_path + "/" + hipfft_or_rocfft), + ("rocblas", rocm_config.rocm_toolkit_path), + (hipfft_or_rocfft, rocm_config.rocm_toolkit_path), ("hiprand", rocm_config.rocm_toolkit_path), - ("MIOpen", rocm_config.rocm_toolkit_path + "/miopen"), - ("rccl", rocm_config.rocm_toolkit_path + "/rccl"), - ("hipsparse", rocm_config.rocm_toolkit_path + "/hipsparse"), + ("MIOpen", miopen_path), + ("rccl", rccl_path), + ("hipsparse", rocm_config.rocm_toolkit_path), ("roctracer64", rocm_config.rocm_toolkit_path + "/roctracer"), - ("rocsolver", rocm_config.rocm_toolkit_path + "/rocsolver"), + ("rocsolver", rocm_config.rocm_toolkit_path), ] ] if int(rocm_config.rocm_version_number) >= 40500: - libs_paths.append(("hipsolver", _rocm_lib_paths(repository_ctx, "hipsolver", rocm_config.rocm_toolkit_path + "/hipsolver"))) - libs_paths.append(("hipblas", _rocm_lib_paths(repository_ctx, "hipblas", rocm_config.rocm_toolkit_path + "/hipblas"))) + libs_paths.append(("hipsolver", _rocm_lib_paths(repository_ctx, "hipsolver", rocm_config.rocm_toolkit_path))) + libs_paths.append(("hipblas", _rocm_lib_paths(repository_ctx, "hipblas", rocm_config.rocm_toolkit_path))) return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin) def _exec_find_rocm_config(repository_ctx, script_path): @@ -463,7 +470,7 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": "hipfft", + "%{hipfft_or_rocfft}": _lib_name("hipfft"), "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), @@ -550,6 +557,10 @@ def _create_local_rocm_repository(repository_ctx): rocm_version_number = int(rocm_config.rocm_version_number) hipfft_or_rocfft = "rocfft" if rocm_version_number < 40100 else "hipfft" + # For ROCm 5.2 and above, find MIOpen and RCCL in the main rocm lib path + miopen_path = rocm_config.rocm_toolkit_path + "/miopen" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path + rccl_path = rocm_config.rocm_toolkit_path + "/rccl" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path + # Copy header and library files to execroot. # rocm_toolkit_path rocm_toolkit_path = rocm_config.rocm_toolkit_path @@ -561,69 +572,14 @@ def _create_local_rocm_repository(repository_ctx): out_dir = "rocm/include", exceptions = ["gtest", "gmock"], ), - make_copy_dir_rule( - repository_ctx, - name = hipfft_or_rocfft + "-include", - src_dir = rocm_toolkit_path + "/" + hipfft_or_rocfft + "/include", - out_dir = "rocm/include/" + hipfft_or_rocfft, - ), - make_copy_dir_rule( - repository_ctx, - name = "rocblas-include", - src_dir = rocm_toolkit_path + "/rocblas/include", - out_dir = "rocm/include/rocblas", - ), make_copy_dir_rule( repository_ctx, name = "rocblas-hsaco", src_dir = rocm_toolkit_path + "/rocblas/lib/library", out_dir = "rocm/lib/rocblas/lib/library", ), - make_copy_dir_rule( - repository_ctx, - name = "miopen-include", - src_dir = rocm_toolkit_path + "/miopen/include", - out_dir = "rocm/include/miopen", - ), - make_copy_dir_rule( - repository_ctx, - name = "rccl-include", - src_dir = rocm_toolkit_path + "/rccl/include", - out_dir = "rocm/include/rccl", - ), - make_copy_dir_rule( - repository_ctx, - name = "hipsparse-include", - src_dir = rocm_toolkit_path + "/hipsparse/include", - out_dir = "rocm/include/hipsparse", - ), - make_copy_dir_rule( - repository_ctx, - name = "rocsolver-include", - src_dir = rocm_toolkit_path + "/rocsolver/include", - out_dir = "rocm/include/rocsolver", - ), ] - # Add Hipsolver on ROCm4.5+ - if rocm_version_number >= 40500: - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = "hipsolver-include", - src_dir = rocm_toolkit_path + "/hipsolver/include", - out_dir = "rocm/include/hipsolver", - ), - ) - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = "hipblas-include", - src_dir = rocm_toolkit_path + "/hipblas/include", - out_dir = "rocm/include/hipblas", - ), - ) - # explicitly copy (into the local_config_rocm repo) the $ROCM_PATH/hiprand/include and # $ROCM_PATH/rocrand/include dirs, only once the softlink to them in $ROCM_PATH/include # dir has been removed. This removal will happen in a near-future ROCm release. @@ -655,7 +611,7 @@ def _create_local_rocm_repository(repository_ctx): ), ) - rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin) + rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin) rocm_lib_srcs = [] rocm_lib_outs = [] for lib in rocm_libs.values(): @@ -710,20 +666,12 @@ def _create_local_rocm_repository(repository_ctx): "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name, "%{copy_rules}": "\n".join(copy_rules), "%{rocm_headers}": ('":rocm-include",\n' + - '":' + hipfft_or_rocfft + '-include",\n' + - '":rocblas-include",\n' + - '":miopen-include",\n' + - '":rccl-include",\n' + hiprand_include + - rocrand_include + - '":hipsparse-include",\n' + - '":rocsolver-include"'), + rocrand_include), } if rocm_version_number >= 40500: repository_dict["%{hipsolver_lib}"] = rocm_libs["hipsolver"].file_name - repository_dict["%{rocm_headers}"] += ',\n":hipsolver-include"' repository_dict["%{hipblas_lib}"] = rocm_libs["hipblas"].file_name - repository_dict["%{rocm_headers}"] += ',\n":hipblas-include"' repository_ctx.template( "rocm/BUILD",