diff --git a/torch/distributed/_spmd/distribute.py b/torch/distributed/_spmd/distribute.py
index 8ddb59cba77..453986af4c6 100644
--- a/torch/distributed/_spmd/distribute.py
+++ b/torch/distributed/_spmd/distribute.py
@@ -169,6 +169,13 @@ def _get_dtensor_dispatch_graph(
 
     op_overload = cast(torch._ops.OpOverload, node.target)
 
+    if node.target == torch.ops.aten.view.default:
+        # HACK: work around the fact that some view operations on a
+        # "global" tensor are invalid usage, but the view op on the
+        # batch input might still hit this path, so we convert the
+        # view op to reshape before calling DTensor.
+        op_overload = torch.ops.aten.reshape.default
+
     # run dispatch once to get the real DTensor output.
     out, op_schema, output_sharding = _operator_dispatch(
         op_overload,
diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py
index 52be898fa1c..d7c2db10961 100644
--- a/torch/distributed/_tensor/ops/view_ops.py
+++ b/torch/distributed/_tensor/ops/view_ops.py
@@ -667,6 +667,7 @@ def register_prop_rule_map(
 register_prop_rule_map(aten.squeeze.default, torch.squeeze)
 register_prop_rule_map(aten.squeeze.dim, torch.squeeze)
 register_prop_rule_map(aten.view.default, Tensor.view)
+register_prop_rule_map(aten.reshape.default, torch.reshape)
 register_prop_rule_map(aten._unsafe_view.default, Tensor.view)
 register_prop_rule_map(aten.unsqueeze.default, torch.unsqueeze)
 register_prop_rule_map(aten.expand.default, Tensor.expand)
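
Note: the first hunk leans on the fact that `torch.reshape` accepts cases where `Tensor.view` would raise. Below is a minimal, non-distributed sketch of that difference (plain PyTorch only, no DTensor or SPMD machinery); it is an illustration of why reshape is the safer overload to dispatch, not a reproduction of the exact "global" tensor failure the hack guards against.

```python
import torch

# view() requires the target shape to be expressible over the tensor's
# existing strides; reshape() falls back to copying when it is not.
x = torch.arange(12).reshape(3, 4).t()  # transpose -> non-contiguous strides

try:
    x.view(12)  # raises: size/stride are incompatible with a pure view
except RuntimeError as err:
    print("view failed:", err)

y = x.reshape(12)  # succeeds by materializing a contiguous copy
print(y.shape)  # torch.Size([12])
```

In the traced graph the same substitution happens at the op level: `aten.view.default` is swapped for `aten.reshape.default` before `_operator_dispatch`, and the new `register_prop_rule_map` entry in the second hunk gives DTensor a sharding propagation rule for that reshape overload, reusing `torch.reshape` semantics.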