mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Improve RNN dtype mismatch error message (#166946)
Enhance error message to explain mismatch and provide two actionable fixes: convert input with input.to(dtype) or convert model with model.to(dtype). Add test to validate error message and verify both suggested fixes work. Fixes #136931 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166946 Approved by: https://github.com/mikaylagawarecki, https://github.com/cyyever
This commit is contained in:
committed by
PyTorch MergeBot
parent
30fd43528e
commit
7c467cad4a
@@ -158,7 +158,9 @@ class TestAutocastCPU(TestAutocast):
|
||||
m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)
|
||||
|
||||
# Raise ValueError when autocast is not enabled
|
||||
with self.assertRaisesRegex(ValueError, "input must have the type"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"RNN input dtype .* does not match weight dtype"
|
||||
):
|
||||
m(x, (hx, cx))
|
||||
|
||||
# Should be able to run the below case with autocast
|
||||
|
||||
@@ -317,7 +317,8 @@ class RNNBase(Module):
|
||||
and not torch._C._is_any_autocast_enabled()
|
||||
):
|
||||
raise ValueError(
|
||||
f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" # type: ignore[union-attr]
|
||||
f"RNN input dtype ({input.dtype}) does not match weight dtype ({self._flat_weights[0].dtype}). " # type: ignore[union-attr]
|
||||
f"Convert input: input.to({self._flat_weights[0].dtype}), or convert model: model.to({input.dtype})" # type: ignore[union-attr]
|
||||
)
|
||||
expected_input_dim = 2 if batch_sizes is not None else 3
|
||||
if input.dim() != expected_input_dim:
|
||||
|
||||
@@ -3341,6 +3341,16 @@ def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_g
|
||||
|
||||
|
||||
def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||
# use float64 for dtype mismatch test if current dtype is float32, otherwise use float32
|
||||
# MPS doesn't support float64, so use float16 instead
|
||||
# Extract device type from device string (e.g., 'mps:0' -> 'mps')
|
||||
device_type = device.split(':')[0] if isinstance(device, str) else device.type
|
||||
if dtype == torch.float32:
|
||||
mismatched_dtype = torch.float16 if device_type == 'mps' else torch.float64
|
||||
else:
|
||||
mismatched_dtype = torch.float32
|
||||
make_input = partial(make_tensor, device=device, dtype=mismatched_dtype, requires_grad=requires_grad)
|
||||
|
||||
samples = [
|
||||
ErrorModuleInput(
|
||||
ModuleInput(constructor_input=FunctionInput(10, 0, 1)),
|
||||
@@ -3354,6 +3364,17 @@ def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_gr
|
||||
error_type=ValueError,
|
||||
error_regex="num_layers must be greater than zero"
|
||||
),
|
||||
# Test dtype mismatch error message
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput(3, 5, dtype=dtype, device=device),
|
||||
forward_input=FunctionInput(make_input((2, 4, 3))),
|
||||
),
|
||||
error_on=ModuleErrorEnum.FORWARD_ERROR,
|
||||
error_type=ValueError,
|
||||
error_regex=(r"RNN input dtype .* does not match weight dtype .* "
|
||||
r"Convert input: input\.to\(.*\), or convert model: model\.to\(.*\)")
|
||||
),
|
||||
# Test bias parameter type validation
|
||||
ErrorModuleInput(
|
||||
ModuleInput(constructor_input=FunctionInput(3, 5, bias=0)),
|
||||
|
||||
Reference in New Issue
Block a user