mirror of https://github.com/zebrajr/pytorch.git, synced 2026-01-15 12:15:51 +00:00
[ROCm] fix and unskip tests on rocm (#169827)
This PR fixes:
- `torch.nonzero` for large tensors on ROCm. It was malfunctioning due to a known hip compiler problem with `::min` for int64_t arguments. Fixed by explicit typing of the call as `std::min<int64_t>`.
- Using `torch.ops.aten.miopen_batch_norm` instead of `torch.ops.aten.cudnn_batch_norm` on ROCm.

Fixed tests:
- Fixes #168878.
- Fixes #168879.
- Fixes #168553.
- Fixes #168554.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169827
Approved by: https://github.com/jeffdaily, https://github.com/mlazos, https://github.com/cyyever
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
committed by PyTorch MergeBot
parent e09550e9db
commit 76e60f375a
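For context, the second fix boils down to picking the batch norm op by backend: ROCm builds report a HIP version and ship MIOpen rather than cuDNN. A minimal sketch of that selection, mirroring what the updated test in the diff below does (the variable name `batch_norm_op` is illustrative; no call is shown since the op arguments are unchanged):

```python
import torch

# ROCm builds expose torch.version.hip and use MIOpen rather than cuDNN,
# so the batch norm op is chosen to match the backend at runtime.
batch_norm_op = (
    torch.ops.aten.cudnn_batch_norm
    if torch.version.hip is None
    else torch.ops.aten.miopen_batch_norm
)
```

On a CUDA build `torch.version.hip` is `None` and the cuDNN op is used; on a ROCm wheel it carries the HIP version string, so the MIOpen op is picked instead.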
@@ -183,7 +183,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
   auto& allocator = *c10::cuda::CUDACachingAllocator::get();
   auto num_nonzeros = allocator.allocate(sizeof(int) * num_chunks);
   for (int64_t idx = 0; idx < num_chunks; idx++) {
-    int64_t remaining = std::min(chunk_size, self.numel() - idx * chunk_size);
+    int64_t remaining = std::min<int64_t>(chunk_size, self.numel() - idx * chunk_size);
     ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp<scalar_t>, const scalar_t*) itr(
         self_.const_data_ptr<scalar_t>() + idx * chunk_size,
         NonZeroOp<scalar_t>());
@@ -241,7 +241,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
   int64_t curr_nonzeros = 0;
   if (self.dim() > 0) {
     for (int64_t idx = 0; idx < num_chunks; idx++) {
-      int remaining = std::min(chunk_size, self.numel() - idx * chunk_size);
+      int remaining = std::min<int64_t>(chunk_size, self.numel() - idx * chunk_size);

       ATEN_CUB_COUNTING_ITERATOR(int64_t) counting_itr(idx * chunk_size);
       ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp<scalar_t>, const scalar_t*)
@@ -353,7 +353,7 @@ void nonzero_static_cuda_out_impl(
       <<<grid_size, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
           in_data_ptr, out_data_ptr, (int64_t*)agg_cum.get(), self.numel(), size, iters_per_cta);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
-  int64_t out_grid = std::min(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS);
+  int64_t out_grid = std::min<int64_t>(num_sms, (size + BLOCK_THREADS - 1)/BLOCK_THREADS);
   write_fill_value<<<out_grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(out_data_ptr, (int64_t *)agg_cum.get() + grid_size - 1, fill_value, size);
   if (self.dim() > 1) {
     TensorDims<int64_t> dims;
@@ -87,7 +87,6 @@ from torch.testing._internal.common_utils import (
     outs_and_grads,
     parametrize,
     run_tests,
-    skipIfRocm,
     TEST_MKL,
     TestCase,
     xfail_inherited_tests,
@@ -3900,7 +3899,6 @@ def forward(self, tangents_1):

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     @unittest.skipIf(not torch.backends.cudnn.is_available(), "CUDNN is unavailable")
-    @skipIfRocm # https://github.com/pytorch/pytorch/issues/96560
     def test_batch_norm_amp(self):
         device = "cuda"
         input_dtype = torch.float16
@@ -3914,7 +3912,12 @@ def forward(self, tangents_1):
         )

         def bn(x):
-            return torch.ops.aten.cudnn_batch_norm(
+            fn = (
+                torch.ops.aten.cudnn_batch_norm
+                if torch.version.hip is None
+                else torch.ops.aten.miopen_batch_norm
+            )
+            return fn(
                 x,
                 weight,
                 bias,
@@ -44,7 +44,6 @@ from torch.testing._internal.common_utils import (
     numpy_to_torch_dtype_dict,
     run_tests,
     skipIfNoSciPy,
-    skipIfRocm,
     slowTest,
     suppress_warnings,
     TEST_SCIPY,
@@ -1613,7 +1612,6 @@ class TestUnaryUfuncs(TestCase):
     @onlyCUDA
     @dtypes(torch.int8)
     @largeTensorTest("8GB")
-    @skipIfRocm(msg="ROCM tries to allocate 60GB")
     def test_nonzero_large(self, device, dtype):
         indices = (
             torch.tensor((0, 2, 3, 4, 6, 100, 103, 2**30, 2**31 - 3, 2**31 - 2)),
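The unskipped `test_nonzero_large` exercises `torch.nonzero` with entries past the 2**31 element boundary, which is where the `::min` issue surfaced on ROCm. A standalone sketch of the same check follows; the function name, device handling, and the roughly 2 GB tensor size are illustrative assumptions rather than the test's exact 8 GB setup:

```python
import torch

def check_nonzero_large(device: str = "cuda") -> None:
    # Place a few nonzero entries past the 2**31 element boundary and confirm
    # torch.nonzero reports exactly those positions.
    indices = torch.tensor([0, 2, 3, 4, 6, 100, 103, 2**30, 2**31 - 3, 2**31 - 2])
    x = torch.zeros(2**31 - 1, dtype=torch.int8, device=device)  # roughly 2 GB of int8
    x[indices.to(device)] = 1
    found = torch.nonzero(x).squeeze(1).cpu()
    assert torch.equal(found, indices), f"unexpected indices: {found}"

if __name__ == "__main__":
    if torch.cuda.is_available():  # True on both CUDA and ROCm builds
        check_nonzero_large()
```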