diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 0e9993f60f8..006642a4d78 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5553,6 +5553,34 @@ class CommonTemplate:
             check_lowp=not is_halide_backend(self.device),  # misaligned addr fp16
         )
 
+    @xfail_if_triton_cpu
+    def test_lp_pool1d_with_inf_norm(self):
+        # https://github.com/pytorch/pytorch/issues/167197
+        # Test that LPPool1d works with infinity norm (should behave like max pooling)
+        def fn(x):
+            return torch.nn.functional.lp_pool1d(
+                x, norm_type=float("inf"), kernel_size=2, stride=2
+            )
+
+        self.common(
+            fn,
+            (torch.randn(3, 4, 8),),
+        )
+
+    @xfail_if_triton_cpu
+    def test_lp_pool2d_with_inf_norm(self):
+        # https://github.com/pytorch/pytorch/issues/167197
+        # Test that LPPool2d works with infinity norm (should behave like max pooling)
+        def fn(x):
+            return torch.nn.functional.lp_pool2d(
+                x, norm_type=float("inf"), kernel_size=2, stride=2
+            )
+
+        self.common(
+            fn,
+            (torch.randn(3, 4, 8, 8),),
+        )
+
     @tf32_on_and_off(0.006)
     @skip_if_gpu_halide  # slow
     def test_alexnet_prefix(self):
@@ -6334,6 +6362,16 @@ class CommonTemplate:
         x = torch.randn([16, 16], device=self.device)
         self.assertEqual(cfn(x), fn(x))
 
+    @xfail_if_triton_cpu
+    def test_pow_infinite(self):
+        def fn(a, b):
+            return torch.pow(a, b)
+
+        opt = torch.compile(fn, backend="inductor")
+        a = torch.randn((3, 4, 8), device=self.device)
+        b = float("inf")
+        self.assertTrue(same(opt(a, b), fn(a, b)))
+
     def test_glu(self):
         def fn(x):
             return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 45ad3be1b0d..2c0fc5d21db 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -6390,7 +6390,7 @@ fallback_pow_tensor_scalar = fallback_handler(
 
 @register_lowering(aten.pow, broadcast=True)
 def pow(a, b):
-    if isinstance(b, float) and b == int(b):
+    if isinstance(b, float) and math.isfinite(b) and b == int(b):
         return pow(a, int(b))
     elif isinstance(b, float) and b == 0.5:
         return sqrt(a)
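
Context for the lowering change (a minimal standalone sketch, not part of the patch): the previous fast path evaluated int(b) for every float scalar exponent, and converting an infinite float to int raises OverflowError in Python, so lowering torch.pow with an infinite exponent failed; the lp_pool tests above exercise the same path, since an infinite norm_type feeds an infinite exponent into pow. The math.isfinite guard lets non-finite exponents skip the integer fast path.

    import math

    # Sketch only: why the math.isfinite guard matters.
    b = float("inf")
    try:
        int(b)  # this is what the old check `b == int(b)` evaluated
    except OverflowError as exc:
        print(exc)  # cannot convert float infinity to integer

    # With the guard, a non-finite exponent simply skips the
    # integer-exponent fast path instead of raising during lowering.
    if isinstance(b, float) and math.isfinite(b) and b == int(b):
        print("take the integer fast path")
    else:
        print("fall through to the remaining pow handling")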