diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index fd11e984bbd..d7877c5a3fa 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -25,6 +25,10 @@ from torch._dynamo.utils import clone_inputs # We are primarily interested in tf32 datatype torch.backends.cuda.matmul.allow_tf32 = True +# Enable FX graph caching +if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + torch._inductor.config.fx_graph_cache = True + def _reassign_parameters(model): # torch_geometric models register parameter as tensors due to