pytorch/benchmarks/dynamo/distributed.py

import argparse
import logging
import os
from functools import partial

import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function

try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
Special-case fsdp wrapped modules to be Unspecialized (#89330)

### Summary

Making dynamo treat the nn.Modules inside FSDP wrappers as 'Unspecialized' results in dynamo-produced graphs where nn.Module parameters are inputs to the graph rather than attributes of the outer GraphModule. This helps FSDP because it forces dynamo to pick up the latest copy of the parameters from the user's nn.Module (which FSDP mutates in every pre_forward), solving the ordering issue in backward.

### Details

Imagine this toy model:

```
import torch
from torch import nn


class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super(MyModule, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net = nn.Sequential(
            *[MyModule(10, 10000)] + [MyModule(10000, 1000)] + [MyModule(1000, 5)]
        )

    def forward(self, x):
        return self.net(x)
```

Here FSDP is recursively wrapped around each `MyModule`, and the result is then dynamo-compiled, with dynamo already configured to skip/break in FSDP code. You'd expect to get 3 compiled AOT functions, corresponding to the contents of each `MyModule`, with FSDP's communication ops running (eagerly) in between them. This almost happens (everything works fine in forward), but in backward there is an ordering issue.

FSDP creates a flat buffer for all the parameters that are bucketed together, and then creates views into this buffer to replace the original parameters. On each forward iteration, it creates a new view after filling the flat buffer with data from an all-gather operation, to 'unshard' the parameters from remote devices. Dynamo traces the first such view and stores it in a compiled GraphModule.

During tracing, we see:
(1) view created for the first MyModule,
(2) compile the first MyModule,
(3) ... and so on for the rest of the layers.

Then during runtime, we see:
(A) view created for the first MyModule (and orphaned),
(B) execute the first compiled MyModule, using the old view,
...

This is a problem because we want the backward hooks to run right after each compiled backward, but autograd executes those hooks in an order mirroring their execution order during forward. Since we are forever using the views created during steps (1, 3, ..., N), which all happen before steps (A, B, ...), all the hooks run after all the compiled backwards.

An illustration of the problem: a torchviz graph showing the 2 possible orderings of autograd, and a profile showing the view-backward ops happening after all the compiled backwards and before all the backward hooks.

<img width="2069" alt="image" src="https://user-images.githubusercontent.com/4984825/202828002-32dbbd15-8fc3-4281-93e9-227ab5e32683.png"> <img width="2069" alt="image" src="https://user-images.githubusercontent.com/4984825/202828632-33e40729-9a7f-4e68-9ce1-571e3a8dd2dd.png">

A solution is to make dynamo not specialize on these nn modules. It is worth pointing out that this nn.Module specialization is failing in practice: we are modifying .parameters, which bypasses dynamo's __setattr__ monkeypatch that should have automatically kicked us out to Unspecialized and forced a recompile.

After unspecializing, the new views (created during steps A, C, ...) are actually _used_ at runtime by the module, so their creation order is interleaved, and autograd executes their backwards interleaved as well.

The new torchviz graph (this time with names added for the view tensors):

<img width="2043" alt="image" src="https://user-images.githubusercontent.com/4984825/202828480-d30005ba-0d20-45d8-b647-30b7ff5e91d3.png">

And a new profile showing the interleaving of compiled backwards and hooks, allowing overlapping of reduce-scatter:

<img width="2293" alt="image" src="https://user-images.githubusercontent.com/4984825/202828533-bb20a041-19b8-499c-b3cf-02808933df47.png">

@jansel @davidberard98 @aazzolini @mrshenli @awgu @ezyang @soumith @voznesenskym @anijain2305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89330
Approved by: https://github.com/davidberard98
2022-11-28 21:08:39 +00:00
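The gap between a specialized module (parameters baked into the traced graph) and an unspecialized one (parameters passed in as graph inputs) can be shown with a plain-Python analogy. This is only a sketch: it does not use dynamo or FSDP, a one-element list stands in for a parameter view, and `Layer`, `compile_specialized` and `compile_unspecialized` are hypothetical names. The specialized version keeps using the object it saw at trace time, as the compiled graph kept using the first view, while the unspecialized version picks up whatever object the hook installed most recently.

```python
# Plain-Python analogy of the stale-view problem; no torch or dynamo involved.
# Layer, compile_specialized and compile_unspecialized are hypothetical names.


class Layer:
    def __init__(self):
        # Stands in for a parameter view into FSDP's flat buffer.
        self.weight = [1.0]

    def pre_forward(self, new_value):
        # FSDP-like hook: rebind the attribute to a brand-new object each iteration.
        self.weight = [new_value]


def compile_specialized(layer):
    # Read the attribute once at "trace" time and bake it into the compiled function.
    captured = layer.weight

    def run(x):
        return x * captured[0]

    return run


def compile_unspecialized(layer):
    # Treat the attribute as an input: look it up again on every call.
    def run(x):
        return x * layer.weight[0]

    return run


layer = Layer()
specialized = compile_specialized(layer)
unspecialized = compile_unspecialized(layer)

layer.pre_forward(3.0)  # next iteration: a fresh "view" replaces the old one
print(specialized(2.0), unspecialized(2.0))  # 2.0 6.0
```

The specialized function silently keeps computing with the orphaned object, which mirrors why the views created during steps (A, C, ...) went unused before the fix.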
log = logging.getLogger(__name__)


def torchviz_model(args, model, inputs, rank):
    from torchviz import make_dot

    outputs = model(*inputs)
    loss = reduce_to_scalar_loss(outputs)
    parameter_names = dict(model.named_parameters())
    dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True)
    if rank == 0:
        dot.render("torchviz.dot")


def profile_model(args, model, inputs, rank):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for i in range(args.repeat):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)

            with record_function("Backward"):
                loss.backward()

    if rank == 0:
        prof.export_chrome_trace(args.trace_file)


def run_model(args, model, inputs, key):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    # result_q = []

    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    inputs = pytree.tree_map(move_tensor, inputs)

    if args.fsdp:
        model = apply_fsdp(
[FSDP+dynamo]: forward treats parameter-views as params (#88781)

Dynamo+AotAutograd needs a way to wrap all tensors (whether inputs or params/buffers) in FakeTensor wrappers, and FSDP's mangling of parameters hides them from this wrapping.

This PR unblocks running hf_bert and hf_T5 with FSDP under dynamo, whether using recursive wrapping around transformer layers or only applying FSDP around the whole model. Perf/memory validation and possibly optimization is the next step.

`python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager --fsdp_wrap`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager --fsdp_wrap`

The problem:

Dynamo (actually aot_autograd) trips up with FSDP because it must wrap all input tensors in FakeTensor wrappers, and it only knows to wrap graph inputs or named_(parameters, buffers). FSDP's pre_forward hook sets views into the flat param (which are not nn.Parameters) as attributes on the module, with the same names as the original params, but they do not show up in named_parameters.

- in use_orig_params mode, FSDP still de-registers params during the pre-forward hook, then re-registers them post-forward
- during forward (between the hooks), the params are setattr'd on the module as regular view tensors, not nn.Parameters
- note: use_orig_params is the recommended way to use FSDP, and use_orig_params=False is being deprecated, so I only consider use_orig_params=True for this enablement

The solution:

- adding them to named_buffers is not possible because it interferes with how FSDP's `_apply` works
- since they are not actual nn.Parameters, register_parameter will complain about registering them
- simply setting `module._parameters[name] = view` seems to be a viable workaround, despite being hacky, and FSDP code already modifies _parameters directly

Note: manual checkpointing still isn't working with FSDP+dynamo, so that will have to be addressed in a follow-up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88781
Approved by: https://github.com/ezyang, https://github.com/awgu
2022-11-11 21:22:49 +00:00
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)

    # warmup
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total
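The `module._parameters[name] = view` workaround from the #88781 annotation above can be modeled with a toy registry. This is a sketch under loud assumptions: `TinyModule` and `names_a_tracer_would_wrap` are hypothetical, strings stand in for tensors, and the real nn.Module, FSDP and aot_autograd logic is far more involved. It only shows why a view set as a plain attribute is invisible to anything that enumerates `named_parameters()`, and how writing it into the registry makes it visible again.

```python
# Toy model of a parameter registry; not torch code.
# TinyModule and names_a_tracer_would_wrap are hypothetical names.


class TinyModule:
    def __init__(self):
        # Registry analogous to nn.Module._parameters.
        self._parameters = {}

    def named_parameters(self):
        # Reports registered entries only, never plain attributes.
        return list(self._parameters.items())


def names_a_tracer_would_wrap(module):
    # Stand-in for a tracer that only wraps what named_parameters() reports.
    return sorted(name for name, _ in module.named_parameters())


m = TinyModule()
m._parameters["weight"] = "original_param"
del m._parameters["weight"]  # FSDP-style de-registration in the pre-forward hook
m.weight = "view_into_flat_param"  # plain attribute: invisible to the registry
before = names_a_tracer_would_wrap(m)

m._parameters["weight"] = m.weight  # the workaround: put the view into the registry
after = names_a_tracer_would_wrap(m)
print(before, after)  # [] ['weight']
```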
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="dynamo backend name (e.g. aot_eager, inductor), or 'print' to print each captured graph; if unset, run eagerly",
    )
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--batch-size", "--batch_size", default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="Output path for the Chrome trace written by --profile",
    )
    parser.add_argument(
        "--repeat", type=int, default=10, help="Repeats for timing run"
    )
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )
    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")
    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. BERT_pytorch",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()
    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    world_size = os.getenv("WORLD_SIZE", 1)
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")
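The command-line handling above relies on a few `argparse` idioms: each flag is registered under both a hyphenated and an underscored spelling, `--ddp`/`--fsdp` and the model-selection flags form mutually exclusive groups (the latter with `required=True`), and a numeric flag such as `--repeat` needs `type=int` for `range()` and the mean-latency division to work on user-supplied values. The sketch below re-declares a subset of those flags with the standard library only; the chosen subset and the `total_time` placeholder (standing in for the value returned by `timed()`) are illustrative assumptions, not benchmark output.

```python
import argparse


def build_parser():
    # A subset of the benchmark's flags, re-declared here for illustration.
    parser = argparse.ArgumentParser()
    # type=int converts the value; without it, "--repeat 5" would arrive as the string "5".
    parser.add_argument("--repeat", type=int, default=10, help="Repeats for timing run")
    # Two spellings, one destination: dest comes from the first long option, "trace_file".
    parser.add_argument("--trace-file", "--trace_file", default="profile.json")
    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")
    # Exactly one way of choosing the model must be given.
    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument("--torchbench-model", "--torchbench_model")
    model_arg.add_argument("--toy-model", "--toy_model", action="store_true")
    return parser


args = build_parser().parse_args(["--toy_model", "--repeat", "4", "--fsdp"])
print(args.repeat, args.trace_file, args.fsdp, args.toy_model)  # 4 profile.json True True

total_time = 2.0  # illustrative placeholder for the value returned by timed()
print(f"mean latency {total_time / args.repeat} across {args.repeat} runs")
# mean latency 0.5 across 4 runs
```

Passing both `--ddp` and `--fsdp` (or neither model flag) makes `parse_args` print a usage error and exit with status 2.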