Enable DDPOptimizer by default in dynamo (#88523)
Performance benchmarks on 6 popular models (hf_Bert, hf_T5_large, hf_T5, hf_GPT2_large, timm_vision_transformer, resnet50) from 1-64 GPUs compiled with torchinductor show performance gains or parity with eager, and showed regressions without DDPOptimizer.

Note: resnet50 with small batch size shows a regression with the optimizer, in part due to failing to compile one subgraph because of input mutation; this will be fixed.

Correctness checks are implemented in CI (test_dynamo_distributed.py), via single-GPU benchmark scripts iterating over many models (benchmarks/dynamo/torchbench.py, timm_models.py, huggingface.py), and via the multi-GPU benchmark scripts in torchbench: https://github.com/pytorch/benchmark/tree/main/userbenchmark/ddp_experiments

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88523
Approved by: https://github.com/davidberard98
Committed by: PyTorch MergeBot
Parent commit: 9048cf16fe
Commit: 7860fcc245
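For orientation before the diff: with optimize_ddp defaulting to True, a DDP-wrapped model compiled through dynamo is split into subgraphs aligned with DDP's gradient buckets so allreduce can overlap with compute. A minimal usage sketch under assumptions not in this commit (the nn.Linear model is a placeholder, and a process group is assumed to be initialized already):

    import torch
    import torch.nn as nn
    import torch._dynamo as dynamo
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Placeholder model; assumes torch.distributed.init_process_group(...)
    # has already been called so DDP can set up its communication.
    model = DDP(nn.Linear(16, 16).cuda())

    # After this commit, DDPOptimizer runs by default: dynamo breaks the graph
    # at DDP bucket boundaries so backward allreduce can overlap with compute.
    compiled_model = dynamo.optimize("inductor")(model)
    out = compiled_model(torch.randn(8, 16, device="cuda"))

    # Opting out keeps one whole graph (no bucket-aligned graph breaks, but
    # also no comm/compute overlap), i.e. the pre-commit default:
    dynamo.config.optimize_ddp = False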
@@ -81,8 +81,8 @@ def run_model(args, model, inputs, key):
     if args.verbose:
         dynamo.config.verbose = True
         dynamo.config.log_level = logging.DEBUG
-    if args.dynamo_optimize_ddp:
-        dynamo.config.optimize_ddp = True
+    if args.dynamo_no_optimize_ddp:
+        dynamo.config.optimize_ddp = False
     if args.dynamo == "inductor" and args.fsdp:
         torch._inductor.config.triton.cudagraphs = False
         log.warn("disabling inductor cudagraphs for compatibility with FSDP")
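The hunk above is the benchmark script adapting to the new default: only an opt-out path for optimize_ddp is needed now, and cudagraphs is still disabled when FSDP runs under the inductor backend. A standalone sketch of that configuration logic (configure_dynamo is an illustrative helper, not the benchmark's own function):

    import logging
    import torch._dynamo as dynamo
    import torch._inductor.config as inductor_config

    log = logging.getLogger(__name__)

    def configure_dynamo(backend: str, use_fsdp: bool, no_optimize_ddp: bool, verbose: bool) -> None:
        if verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if no_optimize_ddp:
            # optimize_ddp now defaults to True, so only an opt-out is needed.
            dynamo.config.optimize_ddp = False
        if backend == "inductor" and use_fsdp:
            # Same workaround as the benchmark: cudagraphs is disabled for FSDP.
            inductor_config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")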
@@ -129,7 +129,7 @@ if __name__ == "__main__":
     parser.add_argument("--trace_file", default="profile.json", help="Run the profiler")
     parser.add_argument("--repeat", default=10, help="Repeats for timing run")
     parser.add_argument(
-        "--dynamo_optimize_ddp",
+        "--dynamo_no_optimize_ddp",
         action="store_true",
         help="Enable dynamo's ddp optimizer",
     )
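Since the config default flipped, the benchmark's enable flag becomes a disable flag. A hypothetical alternative wiring (not what this commit does) would expose both directions with argparse.BooleanOptionalAction, available in Python 3.9+:

    import argparse
    import torch._dynamo as dynamo

    parser = argparse.ArgumentParser()
    # Generates both --dynamo_optimize_ddp and --no-dynamo_optimize_ddp,
    # defaulting to the new dynamo behavior (enabled).
    parser.add_argument(
        "--dynamo_optimize_ddp",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Toggle dynamo's DDP optimizer (on by default)",
    )
    args = parser.parse_args()
    dynamo.config.optimize_ddp = args.dynamo_optimize_ddp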
@@ -138,8 +138,11 @@ capture_scalar_outputs = False
 enforce_cond_guards_match = True
 
 # Automatically split model graph into pieces to match DDP bucket sizes
-# to allow DDP comm/compute overlap
-optimize_ddp = False
+# to allow DDP comm/compute overlap. Disable to allow DDP models to
+# run without graph-breaks, but also without comm/compute overlap.
+# set torch._dynamo.config.log_level to INFO or DEBUG for more info
+# about optimize_ddp behavior.
+optimize_ddp = True
 
 # If True, raises exception if TorchDynamo is called with a context manager
 raise_on_ctx_manager_usage = True
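The new config comment points at dynamo's log level for visibility into what DDPOptimizer does. A small sketch of checking the new default and turning up logging (what exactly gets logged is whatever dynamo emits, not guaranteed here):

    import logging
    import torch._dynamo as dynamo

    # optimize_ddp defaults to True after this commit.
    print(dynamo.config.optimize_ddp)

    # Per the new comment in config.py: INFO or DEBUG surfaces more info
    # about optimize_ddp / DDPOptimizer behavior.
    dynamo.config.log_level = logging.INFO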