mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Support SymInt placeholder in wrapper fxir (#167757)
Summary: add support for symint placeholders added two test cases with dynamic reshape - dynamic info coming from tmd on placeholders - dynamic info coming from placeholders (symints) Test Plan: test_reshape_dynamic_ph test_reshape_dynamic_tmd Differential Revision: D86984100 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167757 Approved by: https://github.com/blaine-rister
This commit is contained in:
committed by
PyTorch MergeBot
parent
9d8ceaa36f
commit
bdd3c3a29c
@@ -831,7 +831,9 @@ class AOTFxirTestCase(InductorTestCase):
|
||||
gm = torch._inductor.aot_compile(
|
||||
ep.module(), inp, options={"fx_wrapper": True, **test_config}
|
||||
)
|
||||
self.assertTrue(same(model(*inp), gm(*inp)))
|
||||
# Flatten args for fx_wrapper gm
|
||||
flat_args, _ = pytree.tree_flatten(inp)
|
||||
self.assertTrue(same(model(*inp), gm(*flat_args)))
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if (
|
||||
@@ -1182,6 +1184,38 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
compiled_out = compiled(*args)
|
||||
self.assertEqual(compiled_out.shape, shape)
|
||||
|
||||
def test_reshape_dynamic_ph(self):
|
||||
"""
|
||||
Test dynamic scalars using SymInts placeholder
|
||||
"""
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, x, shape):
|
||||
return torch.reshape(x, shape) + 2
|
||||
|
||||
ds = {
|
||||
"x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
|
||||
"shape": [torch.export.Dim.AUTO, torch.export.Dim.AUTO],
|
||||
}
|
||||
args = (torch.randn((12, 14), device=self.device), [6, 28])
|
||||
self.check(TestModule(), args, ds)
|
||||
|
||||
def test_reshape_dynamic_tmd(self):
|
||||
"""
|
||||
Test dynamic reshape using shape dependent information
|
||||
"""
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
new_shape = [x.shape[0] // 2, x.shape[1] * 2]
|
||||
return torch.reshape(x, new_shape) + 2
|
||||
|
||||
ds = {
|
||||
"x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
|
||||
}
|
||||
args = (torch.randn((12, 14), device=self.device),)
|
||||
self.check(TestModule(), args, ds)
|
||||
|
||||
|
||||
class TestReplaceFloorDiv(InductorTestCase):
|
||||
"""
|
||||
|
||||
@@ -2537,16 +2537,19 @@ def _extract_inputs_from_exported_gm(
|
||||
fake_inputs = [
|
||||
node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"
|
||||
]
|
||||
# Replace non-tensor (constant) inputs with Nones, since these are not being
|
||||
# used anyways by the graph
|
||||
fake_inputs = [
|
||||
inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs
|
||||
]
|
||||
|
||||
if not config.fx_wrapper:
|
||||
# Replace non-tensor inputs with Nones
|
||||
# constant scalars embedded in the graph
|
||||
# symbolic scalars (symint) are not supported in non-fx_wrapper mode
|
||||
fake_inputs = [
|
||||
inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs
|
||||
]
|
||||
|
||||
if any(v is not None for v in fake_inputs):
|
||||
# Validate devices before switching to fake tensors.
|
||||
for idx, fi, i in zip(count(), fake_inputs, example_inputs_):
|
||||
if fi is not None:
|
||||
if fi is not None and isinstance(fi, torch.Tensor):
|
||||
assert isinstance(i, torch.Tensor)
|
||||
if fi.device != i.device:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user