diff --git a/test/test_autocast.py b/test/test_autocast.py
index 9d565e4f81e..0876c770fa6 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -158,7 +158,9 @@ class TestAutocastCPU(TestAutocast):
 
             m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)
             # Raise ValueError when autocast is not enabled
-            with self.assertRaisesRegex(ValueError, "input must have the type"):
+            with self.assertRaisesRegex(
+                ValueError, r"RNN input dtype .* does not match weight dtype"
+            ):
                 m(x, (hx, cx))
 
             # Should be able to run the below case with autocast
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index fbac210b395..03987ec7557 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -317,7 +317,8 @@ class RNNBase(Module):
             and not torch._C._is_any_autocast_enabled()
         ):
             raise ValueError(
-                f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}"  # type: ignore[union-attr]
+                f"RNN input dtype ({input.dtype}) does not match weight dtype ({self._flat_weights[0].dtype}). "  # type: ignore[union-attr]
+                f"Convert input: input.to({self._flat_weights[0].dtype}), or convert model: model.to({input.dtype})"  # type: ignore[union-attr]
             )
         expected_input_dim = 2 if batch_sizes is not None else 3
         if input.dim() != expected_input_dim:
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index 68e02f16a7d..7947a6ed71c 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -3341,6 +3341,16 @@ def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_g
 
 
 def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
+    # use float64 for dtype mismatch test if current dtype is float32, otherwise use float32
+    # MPS doesn't support float64, so use float16 instead
+    # Extract device type from device string (e.g., 'mps:0' -> 'mps')
+    device_type = device.split(':')[0] if isinstance(device, str) else device.type
+    if dtype == torch.float32:
+        mismatched_dtype = torch.float16 if device_type == 'mps' else torch.float64
+    else:
+        mismatched_dtype = torch.float32
+    make_input = partial(make_tensor, device=device, dtype=mismatched_dtype, requires_grad=requires_grad)
+
     samples = [
         ErrorModuleInput(
             ModuleInput(constructor_input=FunctionInput(10, 0, 1)),
@@ -3354,6 +3364,17 @@ def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_gr
             error_type=ValueError,
             error_regex="num_layers must be greater than zero"
         ),
+        # Test dtype mismatch error message
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(3, 5, dtype=dtype, device=device),
+                forward_input=FunctionInput(make_input((2, 4, 3))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=(r"RNN input dtype .* does not match weight dtype .* "
+                         r"Convert input: input\.to\(.*\), or convert model: model\.to\(.*\)")
+        ),
         # Test bias parameter type validation
         ErrorModuleInput(
             ModuleInput(constructor_input=FunctionInput(3, 5, bias=0)),
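
For context (not part of the patch): a minimal sketch of the situation the reworded error covers, assuming a CPU LSTM with default float32 weights fed a float64 input outside autocast. With this change, the raised `ValueError` should read along the lines of "RNN input dtype (torch.float64) does not match weight dtype (torch.float32). Convert input: ...".

```python
import torch

# Hypothetical repro sketch: LSTM weights default to float32, so a float64
# input outside autocast hits the dtype check in RNNBase and raises the
# reworded ValueError from this patch.
lstm = torch.nn.LSTM(input_size=3, hidden_size=5)
x = torch.randn(2, 4, 3, dtype=torch.float64)  # (seq_len, batch, input_size)

try:
    lstm(x)
except ValueError as err:
    print(err)
```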