#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Macros.h>
#include <c10/util/WaitCounter.h>

#include <limits>

namespace c10::cuda {

namespace {
// returns -1 on failure
int32_t driver_version() {
  int driver_version = -1;
  C10_CUDA_IGNORE_ERROR(cudaDriverGetVersion(&driver_version));
  return driver_version;
}

int device_count_impl(bool fail_if_no_driver) {
  int count = 0;
  auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDeviceCount(&count));
  if (err == cudaSuccess) {
    return count;
  }
  // Clear out the error state, so we don't spuriously trigger someone else.
  // (This shouldn't really matter, since we won't be running very much CUDA
  // code in this regime.)
  [[maybe_unused]] cudaError_t last_err = cudaGetLastError();
  switch (err) {
    case cudaErrorNoDevice:
      // Zero devices is ok here
      count = 0;
      break;
    case cudaErrorInsufficientDriver: {
      auto version = driver_version();
      if (version <= 0) {
        if (!fail_if_no_driver) {
          // No CUDA driver means no devices
          count = 0;
          break;
        }
        TORCH_CHECK(
            false,
            "Found no NVIDIA driver on your system. Please check that you "
            "have an NVIDIA GPU and installed a driver from "
            "http://www.nvidia.com/Download/index.aspx");
      } else {
        TORCH_CHECK(
            false,
            "The NVIDIA driver on your system is too old (found version ",
            version,
            "). Please update your GPU driver by downloading and installing "
            "a new version from the URL: "
            "http://www.nvidia.com/Download/index.aspx Alternatively, go to: "
            "https://pytorch.org to install a PyTorch version that has been "
            "compiled with your version of the CUDA driver.");
      }
    }
    case cudaErrorInitializationError:
      TORCH_CHECK(
          false,
          "CUDA driver initialization failed, you might not "
          "have a CUDA gpu.");
    case cudaErrorUnknown:
      TORCH_CHECK(
          false,
          "CUDA unknown error - this may be due to an "
          "incorrectly set up environment, e.g. changing env "
          "variable CUDA_VISIBLE_DEVICES after program start. "
          "Setting the available devices to be zero.");
#if C10_ASAN_ENABLED
    case cudaErrorMemoryAllocation:
      // In ASAN mode, we know that a cudaErrorMemoryAllocation error will
      // pop up if compiled with NVCC (clang-cuda is fine)
      TORCH_CHECK(
          false,
          "Got 'out of memory' error while trying to initialize CUDA. "
          "CUDA with nvcc does not work well with ASAN and it's probably "
          "the reason. We will simply shut down CUDA support. If you "
          "would like to use GPUs, turn off ASAN.");
      break;
#endif // C10_ASAN_ENABLED
#if defined(_WIN32) && CUDA_VERSION >= 13000
    // Workaround for CUDA-13.0 error handling on Windows, see
    // https://github.com/pytorch/pytorch/issues/162333#issuecomment-3267929585
    case cudaErrorNotSupported:
      if (!fail_if_no_driver) {
        TORCH_WARN(
            "cudaGetDeviceCount() returned cudaErrorNotSupported, "
            "likely using older driver or on CPU machine");
        count = 0;
        break;
      }
#endif
    default:
      TORCH_CHECK(
          false,
          "Unexpected error from cudaGetDeviceCount(). Did you run "
          "some cuda functions before calling NumCudaDevices() "
          "that might have already set an error? Error ",
          err,
          ": ",
          cudaGetErrorString(err));
  }
  return count;
}
} // namespace

DeviceIndex device_count() noexcept {
  // initialize number of devices only once
  static int count = []() {
    try {
      auto result = device_count_impl(/*fail_if_no_driver=*/false);
      TORCH_INTERNAL_ASSERT(
          result <= std::numeric_limits<DeviceIndex>::max(),
          "Too many CUDA devices, DeviceIndex overflowed");
      return result;
    } catch (const c10::Error& ex) {
      // We don't want to fail, but still log the warning
      // msg() returns the message without the stack trace
      TORCH_WARN("CUDA initialization: ", ex.msg());
      return 0;
    }
  }();
  return static_cast<DeviceIndex>(count);
}

DeviceIndex device_count_ensure_non_zero() {
  // Call the implementation every time to throw the exception
  int count = device_count_impl(/*fail_if_no_driver=*/true);
  // Zero gpus doesn't produce a warning in `device_count` but we fail here
  TORCH_CHECK(count, "No CUDA GPUs are available");
  TORCH_INTERNAL_ASSERT(
      count <= std::numeric_limits<DeviceIndex>::max(),
      "Too many CUDA devices, DeviceIndex overflowed");
  return static_cast<DeviceIndex>(count);
}

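// A minimal usage sketch (illustrative only, not part of this file's API
// surface): device_count() never throws and reports 0 on initialization
// failure, whereas device_count_ensure_non_zero() throws when no GPU is
// usable.
//
// ```
// if (c10::cuda::device_count() == 0) {
//   // fall back to CPU
// } else {
//   auto n = c10::cuda::device_count_ensure_non_zero();  // throws if zero
// }
// ```
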
DeviceIndex current_device() {
  DeviceIndex cur_device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
  return cur_device;
}

void set_device(DeviceIndex device, const bool force) {
  C10_CUDA_CHECK(c10::cuda::SetDevice(device, force));
}

void device_synchronize() {
  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(interp)) {
    (*interp)->trace_gpu_device_synchronization(c10::kCUDA);
  }
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.cuda_device_synchronize);
  C10_CUDA_CHECK(cudaDeviceSynchronize());
}

// This function has to be called from callers performing CUDA-synchronizing
// operations, so that the proper error or warning is raised.
void warn_or_error_on_sync() {
  if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_ERROR) {
    TORCH_CHECK(false, "called a synchronizing CUDA operation");
  } else if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_WARN) {
    TORCH_WARN("called a synchronizing CUDA operation");
  }
}

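// A minimal sketch of a hypothetical caller (my_blocking_op and its stream
// argument are assumptions for illustration, not part of this file): any code
// path about to block on the device should call warn_or_error_on_sync() first
// so that sync-debug mode can warn or error before the synchronization.
//
// ```
// void my_blocking_op(cudaStream_t stream) {
//   c10::cuda::warn_or_error_on_sync();
//   C10_CUDA_CHECK(cudaStreamSynchronize(stream));
// }
// ```
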
std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext() {
  // check current device first
  auto current_device_index = current_device();
  if (current_device_index >= 0) {
    if (hasPrimaryContext(current_device_index)) {
      return current_device_index;
    }
  }
  for (const auto device_index : c10::irange(at::cuda::device_count())) {
    if (device_index == current_device_index)
      continue;
    if (hasPrimaryContext(device_index)) {
      return device_index;
    }
  }
  return std::nullopt;
}

namespace _internal {
static bool dummyHasPrimaryContext([[maybe_unused]] DeviceIndex device_index) {
  TORCH_CHECK(false, "Should never be called");
}
static bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext;

// Private API, to be called from CUDAHooks.cpp
// NOLINTNEXTLINE(misc-use-internal-linkage)
C10_CUDA_API void setHasPrimaryContext(bool (*func)(DeviceIndex)) {
  hasPrimaryContext = func ? func : dummyHasPrimaryContext;
}
} // namespace _internal

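// A minimal registration sketch (hypothetical; the real callback is installed
// from CUDAHooks.cpp, and hasRealPrimaryContext below is an assumed helper,
// not code from this file):
//
// ```
// static bool hasRealPrimaryContext(c10::DeviceIndex device_index) {
//   // e.g. query the driver's primary-context state for device_index
//   return /* true iff a primary context is already active */ false;
// }
// // during CUDA hooks initialization:
// // c10::cuda::_internal::setHasPrimaryContext(&hasRealPrimaryContext);
// ```
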
bool hasPrimaryContext(DeviceIndex device_index) {
  return _internal::hasPrimaryContext(device_index);
}

// Wrappers for raw CUDA device management functions
cudaError_t GetDeviceCount(int* dev_count) {
  return cudaGetDeviceCount(dev_count);
}

// This is a codepath for CUDA 12 that comes with a critical change in the
// behavior of `cudaSetDevice`. Unlike previous CUDA versions, which allocate
// the context lazily, CUDA 12.x eagerly allocates the primary context the
// moment `cudaSetDevice` is called. This can have dramatic consequences and
// pollute device memory in distributed runs. To avoid unnecessary context
// creation, a new function called `MaybeSetDevice` was introduced. It is to
// be called in the device guard destructor and at the exit of the
// torch.cuda.device context manager. Its behavior is simple: it calls
// `cudaSetDevice` if a context already exists on the targeted device;
// otherwise it simply saves the device index. This way we keep PyTorch
// backward compatible for applications like this:
//
// ```
// import torch
// x = torch.empty(1, device="cuda:1")  # no CUDA context on cuda:0 after this call
// y = torch.empty(1, device="cuda")    # CUDA context is created on cuda:0
// ```
#if CUDA_VERSION >= 12000
thread_local static DeviceIndex targetDeviceIndex = -1;

cudaError_t GetDevice(DeviceIndex* device) {
  if (targetDeviceIndex >= 0) {
    *device = targetDeviceIndex;
    return cudaSuccess;
  }
  int tmp_device = -1;
  auto err = cudaGetDevice(&tmp_device);
  if (err == cudaSuccess) {
    TORCH_INTERNAL_ASSERT(
        tmp_device >= 0 &&
            tmp_device <= std::numeric_limits<DeviceIndex>::max(),
        "cudaGetDevice returns invalid device ",
        tmp_device);
    *device = static_cast<DeviceIndex>(tmp_device);
  }
  return err;
}

cudaError_t SetDevice(DeviceIndex device, const bool force) {
  TORCH_CHECK(
      device >= 0, "device id must be non-negative!", static_cast<int>(device));
  targetDeviceIndex = -1;
  if (force) {
    return cudaSetDevice(device);
  }
  int cur_device = -1;
  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
  if (device == cur_device) {
    return cudaSuccess;
  }
  return cudaSetDevice(device);
}

cudaError_t MaybeSetDevice(DeviceIndex device) {
  if (hasPrimaryContext(device)) {
    return c10::cuda::SetDevice(device);
  }
  targetDeviceIndex = device;
  return cudaSuccess;
}

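// A minimal sketch of the intended call pattern (illustrative only; the
// braces stand in for a device guard's scope, not the real DeviceGuard
// implementation):
//
// ```
// {
//   c10::cuda::SetDevice(1, /*force=*/false);  // switch to cuda:1
//   // ... work on cuda:1 ...
//   // on scope exit, restore cuda:0 without creating a context there:
//   c10::cuda::MaybeSetDevice(0);  // only records 0 if cuda:0 has no context
// }
// // a later call that actually needs the device applies the pending index:
// // c10::cuda::SetTargetDevice();
// ```
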
// This function always initializes the CUDA context
// on to_device
DeviceIndex ExchangeDevice(DeviceIndex to_device) {
  auto cur_device = targetDeviceIndex;
  targetDeviceIndex = -1;
  if (cur_device < 0) {
    int tmp_device = -1;
    C10_CUDA_CHECK(cudaGetDevice(&tmp_device));
    cur_device = static_cast<DeviceIndex>(tmp_device);
    if (to_device == cur_device) {
      return cur_device;
    }
  }
  C10_CUDA_CHECK(cudaSetDevice(to_device));
  return cur_device;
}

// This function does not initialize the CUDA context
// on to_device if it does not already exist
DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
  int tmp_cur_device = -1;
  C10_CUDA_CHECK(cudaGetDevice(&tmp_cur_device));
  TORCH_INTERNAL_ASSERT(
      tmp_cur_device >= 0 &&
          tmp_cur_device <= std::numeric_limits<DeviceIndex>::max(),
      "cudaGetDevice returns invalid device ",
      tmp_cur_device);
  auto cur_device = static_cast<DeviceIndex>(tmp_cur_device);
  if (to_device == tmp_cur_device) {
    return cur_device;
  }
  if (hasPrimaryContext(to_device)) {
    C10_CUDA_CHECK(cudaSetDevice(to_device));
  } else {
    targetDeviceIndex = to_device;
  }
  return cur_device;
}

void SetTargetDevice() {
  if (targetDeviceIndex >= 0) {
    C10_CUDA_CHECK(c10::cuda::SetDevice(targetDeviceIndex));
  }
}
#else
cudaError_t GetDevice(DeviceIndex* device) {
  int tmp_device = -1;
  auto err = cudaGetDevice(&tmp_device);
  if (err == cudaSuccess) {
    TORCH_INTERNAL_ASSERT(
        tmp_device >= 0 &&
            tmp_device <= std::numeric_limits<DeviceIndex>::max(),
        "cudaGetDevice returns invalid device ",
        tmp_device);
    *device = static_cast<DeviceIndex>(tmp_device);
  }
  return err;
}

cudaError_t SetDevice(DeviceIndex device, const bool force) {
  TORCH_CHECK(
      device >= 0, "device id must be non-negative!", static_cast<int>(device));
  if (force) {
    return cudaSetDevice(device);
  }
  int cur_device = -1;
  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
  if (device == cur_device) {
    return cudaSuccess;
  }
  return cudaSetDevice(device);
}

cudaError_t MaybeSetDevice(DeviceIndex device) {
  return c10::cuda::SetDevice(device);
}

DeviceIndex ExchangeDevice(DeviceIndex to_device) {
  DeviceIndex cur_device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
  if (to_device == cur_device) {
    return cur_device;
  }
  C10_CUDA_CHECK(cudaSetDevice(to_device));
  return cur_device;
}

DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
  return c10::cuda::ExchangeDevice(to_device);
}

void SetTargetDevice() {
  // no-op on CUDA version < 12.x
}
#endif

} // namespace c10::cuda