increase clang-tidy coverage to more c10 source files (#102902)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102902
Approved by: https://github.com/Skylion007
Author: cyy
Date: 2023-06-04 06:33:01 +00:00
Committed by: PyTorch MergeBot
Parent: 992bffe5a3
Commit: 87cbfe957a

16 changed files with 87 additions and 79 deletions

View File

@@ -226,7 +226,7 @@ command = [
[[linter]]
code = 'CLANGTIDY'
include_patterns = [
'c10/core/**/*.cpp',
'c10/**/*.cpp',
'torch/csrc/fx/**/*.cpp',
'torch/csrc/generic/**/*.cpp',
'torch/csrc/onnx/**/*.cpp',
@@ -239,6 +239,7 @@ exclude_patterns = [
# FunctionsManual.cpp is excluded to keep this diff clean. It will be fixed
# in a follow up PR.
# that are not easily converted to accepted c++
'c10/cuda/**/*.cpp',
'c10/test/**/*.cpp',
'torch/csrc/jit/passes/onnx/helper.cpp',
'torch/csrc/jit/passes/onnx/shape_type_inference.cpp',

View File

@@ -7,7 +7,6 @@
using c10::intrusive_ptr;
using c10::intrusive_ptr_target;
using c10::make_intrusive;
using c10::weak_intrusive_ptr;
namespace {

View File

@@ -213,7 +213,7 @@ is_non_overlapping_and_dense
* backend.
**/
struct C10_API BackendMeta : intrusive_ptr_target {
virtual ~BackendMeta(){};
~BackendMeta() override = default;
virtual intrusive_ptr<BackendMeta> clone(
const intrusive_ptr<BackendMeta>& ptr) const {
return ptr;
@@ -263,7 +263,7 @@ struct C10_API ExtraMeta {
c10::optional<std::string> custom_data_ptr_error_msg_ = c10::nullopt)
: symbolic_shape_meta_(std::move(symbolic_shape_meta)),
named_tensor_meta_(std::move(named_tensor_meta)),
backend_meta_(backend_meta) {}
backend_meta_(std::move(backend_meta)) {}
std::unique_ptr<ExtraMeta> clone() const {
return std::make_unique<ExtraMeta>(*this);
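The BackendMeta/ExtraMeta hunks above apply two routine clang-tidy fixes: an empty virtual destructor becomes "~BackendMeta() override = default;" (override already implies virtual), and a by-value intrusive_ptr parameter is moved into the member instead of being copied again. Below is a minimal, self-contained sketch of both idioms; Base, Derived, and Holder are hypothetical stand-ins, not the real c10 classes.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical Base/Derived pair: `override` already implies virtual, and
// `= default` replaces an empty user-provided destructor body.
struct Base {
  virtual ~Base() = default;
};
struct Derived : Base {
  ~Derived() override = default;  // instead of: virtual ~Derived(){};
};

// Hypothetical Holder: a sink parameter taken by value should be moved into
// the member rather than copied a second time.
struct Holder {
  explicit Holder(std::shared_ptr<std::string> meta)
      : meta_(std::move(meta)) {}
  std::shared_ptr<std::string> meta_;
};

int main() {
  Holder h(std::make_shared<std::string>("backend metadata"));
  Derived d;
  (void)d;
  return (h.meta_ && !h.meta_->empty()) ? 0 : 1;
}
```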

View File

@@ -20,6 +20,7 @@
#include <cuda_runtime_api.h>
#include <algorithm>
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <iostream>
@@ -362,14 +363,14 @@ struct ExpandableSegment {
int device,
cudaStream_t stream,
size_t size,
const std::vector<int>& peers)
std::vector<int> peers)
: device_(device),
stream_(stream),
max_handles_(0),
// 2MB for small pool, 20MB for large pool
segment_size_(size),
peers_(peers) {
cudaDeviceProp prop;
peers_(std::move(peers)) {
cudaDeviceProp prop{};
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
// we allocate enough address space for 1 1/8 the total memory on the GPU.
// This allows for some cases where we have to unmap pages earlier in the
@@ -390,11 +391,11 @@ struct ExpandableSegment {
return rangeFromHandles(begin, end);
}
while (end > handles_.size()) {
handles_.push_back(c10::nullopt);
handles_.emplace_back(c10::nullopt);
}
for (auto i : c10::irange(begin, end)) {
TORCH_INTERNAL_ASSERT(!handles_.at(i));
CUmemGenericAllocationHandle handle;
CUmemGenericAllocationHandle handle = 0;
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
@@ -523,7 +524,7 @@ struct ExpandableSegment {
}
int device_;
cudaStream_t stream_;
CUdeviceptr ptr_;
CUdeviceptr ptr_{};
size_t max_handles_;
size_t segment_size_;
std::vector<c10::optional<CUmemGenericAllocationHandle>> handles_;
@@ -561,7 +562,7 @@ struct ExpandableSegment {
// [Checkpointing PrivatePoolState]
struct BlockState {
int device = 0;
cudaStream_t stream = 0;
cudaStream_t stream = nullptr;
stream_set stream_uses = {};
size_t size = 0;
void* ptr = nullptr;
@@ -715,8 +716,8 @@ struct PrivatePool {
PrivatePool()
: use_count(1),
cudaMalloc_count(0),
large_blocks(/*is_small=*/false, this),
small_blocks(/*is_small=*/true, this) {}
large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
PrivatePool(PrivatePool&&) = delete;
PrivatePool& operator=(const PrivatePool&) = delete;
@@ -759,7 +760,7 @@ SegmentState::SegmentState(Block* head) {
PrivatePoolState::PrivatePoolState(
MempoolId_t pool_id,
const std::vector<Block*>& private_pool_head_blocks)
: owner_id(pool_id) {
: owner_id(std::move(pool_id)) {
for (Block* head : private_pool_head_blocks) {
segments.emplace_back(head);
}
@@ -891,7 +892,7 @@ void CachingAllocatorConfig::lexArgs(
size_t env_length = strlen(env);
for (size_t i = 0; i < env_length; i++) {
if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') {
if (buf.size() != 0) {
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
buf.clear();
}
@@ -964,7 +965,7 @@ size_t CachingAllocatorConfig::parseRoundUpPower2Divisions(
if (config[i].compare("[") == 0) {
size_t last_index = 0;
while (++i < config.size() && config[i].compare("]") != 0) {
std::string val1 = config[i];
const std::string& val1 = config[i];
size_t val2 = 0;
consumeToken(config, ++i, ':');
@@ -1048,7 +1049,7 @@ size_t CachingAllocatorConfig::parseAllocatorConfig(
used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
int version;
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
@@ -1131,7 +1132,7 @@ static std::string reportProcessMemoryInfo(int device) {
TORCH_INTERNAL_ASSERT(NVML_SUCCESS == DriverAPI::get()->nvmlInit_v2_());
});
cudaDeviceProp prop;
cudaDeviceProp prop{};
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
char pci_id[80];
@@ -1143,7 +1144,7 @@ static std::string reportProcessMemoryInfo(int device) {
prop.pciBusID,
prop.pciDeviceID);
nvmlDevice_t nvml_device;
nvmlDevice_t nvml_device = nullptr;
TORCH_INTERNAL_ASSERT(
NVML_SUCCESS ==
DriverAPI::get()->nvmlDeviceGetHandleByPciBusId_v2_(
@@ -1250,8 +1251,8 @@ class DeviceCachingAllocator {
public:
DeviceCachingAllocator()
: large_blocks(/*is_small=*/false),
small_blocks(/*is_small=*/true),
: large_blocks(/*small=*/false),
small_blocks(/*small=*/true),
alloc_trace(new std::vector<TraceEntry>()) {
stats.max_split_size = CachingAllocatorConfig::max_split_size();
context_recorder_.store(nullptr);
@@ -1280,7 +1281,7 @@ class DeviceCachingAllocator {
const std::unordered_set<void*>& expected_live_allocations) {
std::unique_lock<std::recursive_mutex> lock(mutex);
PrivatePool* pool;
PrivatePool* pool = nullptr;
auto pool_it = graph_pools.find(mempool_id);
TORCH_CHECK(pool_it != graph_pools.end(), "Could not find pool of id");
pool = pool_it->second.get();
@@ -1370,8 +1371,8 @@ class DeviceCachingAllocator {
// alloc_block should have thrown an exception already.
TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);
size_t device_free;
size_t device_total;
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
std::string allowed_info;
@@ -1660,8 +1661,8 @@ class DeviceCachingAllocator {
/** set memory fraction to limit maximum allocated memory **/
void setMemoryFraction(double fraction) {
size_t device_free;
size_t device_total;
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
set_fraction = true;
@@ -1678,7 +1679,7 @@ class DeviceCachingAllocator {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (*largest ==
0) { // make an initial guess if a zero *largest is passed in
size_t tmp_bytes;
size_t tmp_bytes = 0;
C10_CUDA_CHECK(cudaMemGetInfo(
largest, // Use free memory as an optimistic initial guess of *largest
&tmp_bytes));
@@ -2038,7 +2039,7 @@ class DeviceCachingAllocator {
});
if (record_history) {
record_trace(TraceEntry::SNAPSHOT, 0, total_active, 0, nullptr);
record_trace(TraceEntry::SNAPSHOT, 0, total_active, nullptr, nullptr);
}
return result;
}
@@ -2665,7 +2666,7 @@ class DeviceCachingAllocator {
C10_CUDA_CHECK(cudaGetLastError());
size_t size = p.alloc_size;
void* ptr;
void* ptr = nullptr;
if (isRetry) {
stats.num_alloc_retries += 1;
@@ -2977,7 +2978,7 @@ class DeviceCachingAllocator {
}
void insert_events(Block* block) {
int prev_device;
int prev_device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&prev_device));
stream_set streams(std::move(block->stream_uses));
@@ -2997,7 +2998,7 @@ class DeviceCachingAllocator {
}
void insert_events_deferred_until_no_capture() {
if (C10_UNLIKELY(needs_events_deferred_until_no_capture.size() > 0)) {
if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) {
for (auto* block : needs_events_deferred_until_no_capture) {
TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
insert_events(block);
@@ -3140,7 +3141,7 @@ class NativeCachingAllocator : public CUDAAllocator {
}
bool initialized() override {
return device_allocator.size() > 0;
return !device_allocator.empty();
}
/** allocates a block which is safe to use from the provided stream */
@@ -3196,17 +3197,17 @@ class NativeCachingAllocator : public CUDAAllocator {
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
bool alloc_trace_record_context) override {
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
device_allocator[device]->recordHistory(
enabled,
std::move(context_recorder),
context_recorder,
alloc_trace_max_entries,
alloc_trace_record_context);
}
bool isHistoryEnabled() override {
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return device_allocator[device]->isHistoryEnabled();
}
@@ -3220,7 +3221,7 @@ class NativeCachingAllocator : public CUDAAllocator {
}
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
device_allocator[device]->attachOutOfMemoryObserver(std::move(observer));
}
@@ -3319,7 +3320,7 @@ class NativeCachingAllocator : public CUDAAllocator {
OutOfMemoryError,
size < one_exa_bytes,
"CUDA out of memory. Tried to allocate more than 1EB memory.");
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
if (forceUncachedAllocator()) {
@@ -3396,7 +3397,7 @@ class NativeCachingAllocator : public CUDAAllocator {
if (nbytes == 0) {
return nullptr;
}
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
@@ -3407,7 +3408,7 @@ class NativeCachingAllocator : public CUDAAllocator {
if (nbytes == 0) {
return nullptr;
}
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
malloc(&r, device, nbytes, stream);
@@ -3484,7 +3485,7 @@ class NativeCachingAllocator : public CUDAAllocator {
C10_CUDA_CHECK(cudaIpcOpenMemHandle(
&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess));
// devPtr has to be deleted in same device when created.
int curr_device;
int curr_device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
auto sp =
std::shared_ptr<void>(dev, [handle, curr_device, this](void* ptr) {
@@ -3590,7 +3591,7 @@ struct BackendStaticInitializer {
}
};
std::atomic<CUDAAllocator*> allocator{};
std::atomic<CUDAAllocator*> allocator;
BackendStaticInitializer backend_static_initializer;
} // namespace CUDACachingAllocator
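Most of the caching-allocator edits above repeat a few clang-tidy patterns: locals that are only written through a CUDA out-parameter API get an explicit initializer (cppcoreguidelines-init-variables), aggregate locals such as cudaDeviceProp are value-initialized, and size() comparisons become empty() checks (readability-container-size-empty). The sketch below illustrates these idioms only; the get_count function is a hypothetical stand-in for the CUDA calls so the snippet builds without the CUDA toolkit.

```cpp
#include <string>
#include <vector>

// Hypothetical C-style API that fills an out-parameter, standing in for
// calls like cudaGetDevice / cudaMemGetInfo in the hunks above.
static int get_count(int* out) {
  *out = 3;
  return 0;
}

int main() {
  // clang-tidy (cppcoreguidelines-init-variables) wants locals initialized
  // even when they are written through a pointer immediately afterwards.
  int count = 0;
  get_count(&count);

  // Value-initialize aggregates instead of leaving them indeterminate,
  // mirroring the change to `cudaDeviceProp prop{};` above.
  struct Info {
    int device;
    int ordinal;
  } info{};

  // Prefer empty() over size() comparisons (readability-container-size-empty).
  std::vector<std::string> names;
  if (!names.empty()) {
    return 1;
  }
  return (count == 3 && info.device == 0) ? 0 : 1;
}
```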

View File

@@ -14,7 +14,7 @@ int32_t driver_version() {
}
int device_count_impl(bool fail_if_no_driver) {
int count;
int count = 0;
auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDeviceCount(&count));
if (err == cudaSuccess) {
return count;
@@ -121,7 +121,7 @@ DeviceIndex device_count_ensure_non_zero() {
}
DeviceIndex current_device() {
int cur_device;
int cur_device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
return static_cast<DeviceIndex>(cur_device);
}

View File

@@ -58,7 +58,7 @@ struct PtrUsage {
// recorded_streams holds side usage streams added by record_stream calls.
// In other words, it does NOT include the original creation stream.
ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
UsageStream creation_stream;
UsageStream creation_stream{};
uint64_t size;
bool captured;
PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
@@ -152,7 +152,7 @@ inline void lazy_init_device(int device) {
// See "Retaining memory in the pool" here:
// https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
cudaMemPool_t mempool;
cudaMemPool_t mempool = nullptr;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
C10_CUDA_CHECK(cudaMemPoolSetAttribute(
@@ -183,7 +183,7 @@ inline void lazy_init_device(int device) {
inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
// CUDACachingAllocator.cpp uses raw cuda events, as do we.
cudaEvent_t event;
cudaEvent_t event = nullptr;
C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
C10_CUDA_CHECK(cudaEventRecord(event, dependency));
C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
@@ -331,7 +331,7 @@ void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
std::lock_guard<std::mutex> lk(general_mutex);
if (!capture_underway &&
ungraphed_ptrs_defer_free_until_no_capture.size() > 0) {
!ungraphed_ptrs_defer_free_until_no_capture.empty()) {
// See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
auto it = ptr_info.find(ptr);
@@ -363,8 +363,8 @@ void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
// allocation. This aligns with the behavior of alloc_block in
// CUDACachingAllocator.cpp.
(void)cudaGetLastError(); // clear CUDA error
size_t device_free;
size_t device_total;
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
TORCH_CHECK_WITH(
OutOfMemoryError,
@@ -410,7 +410,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
OutOfMemoryError,
size < one_exa_bytes,
"CUDA out of memory. Tried to allocate more than 1EB memory.");
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
if (size != 0) {
@@ -442,7 +442,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
}
bool initialized() override {
return devs_initialized_flags.size() > 0;
return !devs_initialized_flags.empty();
}
static inline void assertValidDevice(int device) {
@@ -466,8 +466,8 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// TORCH_CHECK(devs_initialized_flags[device], ...)?
lazy_init_device(device);
size_t device_free;
size_t device_total;
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
pytorch_memory_limits[device] =
static_cast<uint64_t>(fraction * device_total);
@@ -481,14 +481,14 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// introduces performance nondeterminism.
}
void emptyCache(void) override {
void emptyCache() override {
std::lock_guard<std::mutex> lk(general_mutex);
for (int dev = 0; dev < device_count; dev++) {
if (devs_initialized_flags[dev]) {
CUDAGuard g(dev);
cudaMemPool_t mempool;
cudaMemPool_t mempool = nullptr;
cudaDeviceGetDefaultMemPool(&mempool, dev);
cudaDeviceSynchronize();
cudaMemPoolTrimTo(mempool, 0);
@@ -533,8 +533,8 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
CUDAGuard g(device);
lazy_init_device(device);
size_t free_upper_bound;
size_t device_total;
size_t free_upper_bound = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&free_upper_bound, &device_total));
TORCH_INTERNAL_ASSERT(
free_upper_bound + pytorch_used_bytes[device] <= device_total);
@@ -542,7 +542,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
free_upper_bound,
pytorch_memory_limits[device] - pytorch_used_bytes[device]);
auto stream = c10::cuda::getCurrentCUDAStream();
void* dummy;
void* dummy = nullptr;
// Defensively checks for preexisting CUDA error state.
auto err = cudaGetLastError();
@@ -668,7 +668,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
if (devs_initialized_flags[device]) {
CUDAGuard g(device);
cudaMemPool_t mempool;
cudaMemPool_t mempool = nullptr;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
C10_CUDA_CHECK(cudaMemPoolGetAttribute(
mempool, cudaMemPoolAttrReservedMemCurrent, &reserved_mem_current));
@@ -725,7 +725,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
assertValidDevice(device);
CUDAGuard g(device);
cudaMemPool_t mempool;
cudaMemPool_t mempool = nullptr;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
// Using zero as the reset value is the method recommended by Cuda driver
// team. Vivek Kini says:
@@ -783,7 +783,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
CUDAGuard g(free_stream.device);
// CUDACachingAllocator.cpp uses raw cuda events, as do we.
cudaEvent_t event;
cudaEvent_t event = nullptr;
C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
@@ -817,7 +817,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
if (nbytes == 0) {
return nullptr;
}
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
@@ -828,7 +828,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
if (nbytes == 0) {
return nullptr;
}
int device;
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
mallocAsync(&r, device, nbytes, stream);
@@ -843,7 +843,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// cudaDeviceEnablePeerAccess. We need pool-specific enablement. See
// https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
c10::cuda::CUDAGuard device_guard(dev);
cudaMemPool_t mempool;
cudaMemPool_t mempool = nullptr;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev_to_access));
cudaMemAccessDesc desc = {};
desc.location.type = cudaMemLocationTypeDevice;
@@ -851,7 +851,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
desc.flags = cudaMemAccessFlagsProtReadWrite;
C10_CUDA_CHECK(cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */));
}
virtual cudaError_t memcpyAsync(
cudaError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,

View File

@@ -14,7 +14,7 @@ DriverAPI create_driver_api() {
C10_FORALL_DRIVER_LIBRARIES(OPEN_LIBRARIES)
#undef OPEN_LIBRARIES
DriverAPI r;
DriverAPI r{};
#define LOOKUP_ENTRY(name, n) \
r.name##_ = ((decltype(&name))dlsym(handle_##n, #name)); \

View File

@@ -10,7 +10,7 @@ namespace cuda {
namespace impl {
bool has_cuda_gpu() {
int count;
int count = 0;
C10_CUDA_IGNORE_ERROR(cudaGetDeviceCount(&count));
return count != 0;

View File

@@ -22,9 +22,8 @@ C10_DEFINE_bool(
namespace c10 {
namespace {
// NOLINTNEXTLINE(modernize-redundant-void-arg)
std::function<string(void)>* GetFetchStackTrace() {
static std::function<string(void)> func = []() {
std::function<string()>* GetFetchStackTrace() {
static std::function<string()> func = []() {
return get_backtrace(/*frames_to_skip=*/1);
};
return &func;

View File

@@ -112,6 +112,7 @@ void* SmallVectorBase<Size_T>::mallocForGrow(
size_t TSize,
size_t& NewCapacity) {
NewCapacity = getNewCapacity<Size_T>(MinSize, TSize, this->capacity());
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
auto Result = std::malloc(NewCapacity * TSize);
if (Result == nullptr) {
throw std::bad_alloc();
@@ -128,6 +129,7 @@ void SmallVectorBase<Size_T>::grow_pod(
size_t NewCapacity = getNewCapacity<Size_T>(MinSize, TSize, this->capacity());
void* NewElts = nullptr;
if (BeginX == FirstEl) {
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
NewElts = std::malloc(NewCapacity * TSize);
if (NewElts == nullptr) {
throw std::bad_alloc();
@@ -137,6 +139,7 @@ void SmallVectorBase<Size_T>::grow_pod(
memcpy(NewElts, this->BeginX, size() * TSize);
} else {
// If this wasn't grown from the inline copy, grow the allocated space.
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
NewElts = std::realloc(this->BeginX, NewCapacity * TSize);
if (NewElts == nullptr) {
throw std::bad_alloc();
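The SmallVector hunks above take the other route clang-tidy offers: instead of rewriting the malloc/realloc-based growth path, they add targeted NOLINTNEXTLINE suppressions for cppcoreguidelines-no-malloc. A small illustrative sketch of that suppression style, using a hypothetical grow_buffer helper rather than the real SmallVectorBase code:

```cpp
#include <cstdlib>
#include <new>

// Illustrative only: NOLINTNEXTLINE silences clang-tidy for the next line,
// keeping a deliberate malloc call without disabling the check globally.
void* grow_buffer(std::size_t bytes) {
  // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
  void* p = std::malloc(bytes);
  if (p == nullptr) {
    throw std::bad_alloc();
  }
  return p;
}

int main() {
  void* p = grow_buffer(64);
  // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
  std::free(p);
  return 0;
}
```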

View File

@@ -338,9 +338,7 @@ struct last<typelist<Head>> final {
};
template <class TypeList>
using last_t = typename last<TypeList>::type;
static_assert(
std::is_same<int, last_t<typelist<double, float, int>>>::value,
"");
static_assert(std::is_same<int, last_t<typelist<double, float, int>>>::value);
/**
* Take/drop a number of arguments from a typelist.

View File

@@ -57,13 +57,14 @@ const uint128_pod kuint128max = {
} while (0)
static inline int Fls64(uint64_t n) {
// GOOGLE_DCHECK_NE(0, n);
int pos = 0;
uint64_t pos = 0;
STEP(uint64_t, n, pos, 0x20);
uint32_t n32 = n;
STEP(uint32_t, n32, pos, 0x10);
STEP(uint32_t, n32, pos, 0x08);
STEP(uint32_t, n32, pos, 0x04);
return pos + ((uint64_t{0x3333333322221100u} >> (n32 << 2)) & 0x3);
return static_cast<int>(
pos + ((uint64_t{0x3333333322221100u} >> (n32 << 2)) & 0x3));
}
#undef STEP
@@ -128,7 +129,7 @@ std::ostream& operator<<(std::ostream& o, const uint128& b) {
// Select a divisor which is the largest power of the base < 2^64.
uint128 div;
std::streamsize div_base_log = 0;
int div_base_log = 0;
switch (flags & std::ios::basefield) {
case std::ios::hex:
div = (uint64_t)0x1000000000000000u; // 16^15

View File

@@ -50,6 +50,7 @@ struct DontIncreaseRefcount {};
// tells us if the object was allocated by us. If it wasn't, no
// intrusive_ptr for you!
// NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor)
class C10_API intrusive_ptr_target {
// Note [Weak references for intrusive refcounting]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -47,6 +47,7 @@ int GetNUMANode(const void* ptr) {
AT_ASSERT(ptr);
int numa_node = -1;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
TORCH_CHECK(
get_mempolicy(
&numa_node,
@@ -78,12 +79,15 @@ void NUMAMove(void* ptr, size_t size, int numa_node_id) {
uintptr_t page_start_ptr =
((reinterpret_cast<uintptr_t>(ptr)) & ~(getpagesize() - 1));
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
// NOLINTNEXTLINE(bugprone-narrowing-conversions)
ptrdiff_t offset = reinterpret_cast<uintptr_t>(ptr) - page_start_ptr;
// Avoid extra dynamic allocation and NUMA api calls
AT_ASSERT(
numa_node_id >= 0 &&
static_cast<unsigned>(numa_node_id) < sizeof(unsigned long) * 8);
unsigned long mask = 1UL << numa_node_id;
// NOLINTNEXTLINE(performance-no-int-to-ptr)
TORCH_CHECK(
mbind(
reinterpret_cast<void*>(page_start_ptr),

View File

@@ -84,7 +84,7 @@ class reverse_iterator {
constexpr reverse_iterator& operator=(const reverse_iterator& rhs) noexcept {
current = rhs.current;
return current;
return *this;
}
template <typename _Iter>
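The one-line change above is a genuine bug fix rather than a style cleanup: copy assignment must return a reference to the assigned-to object (*this), not the wrapped member iterator. A minimal sketch of the canonical shape, using a hypothetical wrapper template instead of the actual c10 reverse_iterator:

```cpp
#include <iostream>

// Hypothetical wrapper illustrating the canonical copy-assignment shape
// fixed in the hunk above: assign the member, then return *this.
template <typename It>
class wrapper {
 public:
  constexpr wrapper() = default;
  constexpr explicit wrapper(It it) : current(it) {}
  constexpr wrapper& operator=(const wrapper& rhs) noexcept {
    current = rhs.current;
    return *this;  // returning `current` would not yield a wrapper&
  }
  constexpr It base() const {
    return current;
  }

 private:
  It current{};
};

int main() {
  int data[3] = {1, 2, 3};
  wrapper<int*> a(data);
  wrapper<int*> b;
  b = a;  // chains correctly because operator= returns *this
  std::cout << *b.base() << '\n';  // prints 1
  return 0;
}
```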

View File

@@ -112,6 +112,7 @@ FatalSignalHandler::FatalSignalHandler()
writingCond(PTHREAD_COND_INITIALIZER),
writingMutex(PTHREAD_MUTEX_INITIALIZER) {}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
FatalSignalHandler::signal_handler FatalSignalHandler::kSignalHandlers[] = {
{"SIGABRT", SIGABRT, {}},
{"SIGINT", SIGINT, {}},
@@ -159,7 +160,7 @@ void FatalSignalHandler::stacktraceSignalHandler(bool needsLock) {
if (needsLock) {
pthread_mutex_lock(&writingMutex);
}
pid_t tid = syscall(SYS_gettid);
pid_t tid = static_cast<pid_t>(syscall(SYS_gettid));
std::string backtrace = fmt::format(
"{}({}), PID: {}, Thread {}: \n {}",
fatalSignalName,
@@ -201,7 +202,7 @@ void FatalSignalHandler::fatalSignalHandler(int signum) {
DIR* procDir = opendir("/proc/self/task");
if (procDir) {
pid_t pid = getpid();
pid_t currentTid = syscall(SYS_gettid);
pid_t currentTid = static_cast<pid_t>(syscall(SYS_gettid));
struct dirent* entry = nullptr;
pthread_mutex_lock(&writingMutex);
while ((entry = readdir(procDir)) != nullptr) {