Support SymInt placeholder in wrapper fxir (#167757)

Summary:
Add support for SymInt placeholders in the wrapper FX IR backend.

Added two test cases covering dynamic reshape, where the dynamic shape
information comes from two different sources (see the sketch below):
- dynamic info coming from tensor metadata (tmd) on placeholders
- dynamic info coming from the placeholders themselves (SymInts)
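For reference, a minimal sketch of the two patterns the new tests exercise (the module names here are illustrative, not from this PR):

```python
import torch

class ReshapeFromPlaceholders(torch.nn.Module):
    # After export, the elements of `shape` become SymInt placeholders.
    def forward(self, x, shape):
        return torch.reshape(x, shape) + 2

class ReshapeFromTensorMeta(torch.nn.Module):
    # The target shape is derived from the tensor metadata (tmd) of `x`.
    def forward(self, x):
        return torch.reshape(x, [x.shape[0] // 2, x.shape[1] * 2]) + 2
```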

Test Plan:
- test_reshape_dynamic_ph
- test_reshape_dynamic_tmd

Differential Revision: D86984100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167757
Approved by: https://github.com/blaine-rister
Author: Nan Zhang
Date: 2025-11-17 21:10:52 +00:00
Committed by: PyTorch MergeBot
Parent: 9d8ceaa36f
Commit: bdd3c3a29c
2 files changed, 44 insertions(+), 7 deletions(-)


@@ -831,7 +831,9 @@ class AOTFxirTestCase(InductorTestCase):
         gm = torch._inductor.aot_compile(
             ep.module(), inp, options={"fx_wrapper": True, **test_config}
         )
-        self.assertTrue(same(model(*inp), gm(*inp)))
+        # Flatten args for fx_wrapper gm
+        flat_args, _ = pytree.tree_flatten(inp)
+        self.assertTrue(same(model(*inp), gm(*flat_args)))
         for node in gm.graph.nodes:
             if (
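Note on the test-helper change above: an fx_wrapper graph module takes a flat argument list, so nested example inputs (like the shape list in the new ph test) have to be pytree-flattened before the call. A minimal sketch of the flattening, with the gm call left as a comment since it assumes a compiled fx_wrapper graph module:

```python
import torch
import torch.utils._pytree as pytree

inp = (torch.randn(12, 14), [6, 28])       # nested example inputs, as in the ph test
flat_args, _spec = pytree.tree_flatten(inp)
print(flat_args)                            # [tensor(...), 6, 28]
# out = gm(*flat_args)                      # gm: the compiled fx_wrapper graph module
```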
@@ -1182,6 +1184,38 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         compiled_out = compiled(*args)
         self.assertEqual(compiled_out.shape, shape)
 
+    def test_reshape_dynamic_ph(self):
+        """
+        Test dynamic scalars using SymInt placeholders
+        """
+
+        class TestModule(torch.nn.Module):
+            def forward(self, x, shape):
+                return torch.reshape(x, shape) + 2
+
+        ds = {
+            "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
+            "shape": [torch.export.Dim.AUTO, torch.export.Dim.AUTO],
+        }
+        args = (torch.randn((12, 14), device=self.device), [6, 28])
+        self.check(TestModule(), args, ds)
+
+    def test_reshape_dynamic_tmd(self):
+        """
+        Test dynamic reshape using shape-dependent information
+        """
+
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                new_shape = [x.shape[0] // 2, x.shape[1] * 2]
+                return torch.reshape(x, new_shape) + 2
+
+        ds = {
+            "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
+        }
+        args = (torch.randn((12, 14), device=self.device),)
+        self.check(TestModule(), args, ds)
+
+
 class TestReplaceFloorDiv(InductorTestCase):
     """


@@ -2537,16 +2537,19 @@ def _extract_inputs_from_exported_gm(
     fake_inputs = [
         node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"
     ]
-    # Replace non-tensor (constant) inputs with Nones, since these are not being
-    # used anyways by the graph
-    fake_inputs = [
-        inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs
-    ]
+    if not config.fx_wrapper:
+        # Replace non-tensor inputs with Nones: constant scalars are embedded
+        # in the graph, and symbolic scalars (SymInts) are not supported in
+        # non-fx_wrapper mode.
+        fake_inputs = [
+            inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs
+        ]
+
     if any(v is not None for v in fake_inputs):
         # Validate devices before switching to fake tensors.
         for idx, fi, i in zip(count(), fake_inputs, example_inputs_):
-            if fi is not None:
+            if fi is not None and isinstance(fi, torch.Tensor):
                 assert isinstance(i, torch.Tensor)
                 if fi.device != i.device:
                     raise ValueError(
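In words: the None-replacement is now gated on `config.fx_wrapper`, so SymInt placeholder values survive input extraction in fx_wrapper mode, and the device validation skips the non-tensor entries that can now appear. A condensed restatement of the resulting control flow (function and parameter names are simplified stand-ins for the diff above, not the real Inductor signatures):

```python
import torch
from itertools import count

def extract_inputs(fake_inputs, example_inputs, fx_wrapper):
    if not fx_wrapper:
        # Non-tensor inputs are constants embedded in the graph; SymInt
        # placeholders are only supported in fx_wrapper mode, so null them out.
        fake_inputs = [
            fi if isinstance(fi, torch.Tensor) else None for fi in fake_inputs
        ]
    # Validate devices before switching to fake tensors; SymInt entries
    # (kept in fx_wrapper mode) are skipped by the isinstance check.
    for idx, fi, ei in zip(count(), fake_inputs, example_inputs):
        if fi is not None and isinstance(fi, torch.Tensor):
            assert isinstance(ei, torch.Tensor)
            if fi.device != ei.device:
                raise ValueError(f"Device mismatch for input {idx}")
    return fake_inputs
```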