mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Merge pull request #56762 from ROCmSoftwarePlatform:fix-upstream-rocm
PiperOrigin-RevId: 489405842
This commit is contained in:
@@ -718,7 +718,8 @@ def create_android_sdk_rule(environ_cp):
|
||||
|
||||
|
||||
def get_ndk_api_level(environ_cp, android_ndk_home_path):
|
||||
"""Gets the appropriate NDK API level to use for the provided Android NDK path."""
|
||||
"""Gets the appropriate NDK API level to use for the provided Android NDK path.
|
||||
"""
|
||||
|
||||
# First check to see if we're using a blessed version of the NDK.
|
||||
properties_path = '%s/source.properties' % android_ndk_home_path
|
||||
@@ -1215,8 +1216,6 @@ def main():
|
||||
|
||||
if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')):
|
||||
write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
|
||||
write_action_env_to_bazelrc('ROCBLAS_TENSILE_LIBPATH',
|
||||
environ_cp.get('ROCM_PATH') + '/lib/library')
|
||||
|
||||
if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('HIP_PLATFORM')):
|
||||
write_action_env_to_bazelrc('HIP_PLATFORM', environ_cp.get('HIP_PLATFORM'))
|
||||
|
||||
@@ -1013,7 +1013,12 @@ tsl::Status RocmActivityCallbackImpl::operator()(const char* begin,
|
||||
}
|
||||
|
||||
RETURN_IF_ROCTRACER_ERROR(static_cast<roctracer_status_t>(
|
||||
roctracer_next_record(record, &record)));
|
||||
#if TF_ROCM_VERSION >= 50300
|
||||
se::wrap::roctracer_next_record(record, &record)
|
||||
#else
|
||||
roctracer_next_record(record, &record)
|
||||
#endif
|
||||
));
|
||||
}
|
||||
|
||||
return tsl::OkStatus();
|
||||
@@ -1486,7 +1491,11 @@ void RocmTracer::ActivityCallbackHandler(const char* begin, const char* end) {
|
||||
while (record < end_record) {
|
||||
DumpActivityRecord(record,
|
||||
"activity_tracing_enabled_ is false. Dropped!");
|
||||
#if TF_ROCM_VERSION >= 50300
|
||||
se::wrap::roctracer_next_record(record, &record);
|
||||
#else
|
||||
roctracer_next_record(record, &record);
|
||||
#endif
|
||||
}
|
||||
VLOG(3) << "Dropped Activity Records End";
|
||||
}
|
||||
|
||||
@@ -30,8 +30,12 @@ limitations under the License.
|
||||
|
||||
// Common place for all collective thunks to include nccl/rccl headers.
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/rccl/rccl.h"
|
||||
#else
|
||||
#include "rocm/include/rccl.h"
|
||||
#endif
|
||||
#else
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#endif
|
||||
|
||||
|
||||
@@ -20,7 +20,11 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
|
||||
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/hipsparse/hipsparse.h"
|
||||
#else
|
||||
#include "rocm/include/hipsparse.h"
|
||||
#endif
|
||||
#include "tensorflow/compiler/xla/stream_executor/lib/env.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
|
||||
|
||||
@@ -22,7 +22,11 @@ limitations under the License.
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/hipfft/hipfft.h"
|
||||
#else
|
||||
#include "rocm/include/hipfft.h"
|
||||
#endif
|
||||
#include "rocm/rocm_config.h"
|
||||
|
||||
#endif
|
||||
|
||||
@@ -20,7 +20,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCSOLVER_WRAPPER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCSOLVER_WRAPPER_H_
|
||||
|
||||
#include "rocm/rocm_config.h"
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/rocsolver/rocsolver.h"
|
||||
#else
|
||||
#include "rocm/include/rocsolver.h"
|
||||
#endif
|
||||
|
||||
#include "tensorflow/compiler/xla/stream_executor/lib/env.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
|
||||
|
||||
@@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "rocm/include/roctracer/roctracer.h"
|
||||
#include "rocm/include/roctracer/roctracer_hcc.h"
|
||||
#include "rocm/include/roctracer/roctracer_hip.h"
|
||||
#include "rocm/rocm_config.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/lib/env.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
|
||||
@@ -60,6 +61,25 @@ namespace wrap {
|
||||
|
||||
#endif // PLATFORM_GOOGLE
|
||||
|
||||
#if TF_ROCM_VERSION >= 50300
|
||||
#define FOREACH_ROCTRACER_API(DO_FUNC) \
|
||||
DO_FUNC(roctracer_default_pool_expl) \
|
||||
DO_FUNC(roctracer_disable_domain_activity) \
|
||||
DO_FUNC(roctracer_disable_domain_callback) \
|
||||
DO_FUNC(roctracer_disable_op_activity) \
|
||||
DO_FUNC(roctracer_disable_op_callback) \
|
||||
DO_FUNC(roctracer_enable_domain_activity_expl) \
|
||||
DO_FUNC(roctracer_enable_domain_callback) \
|
||||
DO_FUNC(roctracer_enable_op_activity_expl) \
|
||||
DO_FUNC(roctracer_enable_op_callback) \
|
||||
DO_FUNC(roctracer_error_string) \
|
||||
DO_FUNC(roctracer_flush_activity_expl) \
|
||||
DO_FUNC(roctracer_get_timestamp) \
|
||||
DO_FUNC(roctracer_op_string) \
|
||||
DO_FUNC(roctracer_open_pool_expl) \
|
||||
DO_FUNC(roctracer_set_properties) \
|
||||
DO_FUNC(roctracer_next_record)
|
||||
#else
|
||||
#define FOREACH_ROCTRACER_API(DO_FUNC) \
|
||||
DO_FUNC(roctracer_default_pool_expl) \
|
||||
DO_FUNC(roctracer_disable_domain_activity) \
|
||||
@@ -76,7 +96,7 @@ namespace wrap {
|
||||
DO_FUNC(roctracer_op_string) \
|
||||
DO_FUNC(roctracer_open_pool_expl) \
|
||||
DO_FUNC(roctracer_set_properties)
|
||||
|
||||
#endif
|
||||
FOREACH_ROCTRACER_API(ROCTRACER_API_WRAPPER)
|
||||
|
||||
#undef FOREACH_ROCTRACER_API
|
||||
|
||||
@@ -80,11 +80,29 @@ struct NumericTraits<tensorflow::bfloat16>
|
||||
} // namespace cub
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#include "rocm/rocm_config.h"
|
||||
namespace gpuprim = ::hipcub;
|
||||
|
||||
// Required for sorting Eigen::half and bfloat16.
|
||||
namespace rocprim {
|
||||
namespace detail {
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
template <>
|
||||
struct float_bit_mask<Eigen::half> {
|
||||
static constexpr uint16_t sign_bit = 0x8000;
|
||||
static constexpr uint16_t exponent = 0x7C00;
|
||||
static constexpr uint16_t mantissa = 0x03FF;
|
||||
using bit_type = uint16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct float_bit_mask<Eigen::bfloat16> {
|
||||
static constexpr uint16_t sign_bit = 0x8000;
|
||||
static constexpr uint16_t exponent = 0x7F80;
|
||||
static constexpr uint16_t mantissa = 0x007F;
|
||||
using bit_type = uint16_t;
|
||||
};
|
||||
#endif
|
||||
template <>
|
||||
struct radix_key_codec_base<Eigen::half>
|
||||
: radix_key_codec_floating<Eigen::half, uint16_t> {};
|
||||
|
||||
@@ -20,7 +20,11 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/rccl/rccl.h"
|
||||
#else
|
||||
#include "rocm/include/rccl.h"
|
||||
#endif
|
||||
#endif
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/nccl/nccl_manager.h"
|
||||
|
||||
@@ -735,14 +735,12 @@ namespace {
|
||||
template <typename T>
|
||||
struct GPUDataType;
|
||||
|
||||
// GPUDataType templates are currently not instantiated in the ROCm flow
|
||||
// So leaving out the #elif TENSORFLOW_USE_ROCM blocks for now
|
||||
// hipblas library is not (yet) being pulled in via rocm_configure.bzl
|
||||
// so cannot reference types from hipblas headers here
|
||||
template <>
|
||||
struct GPUDataType<Eigen::half> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_R_16F;
|
||||
#else
|
||||
static constexpr hipDataType type = HIP_R_16F;
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -750,6 +748,8 @@ template <>
|
||||
struct GPUDataType<float> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_R_32F;
|
||||
#else
|
||||
static constexpr hipDataType type = HIP_R_32F;
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -757,6 +757,8 @@ template <>
|
||||
struct GPUDataType<std::complex<float>> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_C_32F;
|
||||
#else
|
||||
static constexpr hipDataType type = HIP_C_32F;
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -764,6 +766,8 @@ template <>
|
||||
struct GPUDataType<double> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_R_64F;
|
||||
#else
|
||||
static constexpr hipDataType type = HIP_R_64F;
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -771,6 +775,8 @@ template <>
|
||||
struct GPUDataType<std::complex<double>> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_C_64F;
|
||||
#else
|
||||
static constexpr hipDataType type = HIP_C_64F;
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -883,16 +889,16 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateCsr(
|
||||
&matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
|
||||
const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
|
||||
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO,
|
||||
HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO,
|
||||
GPUDataType<T>::type));
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateDnMat(
|
||||
&matB, n, k, ldb, const_cast<T*>(b.data()), GPUDataType<T>::type,
|
||||
HIPSPARSE_ORDER_COL));
|
||||
HIPSPARSE_ORDER_COLUMN));
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateDnMat(
|
||||
&matC, m, n, ldc, c.data(), GPUDataType<T>::type,
|
||||
HIPSPARSE_ORDER_COL));
|
||||
HIPSPARSE_ORDER_COLUMN));
|
||||
|
||||
size_t bufferSize = 0;
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
|
||||
@@ -905,7 +911,7 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
DCHECK(buffer.flat<int8>().data() != nullptr);
|
||||
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
|
||||
&beta, matC, CUSPARSE_MM_ALG_DEFAULT,
|
||||
&beta, matC, HIPSPARSE_MM_ALG_DEFAULT,
|
||||
buffer.flat<int8>().data()));
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseDestroyDnMat(matB));
|
||||
|
||||
@@ -30,7 +30,11 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/rccl/rccl.h"
|
||||
#else
|
||||
#include "rocm/include/rccl.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
|
||||
@@ -733,6 +733,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/xla/stream_executor/rocm:rocblas_wrapper",
|
||||
"//tensorflow/compiler/xla/stream_executor/rocm:rocm_gpu_executor",
|
||||
"//tensorflow/compiler/xla/stream_executor/rocm:rocsolver_wrapper",
|
||||
"//tensorflow/compiler/xla/stream_executor/rocm:hipsolver_wrapper",
|
||||
] + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
@@ -755,6 +756,7 @@ tf_kernel_library(
|
||||
"@local_config_cuda//cuda:cub_headers",
|
||||
]) + if_rocm([
|
||||
"//tensorflow/compiler/xla/stream_executor/rocm:rocsolver_wrapper",
|
||||
"//tensorflow/compiler/xla/stream_executor/rocm:hipsolver_wrapper",
|
||||
"//tensorflow/compiler/xla/stream_executor/rocm:hipsparse_wrapper",
|
||||
]),
|
||||
)
|
||||
|
||||
@@ -490,6 +490,17 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
#define GETRS_INSTANCE(Scalar, type_prefix) \
|
||||
template <> \
|
||||
Status GpuSolver::Getrs<Scalar>( \
|
||||
cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda, \
|
||||
int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const { \
|
||||
return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_, \
|
||||
cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
|
||||
ldb, dev_lapack_info); \
|
||||
}
|
||||
#else
|
||||
#define GETRS_INSTANCE(Scalar, type_prefix) \
|
||||
template <> \
|
||||
Status GpuSolver::Getrs<Scalar>( \
|
||||
@@ -499,6 +510,7 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
|
||||
cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
|
||||
ldb, dev_lapack_info); \
|
||||
}
|
||||
#endif
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
|
||||
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "rocm/rocm_config.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/rocm/hipsparse_wrapper.h"
|
||||
|
||||
using gpusparseStatus_t = hipsparseStatus_t;
|
||||
|
||||
@@ -33,7 +33,11 @@ limitations under the License.
|
||||
#else
|
||||
#include "rocm/include/hip/hip_complex.h"
|
||||
#include "rocm/include/rocblas.h"
|
||||
#include "rocm/rocm_config.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/blas.h"
|
||||
#if TF_ROCM_VERSION >= 40500
|
||||
#include "tensorflow/compiler/xla/stream_executor/rocm/hipsolver_wrapper.h"
|
||||
#endif
|
||||
#include "tensorflow/compiler/xla/stream_executor/rocm/rocsolver_wrapper.h"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@@ -283,7 +287,7 @@ class GpuSolver {
|
||||
// Uses LU factorization to solve A * X = B.
|
||||
template <typename Scalar>
|
||||
Status Getrs(const gpuSolverOp_t trans, int n, int nrhs, Scalar* A, int lda,
|
||||
const int* dev_pivots, Scalar* B, int ldb, int* dev_lapack_info);
|
||||
int* dev_pivots, Scalar* B, int ldb, int* dev_lapack_info);
|
||||
|
||||
template <typename Scalar>
|
||||
Status GetrfBatched(int n, Scalar** dev_A, int lda, int* dev_pivots,
|
||||
@@ -364,8 +368,9 @@ class GpuSolver {
|
||||
#if TF_ROCM_VERSION >= 40500
|
||||
// Hermitian (Symmetric) Eigen decomposition.
|
||||
template <typename Scalar>
|
||||
Status Heevd(gpuSolverOp_t jobz, gpuSolverFill_t uplo, int n, Scalar* dev_A,
|
||||
int lda, typename Eigen::NumTraits<Scalar>::Real* dev_W,
|
||||
Status Heevd(hipsolverEigMode_t jobz, gpuSolverFill_t uplo, int n,
|
||||
Scalar* dev_A, int lda,
|
||||
typename Eigen::NumTraits<Scalar>::Real* dev_W,
|
||||
int* dev_lapack_info);
|
||||
#endif
|
||||
|
||||
@@ -550,6 +555,9 @@ class GpuSolver {
|
||||
#else // TENSORFLOW_USE_ROCM
|
||||
hipStream_t hip_stream_;
|
||||
rocblas_handle rocm_blas_handle_;
|
||||
#if TF_ROCM_VERSION >= 40500
|
||||
hipsolverHandle_t hipsolver_handle_;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
std::vector<TensorReference> scratch_tensor_refs_;
|
||||
|
||||
@@ -582,7 +582,7 @@ TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);
|
||||
#define GETRS_INSTANCE(Scalar, type_prefix) \
|
||||
template <> \
|
||||
Status GpuSolver::Getrs<Scalar>(rocblas_operation trans, int n, int nrhs, \
|
||||
Scalar* A, int lda, const int* dev_pivots, \
|
||||
Scalar* A, int lda, int* dev_pivots, \
|
||||
Scalar* B, int ldb, int* dev_lapack_info) { \
|
||||
mutex_lock lock(handle_map_mutex); \
|
||||
using ROCmScalar = typename ROCmComplexT<Scalar>::type; \
|
||||
@@ -645,7 +645,7 @@ TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
|
||||
pivots.bytes())) { \
|
||||
return errors::Internal("GetriBatched: Failed to copy ptrs to device"); \
|
||||
} \
|
||||
TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getri_batched, type_prefix)( \
|
||||
TF_RETURN_IF_ROCBLAS_ERROR(ROCSOLVER_FN(getri_batched, type_prefix)( \
|
||||
rocm_blas_handle_, n, \
|
||||
reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda, \
|
||||
reinterpret_cast<int*>(pivots.mutable_data()), stride, \
|
||||
@@ -668,7 +668,7 @@ TF_CALL_ROCSOLV_TYPES(GETRI_BATCHED_INSTANCE);
|
||||
if (!CopyHostToDevice(context_, dev_a.mutable_data(), A, dev_a.bytes())) { \
|
||||
return errors::Internal("GetrfBatched: Failed to copy ptrs to device"); \
|
||||
} \
|
||||
TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrf_batched, type_prefix)( \
|
||||
TF_RETURN_IF_ROCBLAS_ERROR(ROCSOLVER_FN(getrf_batched, type_prefix)( \
|
||||
rocm_blas_handle_, n, n, \
|
||||
reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda, dev_pivots, \
|
||||
stride, dev_info->mutable_data(), batch_size)); \
|
||||
@@ -696,7 +696,7 @@ TF_CALL_ROCSOLV_TYPES(GETRF_BATCHED_INSTANCE);
|
||||
if (!CopyHostToDevice(context_, dev_b.mutable_data(), B, dev_b.bytes())) { \
|
||||
return errors::Internal("GetrfBatched: Failed to copy ptrs to device"); \
|
||||
} \
|
||||
TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrs_batched, type_prefix)( \
|
||||
TF_RETURN_IF_ROCBLAS_ERROR(ROCSOLVER_FN(getrs_batched, type_prefix)( \
|
||||
rocm_blas_handle_, trans, n, nrhs, \
|
||||
reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda, dev_pivots, \
|
||||
stride, reinterpret_cast<ROCmScalar**>(dev_b.mutable_data()), ldb, \
|
||||
|
||||
@@ -152,9 +152,9 @@ Status GpuSparse::Initialize() {
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
#define TF_CALL_HIPSPARSE_DTYPES(m) \
|
||||
m(float, ROCM_R_32F) m(double, ROCM_R_64F) \
|
||||
m(std::complex<float>, ROCM_C_32F) m(std::complex<double>, ROCM_C_64F)
|
||||
#define TF_CALL_HIPSPARSE_DTYPES(m) \
|
||||
m(float, HIP_R_32F) m(double, HIP_R_64F) m(std::complex<float>, HIP_C_32F) \
|
||||
m(std::complex<double>, HIP_C_64F)
|
||||
|
||||
// Macro that specializes a sparse method for all 4 standard
|
||||
// numeric types.
|
||||
@@ -357,14 +357,13 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
|
||||
const hipsparseMatDescr_t descrA,
|
||||
Scalar* csrVal, const int* csrRowPtr,
|
||||
int* csrColInd) {
|
||||
GpuSparseCsrSortingConversionInfo info;
|
||||
TF_RETURN_IF_ERROR(info.Initialize());
|
||||
csru2csrInfo_t info;
|
||||
|
||||
size_t pBufferSizeInBytes = 0;
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
buffer_size_op(hipsparse_handle, m, n, nnz, AsHipComplex(csrVal),
|
||||
csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes));
|
||||
csrRowPtr, csrColInd, info, &pBufferSizeInBytes));
|
||||
|
||||
Tensor pBuffer_t;
|
||||
TF_RETURN_IF_ERROR(context->allocate_temp(
|
||||
@@ -375,7 +374,7 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(hipsparse_handle, m, n, nnz, descrA,
|
||||
AsHipComplex(csrVal), csrRowPtr, csrColInd,
|
||||
info.info(), pBuffer.data()));
|
||||
info, pBuffer.data()));
|
||||
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
FROM ubuntu:focal
|
||||
MAINTAINER Jeff Poznanovic <jeffrey.poznanovic@amd.com>
|
||||
|
||||
ARG ROCM_DEB_REPO=https://repo.radeon.com/rocm/apt/5.1/
|
||||
ARG ROCM_DEB_REPO=https://repo.radeon.com/rocm/apt/5.3/
|
||||
ARG ROCM_BUILD_NAME=ubuntu
|
||||
ARG ROCM_BUILD_NUM=main
|
||||
ARG ROCM_PATH=/opt/rocm-5.1.0
|
||||
ARG ROCM_PATH=/opt/rocm-5.3.0
|
||||
|
||||
ENV DEBIAN_FRONTEND noninteractive
|
||||
ENV TF_NEED_ROCM 1
|
||||
@@ -54,7 +54,7 @@ RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteracti
|
||||
libnuma-dev \
|
||||
pciutils \
|
||||
virtualenv \
|
||||
python-pip \
|
||||
python3-pip \
|
||||
libxml2 \
|
||||
libxml2-dev \
|
||||
wget && \
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# ==============================================================================
|
||||
|
||||
# Select bazel version.
|
||||
BAZEL_VERSION="5.1.1"
|
||||
BAZEL_VERSION="5.3.0"
|
||||
|
||||
set +e
|
||||
local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
|
||||
|
||||
@@ -25,7 +25,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
|
||||
echo ""
|
||||
|
||||
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
||||
ROCM_INSTALL_DIR=/opt/rocm-5.1.0
|
||||
ROCM_INSTALL_DIR=/opt/rocm-5.3.0
|
||||
if [[ -n $1 ]]; then
|
||||
ROCM_INSTALL_DIR=$1
|
||||
fi
|
||||
|
||||
@@ -27,7 +27,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
|
||||
echo ""
|
||||
|
||||
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
||||
ROCM_INSTALL_DIR=/opt/rocm-5.1.0
|
||||
ROCM_INSTALL_DIR=/opt/rocm-5.3.0
|
||||
if [[ -n $1 ]]; then
|
||||
ROCM_INSTALL_DIR=$1
|
||||
fi
|
||||
|
||||
162
third_party/gpus/find_rocm_config.py
vendored
162
third_party/gpus/find_rocm_config.py
vendored
@@ -69,36 +69,27 @@ def _get_header_version(path, name):
|
||||
|
||||
def _find_rocm_config(rocm_install_path):
|
||||
|
||||
def rocm_version_numbers_pre_rocm50(path, prior_err):
|
||||
version_file = os.path.join(path, ".info/version-dev")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError("{} ROCm version file ".format(prior_err) +
|
||||
'"{}" not found either.'.format(version_file))
|
||||
version_numbers = []
|
||||
with open(version_file) as f:
|
||||
version_string = f.read().strip()
|
||||
version_numbers = version_string.split(".")
|
||||
major = int(version_numbers[0])
|
||||
minor = int(version_numbers[1])
|
||||
patch = int(version_numbers[2].split("-")[0])
|
||||
return major, minor, patch
|
||||
def rocm_version_numbers(path):
|
||||
possible_version_files = [
|
||||
"include/rocm-core/rocm_version.h", # ROCm 5.2
|
||||
"include/rocm_version.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
version_file_path = os.path.join(path, f)
|
||||
if os.path.exists(version_file_path):
|
||||
version_file = version_file_path
|
||||
break
|
||||
if not version_file:
|
||||
raise ConfigError(
|
||||
"ROCm version file not found in {}".format(possible_version_files))
|
||||
|
||||
def rocm_version_numbers_post_rocm50(path):
|
||||
version_file = os.path.join(path, "include/rocm_version.h")
|
||||
if not os.path.exists(version_file):
|
||||
return False, 'ROCm version file "{}" not found.'.format(version_file) +\
|
||||
" Trying an alternate approach to determine the ROCm version.", 0,0,0
|
||||
major = _get_header_version(version_file, "ROCM_VERSION_MAJOR")
|
||||
minor = _get_header_version(version_file, "ROCM_VERSION_MINOR")
|
||||
patch = _get_header_version(version_file, "ROCM_VERSION_PATCH")
|
||||
return True, "", major, minor, patch
|
||||
return major, minor, patch
|
||||
|
||||
status, error_msg, major, minor, patch = \
|
||||
rocm_version_numbers_post_rocm50(rocm_install_path)
|
||||
|
||||
if not status:
|
||||
major, minor, patch = \
|
||||
rocm_version_numbers_pre_rocm50(rocm_install_path, error_msg)
|
||||
major, minor, patch = rocm_version_numbers(rocm_install_path)
|
||||
|
||||
rocm_config = {
|
||||
"rocm_version_number": _get_composite_version_number(major, minor, patch)
|
||||
@@ -110,10 +101,20 @@ def _find_rocm_config(rocm_install_path):
|
||||
def _find_hipruntime_config(rocm_install_path):
|
||||
|
||||
def hipruntime_version_number(path):
|
||||
version_file = os.path.join(path, "hip/include/hip/hip_version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'HIP Runtime version file "{}" not found'.format(version_file))
|
||||
possible_version_files = [
|
||||
"include/hip/hip_version.h", # ROCm 5.2
|
||||
"hip/include/hip/hip_version.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
version_file_path = os.path.join(path, f)
|
||||
if os.path.exists(version_file_path):
|
||||
version_file = version_file_path
|
||||
break
|
||||
if not version_file:
|
||||
raise ConfigError("HIP Runtime version file not found in {}".format(
|
||||
possible_version_files))
|
||||
|
||||
# This header file has an explicit #define for HIP_VERSION, whose value
|
||||
# is (HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR)
|
||||
# Retrieve the major + minor and re-calculate here, since we do not
|
||||
@@ -132,8 +133,17 @@ def _find_hipruntime_config(rocm_install_path):
|
||||
def _find_miopen_config(rocm_install_path):
|
||||
|
||||
def miopen_version_numbers(path):
|
||||
version_file = os.path.join(path, "miopen/include/miopen/version.h")
|
||||
if not os.path.exists(version_file):
|
||||
possible_version_files = [
|
||||
"include/miopen/version.h", # ROCm 5.2
|
||||
"miopen/include/miopen/version.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
version_file_path = os.path.join(path, f)
|
||||
if os.path.exists(version_file_path):
|
||||
version_file = version_file_path
|
||||
break
|
||||
if not version_file:
|
||||
raise ConfigError(
|
||||
'MIOpen version file "{}" not found'.format(version_file))
|
||||
major = _get_header_version(version_file, "MIOPEN_VERSION_MAJOR")
|
||||
@@ -155,8 +165,8 @@ def _find_rocblas_config(rocm_install_path):
|
||||
|
||||
def rocblas_version_numbers(path):
|
||||
possible_version_files = [
|
||||
"rocblas/include/rocblas-version.h", # ROCm 3.7 and prior
|
||||
"rocblas/include/internal/rocblas-version.h", # ROCm 3.8
|
||||
"include/rocblas/internal/rocblas-version.h", # ROCm 5.2
|
||||
"rocblas/include/internal/rocblas-version.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
@@ -187,8 +197,8 @@ def _find_rocrand_config(rocm_install_path):
|
||||
|
||||
def rocrand_version_number(path):
|
||||
possible_version_files = [
|
||||
"rocrand/include/rocrand_version.h", # ROCm 5.0 and prior
|
||||
"include/rocrand/rocrand_version.h", # ROCm 5.1
|
||||
"rocrand/include/rocrand_version.h", # ROCm 5.0 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
@@ -212,10 +222,19 @@ def _find_rocrand_config(rocm_install_path):
|
||||
def _find_rocfft_config(rocm_install_path):
|
||||
|
||||
def rocfft_version_numbers(path):
|
||||
version_file = os.path.join(path, "rocfft/include/rocfft-version.h")
|
||||
if not os.path.exists(version_file):
|
||||
possible_version_files = [
|
||||
"include/rocfft/rocfft-version.h", # ROCm 5.2
|
||||
"rocfft/include/rocfft-version.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
version_file_path = os.path.join(path, f)
|
||||
if os.path.exists(version_file_path):
|
||||
version_file = version_file_path
|
||||
break
|
||||
if not version_file:
|
||||
raise ConfigError(
|
||||
'rocfft version file "{}" not found'.format(version_file))
|
||||
"rocfft version file not found in {}".format(possible_version_files))
|
||||
major = _get_header_version(version_file, "rocfft_version_major")
|
||||
minor = _get_header_version(version_file, "rocfft_version_minor")
|
||||
patch = _get_header_version(version_file, "rocfft_version_patch")
|
||||
@@ -234,10 +253,19 @@ def _find_rocfft_config(rocm_install_path):
|
||||
def _find_hipfft_config(rocm_install_path):
|
||||
|
||||
def hipfft_version_numbers(path):
|
||||
version_file = os.path.join(path, "hipfft/include/hipfft-version.h")
|
||||
if not os.path.exists(version_file):
|
||||
possible_version_files = [
|
||||
"include/hipfft/hipfft-version.h", # ROCm 5.2
|
||||
"hipfft/include/hipfft-version.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
version_file_path = os.path.join(path, f)
|
||||
if os.path.exists(version_file_path):
|
||||
version_file = version_file_path
|
||||
break
|
||||
if not version_file:
|
||||
raise ConfigError(
|
||||
'hipfft version file "{}" not found'.format(version_file))
|
||||
"hipfft version file not found in {}".format(possible_version_files))
|
||||
major = _get_header_version(version_file, "hipfftVersionMajor")
|
||||
minor = _get_header_version(version_file, "hipfftVersionMinor")
|
||||
patch = _get_header_version(version_file, "hipfftVersionPatch")
|
||||
@@ -256,10 +284,19 @@ def _find_hipfft_config(rocm_install_path):
|
||||
def _find_roctracer_config(rocm_install_path):
|
||||
|
||||
def roctracer_version_numbers(path):
|
||||
version_file = os.path.join(path, "roctracer/include/roctracer.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'roctracer version file "{}" not found'.format(version_file))
|
||||
possible_version_files = [
|
||||
"include/roctracer/roctracer.h", # ROCm 5.2
|
||||
"roctracer/include/roctracer.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
version_file_path = os.path.join(path, f)
|
||||
if os.path.exists(version_file_path):
|
||||
version_file = version_file_path
|
||||
break
|
||||
if not version_file:
|
||||
raise ConfigError("roctracer version file not found in {}".format(
|
||||
possible_version_files))
|
||||
major = _get_header_version(version_file, "ROCTRACER_VERSION_MAJOR")
|
||||
minor = _get_header_version(version_file, "ROCTRACER_VERSION_MINOR")
|
||||
# roctracer header does not have a patch version number
|
||||
@@ -279,10 +316,19 @@ def _find_roctracer_config(rocm_install_path):
|
||||
def _find_hipsparse_config(rocm_install_path):
|
||||
|
||||
def hipsparse_version_numbers(path):
|
||||
version_file = os.path.join(path, "hipsparse/include/hipsparse-version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'hipsparse version file "{}" not found'.format(version_file))
|
||||
possible_version_files = [
|
||||
"include/hipsparse/hipsparse-version.h", # ROCm 5.2
|
||||
"hipsparse/include/hipsparse-version.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
version_file_path = os.path.join(path, f)
|
||||
if os.path.exists(version_file_path):
|
||||
version_file = version_file_path
|
||||
break
|
||||
if not version_file:
|
||||
raise ConfigError("hipsparse version file not found in {}".format(
|
||||
possible_version_files))
|
||||
major = _get_header_version(version_file, "hipsparseVersionMajor")
|
||||
minor = _get_header_version(version_file, "hipsparseVersionMinor")
|
||||
patch = _get_header_version(version_file, "hipsparseVersionPatch")
|
||||
@@ -301,8 +347,9 @@ def _find_hipsolver_config(rocm_install_path):
|
||||
|
||||
def hipsolver_version_numbers(path):
|
||||
possible_version_files = [
|
||||
"hipsolver/include/hipsolver-version.h", # ROCm 5.0 and prior
|
||||
"include/hipsolver/internal/hipsolver-version.h", # ROCm 5.2
|
||||
"hipsolver/include/internal/hipsolver-version.h", # ROCm 5.1
|
||||
"hipsolver/include/hipsolver-version.h", # ROCm 5.0 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
@@ -331,10 +378,19 @@ def _find_hipsolver_config(rocm_install_path):
|
||||
def _find_rocsolver_config(rocm_install_path):
|
||||
|
||||
def rocsolver_version_numbers(path):
|
||||
version_file = os.path.join(path, "rocsolver/include/rocsolver-version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'rocsolver version file "{}" not found'.format(version_file))
|
||||
possible_version_files = [
|
||||
"include/rocsolver/rocsolver-version.h", # ROCm 5.2
|
||||
"rocsolver/include/rocsolver-version.h", # ROCm 5.1 and prior
|
||||
]
|
||||
version_file = None
|
||||
for f in possible_version_files:
|
||||
version_file_path = os.path.join(path, f)
|
||||
if os.path.exists(version_file_path):
|
||||
version_file = version_file_path
|
||||
break
|
||||
if not version_file:
|
||||
raise ConfigError("rocsolver version file not found in {}".format(
|
||||
possible_version_files))
|
||||
major = _get_header_version(version_file, "ROCSOLVER_VERSION_MAJOR")
|
||||
minor = _get_header_version(version_file, "ROCSOLVER_VERSION_MINOR")
|
||||
patch = _get_header_version(version_file, "ROCSOLVER_VERSION_PATCH")
|
||||
|
||||
@@ -1 +1 @@
|
||||
eJztW21v2zgS/q5fQSgoIl8cxeldcYsccoA3TVHftUlhZ7tYNIEh27StrSz6SCqpUfS/3wxJyaQs+SV2+mmzWNSWhg+HM888HCnMEbli8wWPJ1NJXrdet8jdlJI7mgrG3yXsibQzOWVchKSdJKSLZoJ0qaD8kY5C78g7Ih/iIZjTEcnSEeVEwvj2PBrCP+ZOk3ymXMQsJa/DFgnQwDe3/Ma/AGHBMjKLFiRlkmSCAkQsyDhOKKHfhnQuSZySIZvNkzhKh5Q8xXKqpjEg4Ab5w0CwgYzAOgL7OXwb23Ykksph/JlKOb84O3t6egoj5WzI+OQs0Ybi7EPn6vqmd30KDqshv6UJFYJw+r8s5rDUwYJEc/BnGA3AyyR6IoyTaMIp3JMM/X3isYzTSZMINpZPEaeAMoqF5PEgk06wcu9gzbYBhCtKid/ukU7PJ7+2e51eEzB+79y9v/3tjvze7nbbN3ed6x657ZKr25u3nbvO7Q18e0faN3+Q/3Zu3jYJhVDBNPTbnKP/4GSMYVSpIz1KHQfGTDsk5nQYj+MhrCudZNGEkgl7pDyF5ZA55bNYYDIFuDcClCSexTKS6srKonCay4P+eL7vf+JxijS8vZrB9AMe8QU6Q6Y0wvlHkKKhZDymykfyqNkHlGLgIAZWrXIhJJ2FnoeEF0MeA88EjThwQahQ1MEjMYWL0oSMY9Sk8ODiDCkwohJDlaoQxzx3QgHNtf84fsjScTzJuAogjhNyxDIZKq/mERKd5eDIEJMbpNmUs2wyRZLQ9DHmLJ3RVJLHiMeKlAH4/7H/qX33vhF6nTEUF9xL4lFpytiEpamXo+OQO6jcoZyrVHMqM67STuASBGjIRtSNn4y+Ur2uPAcLy2MoGrxV+FXpd2jjJYx91cnQsdf5zHOiE6GqfRrx0Sn6M4IcSqh7T2QDmwdjzmZkEAkTVCMMS98Kf0MCsVq6COEBVfIKQxUmKMszNpdnnA1nPppkKH8R+CIh7+MoS3A9SUY9ZKvnQc0xDulj+Scm8k+gC+YTMMnzvGESQZ1eqRRdY5SDayWBkKrGhUfAe4FmMAvpT6jsm+n66EoflxYoM50r2017kDIGUskoSaxB4OvbnLU60nnKTdpmxAzS1MGRIS4Q5isjksta/8A8HhO/iLGPKWQiNGxAT6oBlzZfrNEPYH9UZ69c5DRK1NQrRo1lqFbu2QHDzYeJWNK+KeN+ms0GlAez6E/GmwQihv/AsOHUjv95C37I34gyIyf4Hb+hNXxT5vY0muT5HAF60SRpNKN5eromH6Dmc3AYJB2AYLhKGETxytYokxqsngR1CO7HLGRzmiP73IcdIoWyAWW/9DM5Pv3Fb+j4z9A3iCGnofoYcP9IT0ReCXIS3I9OGj55pbxrKvyGGgeZVfYahegyABx1MZyAZM2D84a5aaIEWhMou4aHoYtiqFK7Ao7zmf3vP1S56V3tPj0OYXGAHKgYkRMD6/74RPUVuAGizkAUVGfx/Qfsm/epn0MoOtRDxCCHei1aEVAFwW86oRhdKJoo8Rt5LsHZkWa81toK4l3gUtFY3XJZJfrgrBr/pmVSBYLMeB9k16QnH6CWsmT6nywukhvG6ZidGcPTEX30iwyh//kQ+g16DhHYgI08eRWpgLApIci3M+VAkYelmzWRxJ9jlUb0Qe/HOpthAeK40nCWa+IDK/7yoG6oVlBR2hmFWjwuKGjuYGsFHcwlGaMijIJGiFfmQaNkt5zEHRkK6Jxk4IcmkLqoLzV93bFfWg/GRpV6tc25sZmbUquyef2Qz3rqNwpUUzcV4rOeVUxIm1bbcylOh0k2omc2bDh9BqG04++iBJ8KjleZ5FKjhhPk5L4g1zG54wtMK9RjlEAdppGk2JtzBl39SiPmkDc8bpJWE/5z0lmlxPb0TbN1fb7u9qDh7n9s/+e267vp3hmjc1Ng5HTYFQM2w6v3vkOQO56hGah8DVVAkGQmmrqf68/EpNISnLlfbsrrWFW5wXoFRfR0
F8tw18xTN9NSFVcmstagd5Gl+gLsdwXrV6D65OIZWzzA/fBKvYOezNkCpvGcZ6mMZ3SLjcAyLnmwW60Czller/gZ/t+rZFf2AEvYj993PpGu9npdIa/T9iNyh4/7VuMC3bx+yviGT9jQU+f7P7YyMGNO+SZ5mjJwTrfaGguQAstEVyf0XNh5nRDnDtZc7gI0VpzGj1oj8m5NV7Pug0+HUTLMEhQX2KygqkSsXkXAkw5TTwga5ymCFgPohGLOFNggE9idCXzkmEewetQqjtsWPpSLXaVnZW27K89KEFzROHc6VoWLFF3hclFWRHGumrv+xRpeV6qFXVcrczrVNYtx49+isoxhSU52LCsNUlSW+fpihfWxcwv4z62pHQgFE326vtmXU2WU5+1nJZSKHa1mH6veTGryXkE7hLDZZHO7EsW/sJK19/7hzF1+iBgkkdjuOUJZrmE5+CfiQbL0UL9AgWa6WIxvYM6sjg+/ny6J3lR6iW3U38N/5q+zQCNqIfAxCdqyZAPWLwrhoaoeb1iq9R03gDE+wlUvpdzx48XyywCrpsd55w9Vu6Zi+1YIK5xbMS0MB/CY8dWWBdt0KxnIY+nqwPLRCSIBmpALQemJqzpGuysE5OfXD+3eARpeF+bZPa8Ds69I1NVNtUq49WjLRDXOYXXCnb0sFBxqcTuhUJb1TeaWOoEotk7YqE5tvwlbVTpRGroB4vwveVjtEkzI1srDcaU8rBUHlxpb12W3ffM2r0u3Il1AU0gWX0uFVMFPaCRriLuxi3SnKlfNeCy3Kxo03K+F1CB2ycBXazc8cAup8X9CC1kKjhq5+/5QRsGRu28PJRQ1cr/doSrvtZvDkk0lSq+iHHxrWM5dfgmxHceN4X4c1yD2C4gX5bjG/wkc1xN91hc/Po/hLsbz+O1gfNqT3TUZr2a3w6PSw/9Ls9uZu6zgkkdDCNhWIm5s99ZxjWNLub7yIiquoX8CyWEPv+u2r667B+j1y0BWt3+0zERxZINR/Wu1afRISWQYmq/YtA12qbT2kfQaFtSqusOwkrBXYh1c2x0PyvIu8L3ilq+Yje3eIq9xbJ3XV15U6vUUP0ft9Vx7C74L82zNd2AOIPs1NKhVfodiJfGvxDq4/jselAuAJY9b7QBL2z1fkxVATgWoK5Wvt2oegVdhipdlm/D+eh6u0QgVtEM/Ee+oHcqH/bXDgXm+dtgwh9CO6gqq1w67OsvaUYV1eO2wPSi3j1uLx9J27/axVPTFlRd9GVBVGS/VRvZuP3w+SBtZBnr2S+MS0AFeG+9UB2WilbrIn1AHZQ9MHaycVSufNYzIKB7ieU88csvGev9RjqQUDwmb9eBps02HQSvOnXp1HK8qxA00P+4Vx3mXB3gVw4vHCwVe0Hx1Dp0rKvAEL2Tox/LrF316RDKWfI2lsvYfNBFK50bzEWE2H0WSBtscB2zUjNrmBEnd2E2/H68bt/F3jmsGrv8dxJqBa19RGZY4iSiVygP59yX5R+u81TI0qQ7mxmlq3Nv0jmFN/jY9nW2/uDcbFrdpJ1uzvs1DLSFREEY+ZhFsbKo0JV9cFF3nV7po5udlUyIYl3QUrEpNCFI2E0Gj2MTUHx4E/itxQV4JPF8cLJGU/+YPgqz6x7Oepj0UCxHqP1gI8c9vaODfp9fd7m33Akr5PrVO+wrJAwBsFMNAGCQeTPY8yEW/jweK+31yeUn8fh/X2O8rNdbL9f4PWsTtHA==
|
||||
eJztW21v2zgS/q5fQSgoam9dJSlugUMOOcCbZlHftUlhZ7tYtIVB27TNrSz6SCppUPS/7wxJyRIt2YqjvX6JgTa2NPNwOC8Ph7Z4RC7E+l7yxVKTVyevTsjNkpEblighf43FHemneimkikg/jskQxRQZMsXkLZtFwVFwRN7yKYizGUmTGZNEg35/Tafwx93pkQ9MKi4S8io6IR0UCN2tsPsvQLgXKVnRe5IITVLFAIIrMucxI+zrlK014QmZitU65jSZMnLH9dIM40DADPKHgxATTUGagvwaPs2LcoRqYzC+llqvz46P7+7uImqMjYRcHMdWUB2/HVxcXo0uX4LBRuW3JGZKEcn+l3IJU53cE7oGe6Z0AlbG9I4ISehCMrinBdp7J7nmyaJHlJjrOyoZoMy40pJPUl1yVmYdzLkoAO6iCQn7IzIYheSX/mgw6gHG74ObN9e/3ZDf+8Nh/+pmcDki10NycX31enAzuL6CT7+S/tUf5L+Dq9c9wsBVMAz7upZoPxjJ0Y0mdGTEWMmAubAGqTWb8jmfwrySRUoXjCzELZMJTIesmVxxhcFUYN4MUGK+4ppqc2VrUjjMeauvIAzD95InmIbXFysYfiKpvEdjyJJRHH8GIZpqITkzNpJbm32QUgIMRMeaWd4rzVZREGDCq6nkkGeKUQm5oIwr6uAxMVUZpQcRR69pFcDFFabAjGl0VWJczGVmhAFaW/tRfyqSOV+k0jgQ9ZSeiVRHxqo1xUQXGThmiIsNptlSinSxxCRhyS2XIlmxRJNbKrlJyg7Y/278vn/zphsFgzkUF9yL+cwbkju39Ox0rB8yA405TEoTasl0Kk3YCVwCB03FjJX9p+kXZueVxeC+YDEUDd7K7aq0OyrixUJ8scGwvrfxzGJiA2GqfUnl7CXaM4MYaqj7QKWTYh7MpViRCVXOqY4YNrbl9kYEfLUxEdwDrBTkgsZNUJbHYq2PpZiuQhRJkf4o2KIh7nOaxjifOGUBZmsQQM0JCeET2TuhsnfAC+4dZFIQBNOYQp1emBBdopc7l4YCIVTds4CA9QrFYBQyXjA9dsON0ZQxTq1jxGysimYWlYwwJJWmcVxQAltfZ1lrPZ2F3IVtRZySTR3UjHCCMJ6PSM5r7QNxPidh7uMQQyhU5LIBLakG3Mh8LGh/BvmjOnljomQ0NkNvCXU3rtq6V3QYLj5Ccc3GrozHSbqaMNlZ0T+F7BHwGP4Btemy6P/TE3iRn4gRIy/wM35CafhkxIvD2CTPxuigFT2S0BXLwjN08QA2X4PBQOkABOomYODFiyJHudBg9cTIQ3Cfi0isWYYcyhBWiATKBpj9PEz1/OU/w671/wptAx9KFpm3HRke2YHIM0VedD7NXnRD8sxY1zP4XaMHkTXyFoXYMgAcczFaAGWtO6ddd9N5CbimY+S6AbqOcqjSYgU8z0YOv3035WZXtU/J8wgmB8gd4yPywsGWXyExfQUugMgz4AXTWXz7DuvmpyTMIEw61ENwoEM7F8sIyIJgN1sw9C4UDY3DbhZLMHZmM95ybUXineFUUdjcKmeV6jgRHBvyTnHgxVzGst45+ZibGvJkGqczZgodWFDad5lGtIQwQ4mYAv45elWtVyN9mi1YQhq9z+b/oi1gypVImLmOuTZHF1dbnSdF4aJfrH8KnuXnPMsTSKrsNvsKLZLqbCF0z/JpecZtieaCEyCGL1nWYkyLohnedjoWUiQ0XsqWdjMg4tg+A9wA+Zrll5dY1Q7qmgogji/OK1mhqNBzNPrhcjiC5m/8rv+f62Fo3WZZ5gCMwVWOsXYs8FAMIOaLNw7DFXkFU+JcKy4j61QVRQV7ZyuPLTNQ/ObcHFYghGcHkDngfQ+8VcKOVir2JV/LNNF8xRqUfEHYs+DAugdA/Le/4FGwmdJT3e+s+zeD92RoY/g3lP8RbMJhmSms5dDg2sb7K246oc3MlkT0PBiTVV6P3C0F2G27T4sFSJ2CiCUJaEOwGXlBSnew
9LtODXoNyfit3SJmDYwlFdsavpzSeJpCLwjWMQkcoLjZnUPzL0zTbHHuKKy6UHe4XAoDNkkVNiwKu/A1hfnDvhK6f0gH3KeqhzLg1tweToBbTihz12mpiTO4GKitoi8SUG2RAw3VE0Blh1okoK0xSzS04tjeNaAgJ/i4vqPEQRbxuMQm7rVhIo9UchCnvAvriZkaMdPzd4Nr8F2ZlEzfnDNT3jUXwbvdh5YdDPT+8uqxleejHNZ8eCiPbT9qqqO6ASnVXJEBKlHCs0KwHt2OlMb2dx+TmKpmGxAj2d4eBOGOcXMkExpnF17ubU42mhapIcITH+zsVJzz2u1SHkgVEKlf3vZHLWxTyjAH71RKMC1sVioLqHa/UihMb8tSgdMuYZRH9xlDQiU1Ywwj2c7mxaEde6h+lZd4wih4ADWKJ0/0sIceZOE3kt30sJMPytnQuBSH/avXWSmWi7AM6GqnkKJe7VSkJLTaNbm6t88uD+UXynyum9UJCra3sAKa+9NoNUXxsvbTEnpgjYDvfugK6qWS0Xz4AuqjoObD108PxWg+bvmsqpLa1XNTex4BbKO0vnZuxva//WvGCE6wJUawaO7PfkZw4mXtJ0Y4hBGs734oI1gTPtiL7w7jgzLGYWxQwnj/SC6oqY9qLihVnffV29/NBaWx/e5ASzoFhzVqEJxsez2CBdy829khOOEt9Sc2eGB/YN32ozfZN8P+xeWwhW22D1TYaB9tsjZ/3kow+5v4kt4yQl01Z75w7XuRVk4e0yzUVExtv1CqRq9lqMRqvWsoWeA3Dgp/AWn4q6GTba99sICbd42aCKe0BfPUShzYSlj3/ehuwlrx6IaiDHNwT1GCaaGtqCmd2s6iVJZec1GJ1Xp/UbLAJw0R3zbqMDayLZKGAdx8PZ9fasYembb3Jf8+lNNdKPuUn74E3OKgjRcfTDxtko6x4PGkU4I5nHSKMG2QTnXp1ZNOsax90qnCap90ihb4+5rGrLORbW9f4yo+f9foG1CPJvYpP7Uq+/Y5BzJGm6wBsRpdv/3Qyj7HBzr4B0UPqIWfFB/EHX5xetuc/wN3+BY47th6Etp/kp2SGZ/iaQI80CHmthKNIQnDIyhuPjyZi31HDSpONQR5tnslVUVeewrg+Sg/LLI5HmKerMn3vwY8f7xmewwbK6bwfAhE6Pvm40f7wKoWIv7CtZEOP2cPwpZOJWQaUbqeUc064wYPm3drtJo8tVqnu+9Rszq9vQ+m7FDc/fv0DsWd3867LCkFwiuVz+Tf5+QfJ6cnJy5Nqp25d5ga8/Z9Ybgjfvu+Pmg+uZ/3TG7f6r9jfvtVC0RiIBx9rCisj6Y0tbw/y1fdL+y+l53GSIgSUrNZZ5tqIqCylep08+XSHGvrhM/UGXmm8PRKZ4Nk7HfHTQv1j6e63PKo7lVkj8NFeLiTdcJPyeVweD08g1L+lBTOkigtOwDYzdWAGDQeewkCiMV4jMdVxmNyfk7C8RjnOB4bNrbTDf4CEAHDYg==
|
||||
100
third_party/gpus/rocm_configure.bzl
vendored
100
third_party/gpus/rocm_configure.bzl
vendored
@@ -178,7 +178,13 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
|
||||
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include")
|
||||
|
||||
# Add HIP headers (needs to match $HIP_PATH)
|
||||
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include")
|
||||
if int(rocm_config.rocm_version_number) >= 50200:
|
||||
inc_dirs.append(rocm_config.rocm_toolkit_path + "/include")
|
||||
inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip")
|
||||
inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocprim")
|
||||
inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocsolver")
|
||||
else:
|
||||
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include")
|
||||
|
||||
# Add HIP-Clang headers (realpath relative to compiler binary)
|
||||
rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin)
|
||||
@@ -189,6 +195,7 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
|
||||
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/12.0.0/include")
|
||||
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/13.0.0/include")
|
||||
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/14.0.0/include")
|
||||
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include")
|
||||
|
||||
# Support hcc based off clang 10.0.0 (for ROCm 3.3)
|
||||
inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
|
||||
@@ -314,7 +321,7 @@ def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin):
|
||||
|
||||
return libs
|
||||
|
||||
def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin):
|
||||
def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin):
|
||||
"""Returns the ROCm libraries on the system.
|
||||
|
||||
Args:
|
||||
@@ -329,19 +336,19 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin):
|
||||
(name, _rocm_lib_paths(repository_ctx, name, path))
|
||||
for name, path in [
|
||||
("amdhip64", rocm_config.rocm_toolkit_path + "/hip"),
|
||||
("rocblas", rocm_config.rocm_toolkit_path + "/rocblas"),
|
||||
(hipfft_or_rocfft, rocm_config.rocm_toolkit_path + "/" + hipfft_or_rocfft),
|
||||
("rocblas", rocm_config.rocm_toolkit_path),
|
||||
(hipfft_or_rocfft, rocm_config.rocm_toolkit_path),
|
||||
("hiprand", rocm_config.rocm_toolkit_path),
|
||||
("MIOpen", rocm_config.rocm_toolkit_path + "/miopen"),
|
||||
("rccl", rocm_config.rocm_toolkit_path + "/rccl"),
|
||||
("hipsparse", rocm_config.rocm_toolkit_path + "/hipsparse"),
|
||||
("MIOpen", miopen_path),
|
||||
("rccl", rccl_path),
|
||||
("hipsparse", rocm_config.rocm_toolkit_path),
|
||||
("roctracer64", rocm_config.rocm_toolkit_path + "/roctracer"),
|
||||
("rocsolver", rocm_config.rocm_toolkit_path + "/rocsolver"),
|
||||
("rocsolver", rocm_config.rocm_toolkit_path),
|
||||
]
|
||||
]
|
||||
if int(rocm_config.rocm_version_number) >= 40500:
|
||||
libs_paths.append(("hipsolver", _rocm_lib_paths(repository_ctx, "hipsolver", rocm_config.rocm_toolkit_path + "/hipsolver")))
|
||||
libs_paths.append(("hipblas", _rocm_lib_paths(repository_ctx, "hipblas", rocm_config.rocm_toolkit_path + "/hipblas")))
|
||||
libs_paths.append(("hipsolver", _rocm_lib_paths(repository_ctx, "hipsolver", rocm_config.rocm_toolkit_path)))
|
||||
libs_paths.append(("hipblas", _rocm_lib_paths(repository_ctx, "hipblas", rocm_config.rocm_toolkit_path)))
|
||||
return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
|
||||
|
||||
def _exec_find_rocm_config(repository_ctx, script_path):
|
||||
@@ -463,7 +470,7 @@ def _create_dummy_repository(repository_ctx):
|
||||
"%{hipblas_lib}": _lib_name("hipblas"),
|
||||
"%{miopen_lib}": _lib_name("miopen"),
|
||||
"%{rccl_lib}": _lib_name("rccl"),
|
||||
"%{hipfft_or_rocfft}": "hipfft",
|
||||
"%{hipfft_or_rocfft}": _lib_name("hipfft"),
|
||||
"%{hipfft_or_rocfft_lib}": _lib_name("hipfft"),
|
||||
"%{hiprand_lib}": _lib_name("hiprand"),
|
||||
"%{hipsparse_lib}": _lib_name("hipsparse"),
|
||||
@@ -550,6 +557,10 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
rocm_version_number = int(rocm_config.rocm_version_number)
|
||||
hipfft_or_rocfft = "rocfft" if rocm_version_number < 40100 else "hipfft"
|
||||
|
||||
# For ROCm 5.2 and above, find MIOpen and RCCL in the main rocm lib path
|
||||
miopen_path = rocm_config.rocm_toolkit_path + "/miopen" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path
|
||||
rccl_path = rocm_config.rocm_toolkit_path + "/rccl" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path
|
||||
|
||||
# Copy header and library files to execroot.
|
||||
# rocm_toolkit_path
|
||||
rocm_toolkit_path = rocm_config.rocm_toolkit_path
|
||||
@@ -561,69 +572,14 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
out_dir = "rocm/include",
|
||||
exceptions = ["gtest", "gmock"],
|
||||
),
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = hipfft_or_rocfft + "-include",
|
||||
src_dir = rocm_toolkit_path + "/" + hipfft_or_rocfft + "/include",
|
||||
out_dir = "rocm/include/" + hipfft_or_rocfft,
|
||||
),
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = "rocblas-include",
|
||||
src_dir = rocm_toolkit_path + "/rocblas/include",
|
||||
out_dir = "rocm/include/rocblas",
|
||||
),
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = "rocblas-hsaco",
|
||||
src_dir = rocm_toolkit_path + "/rocblas/lib/library",
|
||||
out_dir = "rocm/lib/rocblas/lib/library",
|
||||
),
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = "miopen-include",
|
||||
src_dir = rocm_toolkit_path + "/miopen/include",
|
||||
out_dir = "rocm/include/miopen",
|
||||
),
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = "rccl-include",
|
||||
src_dir = rocm_toolkit_path + "/rccl/include",
|
||||
out_dir = "rocm/include/rccl",
|
||||
),
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = "hipsparse-include",
|
||||
src_dir = rocm_toolkit_path + "/hipsparse/include",
|
||||
out_dir = "rocm/include/hipsparse",
|
||||
),
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = "rocsolver-include",
|
||||
src_dir = rocm_toolkit_path + "/rocsolver/include",
|
||||
out_dir = "rocm/include/rocsolver",
|
||||
),
|
||||
]
|
||||
|
||||
# Add Hipsolver on ROCm4.5+
|
||||
if rocm_version_number >= 40500:
|
||||
copy_rules.append(
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = "hipsolver-include",
|
||||
src_dir = rocm_toolkit_path + "/hipsolver/include",
|
||||
out_dir = "rocm/include/hipsolver",
|
||||
),
|
||||
)
|
||||
copy_rules.append(
|
||||
make_copy_dir_rule(
|
||||
repository_ctx,
|
||||
name = "hipblas-include",
|
||||
src_dir = rocm_toolkit_path + "/hipblas/include",
|
||||
out_dir = "rocm/include/hipblas",
|
||||
),
|
||||
)
|
||||
|
||||
# explicitly copy (into the local_config_rocm repo) the $ROCM_PATH/hiprand/include and
|
||||
# $ROCM_PATH/rocrand/include dirs, only once the softlink to them in $ROCM_PATH/include
|
||||
# dir has been removed. This removal will happen in a near-future ROCm release.
|
||||
@@ -655,7 +611,7 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
),
|
||||
)
|
||||
|
||||
rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin)
|
||||
rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin)
|
||||
rocm_lib_srcs = []
|
||||
rocm_lib_outs = []
|
||||
for lib in rocm_libs.values():
|
||||
@@ -710,20 +666,12 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
"%{rocsolver_lib}": rocm_libs["rocsolver"].file_name,
|
||||
"%{copy_rules}": "\n".join(copy_rules),
|
||||
"%{rocm_headers}": ('":rocm-include",\n' +
|
||||
'":' + hipfft_or_rocfft + '-include",\n' +
|
||||
'":rocblas-include",\n' +
|
||||
'":miopen-include",\n' +
|
||||
'":rccl-include",\n' +
|
||||
hiprand_include +
|
||||
rocrand_include +
|
||||
'":hipsparse-include",\n' +
|
||||
'":rocsolver-include"'),
|
||||
rocrand_include),
|
||||
}
|
||||
if rocm_version_number >= 40500:
|
||||
repository_dict["%{hipsolver_lib}"] = rocm_libs["hipsolver"].file_name
|
||||
repository_dict["%{rocm_headers}"] += ',\n":hipsolver-include"'
|
||||
repository_dict["%{hipblas_lib}"] = rocm_libs["hipblas"].file_name
|
||||
repository_dict["%{rocm_headers}"] += ',\n":hipblas-include"'
|
||||
|
||||
repository_ctx.template(
|
||||
"rocm/BUILD",
|
||||
|
||||
Reference in New Issue
Block a user