From 73da7a40b6dd0509b179f0ca92d9fa79973ee306 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 3 Nov 2025 07:17:43 -0800 Subject: [PATCH] [MPS] Error out when BatchNorm is called for Complex (#166215) Or BatchNorm or LayerNorm for Long types Discovered while trying to enable `test_ops.py` for MPS Pull Request resolved: https://github.com/pytorch/pytorch/pull/166215 Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007 ghstack dependencies: #166214 --- aten/src/ATen/native/mps/operations/Normalization.mm | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index f5264cf32d9..0c95fec667e 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -84,6 +84,9 @@ std::tuple batch_norm_mps_out(const Tensor& self, Tensor& output, Tensor& save_mean, Tensor& save_var) { + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Long batch norm is not supported with MPS"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), + "Batch norm for complex is not supported for MPS"); using namespace at::native::mps; struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} @@ -918,6 +921,7 @@ std::tuple layer_norm_mps(const Tensor& input, // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) const int axis = input_ndim - normalized_ndim; MPSStream* stream = getCurrentMPSStream(); + TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS"); @autoreleasepool { mps::dispatch_sync_with_rethrow(stream->queue(), ^() { // which kernel variant to use based on the normalized axis N size