Check that the exponent is finite before converting it in the pow lowering (#167723)

Fixes #167197

The inductor backend tries to convert a float infinity exponent to an integer in the pow lowering: the `b == int(b)` check that dispatches integer-valued exponents to the integer pow path evaluates `int(b)` unconditionally. Python cannot convert `float('inf')` to an integer, so the conversion raises an OverflowError.
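
For illustration (a minimal sketch, not part of this commit), the failure can be reproduced in plain Python: `int()` cannot represent an infinite float, so the old check blew up before any lowering decision was made.

b = float("inf")
try:
    int(b)  # what the old `b == int(b)` check effectively attempted
except OverflowError as err:
    print(err)  # cannot convert float infinity to integer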

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167723
Approved by: https://github.com/eellison, https://github.com/RohitRathore1
Author: Kushagra Rastogi
Date: 2025-12-15 19:16:13 +00:00
Committed by: PyTorch MergeBot
parent 4c043d855b
commit 71ecbc44fb
2 changed files with 39 additions and 1 deletion


@@ -5553,6 +5553,34 @@ class CommonTemplate:
             check_lowp=not is_halide_backend(self.device), # misaligned addr fp16
         )
 
+    @xfail_if_triton_cpu
+    def test_lp_pool1d_with_inf_norm(self):
+        # https://github.com/pytorch/pytorch/issues/167197
+        # Test that LPPool1d works with infinity norm (should behave like max pooling)
+        def fn(x):
+            return torch.nn.functional.lp_pool1d(
+                x, norm_type=float("inf"), kernel_size=2, stride=2
+            )
+
+        self.common(
+            fn,
+            (torch.randn(3, 4, 8),),
+        )
+
+    @xfail_if_triton_cpu
+    def test_lp_pool2d_with_inf_norm(self):
+        # https://github.com/pytorch/pytorch/issues/167197
+        # Test that LPPool2d works with infinity norm (should behave like max pooling)
+        def fn(x):
+            return torch.nn.functional.lp_pool2d(
+                x, norm_type=float("inf"), kernel_size=2, stride=2
+            )
+
+        self.common(
+            fn,
+            (torch.randn(3, 4, 8, 8),),
+        )
+
     @tf32_on_and_off(0.006)
     @skip_if_gpu_halide # slow
     def test_alexnet_prefix(self):
@@ -6334,6 +6362,16 @@ class CommonTemplate:
         x = torch.randn([16, 16], device=self.device)
         self.assertEqual(cfn(x), fn(x))
 
+    @xfail_if_triton_cpu
+    def test_pow_infinite(self):
+        def fn(a, b):
+            return torch.pow(a, b)
+
+        opt = torch.compile(fn, backend="inductor")
+        a = torch.randn((3, 4, 8), device=self.device)
+        b = float("inf")
+        self.assertTrue(same(opt(a, b), fn(a, b)))
+
     def test_glu(self):
         def fn(x):
             return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2)


@@ -6390,7 +6390,7 @@ fallback_pow_tensor_scalar = fallback_handler(
 
 @register_lowering(aten.pow, broadcast=True)
 def pow(a, b):
-    if isinstance(b, float) and b == int(b):
+    if isinstance(b, float) and math.isfinite(b) and b == int(b):
         return pow(a, int(b))
     elif isinstance(b, float) and b == 0.5:
         return sqrt(a)
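
As a rough sketch of the new behavior (the helper name below is hypothetical, not part of the patch), the added `math.isfinite(b)` guard short-circuits before `int(b)` is evaluated, so infinite and NaN exponents fall through to the general pow lowering instead of raising:

import math

def takes_integer_fast_path(b):
    # Mirrors the fixed predicate: math.isfinite(b) is checked first,
    # so int(b) is never evaluated for inf, -inf, or nan.
    return isinstance(b, float) and math.isfinite(b) and b == int(b)

print(takes_integer_fast_path(3.0))           # True  -> lowered as an integer power
print(takes_integer_fast_path(0.5))           # False -> handled by the sqrt branch
print(takes_integer_fast_path(float("inf")))  # False -> general pow path, no OverflowError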