Improve RNN dtype mismatch error message (#166946)

Enhance the error message to explain the dtype mismatch and provide two actionable fixes: convert the input with input.to(dtype), or convert the model with model.to(dtype). Add a test that validates the error message and verifies both suggested fixes work.
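
For reference, a minimal sketch of the failure and the two suggested fixes (assuming a CPU RNN with float64 weights fed a default-float32 input; the names here are illustrative, not from the patch):

    import torch

    rnn = torch.nn.RNN(3, 5).to(torch.float64)  # weights are float64
    x = torch.randn(2, 4, 3)                    # input is float32 by default

    try:
        rnn(x)  # raises ValueError with the new, actionable message
    except ValueError as e:
        print(e)

    rnn(x.to(torch.float64))  # fix 1: convert the input to the weight dtype
    rnn.to(torch.float32)     # fix 2: convert the model to the input dtype
    rnn(x)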

Fixes #136931

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166946
Approved by: https://github.com/mikaylagawarecki, https://github.com/cyyever
Author: vishalgoyal316
Date: 2025-12-31 06:51:06 +00:00
Committed by: PyTorch MergeBot
Parent: 30fd43528e
Commit: 7c467cad4a
3 changed files with 26 additions and 2 deletions

test/test_autocast.py

@@ -158,7 +158,9 @@ class TestAutocastCPU(TestAutocast):
         m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)
         # Raise ValueError when autocast is not enabled
-        with self.assertRaisesRegex(ValueError, "input must have the type"):
+        with self.assertRaisesRegex(
+            ValueError, r"RNN input dtype .* does not match weight dtype"
+        ):
             m(x, (hx, cx))
         # Should be able to run the below case with autocast
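
A sketch of the passing autocast case that last comment refers to (assuming CPU autocast with a bfloat16 LSTM and default-float32 inputs, mirroring the test's setup):

    import torch

    m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)
    x = torch.randn(1, 2, 1)  # float32 by default
    hx, cx = torch.randn(2, 2, 1), torch.randn(2, 2, 1)

    # m(x, (hx, cx))  # without autocast: ValueError, dtypes mismatch
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        m(x, (hx, cx))  # runs: the dtype check is skipped under autocast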

torch/nn/modules/rnn.py

@@ -317,7 +317,8 @@ class RNNBase(Module):
             and not torch._C._is_any_autocast_enabled()
         ):
             raise ValueError(
-                f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}"  # type: ignore[union-attr]
+                f"RNN input dtype ({input.dtype}) does not match weight dtype ({self._flat_weights[0].dtype}). "  # type: ignore[union-attr]
+                f"Convert input: input.to({self._flat_weights[0].dtype}), or convert model: model.to({input.dtype})"  # type: ignore[union-attr]
             )
         expected_input_dim = 2 if batch_sizes is not None else 3
         if input.dim() != expected_input_dim:
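
Concretely, for a float32 input fed to float64 weights, the new f-string renders as:

    RNN input dtype (torch.float32) does not match weight dtype (torch.float64). Convert input: input.to(torch.float64), or convert model: model.to(torch.float32)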

torch/testing/_internal/common_modules.py

@@ -3341,6 +3341,16 @@ def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_g
 def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
+    # use float64 for dtype mismatch test if current dtype is float32, otherwise use float32
+    # MPS doesn't support float64, so use float16 instead
+    # Extract device type from device string (e.g., 'mps:0' -> 'mps')
+    device_type = device.split(':')[0] if isinstance(device, str) else device.type
+    if dtype == torch.float32:
+        mismatched_dtype = torch.float16 if device_type == 'mps' else torch.float64
+    else:
+        mismatched_dtype = torch.float32
-    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_input = partial(make_tensor, device=device, dtype=mismatched_dtype, requires_grad=requires_grad)
     samples = [
         ErrorModuleInput(
             ModuleInput(constructor_input=FunctionInput(10, 0, 1)),
@@ -3354,6 +3364,17 @@ def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_gr
             error_type=ValueError,
             error_regex="num_layers must be greater than zero"
         ),
+        # Test dtype mismatch error message
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(3, 5, dtype=dtype, device=device),
+                forward_input=FunctionInput(make_input((2, 4, 3))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=(r"RNN input dtype .* does not match weight dtype .* "
+                         r"Convert input: input\.to\(.*\), or convert model: model\.to\(.*\)")
+        ),
         # Test bias parameter type validation
         ErrorModuleInput(
             ModuleInput(constructor_input=FunctionInput(3, 5, bias=0)),
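
As a sanity check, the error_regex above does match the message produced by the new f-string (a standalone sketch; the example message below is written out by hand for a float64 input against float32 weights):

    import re

    pattern = (r"RNN input dtype .* does not match weight dtype .* "
               r"Convert input: input\.to\(.*\), or convert model: model\.to\(.*\)")
    msg = ("RNN input dtype (torch.float64) does not match weight dtype (torch.float32). "
           "Convert input: input.to(torch.float32), or convert model: model.to(torch.float64)")
    assert re.search(pattern, msg) is not None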