Files
pytorch/torch/csrc/distributed/python_placement.cpp
zpcore 1d62781fbe Isolate _StridedShard from Shard (#167887)
### Summary

Separate out `_StridedShard` class from `Shard`.

### Details
1. `_StridedShard(dim, split_factor=...) == Shard(dim)` should return False.
2. `is_shard()` should only be true for the `Shard` class.
3. `_StridedShard` is no longer a subclass of `Shard` (see the sketch after this list).
4. Support redistribution for any `_StridedShard` placement that is convertible to an ordered `Shard` placement.
5. Fix pointwise-op sharding propagation to support `_StridedShard`. Previously `_StridedShard` was silently cast to `Shard`, which caused the FSDP optimizer to update the model parameters incorrectly.
    1. Operations impacted: operations that work on model parameters. From what I have seen, these are pointwise foreach operations, e.g. `optimizer.step()`, `torch.nn.utils.clip_grad_norm_()`, `GradScaler.step()`.
6. Some minor fixes to op sharding propagation regarding the `_StridedShard` type, which impacts ops like `rand`.
    1. view ops: fix `maybe_get_shard_mesh_dim_and_placement()` to propagate `_StridedShard`.
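
A minimal sketch of the new semantics for items 1-3, assuming the private `_StridedShard` import path from `torch.distributed.tensor.placement_types` (the exact location may differ across versions):

```python
from torch.distributed.tensor.placement_types import Shard, _StridedShard

s = Shard(0)
ss = _StridedShard(0, split_factor=2)

print(ss == s)                # False after this PR (previously True, since _StridedShard subclassed Shard)
print(s.is_shard())           # True
print(ss.is_shard())          # False after this PR: only plain Shard reports is_shard()
print(isinstance(ss, Shard))  # False after this PR: _StridedShard is no longer a subclass of Shard
```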

Note: if we isolate `_StridedShard` from `Shard` but do not change the sharding strategies to support `_StridedShard`, this is generally what will happen: https://github.com/pytorch/pytorch/issues/166598#issuecomment-3583225318

### Analysis of impact to pointwise ops
**TL;DR** `_StridedShard` is propagated through with the same split_factor; correctness is not impacted.

The change is centralized in `map_placements_after_reduction`, where the function handles two cases (a simplified sketch follows the list):
1. If the reduction dims contain `_StridedShard`, the function replaces the reduction dims in the original placements from `Shard`/`_StridedShard` with `Partial(...)`. E.g., for `torch.sum(dim=0)` we have the strategy `SS(0)S(0)SS(0) -> PPP`. This works because:
    - All of the Partial reduce ops (`sum`, `avg`, `product`, `max`, `min`) are independent of whether the placement is a strided shard or a normal shard. E.g., for `Partial('sum')` we sum over every mesh dim that uses `Shard(0)` or `_StridedShard(0)`, so this works for arbitrary `_StridedShard` combinations regardless of split factor and ordering.
2. For non-reduction dims that contain `_StridedShard`, e.g. `S(1)SS(1)`, we keep the original sharding but adjust the shard dim, `S(1)SS(1) -> S(0)SS(0)`, when tensor dim 0 is reduced. This also works because we do not touch any of the sharding on tensor dim 1.
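
A simplified sketch of the two cases above (an illustration only, not the actual `map_placements_after_reduction` implementation; the private `_StridedShard`/`Partial` import path is assumed):

```python
from typing import List, Sequence

from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Shard,
    _StridedShard,
)


def map_after_reduction_sketch(
    placements: Sequence[Placement], reduced_dim: int, reduce_op: str = "sum"
) -> List[Placement]:
    out: List[Placement] = []
    for p in placements:
        if isinstance(p, (Shard, _StridedShard)) and p.dim == reduced_dim:
            # Case 1: the reduced tensor dim is sharded (strided or not) -> Partial.
            out.append(Partial(reduce_op))
        elif isinstance(p, _StridedShard) and p.dim > reduced_dim:
            # Case 2: keep the strided shard with the same split_factor;
            # only the shard dim shifts down because a tensor dim was removed.
            out.append(_StridedShard(p.dim - 1, split_factor=p.split_factor))
        elif isinstance(p, Shard) and p.dim > reduced_dim:
            out.append(Shard(p.dim - 1))
        else:
            out.append(p)  # Replicate / Partial / shards on lower dims stay as-is
    return out


# torch.sum(dim=0):  SS(0) S(0) SS(0)  ->  P P P
# reduce dim 0 with  S(1) SS(1)        ->  S(0) SS(0)
```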

### Analysis of impact to Optimizer in FSDP:
**TL;DR** `_StridedShard` can pass through the foreach pointwise ops that the optimizer applies to the model parameters, and it does not impact the correctness of FSDP; no SS->S or S->SS redistribution happens.

I added a test `Test_StridedShard_Optimizer` in this PR. The test verifies that different optimizers apply gradient updates to the FSDP model (whose parameters contain `_StridedShard`) correctly, compared with the unsharded model. I also checked, for each optimizer, which aten ops take `_StridedShard` in their input placements, as listed in the table below. I see three cases for those optimizers:
1. `detach`, `zeros_like`, `full_like`, and `clone` all pass SS through from input to output. I guess this happens during model weight initialization.
2. For out-of-place foreach ops like `_foreach_add.Scalar`, SS passes through to the output spec.
3. In-place foreach ops are somewhat odd in that they do not have an output spec (if we check the output spec it is `None`; I debugged this and found it is because fake-tensor metadata propagation returns `None` when we call an in-place foreach op). However, they do generate a new input spec for each arg, and those new input specs still keep SS if the original input spec is SS.

So I am now fairly confident that this PR lets `_StridedShard` pass through the pointwise ops in the optimizer and does not impact the correctness of FSDP.
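
A minimal sketch of the kind of check involved (this is not the PR's `Test_StridedShard_Optimizer` test; it assumes the private `_StridedShard` import path and a small multi-rank gloo run, e.g. `torchrun --nproc_per_node=2`):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import _StridedShard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# A parameter whose placement is a strided shard, similar to what FSDP2 produces.
placement = [_StridedShard(0, split_factor=2)]
param = torch.nn.Parameter(DTensor.from_local(torch.randn(4, 8), mesh, placement))
param.grad = DTensor.from_local(torch.randn(4, 8), mesh, placement)

opt = torch.optim.Adam([param], foreach=True)
opt.step()

# The placement should still be the strided shard, i.e. no SS->S redistribution happened.
print(param.data.placements)
```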

- reference: a modified [_sharding_prop.py](https://gist.github.com/zpcore/fc81a0017fd673333ac282b67c750bf2) for testing which ops received inputs that contain `_StridedShard` and whether `_StridedShard` passed through the aten op or not.

| Optimizer      | Aten Op                        |
|----------------|-------------------------------|
| adadelta       | `detach.default`              |
|                | `zeros_like.default`         |
|                | `_foreach_mul_.Scalar`         |
|                | `_foreach_addcmul_.Scalar`   |
|                | `_foreach_add.Scalar`          |
|                | `_foreach_sqrt_.default`       |
|                | `_foreach_div_.List`           |
|                | `_foreach_mul_.List`           |
|                | `_foreach_add_.List`           |
| adagrad        | `detach.default`         |
|                | `full_like.default`          |
|                | `_foreach_addcmul_.Scalar`     |
|                | `_foreach_sqrt.default`       |
|                | `_foreach_add_.Scalar`         |
|                | `_foreach_mul.ScalarList`    |
|                | `_foreach_addcdiv_.Scalar`     |
| adam           | `detach.default`         |
|                | `zeros_like.default`       |
|                | `_foreach_lerp_.Scalar`        |
|                | `_foreach_mul_.Scalar`         |
|                | `_foreach_addcmul_.Scalar`     |
|                | `_foreach_sqrt.default`     |
|                | `_foreach_div_.ScalarList`     |
|                | `_foreach_add_.Scalar`         |
|                | `_foreach_addcdiv_.ScalarList` |
| adamw          | `detach.default`               |
|                | `zeros_like.default`           |
|                | `_foreach_mul_.Scalar`         |
|                | `_foreach_lerp_.Scalar`        |
|                | `_foreach_addcmul_.Scalar`   |
|                | `_foreach_sqrt.default`        |
|                | `_foreach_div_.ScalarList`     |
|                | `_foreach_add_.Scalar`         |
|                | `_foreach_addcdiv_.ScalarList` |
| rmsprop        | `detach.default`               |
|                | `zeros_like.default`           |
|                | `_foreach_mul_.Scalar`         |
|                | `_foreach_addcmul_.Scalar`     |
|                | `_foreach_sqrt.default`        |
|                | `_foreach_add_.Scalar`         |
|                | `_foreach_addcdiv_.Scalar`     |
| sgd            | `detach.default`               |
|                | `_foreach_add_.List`           |
| sgd (momentum) | `detach.default`               |
|                | `clone.default`                |
|                | `_foreach_add_.List`           |
|                | `_foreach_add_.List`           |
|                | `_foreach_mul_.Scalar`        |

### Analysis of impact to view ops
**TL;DR** Both the previous and the current code can propagate `_StridedShard` by keeping the same split_factor (which I do not think is correct for all cases), as long as the tensor dim is shardable on the mesh. We need a follow-up fix to let view ops propagate arbitrary combinations of `_StridedShard`.

I marked `test_view_propagation_not_supported_yet` as `expectedFailure` in this PR. @weifengpy plans to work on broader support.
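
For the currently supported cases, the propagation can be inspected with a short sketch like the following (reusing the `mesh` and the private `_StridedShard` import from the optimizer sketch above; shapes are hypothetical):

```python
# Local shape (4, 8) per rank; tensor dim 0 uses a strided shard.
x = DTensor.from_local(torch.randn(4, 8), mesh, [_StridedShard(0, split_factor=2)])

# A view-style op that splits tensor dim 1 (size 8) into (2, 4),
# leaving the strided-sharded dim 0 untouched.
y = x.reshape(x.shape[0], 2, 4)

# Inspect whether the strided shard (and its split_factor) was preserved.
print(y.placements)
```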

### Analysis of impact to random ops
Our current `OffsetBasedRNGTracker` uses the `Shard` placement to compute each shard's index and size, which are then used to decide the RNG offset. However, the offset computation does not take the split_factor of `_StridedShard` into account.

This PR does not fix the offset computation for `_StridedShard`. It only keeps treating `_StridedShard` the same as `Shard`, which matches the behavior before this PR, when `_StridedShard` was a subclass of `Shard`.
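
To illustrate why split_factor matters here, a plain-Python sketch of the two layouts (an illustration of which global rows each rank owns, not `OffsetBasedRNGTracker`'s actual offset code):

```python
def shard_rows(N, n, rank):
    # Contiguous Shard(0): rank r owns one contiguous block of rows.
    block = N // n
    return list(range(rank * block, (rank + 1) * block))


def strided_shard_rows(N, n, rank, split_factor):
    # _StridedShard(0, split_factor=f): the dim is first split into f contiguous
    # chunks, then each chunk is sharded across the mesh dim, so rank r owns the
    # r-th slice of every chunk and its rows are not contiguous.
    chunk = N // split_factor
    piece = chunk // n
    rows = []
    for c in range(split_factor):
        start = c * chunk + rank * piece
        rows.extend(range(start, start + piece))
    return rows


print(shard_rows(8, 2, 0))             # [0, 1, 2, 3]
print(strided_shard_rows(8, 2, 0, 2))  # [0, 1, 4, 5]
# An RNG offset derived from "shard index * contiguous shard size" matches the
# first layout but not the second, which is why split_factor would need to enter
# the offset computation.
```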

---

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167887
Approved by: https://github.com/wconstab
2025-12-10 06:10:23 +00:00

#include <torch/csrc/distributed/python_placement.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/distributed/Placement.h>
#include <torch/csrc/utils/pybind.h>
using namespace pybind11::literals;
namespace torch::distributed {
namespace {
const auto placement_class_docstring =
    R"(The base class for the Placement type, where it describes how a DTensor is placed onto the
``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
and ``Partial``.
This class is not meant to be used directly, mainly served as a typing stub.
)";
} // namespace
void initPlacementBindings(PyObject* module) {
  auto py_module = py::reinterpret_borrow<py::module>(module);
  auto distributed_module = py_module.def_submodule("_distributed");
  py::class_<Placement>(
      distributed_module, "Placement", placement_class_docstring)
      .def(py::init<>()) // Allow construction of Python subclasses.
      .def(
          "is_partial",
          &Placement::is_partial,
          py::arg("reduce_op") = py::none())
      .def("is_replicate", &Placement::is_replicate)
      .def("is_shard", &Placement::is_shard, py::arg("dim") = py::none());
  py::class_<Shard, Placement>(distributed_module, "Shard")
      .def(py::init<int64_t>(), py::arg("dim"))
      .def_readonly("dim", &Shard::dim)
      .def("is_shard", &Shard::is_shard, py::arg("dim") = py::none())
      .def(
          "__eq__",
          [](const Shard& lhs, const Shard& rhs) { return lhs == rhs; },
          py::is_operator())
      // Note: we need to use dicts for pickling to match the old
      // dataclasses.
      .def(py::pickle(
          [](const Shard& shard) { return py::dict("dim"_a = shard.dim); },
          [](const py::dict& d) {
            return Shard(py::cast<int64_t>(d["dim"]));
          }));
  py::class_<StridedShard, Placement>(distributed_module, "StridedShard")
      .def(
          py::init<int64_t, int64_t>(),
          py::arg("dim"),
          py::kw_only(),
          py::arg("split_factor"))
      .def_readonly("dim", &StridedShard::dim)
      .def_readonly("split_factor", &StridedShard::split_factor)
      .def(
          "__eq__",
          [](const StridedShard& lhs, const StridedShard& rhs) {
            return lhs == rhs;
          },
          py::is_operator())
      .def(py::pickle(
          [](const StridedShard& shard) {
            return py::dict(
                "dim"_a = shard.dim, "split_factor"_a = shard.split_factor);
          },
          [](const py::dict& d) {
            return StridedShard(
                py::cast<int64_t>(d["dim"]),
                py::cast<int64_t>(d["split_factor"]));
          }));
  py::class_<Replicate, Placement>(distributed_module, "Replicate")
      .def(py::init())
      .def("is_replicate", &Replicate::is_replicate)
      .def(
          "__eq__",
          [](const Replicate& lhs, const Replicate& rhs) { return lhs == rhs; },
          py::is_operator())
      .def(py::pickle(
          // I observed SIGSEGV when trying to use None as the
          // pickled state, though AFAICT that matches the
          // behavior of
          // object().__reduce__().
          // test_placement_types.test_type_identification will repro if an
          // enterprising reader wants to get this fixed.
          [](const Replicate& repl) { return py::dict(); },
          [](const py::dict&) { return Replicate(); }));
  py::class_<Partial, Placement>(distributed_module, "Partial")
      .def(py::init<>())
      .def(py::init<std::optional<std::string>>(), py::arg("reduce_op"))
      .def_readonly("reduce_op", &Partial::reduce_op)
      .def(
          "is_partial", &Partial::is_partial, py::arg("reduce_op") = py::none())
      .def(
          "__eq__",
          [](const Partial& lhs, const Partial& rhs) { return lhs == rhs; },
          py::is_operator())
      .def(py::pickle(
          [](const Partial& part) {
            return py::dict("reduce_op"_a = part.reduce_op);
          },
          [](const py::dict& d) {
            return Partial(py::cast<std::string>(d["reduce_op"]));
          }));
}
} // namespace torch::distributed