Allow SymInt input for torch.fx reinplace pass (#133178)

Fixes #133176

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133178
Approved by: https://github.com/ezyang
This commit is contained in:
YangQun1
2024-08-13 20:07:17 +00:00
committed by PyTorch MergeBot
parent 61625a18ef
commit dfc7c860e4
2 changed files with 41 additions and 3 deletions

View File

@@ -3,6 +3,9 @@ import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.fx.passes.reinplace import reinplace
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.sym_node import SymNode
try:
from functorch.experimental import functionalize
@@ -358,5 +361,41 @@ def forward(self):
return zeros
""")
def test_reinplace_sym_input(self):
# Symbolic input test: the out-of-place add() call should be converted
# into add_(), and symbolic input won't cause any error.
def f(x, index):
a = torch.select(x, 0, index)
clone = a.clone()
b = clone.add(1)
return b
x = torch.randn((4, 8, 16, 16), requires_grad=False)
index = 2
shape_env = ShapeEnv()
symbol = shape_env.create_symbol(index, source=ConstantSource(
f"__testing_only{len(shape_env.var_to_val)}"))
sym_index = torch.SymInt(SymNode(symbol, shape_env, int, hint=index))
inpt = [x, sym_index]
f2 = reinplace(make_fx(f)(*inpt), *inpt)
real_inpt = [x, index]
expected_out = f(*real_inpt)
actual_out = f2(*real_inpt)
self.assertEqual(actual_out, expected_out)
print(f2.code)
self.assertExpectedInline(f2.code, """\
def forward(self, x_1, index_1):
select = torch.ops.aten.select.int(x_1, 0, index_1); x_1 = index_1 = None
clone = torch.ops.aten.clone.default(select); select = None
add = torch.ops.aten.add_.Tensor(clone, 1); add = None
return clone
""")
if __name__ == '__main__':
run_tests()

View File

@@ -114,7 +114,7 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
self.node_counter = -1
with FakeTensorMode() as mode:
fake_args = [mode.from_tensor(a) for a in args]
fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
return super().run(*fake_args)
def _schemas_match(functional_schema, inplace_schema):
@@ -473,8 +473,7 @@ def reinplace(gm, *sample_args):
input_storages = {
StorageWeakRef(
node.meta['fake_result']._typed_storage()
) for node in gm.graph.nodes if node.op == 'placeholder'}
) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))}
# We also need to know for a given node, what are all of its aliasing nodes.
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)