Mirror of https://github.com/zebrajr/pytorch.git, synced 2026-01-15 12:15:51 +00:00
Allow SymInt input for torch.fx reinplace pass (#133178)
Fixes #133176

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133178
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 61625a18ef
Commit: dfc7c860e4
--- a/test/test_fx_reinplace_pass.py
+++ b/test/test_fx_reinplace_pass.py
@@ -3,6 +3,9 @@ import torch
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.fx.passes.reinplace import reinplace
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch._dynamo.source import ConstantSource
+from torch.fx.experimental.sym_node import SymNode
 
 try:
     from functorch.experimental import functionalize
@@ -358,5 +361,41 @@ def forward(self):
     return zeros
     """)
 
+    def test_reinplace_sym_input(self):
+        # Symbolic input test: the out-of-place add() call should be converted
+        # into add_(), and symbolic input won't cause any error.
+        def f(x, index):
+            a = torch.select(x, 0, index)
+            clone = a.clone()
+            b = clone.add(1)
+            return b
+
+        x = torch.randn((4, 8, 16, 16), requires_grad=False)
+        index = 2
+        shape_env = ShapeEnv()
+        symbol = shape_env.create_symbol(index, source=ConstantSource(
+            f"__testing_only{len(shape_env.var_to_val)}"))
+        sym_index = torch.SymInt(SymNode(symbol, shape_env, int, hint=index))
+
+        inpt = [x, sym_index]
+        f2 = reinplace(make_fx(f)(*inpt), *inpt)
+
+        real_inpt = [x, index]
+        expected_out = f(*real_inpt)
+        actual_out = f2(*real_inpt)
+        self.assertEqual(actual_out, expected_out)
+        print(f2.code)
+        self.assertExpectedInline(f2.code, """\
+
+
+
+def forward(self, x_1, index_1):
+    select = torch.ops.aten.select.int(x_1, 0, index_1); x_1 = index_1 = None
+    clone = torch.ops.aten.clone.default(select); select = None
+    add = torch.ops.aten.add_.Tensor(clone, 1); add = None
+    return clone
+    """)
+
+
 if __name__ == '__main__':
     run_tests()
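For contrast with the expected output asserted in the new test above, here is a small standalone sketch, not part of the commit, with arbitrary shapes and a plain-int index: it traces the same f with make_fx but does not run reinplace, so the out-of-place aten.add.Tensor call is still visible in the printed code.

# Illustration only: trace without reinplacing to see the out-of-place add.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, index):
    a = torch.select(x, 0, index)
    clone = a.clone()
    return clone.add(1)

gm = make_fx(f)(torch.randn(4, 8, 16, 16), 2)
print(gm.code)  # contains aten.add.Tensor; reinplace(gm, ...) rewrites it to aten.add_.Tensor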
--- a/torch/fx/passes/reinplace.py
+++ b/torch/fx/passes/reinplace.py
@@ -114,7 +114,7 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
         self.node_counter = -1
 
         with FakeTensorMode() as mode:
-            fake_args = [mode.from_tensor(a) for a in args]
+            fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
             return super().run(*fake_args)
 
 def _schemas_match(functional_schema, inplace_schema):
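A minimal sketch of why this guard is needed (an illustration only, where a plain int stands in for a SymInt sample argument; not code from this commit): FakeTensorMode.from_tensor() converts real tensors into fake ones, so non-Tensor sample args must be passed through unchanged.

# Illustration only: tensors are faked, non-Tensor args pass through untouched.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

args = [torch.randn(2, 3), 5]  # the int 5 stands in for a SymInt input
with FakeTensorMode() as mode:
    fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
print([type(a).__name__ for a in fake_args])  # ['FakeTensor', 'int']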
@@ -473,8 +473,7 @@ def reinplace(gm, *sample_args):
     input_storages = {
         StorageWeakRef(
             node.meta['fake_result']._typed_storage()
-        ) for node in gm.graph.nodes if node.op == 'placeholder'}
-
+        ) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))}
 
     # We also need to know for a given node, what are all of its aliasing nodes.
     storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
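Similarly, a hedged sketch of this second change (illustration only, using plain Python values in place of graph nodes' meta['fake_result']): only Tensor results own storage, so placeholders carrying a SymInt are skipped when collecting input storage references.

# Illustration only: _typed_storage() exists on Tensors, not on SymInt-like values,
# so the comprehension filters to Tensor results before taking weak storage refs.
import torch
from torch.multiprocessing.reductions import StorageWeakRef

fake_results = [torch.randn(4), 7]  # stand-ins for placeholder nodes' fake results
input_storages = {
    StorageWeakRef(r._typed_storage())
    for r in fake_results
    if isinstance(r, torch.Tensor)
}
print(len(input_storages))  # 1: only the Tensor contributes a storage ref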