increase clang-tidy coverage to more c10 source files (#102902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102902
Approved by: https://github.com/Skylion007
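The bulk of the diff below applies a handful of recurring clang-tidy fixes: initializing locals at their declaration (cppcoreguidelines-init-variables), preferring empty() over size() > 0 (readability-container-size-empty), defaulted override destructors, sink parameters taken by value and moved into members, and targeted NOLINTNEXTLINE suppressions where a rewrite is not worthwhile. A standalone sketch of these patterns follows; the names Base, Derived, Widget, and process are illustrative only, not code from this commit.

#include <utility>
#include <vector>

struct Base {
  virtual ~Base() = default;
  virtual int id() const { return 0; }
};

// Instead of `virtual ~Derived(){};`, spell out a defaulted override destructor.
struct Derived : Base {
  ~Derived() override = default;
  int id() const override { return 1; }
};

// Instead of copying from a const reference, take the sink parameter by value
// and move it into the member (mirrors the ExpandableSegment/ExtraMeta hunks).
struct Widget {
  explicit Widget(std::vector<int> peers) : peers_(std::move(peers)) {}
  std::vector<int> peers_;
};

int process(const std::vector<int>& values) {
  int count = 0;                 // initialize locals at declaration
  if (!values.empty()) {         // prefer !empty() over size() > 0
    count = static_cast<int>(values.size());
  }
  return count;
}

int main() {
  Widget w({1, 2, 3});
  Derived d;
  return process(w.peers_) + d.id() - 4;  // exercises the sketch; exits with 0
}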
@@ -226,7 +226,7 @@ command = [
 [[linter]]
 code = 'CLANGTIDY'
 include_patterns = [
-'c10/core/**/*.cpp',
+'c10/**/*.cpp',
 'torch/csrc/fx/**/*.cpp',
 'torch/csrc/generic/**/*.cpp',
 'torch/csrc/onnx/**/*.cpp',
@@ -239,6 +239,7 @@ exclude_patterns = [
 # FunctionsManual.cpp is excluded to keep this diff clean. It will be fixed
 # in a follow up PR.
 # that are not easily converted to accepted c++
+'c10/cuda/**/*.cpp',
 'c10/test/**/*.cpp',
 'torch/csrc/jit/passes/onnx/helper.cpp',
 'torch/csrc/jit/passes/onnx/shape_type_inference.cpp',

@@ -7,7 +7,6 @@
 using c10::intrusive_ptr;
 using c10::intrusive_ptr_target;
 using c10::make_intrusive;
-using c10::weak_intrusive_ptr;

 namespace {

@@ -213,7 +213,7 @@ is_non_overlapping_and_dense
 * backend.
 **/
 struct C10_API BackendMeta : intrusive_ptr_target {
-virtual ~BackendMeta(){};
+~BackendMeta() override = default;
 virtual intrusive_ptr<BackendMeta> clone(
 const intrusive_ptr<BackendMeta>& ptr) const {
 return ptr;
@@ -263,7 +263,7 @@ struct C10_API ExtraMeta {
 c10::optional<std::string> custom_data_ptr_error_msg_ = c10::nullopt)
 : symbolic_shape_meta_(std::move(symbolic_shape_meta)),
 named_tensor_meta_(std::move(named_tensor_meta)),
-backend_meta_(backend_meta) {}
+backend_meta_(std::move(backend_meta)) {}

 std::unique_ptr<ExtraMeta> clone() const {
 return std::make_unique<ExtraMeta>(*this);

@@ -20,6 +20,7 @@
 #include <cuda_runtime_api.h>
 #include <algorithm>
 #include <bitset>
+#include <cstddef>
 #include <cstdint>
 #include <deque>
 #include <iostream>
@@ -362,14 +363,14 @@ struct ExpandableSegment {
 int device,
 cudaStream_t stream,
 size_t size,
-const std::vector<int>& peers)
+std::vector<int> peers)
 : device_(device),
 stream_(stream),
 max_handles_(0),
 // 2MB for small pool, 20MB for large pool
 segment_size_(size),
-peers_(peers) {
-cudaDeviceProp prop;
+peers_(std::move(peers)) {
+cudaDeviceProp prop{};
 C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
 // we allocate enough address space for 1 1/8 the total memory on the GPU.
 // This allows for some cases where we have to unmap pages earlier in the
@@ -390,11 +391,11 @@ struct ExpandableSegment {
 return rangeFromHandles(begin, end);
 }
 while (end > handles_.size()) {
-handles_.push_back(c10::nullopt);
+handles_.emplace_back(c10::nullopt);
 }
 for (auto i : c10::irange(begin, end)) {
 TORCH_INTERNAL_ASSERT(!handles_.at(i));
-CUmemGenericAllocationHandle handle;
+CUmemGenericAllocationHandle handle = 0;
 CUmemAllocationProp prop = {};
 prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
 prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
@@ -523,7 +524,7 @@ struct ExpandableSegment {
 }
 int device_;
 cudaStream_t stream_;
-CUdeviceptr ptr_;
+CUdeviceptr ptr_{};
 size_t max_handles_;
 size_t segment_size_;
 std::vector<c10::optional<CUmemGenericAllocationHandle>> handles_;
@@ -561,7 +562,7 @@ struct ExpandableSegment {
 // [Checkpointing PrivatePoolState]
 struct BlockState {
 int device = 0;
-cudaStream_t stream = 0;
+cudaStream_t stream = nullptr;
 stream_set stream_uses = {};
 size_t size = 0;
 void* ptr = nullptr;
@@ -715,8 +716,8 @@ struct PrivatePool {
 PrivatePool()
 : use_count(1),
 cudaMalloc_count(0),
-large_blocks(/*is_small=*/false, this),
-small_blocks(/*is_small=*/true, this) {}
+large_blocks(/*small=*/false, this),
+small_blocks(/*small=*/true, this) {}
 PrivatePool(const PrivatePool&) = delete;
 PrivatePool(PrivatePool&&) = delete;
 PrivatePool& operator=(const PrivatePool&) = delete;
@@ -759,7 +760,7 @@ SegmentState::SegmentState(Block* head) {
 PrivatePoolState::PrivatePoolState(
 MempoolId_t pool_id,
 const std::vector<Block*>& private_pool_head_blocks)
-: owner_id(pool_id) {
+: owner_id(std::move(pool_id)) {
 for (Block* head : private_pool_head_blocks) {
 segments.emplace_back(head);
 }
@@ -891,7 +892,7 @@ void CachingAllocatorConfig::lexArgs(
 size_t env_length = strlen(env);
 for (size_t i = 0; i < env_length; i++) {
 if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') {
-if (buf.size() != 0) {
+if (!buf.empty()) {
 config.emplace_back(buf.begin(), buf.end());
 buf.clear();
 }
@@ -964,7 +965,7 @@ size_t CachingAllocatorConfig::parseRoundUpPower2Divisions(
 if (config[i].compare("[") == 0) {
 size_t last_index = 0;
 while (++i < config.size() && config[i].compare("]") != 0) {
-std::string val1 = config[i];
+const std::string& val1 = config[i];
 size_t val2 = 0;

 consumeToken(config, ++i, ':');
@@ -1048,7 +1049,7 @@ size_t CachingAllocatorConfig::parseAllocatorConfig(
 used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
 if (used_cudaMallocAsync) {
 #if CUDA_VERSION >= 11040
-int version;
+int version = 0;
 C10_CUDA_CHECK(cudaDriverGetVersion(&version));
 TORCH_CHECK(
 version >= 11040,
@@ -1131,7 +1132,7 @@ static std::string reportProcessMemoryInfo(int device) {
 TORCH_INTERNAL_ASSERT(NVML_SUCCESS == DriverAPI::get()->nvmlInit_v2_());
 });

-cudaDeviceProp prop;
+cudaDeviceProp prop{};
 C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));

 char pci_id[80];
@@ -1143,7 +1144,7 @@ static std::string reportProcessMemoryInfo(int device) {
 prop.pciBusID,
 prop.pciDeviceID);

-nvmlDevice_t nvml_device;
+nvmlDevice_t nvml_device = nullptr;
 TORCH_INTERNAL_ASSERT(
 NVML_SUCCESS ==
 DriverAPI::get()->nvmlDeviceGetHandleByPciBusId_v2_(
@@ -1250,8 +1251,8 @@ class DeviceCachingAllocator {

 public:
 DeviceCachingAllocator()
-: large_blocks(/*is_small=*/false),
-small_blocks(/*is_small=*/true),
+: large_blocks(/*small=*/false),
+small_blocks(/*small=*/true),
 alloc_trace(new std::vector<TraceEntry>()) {
 stats.max_split_size = CachingAllocatorConfig::max_split_size();
 context_recorder_.store(nullptr);
@@ -1280,7 +1281,7 @@ class DeviceCachingAllocator {
 const std::unordered_set<void*>& expected_live_allocations) {
 std::unique_lock<std::recursive_mutex> lock(mutex);

-PrivatePool* pool;
+PrivatePool* pool = nullptr;
 auto pool_it = graph_pools.find(mempool_id);
 TORCH_CHECK(pool_it != graph_pools.end(), "Could not find pool of id");
 pool = pool_it->second.get();
@@ -1370,8 +1371,8 @@ class DeviceCachingAllocator {
 // alloc_block should have thrown an exception already.
 TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);

-size_t device_free;
-size_t device_total;
+size_t device_free = 0;
+size_t device_total = 0;
 C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
 std::string allowed_info;

@@ -1660,8 +1661,8 @@ class DeviceCachingAllocator {

 /** set memory fraction to limit maximum allocated memory **/
 void setMemoryFraction(double fraction) {
-size_t device_free;
-size_t device_total;
+size_t device_free = 0;
+size_t device_total = 0;
 C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
 allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
 set_fraction = true;
@@ -1678,7 +1679,7 @@ class DeviceCachingAllocator {
 std::lock_guard<std::recursive_mutex> lock(mutex);
 if (*largest ==
 0) { // make an initial guess if a zero *largest is passed in
-size_t tmp_bytes;
+size_t tmp_bytes = 0;
 C10_CUDA_CHECK(cudaMemGetInfo(
 largest, // Use free memory as an optimistic initial guess of *largest
 &tmp_bytes));
@@ -2038,7 +2039,7 @@ class DeviceCachingAllocator {
 });

 if (record_history) {
-record_trace(TraceEntry::SNAPSHOT, 0, total_active, 0, nullptr);
+record_trace(TraceEntry::SNAPSHOT, 0, total_active, nullptr, nullptr);
 }
 return result;
 }
@@ -2665,7 +2666,7 @@ class DeviceCachingAllocator {
 C10_CUDA_CHECK(cudaGetLastError());

 size_t size = p.alloc_size;
-void* ptr;
+void* ptr = nullptr;

 if (isRetry) {
 stats.num_alloc_retries += 1;
@@ -2977,7 +2978,7 @@ class DeviceCachingAllocator {
 }

 void insert_events(Block* block) {
-int prev_device;
+int prev_device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&prev_device));

 stream_set streams(std::move(block->stream_uses));
@@ -2997,7 +2998,7 @@ class DeviceCachingAllocator {
 }

 void insert_events_deferred_until_no_capture() {
-if (C10_UNLIKELY(needs_events_deferred_until_no_capture.size() > 0)) {
+if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) {
 for (auto* block : needs_events_deferred_until_no_capture) {
 TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
 insert_events(block);
@@ -3140,7 +3141,7 @@ class NativeCachingAllocator : public CUDAAllocator {
 }

 bool initialized() override {
-return device_allocator.size() > 0;
+return !device_allocator.empty();
 }

 /** allocates a block which is safe to use from the provided stream */
@@ -3196,17 +3197,17 @@ class NativeCachingAllocator : public CUDAAllocator {
 CreateContextFn context_recorder,
 size_t alloc_trace_max_entries,
 bool alloc_trace_record_context) override {
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 device_allocator[device]->recordHistory(
 enabled,
-std::move(context_recorder),
+context_recorder,
 alloc_trace_max_entries,
 alloc_trace_record_context);
 }

 bool isHistoryEnabled() override {
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 return device_allocator[device]->isHistoryEnabled();
 }
@@ -3220,7 +3221,7 @@ class NativeCachingAllocator : public CUDAAllocator {
 }

 void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 device_allocator[device]->attachOutOfMemoryObserver(std::move(observer));
 }
@@ -3319,7 +3320,7 @@ class NativeCachingAllocator : public CUDAAllocator {
 OutOfMemoryError,
 size < one_exa_bytes,
 "CUDA out of memory. Tried to allocate more than 1EB memory.");
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 void* r = nullptr;
 if (forceUncachedAllocator()) {
@@ -3396,7 +3397,7 @@ class NativeCachingAllocator : public CUDAAllocator {
 if (nbytes == 0) {
 return nullptr;
 }
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 void* r = nullptr;
 malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
@@ -3407,7 +3408,7 @@ class NativeCachingAllocator : public CUDAAllocator {
 if (nbytes == 0) {
 return nullptr;
 }
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 void* r = nullptr;
 malloc(&r, device, nbytes, stream);
@@ -3484,7 +3485,7 @@ class NativeCachingAllocator : public CUDAAllocator {
 C10_CUDA_CHECK(cudaIpcOpenMemHandle(
 &dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess));
 // devPtr has to be deleted in same device when created.
-int curr_device;
+int curr_device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
 auto sp =
 std::shared_ptr<void>(dev, [handle, curr_device, this](void* ptr) {
@@ -3590,7 +3591,7 @@ struct BackendStaticInitializer {
 }
 };

-std::atomic<CUDAAllocator*> allocator{};
+std::atomic<CUDAAllocator*> allocator;
 BackendStaticInitializer backend_static_initializer;

 } // namespace CUDACachingAllocator

@@ -14,7 +14,7 @@ int32_t driver_version() {
 }

 int device_count_impl(bool fail_if_no_driver) {
-int count;
+int count = 0;
 auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDeviceCount(&count));
 if (err == cudaSuccess) {
 return count;
@@ -121,7 +121,7 @@ DeviceIndex device_count_ensure_non_zero() {
 }

 DeviceIndex current_device() {
-int cur_device;
+int cur_device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
 return static_cast<DeviceIndex>(cur_device);
 }

@@ -58,7 +58,7 @@ struct PtrUsage {
 // recorded_streams holds side usage streams added by record_stream calls.
 // In other words, it does NOT include the original creation stream.
 ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
-UsageStream creation_stream;
+UsageStream creation_stream{};
 uint64_t size;
 bool captured;
 PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
@@ -152,7 +152,7 @@ inline void lazy_init_device(int device) {

 // See "Retaining memory in the pool" here:
 // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
-cudaMemPool_t mempool;
+cudaMemPool_t mempool = nullptr;
 C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
 uint64_t threshold = UINT64_MAX;
 C10_CUDA_CHECK(cudaMemPoolSetAttribute(
@@ -183,7 +183,7 @@ inline void lazy_init_device(int device) {

 inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
 // CUDACachingAllocator.cpp uses raw cuda events, as do we.
-cudaEvent_t event;
+cudaEvent_t event = nullptr;
 C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
 C10_CUDA_CHECK(cudaEventRecord(event, dependency));
 C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
@@ -331,7 +331,7 @@ void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
 std::lock_guard<std::mutex> lk(general_mutex);

 if (!capture_underway &&
-ungraphed_ptrs_defer_free_until_no_capture.size() > 0) {
+!ungraphed_ptrs_defer_free_until_no_capture.empty()) {
 // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
 for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
 auto it = ptr_info.find(ptr);
@@ -363,8 +363,8 @@ void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
 // allocation. This aligns with the behavior of alloc_block in
 // CUDACachingAllocator.cpp.
 (void)cudaGetLastError(); // clear CUDA error
-size_t device_free;
-size_t device_total;
+size_t device_free = 0;
+size_t device_total = 0;
 C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
 TORCH_CHECK_WITH(
 OutOfMemoryError,
@@ -410,7 +410,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 OutOfMemoryError,
 size < one_exa_bytes,
 "CUDA out of memory. Tried to allocate more than 1EB memory.");
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 void* r = nullptr;
 if (size != 0) {
@@ -442,7 +442,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 }

 bool initialized() override {
-return devs_initialized_flags.size() > 0;
+return !devs_initialized_flags.empty();
 }

 static inline void assertValidDevice(int device) {
@@ -466,8 +466,8 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 // TORCH_CHECK(devs_initialized_flags[device], ...)?
 lazy_init_device(device);

-size_t device_free;
-size_t device_total;
+size_t device_free = 0;
+size_t device_total = 0;
 C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
 pytorch_memory_limits[device] =
 static_cast<uint64_t>(fraction * device_total);
@@ -481,14 +481,14 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 // introduces performance nondeterminism.
 }

-void emptyCache(void) override {
+void emptyCache() override {
 std::lock_guard<std::mutex> lk(general_mutex);

 for (int dev = 0; dev < device_count; dev++) {
 if (devs_initialized_flags[dev]) {
 CUDAGuard g(dev);

-cudaMemPool_t mempool;
+cudaMemPool_t mempool = nullptr;
 cudaDeviceGetDefaultMemPool(&mempool, dev);
 cudaDeviceSynchronize();
 cudaMemPoolTrimTo(mempool, 0);
@@ -533,8 +533,8 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 CUDAGuard g(device);
 lazy_init_device(device);

-size_t free_upper_bound;
-size_t device_total;
+size_t free_upper_bound = 0;
+size_t device_total = 0;
 C10_CUDA_CHECK(cudaMemGetInfo(&free_upper_bound, &device_total));
 TORCH_INTERNAL_ASSERT(
 free_upper_bound + pytorch_used_bytes[device] <= device_total);
@@ -542,7 +542,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 free_upper_bound,
 pytorch_memory_limits[device] - pytorch_used_bytes[device]);
 auto stream = c10::cuda::getCurrentCUDAStream();
-void* dummy;
+void* dummy = nullptr;

 // Defensively checks for preexisting CUDA error state.
 auto err = cudaGetLastError();
@@ -668,7 +668,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 if (devs_initialized_flags[device]) {
 CUDAGuard g(device);

-cudaMemPool_t mempool;
+cudaMemPool_t mempool = nullptr;
 C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
 C10_CUDA_CHECK(cudaMemPoolGetAttribute(
 mempool, cudaMemPoolAttrReservedMemCurrent, &reserved_mem_current));
@@ -725,7 +725,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 assertValidDevice(device);

 CUDAGuard g(device);
-cudaMemPool_t mempool;
+cudaMemPool_t mempool = nullptr;
 C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
 // Using zero as the reset value is the method recommended by Cuda driver
 // team. Vivek Kini says:
@@ -783,7 +783,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 CUDAGuard g(free_stream.device);

 // CUDACachingAllocator.cpp uses raw cuda events, as do we.
-cudaEvent_t event;
+cudaEvent_t event = nullptr;
 C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
 C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
 C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
@@ -817,7 +817,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 if (nbytes == 0) {
 return nullptr;
 }
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 void* r = nullptr;
 mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
@@ -828,7 +828,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 if (nbytes == 0) {
 return nullptr;
 }
-int device;
+int device = 0;
 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
 void* r = nullptr;
 mallocAsync(&r, device, nbytes, stream);
@@ -843,7 +843,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 // cudaDeviceEnablePeerAccess. We need pool-specific enablement. See
 // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
 c10::cuda::CUDAGuard device_guard(dev);
-cudaMemPool_t mempool;
+cudaMemPool_t mempool = nullptr;
 C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev_to_access));
 cudaMemAccessDesc desc = {};
 desc.location.type = cudaMemLocationTypeDevice;
@@ -851,7 +851,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 desc.flags = cudaMemAccessFlagsProtReadWrite;
 C10_CUDA_CHECK(cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */));
 }
-virtual cudaError_t memcpyAsync(
+cudaError_t memcpyAsync(
 void* dst,
 int dstDevice,
 const void* src,

@@ -14,7 +14,7 @@ DriverAPI create_driver_api() {

 C10_FORALL_DRIVER_LIBRARIES(OPEN_LIBRARIES)
 #undef OPEN_LIBRARIES
-DriverAPI r;
+DriverAPI r{};

 #define LOOKUP_ENTRY(name, n) \
 r.name##_ = ((decltype(&name))dlsym(handle_##n, #name)); \

@@ -10,7 +10,7 @@ namespace cuda {
 namespace impl {

 bool has_cuda_gpu() {
-int count;
+int count = 0;
 C10_CUDA_IGNORE_ERROR(cudaGetDeviceCount(&count));

 return count != 0;

@@ -22,9 +22,8 @@ C10_DEFINE_bool(
 namespace c10 {

 namespace {
-// NOLINTNEXTLINE(modernize-redundant-void-arg)
-std::function<string(void)>* GetFetchStackTrace() {
-static std::function<string(void)> func = []() {
+std::function<string()>* GetFetchStackTrace() {
+static std::function<string()> func = []() {
 return get_backtrace(/*frames_to_skip=*/1);
 };
 return &func;

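In the hunk above the existing NOLINTNEXTLINE(modernize-redundant-void-arg) suppression is dropped and the signatures are fixed instead: std::function<string(void)> becomes std::function<string()>. A minimal standalone sketch of that modernization (make_greeter is a hypothetical name, not from this commit):

#include <functional>
#include <string>

// `()` and `(void)` both declare a callable taking no arguments; the empty
// parameter list is the idiomatic C++ spelling, so clang-tidy flags `(void)`.
std::function<std::string()> make_greeter() {
  return []() { return std::string("hello"); };
}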
@@ -112,6 +112,7 @@ void* SmallVectorBase<Size_T>::mallocForGrow(
 size_t TSize,
 size_t& NewCapacity) {
 NewCapacity = getNewCapacity<Size_T>(MinSize, TSize, this->capacity());
+// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
 auto Result = std::malloc(NewCapacity * TSize);
 if (Result == nullptr) {
 throw std::bad_alloc();
@@ -128,6 +129,7 @@ void SmallVectorBase<Size_T>::grow_pod(
 size_t NewCapacity = getNewCapacity<Size_T>(MinSize, TSize, this->capacity());
 void* NewElts = nullptr;
 if (BeginX == FirstEl) {
+// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
 NewElts = std::malloc(NewCapacity * TSize);
 if (NewElts == nullptr) {
 throw std::bad_alloc();
@@ -137,6 +139,7 @@ void SmallVectorBase<Size_T>::grow_pod(
 memcpy(NewElts, this->BeginX, size() * TSize);
 } else {
 // If this wasn't grown from the inline copy, grow the allocated space.
+// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
 NewElts = std::realloc(this->BeginX, NewCapacity * TSize);
 if (NewElts == nullptr) {
 throw std::bad_alloc();

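The SmallVector hunks above, and several hunks below (intrusive_ptr.h, numa.cpp, signal_handler.cpp), take the opposite approach: rather than rewriting code that intentionally breaks a rule, they silence one named check on the single line that follows. A minimal sketch of the pattern, using a hypothetical grow_buffer helper:

#include <cstdlib>
#include <new>

void* grow_buffer(std::size_t bytes) {
  // The comment suppresses only the named check, and only for the next line.
  // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
  void* p = std::malloc(bytes);
  if (p == nullptr) {
    throw std::bad_alloc();
  }
  return p;
}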
@@ -338,9 +338,7 @@ struct last<typelist<Head>> final {
 };
 template <class TypeList>
 using last_t = typename last<TypeList>::type;
-static_assert(
-std::is_same<int, last_t<typelist<double, float, int>>>::value,
-"");
+static_assert(std::is_same<int, last_t<typelist<double, float, int>>>::value);

 /**
 * Take/drop a number of arguments from a typelist.

@@ -57,13 +57,14 @@ const uint128_pod kuint128max = {
 } while (0)
 static inline int Fls64(uint64_t n) {
 // GOOGLE_DCHECK_NE(0, n);
-int pos = 0;
+uint64_t pos = 0;
 STEP(uint64_t, n, pos, 0x20);
 uint32_t n32 = n;
 STEP(uint32_t, n32, pos, 0x10);
 STEP(uint32_t, n32, pos, 0x08);
 STEP(uint32_t, n32, pos, 0x04);
-return pos + ((uint64_t{0x3333333322221100u} >> (n32 << 2)) & 0x3);
+return static_cast<int>(
+pos + ((uint64_t{0x3333333322221100u} >> (n32 << 2)) & 0x3));
 }
 #undef STEP

@@ -128,7 +129,7 @@ std::ostream& operator<<(std::ostream& o, const uint128& b) {

 // Select a divisor which is the largest power of the base < 2^64.
 uint128 div;
-std::streamsize div_base_log = 0;
+int div_base_log = 0;
 switch (flags & std::ios::basefield) {
 case std::ios::hex:
 div = (uint64_t)0x1000000000000000u; // 16^15

@@ -50,6 +50,7 @@ struct DontIncreaseRefcount {};
 // tells us if the object was allocated by us. If it wasn't, no
 // intrusive_ptr for you!

+// NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor)
 class C10_API intrusive_ptr_target {
 // Note [Weak references for intrusive refcounting]
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -47,6 +47,7 @@ int GetNUMANode(const void* ptr) {
 AT_ASSERT(ptr);

 int numa_node = -1;
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
 TORCH_CHECK(
 get_mempolicy(
 &numa_node,
@@ -78,12 +79,15 @@ void NUMAMove(void* ptr, size_t size, int numa_node_id) {

 uintptr_t page_start_ptr =
 ((reinterpret_cast<uintptr_t>(ptr)) & ~(getpagesize() - 1));
+// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
+// NOLINTNEXTLINE(bugprone-narrowing-conversions)
 ptrdiff_t offset = reinterpret_cast<uintptr_t>(ptr) - page_start_ptr;
 // Avoid extra dynamic allocation and NUMA api calls
 AT_ASSERT(
 numa_node_id >= 0 &&
 static_cast<unsigned>(numa_node_id) < sizeof(unsigned long) * 8);
 unsigned long mask = 1UL << numa_node_id;
+// NOLINTNEXTLINE(performance-no-int-to-ptr)
 TORCH_CHECK(
 mbind(
 reinterpret_cast<void*>(page_start_ptr),

@@ -84,7 +84,7 @@ class reverse_iterator {

 constexpr reverse_iterator& operator=(const reverse_iterator& rhs) noexcept {
 current = rhs.current;
-return current;
+return *this;
 }

 template <typename _Iter>

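The reverse_iterator hunk above is a correctness fix rather than a pure style cleanup: operator= returned the wrapped iterator member `current` where the declared return type is reverse_iterator&. The conventional shape of a copy assignment operator, shown on a hypothetical Counter type:

// Assign the members, then return *this so assignments can be chained (a = b = c).
struct Counter {
  int current = 0;
  Counter& operator=(const Counter& rhs) {
    current = rhs.current;
    return *this;
  }
};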
@@ -112,6 +112,7 @@ FatalSignalHandler::FatalSignalHandler()
 writingCond(PTHREAD_COND_INITIALIZER),
 writingMutex(PTHREAD_MUTEX_INITIALIZER) {}

+// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
 FatalSignalHandler::signal_handler FatalSignalHandler::kSignalHandlers[] = {
 {"SIGABRT", SIGABRT, {}},
 {"SIGINT", SIGINT, {}},
@@ -159,7 +160,7 @@ void FatalSignalHandler::stacktraceSignalHandler(bool needsLock) {
 if (needsLock) {
 pthread_mutex_lock(&writingMutex);
 }
-pid_t tid = syscall(SYS_gettid);
+pid_t tid = static_cast<pid_t>(syscall(SYS_gettid));
 std::string backtrace = fmt::format(
 "{}({}), PID: {}, Thread {}: \n {}",
 fatalSignalName,
@@ -201,7 +202,7 @@ void FatalSignalHandler::fatalSignalHandler(int signum) {
 DIR* procDir = opendir("/proc/self/task");
 if (procDir) {
 pid_t pid = getpid();
-pid_t currentTid = syscall(SYS_gettid);
+pid_t currentTid = static_cast<pid_t>(syscall(SYS_gettid));
 struct dirent* entry = nullptr;
 pthread_mutex_lock(&writingMutex);
 while ((entry = readdir(procDir)) != nullptr) {