mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
[ROCm]: Add backward support for ROCm5.0
This commit is contained in:
@@ -30,8 +30,12 @@ limitations under the License.
|
||||
|
||||
// Common place for all collective thunks to include nccl/rccl headers.
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/rccl/rccl.h"
|
||||
#else
|
||||
#include "rocm/include/rccl.h"
|
||||
#endif
|
||||
#else
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#endif
|
||||
|
||||
|
||||
@@ -20,7 +20,12 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
|
||||
|
||||
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/hipsparse/hipsparse.h"
|
||||
#else
|
||||
#include "rocm/include/hipsparse.h"
|
||||
#endif
|
||||
#include "tensorflow/compiler/xla/stream_executor/lib/env.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
|
||||
|
||||
@@ -22,7 +22,12 @@ limitations under the License.
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/hipfft/hipfft.h"
|
||||
#else
|
||||
#include "rocm/include/hipfft.h"
|
||||
#endif
|
||||
#include "rocm/rocm_config.h"
|
||||
#include "rocm/rocm_config.h"
|
||||
|
||||
#endif
|
||||
|
||||
@@ -20,7 +20,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCSOLVER_WRAPPER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_ROCM_ROCSOLVER_WRAPPER_H_
|
||||
|
||||
#include "rocm/rocm_config.h"
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/rocsolver/rocsolver.h"
|
||||
#else
|
||||
#include "rocm/include/rocsolver.h"
|
||||
#endif
|
||||
|
||||
#include "tensorflow/compiler/xla/stream_executor/lib/env.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
|
||||
|
||||
@@ -20,7 +20,11 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/rccl/rccl.h"
|
||||
#else
|
||||
#include "rocm/include/rccl.h"
|
||||
#endif
|
||||
#endif
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/nccl/nccl_manager.h"
|
||||
|
||||
@@ -30,7 +30,11 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#if (TF_ROCM_VERSION >= 50200)
|
||||
#include "rocm/include/rccl/rccl.h"
|
||||
#else
|
||||
#include "rocm/include/rccl.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
|
||||
Reference in New Issue
Block a user