diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py
index f9e84284f0d..2c232594f33 100644
--- a/test/inductor/test_fxir_backend.py
+++ b/test/inductor/test_fxir_backend.py
@@ -831,7 +831,9 @@ class AOTFxirTestCase(InductorTestCase):
         gm = torch._inductor.aot_compile(
             ep.module(), inp, options={"fx_wrapper": True, **test_config}
         )
-        self.assertTrue(same(model(*inp), gm(*inp)))
+        # The fx_wrapper gm takes pytree-flattened args
+        flat_args, _ = pytree.tree_flatten(inp)
+        self.assertTrue(same(model(*inp), gm(*flat_args)))
 
         for node in gm.graph.nodes:
             if (
@@ -1182,6 +1184,38 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         compiled_out = compiled(*args)
         self.assertEqual(compiled_out.shape, shape)
 
+    def test_reshape_dynamic_ph(self):
+        """
+        Test dynamic scalars passed via SymInt placeholders
+        """
+
+        class TestModule(torch.nn.Module):
+            def forward(self, x, shape):
+                return torch.reshape(x, shape) + 2
+
+        ds = {
+            "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
+            "shape": [torch.export.Dim.AUTO, torch.export.Dim.AUTO],
+        }
+        args = (torch.randn((12, 14), device=self.device), [6, 28])
+        self.check(TestModule(), args, ds)
+
+    def test_reshape_dynamic_tmd(self):
+        """
+        Test dynamic reshape using shape-dependent information
+        """
+
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                new_shape = [x.shape[0] // 2, x.shape[1] * 2]
+                return torch.reshape(x, new_shape) + 2
+
+        ds = {
+            "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
+        }
+        args = (torch.randn((12, 14), device=self.device),)
+        self.check(TestModule(), args, ds)
+
 
 class TestReplaceFloorDiv(InductorTestCase):
     """
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index b6796d2b7ce..46ca6048382 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -2537,16 +2537,19 @@ def _extract_inputs_from_exported_gm(
     fake_inputs = [
         node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"
     ]
-    # Replace non-tensor (constant) inputs with Nones, since these are not being
-    # used anyways by the graph
-    fake_inputs = [
-        inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs
-    ]
+
+    if not config.fx_wrapper:
+        # Replace non-tensor inputs with Nones: constant scalars are already
+        # embedded in the graph, and symbolic scalars (SymInts) are not
+        # supported in non-fx_wrapper mode.
+        fake_inputs = [
+            inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs
+        ]
 
     if any(v is not None for v in fake_inputs):
         # Validate devices before switching to fake tensors.
         for idx, fi, i in zip(count(), fake_inputs, example_inputs_):
-            if fi is not None:
+            if fi is not None and isinstance(fi, torch.Tensor):
                 assert isinstance(i, torch.Tensor)
                 if fi.device != i.device:
                     raise ValueError(
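
A minimal standalone sketch (not part of the patch) of why the test now
flattens its inputs before calling the compiled module: with
options={"fx_wrapper": True}, the returned GraphModule has one placeholder
per pytree leaf, so a nested argument such as a shape list must be spread
into individual leaves, with the ints feeding the new SymInt placeholders.
The aot_compile call is elided here; only the flattening step is shown.

    import torch
    import torch.utils._pytree as pytree

    # Nested inputs as the new test passes them: a tensor plus a shape list.
    inp = (torch.randn(12, 14), [6, 28])

    # tree_flatten spreads the nested structure into its leaves:
    # [tensor, 6, 28]. These line up one-to-one with the fx_wrapper
    # gm's placeholders (one tensor input plus two SymInt inputs).
    flat_args, spec = pytree.tree_flatten(inp)
    assert len(flat_args) == 3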