diff --git a/docs/source/_static/img/dynamo/TorchDynamo.png b/docs/source/_static/img/dynamo/TorchDynamo.png new file mode 100644 index 00000000000..351689d80dc Binary files /dev/null and b/docs/source/_static/img/dynamo/TorchDynamo.png differ diff --git a/docs/source/_static/img/dynamo/td_stack.png b/docs/source/_static/img/dynamo/td_stack.png new file mode 100644 index 00000000000..d20b3250453 Binary files /dev/null and b/docs/source/_static/img/dynamo/td_stack.png differ diff --git a/docs/source/_static/img/dynamo/torchinductor_backend.png b/docs/source/_static/img/dynamo/torchinductor_backend.png new file mode 100644 index 00000000000..84e37aa7c4b Binary files /dev/null and b/docs/source/_static/img/dynamo/torchinductor_backend.png differ diff --git a/docs/source/dynamo/custom-backends.rst b/docs/source/dynamo/custom-backends.rst new file mode 100644 index 00000000000..2c8b338045e --- /dev/null +++ b/docs/source/dynamo/custom-backends.rst @@ -0,0 +1,154 @@ +Custom Backends +=============== + +Debugging Backend +----------------- + +Suppose you wanted to better understand what is going on during a +compilation you can create a custom compiler which we’ll refer to as a +backend that will print pretty print the fx ``GraphModule`` extracted +from dynamo’s bytecode analysis and return a ``forward()`` callable. + +.. code-block:: python + + from typing import List + import torch + import torch._dynamo as dynamo + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + @dynamo.optimize(my_compiler) + def fn(x, y): + a = torch.cos(x) + b = torch.sin(y) + return a + b + fn(torch.randn(10), torch.randn(10)) + +Running the above example produces the following output: + +:: + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ------------------------------------------------------ ---------- -------- + placeholder x x () {} + placeholder y y () {} + call_function cos (x,) {} + call_function sin (y,) {} + call_function add (cos, sin) {} + output output output ((add,),) {} + +This works for ``torch.nn.Module`` as well as shown below + +.. code-block:: python + + import torch + import torch._dynamo as dynamo + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + def forward(self, x): + return self.relu(torch.cos(x)) + mod = MockModule() + optimized_mod = dynamo.optimize(my_compiler)(mod) + optimized_mod(torch.randn(10)) + +Let’s take a look at one more example with control flow. + +.. 
code-block:: python + + from typing import List + import torch + import torch._dynamo as dynamo + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + @dynamo.optimize(my_compiler) + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b + for _ in range(100): + toy_example(torch.randn(10), torch.randn(10)) + +Running this example produces the following output: + +:: + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------- ------------------------------------------------------ ---------------- -------- + placeholder a a () {} + placeholder b b () {} + call_function abs_1 (a,) {} + call_function add (abs_1, 1) {} + call_function truediv (a, add) {} + call_method sum_1 sum (b,) {} + call_function lt (sum_1, 0) {} + output output output ((truediv, lt),) {} + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ----------------------- ----------- -------- + placeholder b b () {} + placeholder x x () {} + call_function mul (b, -1) {} + call_function mul_1 (x, mul) {} + output output output ((mul_1,),) {} + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ----------------------- --------- -------- + placeholder b b () {} + placeholder x x () {} + call_function mul (x, b) {} + output output output ((mul,),) {} + +The order of the last two graphs is nondeterministic depending +on which one is encountered first by the just-in-time compiler. + +Speedy Backend +-------------- + +Integrating a custom backend that offers superior performance is also +easy and we’ll integrate a real one +with `optimize_for_inference `__: + +.. code-block :: python + + def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + scripted = torch.jit.trace(gm, example_inputs) + return torch.jit.optimize_for_inference(scripted) + +And then you should be able to optimize any existing code with + +.. code-block:: python + + @dynamo.optimize(optimize_for_inference_compiler) + def code_to_accelerate(): + ... + +Composable Backends +------------------- + +TorchDynamo includes many backends, which can be found in +`backends.py `__ +or ``torchdynamo.list_backends()``. You can combine these backends +together with the following code: + +.. code-block:: python + + from torch._dynamo.optimizations import BACKENDS + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + trt_compiled = BACKENDS["tensorrt"](gm, example_inputs) + if trt_compiled is not None: + return trt_compiled + # first backend failed, try something else... + cudagraphs_compiled = BACKENDS["cudagraphs"](gm, example_inputs) + if cudagraphs_compiled is not None: + return cudagraphs_compiled + return gm.forward diff --git a/docs/source/dynamo/deep-dive.rst b/docs/source/dynamo/deep-dive.rst new file mode 100644 index 00000000000..c60047c2a3d --- /dev/null +++ b/docs/source/dynamo/deep-dive.rst @@ -0,0 +1,145 @@ +TorchDynamo Deeper Dive +======================= +**Author**: `Jason Ansel `_ + +What is a guard? +---------------- + +TorchDynamo operates just-in-time and specializes graphs based on +dynamic properties. 
For example, the first graph above has the following +guards: + +:: + + GUARDS: + - local 'a' TENSOR_MATCH + - local 'b' TENSOR_MATCH + - global 'torch' FUNCTION_MATCH + +If any of those guards fail, the graph will be recaptured and +recompiled. The interesting guard type there is ``TENSOR_MATCH``, which +checks the following torch.Tensor properties: + +- Python class of the tensor (tensor subclassing, etc) +- dtype +- device +- requires_grad +- dispatch_key (with thread-local includes/excludes applied) +- ndim +- sizes\* (optional) +- strides\* (optional) + +For sizes/strides you can disable this specialization by setting the +following parameter: + +.. code-block:: python + +torch._dynamo.config.dynamic_shapes = True + +The full specialization mode allows the backend compiler to assume an +entirely static graph. Unfortunately, most backends require this. +Operators which return dynamic shapes will trigger a graph break when +not in dynamic shape mode. + +What is dynamo doing? +--------------------- + +If you want to understand better what TorchDynamo is doing, you can set: + +.. code-block:: python + + torchdynamo.config.debug = True + +which triggers useful (but spammy) printouts. + +For example, the printouts for the first graph in the ``toy_example`` +above are: + +:: + + __compiled_fn_0 .1 + opcode name target args kwargs + ------------- ------- ------------------------------------------------------ ---------------- -------- + placeholder a a () {} + placeholder b b () {} + call_function abs_1 (a,) {} + call_function add (abs_1, 1) {} + call_function truediv (a, add) {} + call_method sum_1 sum (b,) {} + call_function lt (sum_1, 0) {} + output output output ((truediv, lt),) {} + + ORIGINAL BYTECODE toy_example example.py 9 + 10 0 LOAD_FAST 0 (a) + 2 LOAD_GLOBAL 0 (torch) + 4 LOAD_METHOD 1 (abs) + 6 LOAD_FAST 0 (a) + 8 CALL_METHOD 1 + 10 LOAD_CONST 1 (1) + 12 BINARY_ADD + 14 BINARY_TRUE_DIVIDE + 16 STORE_FAST 2 (x) + + 11 18 LOAD_FAST 1 (b) + 20 LOAD_METHOD 2 (sum) + 22 CALL_METHOD 0 + 24 LOAD_CONST 2 (0) + 26 COMPARE_OP 0 (<) + 28 POP_JUMP_IF_FALSE 38 + + 12 30 LOAD_FAST 1 (b) + 32 LOAD_CONST 3 (-1) + 34 BINARY_MULTIPLY + 36 STORE_FAST 1 (b) + + 13 >> 38 LOAD_FAST 2 (x) + 40 LOAD_FAST 1 (b) + 42 BINARY_MULTIPLY + 44 RETURN_VALUE + + MODIFIED BYTECODE + 9 0 LOAD_GLOBAL 3 (__compiled_fn_0) + 2 LOAD_FAST 0 (a) + 4 LOAD_FAST 1 (b) + 6 CALL_FUNCTION 2 + 8 UNPACK_SEQUENCE 2 + 10 STORE_FAST 2 (x) + 12 POP_JUMP_IF_FALSE 24 + 14 LOAD_GLOBAL 4 (__resume_at_30_1) + 16 LOAD_FAST 1 (b) + 18 LOAD_FAST 2 (x) + 20 CALL_FUNCTION 2 + 22 RETURN_VALUE + >> 24 LOAD_GLOBAL 5 (__resume_at_38_2) + 26 LOAD_FAST 1 (b) + 28 LOAD_FAST 2 (x) + 30 CALL_FUNCTION 2 + 32 RETURN_VALUE + + GUARDS: + - local 'a' TENSOR_MATCH + - local 'b' TENSOR_MATCH + - global 'torch' FUNCTION_MATCH + +At the top you can see the FX graph (which we already shared above). +Next you see the original bytecode of the function, followed by the +modified bytecode generated by TorchDynamo. Finally, you see the guards +which we covered above. + +In the modified bytecode ``__compiled_fn_0`` is the return value of +``my_compiler()`` (the compiled graph). ``__resume_at_30_1`` and +``__resume_at_38_2`` are both generated continuation functions that pick +up execution after a graph break (at bytecode offsets 30 and 38). Each +of these functions take the form: + +:: + + __resume_at_: + ... restore stack state if needed ... + JUMP_ABSOLUTE into toy_example + ... original bytecode of toy_example ... 
+ +By generating this `resume_at` function we force the remainder of the +function to be executed in a new Python frame which recursively +triggers TorchDynamo to restart its capture once execution reaches that +point for the first time. diff --git a/docs/source/dynamo/faq.rst b/docs/source/dynamo/faq.rst new file mode 100644 index 00000000000..2b66e81ebc6 --- /dev/null +++ b/docs/source/dynamo/faq.rst @@ -0,0 +1,376 @@ +Frequently Asked Questions +========================== + +At a high level, the TorchDynamo stack consists of a graph capture from +Python code using dynamo and a backend compiler. In this example the +backend compiler consists of backward graph tracing using AOTAutograd +and graph lowering using TorchInductor. There are of course many more +compilers available `here `__ +but for this document we will focus on inductor as a motivating example. + +Torchdynamo supports training, using AotAutograd to capture backwards: + + 1. the ``.forward()`` graph and ``optimizer.step()`` is captured by torchdynamo’s python evalframe frontend + 2. for each segment of ``.forward()`` that torchdynamo captures, it uses AotAutograd to generate a backward graph segment + 3. each pair of forward, backward graph are (optionally) min-cut partitioned to save the minimal state between forward/backward + 4. the forward, backward pairs are wrapped in autograd.function modules 5. usercode calling\ ``.backward()`` still triggers eager’s autograd engine, which runs each ‘compiled backward’ graph as if it were one op, also running any non-compiled eager ops’ .backward() functions + +Do you support Distributed code? +-------------------------------- + +DDP has been tested and works, support for other distributed training +libraries is under discussion. + +The main reason why Distributed code is challenging with dynamo is +because AOTAutograd unrolls both the forward and backward pass and +provides 2 graphs for backends to optimize. This is a problem for +distributed code because we’d like to ideally overlap communication +operations with computations. Eager pytorch accomplishes this in +different ways for DDP/FSDP- using autograd hooks, module hooks, and +modifications/mutations of module states. In a naive application of +dynamo, hooks that should run directly after an operation during +backwards may be delayed until after the entire compiled region of +backwards ops, due to how AOTAutograd compiled functions interact with +dispatcher hooks. + +The basic strategy for optimizing DDP with Dynamo is outlined in +`distributed.py `__ +where the main idea will be to graph break on `DDP bucket +boundaries `__. + +When each node in DDP needs to synchronize its weights with the other +nodes it organizes its gradients and parameters into buckets which +reduces communication times and allows a node to broadcast a fraction of +its gradients to other waiting nodes. + +Graph breaks in distributed code means you can expect dynamo and its +backends to optimize the compute overhead of a distributed program but +not its communication overhead. Graph-breaks may interfere with +compilation speedups, if the reduced graph-size robs the compiler of +fusion opportunities. However, there are diminishing returns with +increasing graph size since most of the current compute optimizations +are local fusions. So in practice this approach may be sufficient. + +Do I still need to export whole graphs? 
+----------------------------------------
+
+For the vast majority of models you probably don’t, and you can use
+``torch._dynamo.optimize()`` as is, but there are a few situations where
+full graphs are necessary, and you can ensure a full graph by simply
+running ``torch._dynamo.optimize(..., nopython=True)``. Such situations include:
+
+* Large scale training runs, think $250K+, that require pipeline parallelism
+  and other advanced sharding strategies
+
+* Inference optimizers like `TensorRT `__ or
+  `AITemplate `__ that rely on fusing much more aggressively
+  than training optimizers
+
+* Mobile training or inference
+
+Future work will include tracing communication operations into graphs,
+coordinating these operations with compute optimizations, and optimizing
+the communication operations.
+
+Why is my code crashing?
+------------------------
+
+If your code ran just fine without dynamo and started to crash with it
+enabled, then the most important first step is figuring out which part of
+the stack your failure occurred in. Try running things in the order below,
+and only move on to the next step if the previous step succeeded:
+
+1. ``dynamo.optimize("eager")`` which only runs torchdynamo forward graph
+   capture and then runs the captured graph with PyTorch. If this fails
+   then there’s an issue with TorchDynamo.
+
+2. ``dynamo.optimize("aot_eager")``
+   which runs torchdynamo to capture a forward graph, and then AOTAutograd
+   to trace the backward graph without any additional backend compiler
+   steps. PyTorch eager will then be used to run the forward and backward
+   graphs. If this fails then there’s an issue with AOTAutograd.
+
+3. ``dynamo.optimize("inductor")`` which runs torchdynamo to capture a
+   forward graph, and then AOTAutograd to trace the backward graph with the
+   TorchInductor compiler. If this fails then there’s an issue with TorchInductor.
+
+TorchDynamo Errors
+~~~~~~~~~~~~~~~~~~
+
+If the error that is generated occurs with the ``"eager"`` backend, then
+torchdynamo is the most likely source of the error.
+
+To debug these issues we recommend setting
+``torch._dynamo.config.verbose=True`` to get a full stack trace pointing to
+both the error in torchdynamo and in the user code. In addition to this flag,
+you can also set the ``log_level`` of torchdynamo through
+``torch._dynamo.config.log_level``. The available levels are the
+following:
+
+- ``logging.DEBUG``: Print every instruction that is encountered in
+  addition to all below log levels
+- ``logging.INFO``: Print each function that is compiled (original and
+  modified bytecode) and the graph that is captured in addition to all
+  below log levels
+- ``logging.WARNING`` (default): Print graph breaks in addition to all
+  below log levels
+- ``logging.ERROR``: Print errors only
+
+If a model is sufficiently large, the logs can become overwhelming. If
+an error occurs deep within a model’s python code, it can be useful to
+execute only the frame in which the error occurs to enable easier
+debugging. There are two tools available to enable this:
+
+* Setting the environment variable ``TORCHDYNAMO_DEBUG_FUNCTION`` to the
+  name of a function makes TorchDynamo run only on functions with that name.
+
+* Setting ``torch._dynamo.config.replay_record_enabled = True`` dumps an
+  execution record when an error is encountered. This record can then be
+  replayed to run only the frame where the error occurred.
+
+TorchInductor Errors
+--------------------
+
+With TorchInductor as the chosen backend, AOTAutograd is used to
+generate the backward graph from the forward graph captured by
+torchdynamo. 
It’s important to note that errors can occur during this +tracing and also while TorchInductor lowers the forward and backward +graphs to GPU code or C++. + +A model can often consist of hundreds or thousands of FX nodes, so +narrowing the exact nodes where this problem occurred can be very +difficult which is why we highly recommend you use our minifier to +create tiny reproducible examples of failures you’re seeing. We can +minify errors that occur either at the AOTAutograd layer or Inductor +layer which you should try in the following order. + +1. ``env TORCHDYNAMO_REPRO_AFTER="aot" python your_model.py`` +2. ``env TORCHDYNAMO_REPRO_AFTER="dynamo" python your_model.py`` + +Minifying your error is the quickest path to getting it fixed. + +The minifier will actually create a ``repro.py`` for you at the location +set by ``env TORCHDYNAMO_REPRO_DIR`` so make you have right access to +that directory. You can then run ``python repro.py`` and confirm that +you are getting the same error. + +.. note:: + For other compilers such as nvfuser, the process is similar but + instead you would leverage ``env TORCHDYNAMO_REPRO_AFTER="dynamo" python your_model.py``. + +Why is compilation slow? +------------------------ + +Dynamo Compilation +~~~~~~~~~~~~~~~~~~ + +TorchDynamo has a builtin stats function for collecting and displaying +the time spent in each compilation phase. These stats can be accessed by +calling ``torch._dynamo.utils.compile_times()`` after executing +``torch._dynamo``. By default, this returns a string representation of +the compile times spent in each TorchDynamo function by name. + +Inductor Compilation +~~~~~~~~~~~~~~~~~~~~ + +TorchInductor has a builtin stats and trace function for displaying time +spent in each compilation phase, output code, output graph visualization +and IR dump. ``env TORCHINDUCTOR_TRACE=1 python repro.py``. This is a +debugging tool designed to make it easier to debug/understand the +internals of TorchInductor with an output that will look something like +`this `__ + +Each file in that debug trace can be enabled/disabled via +``torch._inductor.config.trace.*``. The profile and the diagram are both +disabled by default since they are expensive to generate. See the +`example debug directory +output `__ +for more examples. + +Excessive Recompilation +~~~~~~~~~~~~~~~~~~~~~~~ + +When TorchDynamo compiles a function (or part of one), it makes certain +assumptions about locals and globals in order to allow compiler +optimizations, and expresses these assumptions as guards that check +particular values at runtime. If any of these guards fail, Dynamo will +recompile that function (or part) up to +``torch._dynamo.config.cache_size_limit`` times. If your program is +hitting the cache limit, you will first need to determine which guard is +failing and what part of your program is triggering it. + +The `recompilation profiler <#recompilation-profiler>`__ automates the +process of setting TorchDynamo’s cache limit to 1 and running your +program under an observation-only ‘compiler’ that records the causes of +any guard failures. You should be sure to run your program for at least +as long (as many iterations) as you were running when you ran into +trouble, and the profiler will accumulate statistics over this duration. + +.. code-block:: python + + prof = dynamo.utils.CompilationProfiler() + @dynamo.optimize(prof) + def my_model(): + ... 
+
+    my_model()
+    print(prof.report())
+
+Many of the reasons for graph breaks and excessive recompilation will be
+fixed with upcoming support for `tracing dynamic tensor
+shapes `__,
+more careful choices for guards and better tuned heuristics.
+
+Why are you recompiling in production?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In some cases, you may not want unexpected compiles after a program has
+warmed up, for example, if you are serving production traffic in a
+latency critical application. For this, TorchDynamo provides an
+alternate mode where prior compiled graphs are used, but no new ones are
+generated:
+
+.. code-block:: python
+
+    frozen_toy_example = dynamo.run(toy_example)
+    frozen_toy_example(torch.randn(10), torch.randn(10))
+
+How are you speeding up my code?
+---------------------------------
+
+There are 3 major ways to accelerate PyTorch code:
+
+1. Kernel fusion:
+
+   * Vertical fusion fuses sequential operations to avoid excessive
+     reads/writes. For example, fusing 2 subsequent cosines means you can
+     do 1 read and 1 write instead of 2 reads and 2 writes.
+   * Horizontal fusion: the simplest example is batching, where a single
+     matrix is multiplied with a batch of examples, but the more general
+     scenario is a grouped GEMM, where a group of matrix multiplications
+     are scheduled together.
+
+2. Out of order execution: A general optimization for compilers. By looking
+   ahead at the exact data dependencies within a graph we can decide on the
+   most opportune time to execute a node and which buffers can be reused.
+
+3. Automatic work placement: Similar to the out of order execution point,
+   but by matching nodes of a graph to resources like physical hardware or
+   memory we can design an appropriate schedule.
+
+The above are general principles for accelerating PyTorch code, but
+different backends will each make different tradeoffs on what to
+optimize. For example, Inductor first takes care of fusing whatever it
+can and only then generates `Triton `__
+kernels.
+
+In addition, Triton offers speedups because of automatic memory
+coalescing, memory management and scheduling within each Streaming
+Multiprocessor, and it has been designed to handle tiled computations.
+
+However, regardless of the backend you use, it’s best to take a
+benchmark-and-see approach, so try out the PyTorch profiler, visually
+inspect the generated kernels, and try to see what’s going on for
+yourself.
+
+Why am I not seeing speedups?
+------------------------------
+
+Graph Breaks
+~~~~~~~~~~~~
+
+The main reason you won’t see the speedups you’d like from using dynamo
+is excessive graph breaks. So what’s a graph break?
+
+Given a program like:
+
+.. code-block:: python
+
+    @dynamo.optimize(...)
+    def some_fun(x):
+        ...
+    some_fun(x)
+    ...
+
+Torchdynamo will attempt to compile all of the torch/tensor operations
+within ``some_fun()`` into a single FX graph, but it may fail to capture
+everything into one graph.
+
+Some graph break causes are insurmountable for TorchDynamo: for example, a
+call into a C extension other than torch is invisible to torchdynamo and
+could do arbitrary things without TorchDynamo being able to introduce the
+guards needed to ensure that the compiled program would be safe to reuse.
+
+   To maximize performance, it’s important to have as few graph breaks
+   as possible.
+
+Identifying the cause of a graph break
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To identify all graph breaks in a program and the associated reasons for
+the breaks, ``torch._dynamo.explain`` can be used. 
This tool runs +TorchDynamo on the supplied function and aggregates the graph breaks +that are encountered. Here is an example usage: + +.. code-block:: python + + import torch + import torch._dynamo as dynamo + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + print("woo") + if b.sum() < 0: + b = b * -1 + return x * b + explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10)) + print(explanation) + """ + Dynamo produced 3 graphs, with 2 graph break and 6 ops. + Break reasons: + 1. call_function BuiltinVariable(print) [ConstantVariable(str)] {} + File "t2.py", line 16, in toy_example + print("woo") + + 2. generic_jump + File "t2.py", line 17, in toy_example + if b.sum() < 0: + """ + +To throw an error on the first graph break encountered you can use +disable python fallback by using ``nopython=True``, this should be +familiar if you’ve worked with export based compilers. + +.. code-block:: python + + @dynamo.optimize(, nopython=True) + def toy_example(a, b): + ... + +Why didn’t my code recompile when I changed it? +----------------------------------------------- + +If you went ahead and enabled dynamic shapes via +``env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py`` then your code +won’t recompile on shape changes. We’ve added support for dynamic shapes +which avoids recompilations in the case when shapes vary by less than a +factor of 2. This is especially useful in scenarios like varying image +sizes in CV or variable sequence length in NLP. In inference scenarios +it’s often not possible to know what a batch size will be beforehand +because you take what you can get from different client apps. + +In general, TorchDynamo tries very hard not to recompile things +unnecessarily so if for example torchdynamo finds 3 graphs and your +change only modified one graph then only that graph will recompile. So +another tip to avoid potentially slow compilation times is to warmup a +model by compiling it once after which subsequent compilations will be +much faster. Cold start compile times is still a metric we track +visibly. + +Why am I getting incorrect results? +----------------------------------- + +Accuracy issues can also be minified if you set the environment variable +``TORCHDYNAMO_REPRO_LEVEL=4``, it operates with a similar git bisect +model and a full repro might be something like +``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4`` the reason +we need this is downstream compilers will codegen code whether it’s +Triton code or the C++ backend, the numerics from those downstream +compilers can be different in subtle ways yet have dramatic impact on +your training stability. So the accuracy debugger is very useful for us +to detect bugs in our codegen or with a backend compiler. + +Why am I getting OOMs? +---------------------- + +Dynamo is still an alpha product so there’s a few sources of OOMs and if +you’re seeing an OOM try disabling the following configurations in this +order and then open an issue on Github so we can solve the root problem +1. If you’re using dynamic shapes try disabling them, we’ve disabled +them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2. +CUDA graphs with Triton are enabled by default in inductor but removing +them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``. 
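+
+As a sketch, both mitigations can be applied together at the top of your
+script before the model is compiled (``torch._dynamo.reset()`` clears any
+graphs that were already compiled with the old settings):
+
+.. code-block:: python
+
+    import torch._dynamo
+    import torch._inductor.config
+
+    # Make sure dynamic shapes are off (the default) and disable CUDA graphs
+    # inside inductor to reduce memory pressure.
+    torch._dynamo.config.dynamic_shapes = False
+    torch._inductor.config.triton.cudagraphs = False
+    torch._dynamo.reset()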
\ No newline at end of file diff --git a/docs/source/dynamo/get-started.rst b/docs/source/dynamo/get-started.rst new file mode 100644 index 00000000000..44434d49e52 --- /dev/null +++ b/docs/source/dynamo/get-started.rst @@ -0,0 +1,181 @@ +Getting Started +=============== + +Let’s start with a simple example and make things more complicated step +by step. Please note that you’re likely to see more significant speedups +the newer your GPU is. + +.. code:: python + + from torch._dynamo import optimize + import torch + def fn(x, y): + a = torch.cos(x).cuda() + b = torch.sin(y).cuda() + return a + b + new_fn = optimize("inductor")(fn) + input_tensor = torch.randn(10000).to(device="cuda:0") + a = new_fn() + +This example will not actually run faster. Its purpose is to demonstrate +the ``torch.cos()`` and ``torch.sin()`` features which are +examples of pointwise ops as in they operate element by element on a +vector. A more famous pointwise op you might actually want to use would +be something like ``torch.relu()``. Pointwise ops in eager mode are +suboptimal because each one would need to need to read a tensor from +memory, make some changes and then write back those changes. The single +most important optimization that inductor does is fusion. So back to our +example we can turn 2 reads and 2 writes into 1 read and 1 write which +is crucial especially for newer GPUs where the bottleneck is memory +bandwidth (how quickly you can send data to a GPU) instead of compute +(how quickly your GPU can crunch floating point operations) + +Another major optimization that inductor makes available is automatic +support for CUDA graphs. +CUDA graphs help eliminate the overhead from launching individual +kernels from a python program which is especially relevant for newer GPUs. + +dynamo supports many different backends but inductor specifically works +by generating `Triton `__ kernels and +we can inspect them by running ``TORCHINDUCTOR_TRACE=1 python trig.py`` +with the actual generated kernel being + +.. code:: python + + @pointwise(size_hints=[16384], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 10000 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK]) + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.sin(tmp0) + tmp2 = tl.sin(tmp1) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask) + +And you can verify that fusing the two ``sins`` did actually occur +because the two ``sin`` operations occur within a single Triton kernel +and the temporary variables are held in registers with very fast access. + +You can read up a lot more on Triton’s performance +`here `__ but the key is it’s in python +so you can easily understand it even if you haven’t written all that +many CUDA kernels. + +As a next step let’s try a real model like resnet50 from the PyTorch +hub. + +.. code:: python + + import torch + import torch._dynamo as dynamo + model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True) + opt_model = dynamo.optimize("inductor")(model) + model(torch.randn(1,3,64,64)) + +And that’s not the only available backend, you can run in a REPL +``dynamo.list_backends()`` to see all the available ones. Try out the +``aot_cudagraphs`` or ``nvfuser`` next as inspiration. 
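+
+A quick way to see whether a backend is paying off is to time the eager
+module against the optimized one. The snippet below is a rough sketch
+(numbers will vary heavily with your GPU and the model, and the batch size
+here is arbitrary); the warm-up call is there so that compilation time is
+not counted in the measurement:
+
+.. code:: python
+
+   import time
+   import torch
+   import torch._dynamo as dynamo
+
+   model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).cuda().eval()
+   opt_model = dynamo.optimize("inductor")(model)
+   x = torch.randn(16, 3, 224, 224, device="cuda")
+
+   def bench(fn, iters=10):
+       fn(x)  # warm up (triggers compilation for the optimized module)
+       torch.cuda.synchronize()
+       start = time.perf_counter()
+       for _ in range(iters):
+           fn(x)
+       torch.cuda.synchronize()
+       return (time.perf_counter() - start) / iters
+
+   print(f"eager:     {bench(model):.4f} s/iter")
+   print(f"optimized: {bench(opt_model):.4f} s/iter")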
+
+Let’s do something a bit more interesting now. Our community frequently
+uses pretrained models from
+`transformers `__ or
+`TIMM `__, and one of
+our design goals is for dynamo and inductor to work out of the box with
+any model that people would like to author.
+
+So we’re going to directly download a pretrained model from the
+HuggingFace hub and optimize it:
+
+.. code:: python
+
+   import torch
+   from transformers import BertTokenizer, BertModel
+   import torch._dynamo as dynamo
+   # Copy pasted from here https://huggingface.co/bert-base-uncased
+   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+   model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
+   model = dynamo.optimize("inductor")(model) # This is the only line of code that we changed
+   text = "Replace me by any text you'd like."
+   encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
+   output = model(**encoded_input)
+
+If you remove the ``to(device="cuda:0")`` from the model and
+encoded_input, then inductor will generate C++ kernels that are
+optimized for running on your CPU. You can inspect both the Triton and C++
+kernels for BERT; they’re obviously more complex than the trigonometry
+example we had above, but you can similarly skim them and understand them
+if you understand PyTorch.
+
+Similarly, let’s try out a TIMM example:
+
+.. code:: python
+
+   import timm
+   import torch._dynamo as dynamo
+   import torch
+   model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2)
+   opt_model = dynamo.optimize("inductor")(model)
+   opt_model(torch.randn(64,3,7,7))
+
+Our goal with dynamo and inductor is to build the highest-coverage ML
+compiler, one that should work with any model you throw at it.
+
+Existing Backends
+~~~~~~~~~~~~~~~~~
+
+TorchDynamo has a growing list of backends, which can be found in
+`backends.py `__
+or with ``torchdynamo.list_backends()``, each with its own optional dependencies.
+
+Some of the most commonly used backends include:
+
+* **Debugging backends**:
+
+  * ``dynamo.optimize("eager")`` - Uses PyTorch to run the extracted
+    GraphModule. This is quite useful in debugging TorchDynamo issues.
+  * ``dynamo.optimize("aot_eager")`` - Uses AotAutograd with no compiler,
+    i.e., just using PyTorch eager for AotAutograd’s extracted forward and
+    backward graphs. This is useful for debugging, and unlikely to give
+    speedups.
+
+* **Training & inference backends**:
+
+  * ``dynamo.optimize("inductor")`` - Uses the TorchInductor backend with
+    AotAutograd and cudagraphs by leveraging codegened Triton kernels. `Read
+    more `__
+  * ``dynamo.optimize("nvfuser")`` - nvFuser with TorchScript. `Read more `__
+  * ``dynamo.optimize("aot_nvfuser")`` - nvFuser with AotAutograd. `Read more `__
+  * ``dynamo.optimize("aot_cudagraphs")`` - cudagraphs with AotAutograd. `Read more `__
+
+* **Inference-only backends**:
+
+  * ``dynamo.optimize("ofi")`` - Uses TorchScript optimize_for_inference. `Read
+    more `__
+  * ``dynamo.optimize("fx2trt")`` - Uses Nvidia TensorRT for inference optimizations. `Read more `__
+  * ``dynamo.optimize("onnxrt")`` - Uses ONNXRT for inference on CPU/GPU. `Read more `__
+  * ``dynamo.optimize("ipex")`` - Uses IPEX for inference on CPU. `Read more `__
+
+Why do you need another way of optimizing PyTorch code?
+--------------------------------------------------------
+
+While a number of other code optimization tools exist in the PyTorch
+ecosystem, each of them has its own flow. 
Here is a few examples of +existing methods and their limitations: + +- ``torch.jit.trace()`` is silently wrong if it cannot trace e.g: + during control flow +- ``torch.jit.script()`` requires modifications to user or library code + by adding type annotations and removing non PyTorch code +- ``torch.fx.symbolic_trace()`` either traces correctly or gives a hard + error but it’s limited to traceable code so still can’t handle + control flow +- ``torch._dynamo`` works out of the box and produces partial graphs. + It still has the option of producing a single graph with + ``nopython=True`` which are needed for `some + situations <./documentation/FAQ.md#do-i-still-need-to-export-whole-graphs>`__ + but allows a smoother transition where partial graphs can be + optimized without code modification + +.. |image0| image:: ../_static/img/dynamo/TorchDynamo.png diff --git a/docs/source/dynamo/guards-overview.rst b/docs/source/dynamo/guards-overview.rst new file mode 100644 index 00000000000..99a004ec221 --- /dev/null +++ b/docs/source/dynamo/guards-overview.rst @@ -0,0 +1,513 @@ +Guards Overview +=============== + +From a UX perspective, TorchDynamo is very easy to use. The user invokes +``torchdynamo.optimize`` as an annotation: + +.. code-block:: python + + @torchdynamo.optimize(my_compiler) + def fn_foo(bar): + +Where a complete example looks like this: + +.. code-block:: python + + from typing import List + import torch + import torchdynamo + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + @torchdynamo.optimize(my_compiler) + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b + for _ in range(100): + toy_example(torch.randn(10), torch.randn(10)) + +This allows TorchDynamo to capture the interpreted Python frames, grab +any and all relevant information, and speed things up wherever it can. +The speedup comes from a few places, and can be rather dependent on the +backend (my_compiler above) provided, but the one speedup we care about +most for today’s overview is **caching**. Caching itself is not a direct +speedup, so much as a critical enablement to allow us to prevent +recompilation. We dig a hole with dynamo, and caching allows us to get +out. Its a speedup from that perspective, but relatively neutral when +all things are considered - however, it enables us to hold perf +neutrality while then enabling backends - the true source of our +speedups. + +With even a pass-through no-op backend provided: + +.. code-block:: python + + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + return gm.forward + +We can see TorchDynamo speeding up Python execution quite a bit, even on +regular Python, not just PyTorch. + +Caching and Guards Overview +--------------------------- + +TorchDynamo operates through caching transformed (by TorchDynamo) user +bytecode. When we receive a frame for evaluation, we check if the +**objects referenced in the frame have changed** in certain ways, and if +not, we read the previously transformed user bytecode to evaluate it. +The details of how we do this will be saved for a later writeup. +Instead, we will focus on how we can identify whether or not the +**objects referenced in the frame have changed**. This is a critical +piece of functionality in TorchDynamo, because it drives the entire +invalidation lifecycle. 
We refer to this functionality as **guards**. + +At a very high level, the vastly oversimplified TLDR flow is this: + +1) We receive a python frame +2) We convert the given frame from (1), passing it through instruction + translation +3) For the objects captured in (2), we create tracking objects that are + (a) tracked on an output graph, which is an internal specialization + of a torch.fx.Tracer (and the topic of a later writeup), and (b) + guards, the topic of this document. +4) We process the guard objects created in (3), turning them into a + generated python function, check_fn, associated with a piece of code. +5) The check_fn is evaluated whenever we encounter this code a + subsequent time - if a check_fn passes and evaluates to True, we know + the code in the cache and the code encountered here is the same, and + can be safely used. If it fails and evaluates to False, we know the + code in the cache is not valid, and can be thrown out in favor of a + new entry, through recompilation or a graph break. + +Python Frame Evaluation and PEP 523 +----------------------------------- + +The functionality of TorchDynamo is based on +`PEP 523 `__. + +TorchDynamo installs a frame evaluation function on Python, via +`_PyInterpreterState_SetEvalFrameFunc`. The overview of function +selection, thread management, and cleanup is out of scope for this +writeup, but the important part is that TorchDynamo has a hook where +Python can hand control back to us during evaluation. + +The function we have installed is ``convert_frame`` or +``convert_frame_assert`` in the ``nopython=True`` case, but glossing +over that nuance for now, let’s take a look at ``convert_frame_assert``, +as ``convert_frame`` proxies to it anyway. + +We can find it on `line 20 of convert_frame.py +`__, +with a signature as follows: + +.. code-block:: python + + def convert_frame_assert(compiler_fn: Callable, one_graph=True): + +This function wraps the entry point of where Python invokes TorchDynamo +with a frame, glossing over the nuances of ``wrap_convert_context`` for +now: + +.. code-block:: python + + def _convert_frame_assert(frame: types.FrameType, cache_size: int): + +Here is what this function does: + +1) Checks if it has seen this ``code``\ (see: f_code `here + `__) before and exits + early if it did. +2) Checks if the code is an unsupported case. +3) Checks if the ``cache_size`` (second arg above) crosses the limit + defined in the config, ``cache_size_limit``. If it has, the function + drops the frame and logs warnings. This helps to avoid constant + recompilation of a frame as it generally means that the frame is hot + in an unexpected way and caching it produces needless overhead, + as it is likely to get evicted the next time it is encountered. +4) Passes the frame, alongside a function that creates an + ``InstructionTranslator`` through bytecode + transformation, via ``transform_code_object``. A few crucial things + happen under the hood here: + + 1) New code is produced through ``transform_code_object``. + + 2) An FX tracer named ``output`` is produced through + ``InstructionTranslator``. + + This can be a bit confusing, + as ``InstructionTranslator`` is not an `fx` tracer, but its stored + in a variable named tracer, and its output*\ **is**\ *an `fx`tracer.* + + 3) The function produces guards and stores them on ``output`` above. + + 4) The function produces ``output_instructions`` and stores them on + ``output`` above. 
+ + 5) The function maps the newly produced transformed code to the initial code it + read off the frame. This mapping is worth remembering, we will + refer to it much later on below where we cover guard failures. + +5) Using the transformed code from 4.1 and the guards from 4.3 + the function produces a `GuardedCode`. + +Now that we have learned about frame evoluation, let’s review +``InstructionTranslator``, and see how it turns the frame we handed +it over into TorchDynamo internal types. + +InstructionTranslator +--------------------- + +`InstructionTranslator` does a lot! We won’t cover the details of +everything it does, but most importantly for this document, it produces +a mapping of ``symbolic_locals`` which maintains a mapping from the +frame’s f_locals to TorchDynamo internal Variable objects (more on these +in a moment. ``symbolic_locals`` is filled via traversing the frame’s +locals: + +.. code-block:: python + + self.symbolic_locals = collections.OrderedDict( + (k, VariableBuilder(self, LocalSource(k))(f_locals[k])) + for k in vars + if k in f_locals + ) + +We will get to how this works later, from a few other examples that lead +us to understanding ``VariableTracker`` and ``VariableBuilder``. The +important component here, for us, for now, is the invocation of a call +into ``VariableBuilder``. ``VariableBuilder``\ ’s call implementation +proxies into a function called ``_wrap``, which in turn both constructs +instances of ``VariableTracker`` and calls ``make_guards`` on them. More +on that later. + +This mapping, in turn, is critical as each Variable has associated +guards, which are then passed to ``self.output``, the instance of +``OutputGraph``, an fx tracer, mentioned in 4.2 of the section above. If +you recall, this ``OutputGraph``, stored in a variable called ``output`` +is where our guards are stored before being passed on to become +``GuardedCode`` + +How does ``InstructionTranslator`` do this? At the heart of it, there is +a loop that is pumped, which drives a function ``step``. + +``step`` is just that - a single processing step, taking exactly one +instruction and doing *something* with it. Note: These are real +instructions processed by TorchDynamo’s ``transform_code_object``, and +it’s pretty cool. + +.. note:: This section purposly skips the details of + `dis.get_instructions `__, + and how we set up the ``Instruction`` class. + +For the toy example above, here is a snippet of a what a few +``Instruction``\'s may look like: + +.. code-block:: python + + Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None) + Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None) + Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None) + +This is the core functionality of this function. Take a look at the ``opname``, +and then take a look at this little snippet from inside ``step``; + +.. code-block:: python + + if not hasattr(self, inst.opname): + unimplemented(f"missing: {inst.opname}") + getattr(self, inst.opname)(inst) + +As we can see, we check if the current class, the +``InstructionTranslator`` has a attribute set matching the operator name +(ex: LOAD_CONST). If it does, we invoke it, passing the whole +instruction object in. If it does not, we drop the frame as +unimplemented. 
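+
+To make that dispatch pattern concrete, here is a tiny standalone sketch
+(illustrative only, not TorchDynamo code) of a translator whose method
+names mirror bytecode op names, driven by ``dis.get_instructions``:
+
+.. code-block:: python
+
+    import dis
+
+    class MiniTranslator:
+        """Toy translator: one method per bytecode op name, found via getattr."""
+
+        def __init__(self):
+            self.stack = []
+
+        def LOAD_CONST(self, inst):
+            self.stack.append(inst.argval)
+
+        def RETURN_VALUE(self, inst):
+            print("would return:", self.stack.pop())
+
+        def run(self, fn):
+            for inst in dis.get_instructions(fn):
+                handler = getattr(self, inst.opname, None)
+                if handler is None:
+                    # TorchDynamo would fall back (drop the frame) here
+                    print(f"missing: {inst.opname}")
+                    continue
+                handler(inst)
+
+    def f():
+        return 42
+
+    MiniTranslator().run(f)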
+ +For the LOAD_CONST example, we can see that we do indeed support it, +with a relatively straightforward definition: + +:: + + def LOAD_CONST(self, inst): + self.push(ConstantVariable(value=inst.argval)) + +Passing over, for now, on the other details of ``InstructionTranslator`` +we can see that this function creates a new instance of the class +``ConstantVariable`` , with a value, in our example case, -1, and then +pushes it onto the stack. + +There are dozens of such methods - see symbolic_convert.py for all of +them. Generally, we implement as many matching methods to python +bytecode instructions as possible. + +Across both the logic downstream of ``step`` and the logic from invoking +``VariableBuilder`` - we now have a lot of ``VariableTracker``\ s and of +course, we’ve spoken about creating guards quiet a bit. Let’s dig into +what Variables are, and get a little closer to understanding guards. + +Variables +--------- + +A ``ConstantVariable`` is an instance of\ ``VariableTracker``. +``VariableTracker`` represents a tracked python local or stack value. + +When it comes to representing an object inside TorchDynamo, a +VariableTracker does exactly what it says - it tracks a given variable. +Its an extremely flexible class, but there are a few points to keep in +mind: + +- It manages the ``guard`` relationship around the underlying object + through: + + - `make_guard` + - `replace_guards` + - `add_guard(s)` + - `propagate` - ``propagate(*vars: List[List["VariableTracker"]])`` - + Perhaps the most important of all, in that it combines guards from + all the provided VariableTracker instances passed in. It visits + the guards and combines the guards from these onto itself. + +- It acts as a proxy on behalf of the underlying object, implementing + methods for the rest of TorchDynamo to get information about the + tracked object: + + - `call_method` + - `call_function` + - `python_type` + - `as_proxy` + - `is/as_python_proxy` + +- It stores the variable ``source`` of type ``Source``, from + torchdynamo/source.py. This source type is a relatively self + contained class to help us organize and bookeep where the original + source came from, and helps provide convenience methods for things + like getting the name, and importantly for us, producing guards. + +And this class (``VariableTracker``) is built around subclassing, +somewhere between a full Abstract Base Class and fully fleshed out class +- it leaves many methods raising NotImplementedError - with reliance on +subclasses (see: torchdynamo/variables/ for all subclasses) to fulfill +contracts and custom behaviors. + +Knowing what we know now, we can see an example of how an instruction +from ``dis``, ``BUILD_TUPLE`` + + BUILD_TUPLE(count) Creates a tuple consuming count items from the + stack, and pushes the resulting tuple onto the stack. + +In our case, our signature will be a *little* different due to the way +we create ``Instruction`` objects, but the gist of it will be the same. +Instead of passing in ``count``, we pass in an object with a little +extra bookkeeping, and of course, we deal with turning regular old +python objects into TorchDynamo notions: + +:: + + def BUILD_TUPLE(self, inst): + items = self.popn(inst.argval) + options = VariableTracker.propagate(items) + self.push(TupleVariable(items, **options)) + +What is happening here? 1) We read argval, which in this case, is +analogous to ``counts`` in the pydoc for the equivalent instruction. 
+ +2) We ``popn`` the items, in this case, the signature is + ``def popn(self, n: int) -> List[TensorVariable]:`` this hints at an + underlying contract - we are returning ``TensorVariables``. If we + take a closer look at sybmolic_convert.py and + ``InstructionTranslatorBase``/``InstructionTranslator``\ we see that + the only thing pushed onto and popped from our stack are + ``VariableTracker``\ s. + +3) We call ``VariableTracker.propogate`` (remember it, from above?) This + takes the guards from every single item popped off the stack in 2, + and recursively traverses it and combines all the guards into + ``options``: ``py return { "guards": guards, }`` + +4) We then make a new instance of a ``VariableTracker``, + ``TupleVariable``\ out of the ``items`` and ``options``. This then + allows us to install all the appropriate guards from the ``items`` + that make up the new ``TupleVariable`` + +Note: You may wonder - where did the first guards come from? Propagation +is good and all, but don’t we need something created before it can be +propagated. Yes! Remember that ``VariableBuilder`` above? It calls +``make_guards`` as it creates ``VariableTracker`` instances, from +``f_locals``. This in turn calls into the ``source``, to have it create +guards. + +After all this, bytecode translation is done and we are one step closer +to producing ``GuardedCode``. We now understand how locals become +``VariableTracker``\ s, how instructions are handled, and where guards +are called on for creation. Before we can go into seeing how code and +guards are combined into a GuardedCode object, we need to dig a little +bit into those ``make_guard`` and ``source.make_guard`` calls above. We +can then understand, really, what was going on when we made guards +alongside, and on, ``VariableTracker`` instances. + +Making Guards +------------- + +Guards are just python objects, of the class ``Guard``, however, theres +a good amount of detail around this little class. + +Looking at the definition of the dataclass (and therefore, ctor +signature), we see that it has a name, a source, and a create function. + +:: + + @dataclasses.dataclass + class Guard: + name: str + source: GuardSource + create_fn: Callable + +The name should be the name of the variable. + +The source here is an enum indicating what *kind* of source the guard +belongs to [Note: not to be confused with ``Source`` and the other types +in source.py, as stored on ``VariableTracker``, as discussed above] + +And create_fn is the heart of how we go from having this simple +dataclass to actually producing valid python code to be invoked for +knowing whether or not things have changed in between invocations, and +whether we can safely read from the code cache or not (In case you +forgot what all this was for!) + +The most common code paths for getting an instance of a guard are +through ``make_guards`` on ``VariableTracker``. +``make_guards``->``source.make_guard``->``return Guard(self.name(), self.guard_source(), fn)`` + +Or, in a concrete example: + +.. code-block:: python + + ... + elif istype(value, range): + guards = self.make_guards(GuardBuilder.EQUALS_MATCH) + return RangeVariable(value=value, guards=guards) + +Since ``source`` was set at the construction time of this +``VariableTracker``, all that was needed here was to provide the fn, +``GuardBuilder.EQUALS_MATCH`` to the ``create_fn`` field. + +This ``create_fn`` must be a method on ``GuardBuilder``. The reason for +this becomes apparent in our next step. 
Once we have all the guards +created for a frame, we move on to ``CheckFunctionManager`` and +``compile_check_fn``. + +Remember that ``convert_frame`` function way above, in the first +section? Before it can produce a ``GuardedCode``, it needs to run the +``CheckFunctionManager``, with all the guards, to produce a ``check_fn`` +which will then, in turn get passed in alongside the code into +``GuardedCode``. This is the same ``check_fn`` that we store in our +cache entry, and the same one we run to know whether or not to retrieve +the code stored alongside. For reference, here is that code: + +.. code-block:: cpp + + static CacheEntry *create_cache_entry(CacheEntry *next, + PyObject *guarded_code) { + CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry)); + DEBUG_NULL_CHECK(e); + e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn"); + NULL_CHECK(e->check_fn); + e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code"); + NULL_CHECK(e->code); + e->next = next; + return e; + } + +We now know how a ``check_fn`` function is used, and who makes it, and +what it is composed of, but what we do not yet know is how. How does a +list of ``Guard`` objects become a function we can run later on? + +First, we iterate these guards: + +.. code-block:: python + + for guard in sorted(guards or [], key=Guard.sort_key): + if not config.guard_nn_modules and guard.is_nn_module(): + continue + guard.create(local_builder, global_builder) + +Calling ``guard.create`` runs that ``create_fn`` we set on the ``Guard`` +class above (don’t confuse it with the ``check_fn`` we are working on +producing, the names are similar, so it can get a little confusing). In +our example above, our ``create_fn`` is ``GuardBuilder.EQUALS_MATCH``. +So we are now invoking it, passing in the ``self``, the guard itself, +in. + +The signature is: ``def EQUALS_MATCH(self, guard: Guard):`` + +And internally to that function, we can use the ``name`` on the guard to +get back our original object, querying it for data and type information, +which in turn gets us to the most important bit: appending code. + +At its simplest, ``EQUALS_MATCH`` appends just one line of code: +``self.code.append(f"{ref} == {val!r}")``. Where ``ref`` is the name of +the variable, and val is the value. It might produce code like this: + +.. code-block:: + + y == 2 + +Pretty simple, but if we append a few other kinds of ``GuardBuilder`` +functions on (For a more complex case), and then combine them all with +``and`` in between each statement (as we do), we might get something +like this: + +.. code-block:: + + ___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x) + +Now we’re talking! Let’s see what we have here: 1) A check for +``.valid`` (we will come back to invalidation later on) 2) A type id +check 3) A value check 4) A tensor check + +This becomes the heart of the code our ``check_fn``, which in turn, as +you recall, is evaluated the **next** time we encounter this code. It +will then check: + +1) Is this code still valid? +2) If (1), Does ``y`` still have a type of ``94367738391392``? +3) If (2), is ``y`` still 2? +4) If (3), let’s check on if tensor ``x`` changed in some specific ways + +If all of these are still true, then we can use the code cached +alongside this ``check_fn``! Joyous day! 
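+
+As an illustration only (the real logic lives in ``compile_check_fn`` and
+differs in its details), gluing guard code fragments together into a
+callable looks roughly like this; the fragment strings and the
+``___check_type_id`` helper below are stand-ins for the real ones:
+
+.. code-block:: python
+
+    # Hypothetical sketch: combine guard fragments into a single check function.
+    code_parts = [
+        "___check_type_id(y, y_type_id)",
+        "y == 2",
+    ]
+    scope = {"___check_type_id": lambda obj, type_id: id(type(obj)) == type_id}
+    check_fn = eval("lambda y, y_type_id: " + " and ".join(code_parts), scope)
+
+    print(check_fn(2, id(int)))    # True  -> reuse the cached code
+    print(check_fn(2.0, id(int)))  # False -> recompile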
[Note: a deeper dive for how +and where this happens if saved for a later writeup, but reading +``static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) {`` of +``_eval_frame.c`` is a good place to start for the inquisitive reader +who has made it thus far]. + +If not, then, we can move on to recompiling the code anew, and storing +that in the cache alongside this code, and a whole new ``check_fn``, +again to be checked on yet another subsequent frame. + +There are lots of other such functions on ``GuardBuilder`` which get +coalesced into, at times massive, strings which then get evaluated as +python code and stored into ``check_fn``. Our example above is +illustrative of a simple case, but I urge you to read the other +functions on ``GuardBuilder``, or better yet, dump the ``code`` variable +in ``compile_check_fn`` to really see what’s getting produced, +especially on larger, real models! + +Summary +------- + +In this, we have glossed over: - The role of ``.valid`` and invalidation +around weak references (and potentially soon to be NN Module +invalidations) - How the C++ side of guard functions +(``___check_type_id``, ``___check_tensors``, etc) operate - What happens +when guards fail? - What happens if we produce invalid guard code? + +Despite all that, I hope this has been a useful read. We covered how +user provided code, wrapped in a TorchDynamo context goes on to get +traced and tracked internally, organized into ``VariableTracker``\ s +``Source``\ s and subsequently ``Guard``\ s, and how those ``Guards`` in +turn guide cache entry selection and invalidation when handing Python +code. diff --git a/docs/source/dynamo/index.rst b/docs/source/dynamo/index.rst new file mode 100644 index 00000000000..d34f6a7d275 --- /dev/null +++ b/docs/source/dynamo/index.rst @@ -0,0 +1,44 @@ +TorchDynamo Documentation +========================= + +**TorchDynamo** is a Python-level JIT compiler designed to make unmodified +PyTorch programs faster. TorchDynamo hooks into the frame evaluation API +in CPython (`PEP 523 `__) to +dynamically modify Python bytecode right before it is executed. It +rewrites Python bytecode in order to extract sequences of PyTorch +operations into an `FX Graph `__ +which is then just-in-time compiled with a customizable backend. +It creates this FX Graph through bytecode analysis and is designed to +mix Python execution with compiled backends to get the best of both +worlds: usability and performance. + +TorchDynamo makes it easy to experiment with different compiler +backends to make PyTorch code faster with a single line decorator +``torch._dynamo.optimize()`` + +.. image:: ../_static/img/dynamo/TorchDynamo.png + +For more information about `TorchInductor`, one of the backends +supported by `TorchDynamo Graph `__ +into `Triton `__ for GPUs or +`C++/OpenMP `__ for CPUs. We have a +`training performance dashboard `__ +that provides performance comparison for different training backends. You can read +more in the `TorchInductor post on PyTorch +dev-discuss `__. + +.. seealso:: + + * `TorchDynamo deep-dive video `__ + * `dev-discuss topics `__ + +.. 
toctree:: + :hidden: + + installation + get-started + guards-overview + custom-backends + deep-dive + troubleshooting + faq diff --git a/docs/source/dynamo/installation.rst b/docs/source/dynamo/installation.rst new file mode 100644 index 00000000000..6d1b09f0415 --- /dev/null +++ b/docs/source/dynamo/installation.rst @@ -0,0 +1,83 @@ +Installing TorchDynamo +====================== + +This section describes how to install TorchDynamo. + +Requirements and Setup +---------------------- + +Python 3.8 is recommended. Python 3.7 through 3.10 are supported and +tested. Make sure to have a development version of Python installed +locally as well. + +TorchDynamo is included in the nightly binaries of PyTorch. You can +find more information `here `__ + +Install GPU/CUDA version requirements +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To use GPU back ends (and in particular Triton), please make sure that +the CUDA that you have installed locally matches the PyTorch version you +are running. + +The following command installs GPU PyTorch+TorchDynamo along with GPU +TorchDynamo dependencies (for CUDA 11.7): + +.. code-block:: python + + pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 + +CPU requirements +~~~~~~~~~~~~~~~~ + +There are no additional requirements for CPU TorchDynamo. CPU +TorchDynamo is included in the nightly versions of PyTorch, which, for +reference, can be installed with the following command: + +.. code-block:: shell + + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu + + +Install from local source +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Build PyTorch from source: +https://github.com/pytorch/pytorch#from-source, which has TorchDynamo +included. + +To install GPU TorchDynamo dependencies, run ``make triton`` in the +PyTorch repo root directory. + +Verify Installation +~~~~~~~~~~~~~~~~~~~ + +If you built PyTorch from source, then you can run the following +commands (from the PyTorch repo root directory) that run minimal +examples to check that TorchDynamo is installed correctly: + +.. code:: shell + + cd tools/dynamo + python verify_dynamo.py + +If you do not have the PyTorch source locally, you can alternatively +copy the script (``tools/dynamo/verify_dynamo.py``) from the PyTorch +repo and run it locally. + +Docker installation +------------------- + +We also provide all the required dependencies in the PyTorch nightly +binaries which you can download with + +.. code-block:: + + docker pull ghcr.io/pytorch/pytorch-nightly + +And for ad hoc experiments just make sure that your container has access +to all your GPUs + +.. code-block:: bash + + docker run --gpus all -it ghcr.io/pytorch/pytorch-nightly:latest /bin/bash diff --git a/docs/source/dynamo/troubleshooting.rst b/docs/source/dynamo/troubleshooting.rst new file mode 100644 index 00000000000..8542d02bfa9 --- /dev/null +++ b/docs/source/dynamo/troubleshooting.rst @@ -0,0 +1,665 @@ +TorchDynamo Troubleshooting +=========================== + +**Author**: `Michael Lazos `_ + +TorchDynamo is still in active development, and many of the reasons for +graph breaks and excessive recompilation will be fixed with upcoming +support for `tracing dynamic tensor +shapes `__, +more careful choices for guards and better tuned heuristics. + +In the mean time, you may need to diagnose a particular issue and +determine if it is easy to work around with a change to your model, or +file an issue for support. 
Also, we are actively developing debug tools and profilers, and
improving our error and warning messages. Please give us feedback if you
have an issue with this infra, or an idea for an improvement. Below is a
table of the available tools and their typical usage. For additional
help see `Diagnosing Runtime Errors <#diagnosing-runtime-errors>`__.

.. list-table:: TorchDynamo debugging tools
   :widths: 25 25 50
   :header-rows: 1

   * - Tool
     - Purpose
     - Usage
   * - Info logging
     - View summarized steps of compilation
     - ``torch._dynamo.config.log_level = logging.INFO``
   * - Debug logging
     - View detailed steps of compilation (print every instruction traced)
     - ``torch._dynamo.config.log_level = logging.DEBUG`` and
       ``torch._dynamo.config.verbose = True``
   * - Minifier for any backend
     - Find the smallest subgraph which reproduces errors for any backend
     - Set the environment variable ``TORCHDYNAMO_REPRO_AFTER="dynamo"``
   * - Minifier for ``TorchInductor``
     - If the error is known to occur after ``AOTAutograd``, find the
       smallest subgraph which reproduces errors during TorchInductor
       lowering
     - Set the environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"``
   * - Accuracy minifier
     - Find the smallest subgraph which reproduces an accuracy issue
       between an eager model and the optimized model
     - ``TORCHDYNAMO_REPRO_AFTER=<"aot"/"dynamo"> TORCHDYNAMO_REPRO_LEVEL=4``
   * - ``torch._dynamo.explain``
     - Find graph breaks and display reasoning for them
     - ``torch._dynamo.explain(fn, *inputs)``
   * - Record/Replay
     - Record and replay frames to reproduce errors during graph capture
     - ``torch._dynamo.config.replay_record_enabled = True``
   * - TorchDynamo function name filtering
     - Only compile functions with the given name to reduce noise when
       debugging an issue
     - Set the environment variable ``TORCHDYNAMO_DEBUG_FUNCTION=<name>``
   * - TorchInductor Debug logging
     - Print general TorchInductor debug info and generated Triton/C++ code
     - ``torch._inductor.config.debug = True``
   * - TorchInductor Tracing
     - Show the time taken in each TorchInductor stage, plus output code
       and graph visualization
     - Set the environment variable ``TORCHINDUCTOR_TRACE=1`` or
       ``torch._inductor.config.trace.enabled = True``

Diagnosing Runtime Errors
~~~~~~~~~~~~~~~~~~~~~~~~~

Below is the TorchDynamo compiler stack.

At a high level, the TorchDynamo stack consists of graph capture from
Python code (TorchDynamo) and a backend compiler. In this example, the
backend compiler consists of backward graph tracing (AOTAutograd) and
graph lowering (TorchInductor)*. Errors can occur in any component of
the stack and will come with a full stack trace.

You may use info logging
(``torch._dynamo.config.log_level = logging.INFO``) and look for
``Step #: ...`` outputs in order to determine in which component the
error occurred. Logs are made at the beginning and end of each step, so
the step that an error corresponds to is the most recently logged step
whose end has not yet been logged. The steps correspond to the following
parts of the stack (according to the image above):

==== ================
Step Component
==== ================
1    TorchDynamo
2    Compiler Backend
3    TorchInductor
==== ================

The beginning and end of AOTAutograd are currently not logged, but we
plan to add this soon.

If info logging is insufficient, there are also some backend options
which can help you determine which component is causing the error when
you are unable to understand the error message that is generated.
These are the following:

- ``"eager"``: only runs torchdynamo forward graph capture and then
  runs the captured graph with PyTorch. This provides an indication as
  to whether TorchDynamo is raising the error.

- ``"aot_eager"``: runs torchdynamo to capture a forward graph, and
  then AOTAutograd to trace the backward graph without any additional
  backend compiler steps. PyTorch eager will then be used to run the
  forward and backward graphs. This is useful to narrow down the issue
  to AOTAutograd.

The general procedure to narrow down an issue is the following:

1. Run your program with the ``"eager"`` backend. If the error no longer
   occurs, the issue is in the backend compiler that is being used (if
   using TorchInductor, proceed to step 2; if not, see `this
   section <#minifying-backend-compiler-errors>`__). If the error still
   occurs with the ``"eager"`` backend, it is an `error while running
   torchdynamo <#torchdynamo-errors>`__.

2. This step is only necessary if TorchInductor is used as the backend
   compiler. Run the model with the ``"aot_eager"`` backend. If this
   backend raises an error, then the error is occurring during
   AOTAutograd tracing. If the error no longer occurs with this backend,
   then `the error is in
   TorchInductor\* <#minifying-torchinductor-errors>`__.

Each of these cases is analyzed in the following sections.

\*Note on TorchInductor naming: The TorchInductor backend consists of
both AOTAutograd tracing and the TorchInductor compiler itself. We will
disambiguate by referring to TorchInductor as the backend, and
TorchInductor lowering as the phase which lowers the graph traced by
AOTAutograd.

TorchDynamo Errors
------------------

If the error occurs with the ``"eager"`` backend, then torchdynamo is
the most likely source of the error. Here is an example of code which
will generate an error:

.. code:: py

    import torch
    import torch._dynamo as dynamo


    @dynamo.optimize("eager")
    def test_assertion_error():
        y = torch.ones(200, 200)
        z = {y: 5}
        return z


    test_assertion_error()

Which will generate the following error:

::

    torch._dynamo.convert_frame: [ERROR] WON'T CONVERT test_assertion_error /scratch/mlazos/torchdynamo/../test/errors.py line 26
    due to:
    Traceback (most recent call last):
      File "/scratch/mlazos/torchdynamo/torchdynamo/symbolic_convert.py", line 837, in BUILD_MAP
        assert isinstance(k, ConstantVariable) or (
    AssertionError

    from user code:
       File "/scratch/mlazos/torchdynamo/../test/errors.py", line 34, in test_assertion_error
        z = {y: 5}

    Set torch._dynamo.config.verbose=True for more information
    ==========

As the message suggests, you can set
``torch._dynamo.config.verbose=True`` to get a full stack trace for both
the error in torchdynamo and the user code. In addition to this flag,
you can also set the ``log_level`` of torchdynamo through
``torch._dynamo.config.log_level``. The available levels are the
following:

- ``logging.DEBUG``: Print every instruction that is encountered, in
  addition to everything from the log levels below.
- ``logging.INFO``: Print each function that is compiled (original and
  modified bytecode) and the graph that is captured, in addition to
  everything from the log levels below.
- ``logging.WARNING`` (default): Print graph breaks, in addition to
  everything from the log levels below.
- ``logging.ERROR``: Print errors only.
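Putting the two techniques together (info logging plus the ``"eager"``
and ``"aot_eager"`` narrowing procedure above), a minimal sketch might
look like the following; the toy ``model_fn`` and the
``torch._dynamo.reset()`` calls between attempts are illustrative
assumptions, not part of the procedure itself:

.. code-block:: python

    import logging
    import torch
    import torch._dynamo as dynamo

    # Surface the per-step compilation logs so you can see which component
    # (TorchDynamo, the backend compiler, TorchInductor) raised the error.
    dynamo.config.log_level = logging.INFO
    dynamo.config.verbose = True

    def model_fn(x):  # stand-in for the model that is failing
        return torch.nn.functional.relu(x @ x)

    x = torch.randn(8, 8)

    # Step 1: "eager" exercises only TorchDynamo's graph capture.
    # If the failure reproduces here, it is a TorchDynamo error.
    dynamo.reset()
    dynamo.optimize("eager")(model_fn)(x)

    # Step 2 (only needed with TorchInductor): "aot_eager" adds AOTAutograd
    # tracing but skips lowering. A failure here points at AOTAutograd; if
    # both steps pass but "inductor" fails, the problem is in TorchInductor
    # lowering.
    dynamo.reset()
    dynamo.optimize("aot_eager")(model_fn)(x)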
If a model is sufficiently large, the logs can become overwhelming. If
an error occurs deep within a model’s python code, it can be useful to
execute only the frame in which the error occurs to enable easier
debugging. There are two tools available to enable this:

- Setting the environment variable ``TORCHDYNAMO_DEBUG_FUNCTION`` to the
  desired function name will only run torchdynamo on functions with that
  name.
- The record/replay tool (set
  ``torch._dynamo.config.replay_record_enabled = True``) dumps an
  execution record when an error is encountered. This record can then be
  replayed to run only the frame where the error occurred.

TorchInductor Errors
--------------------

If the error does not occur with the ``"eager"`` backend, then the
backend compiler is the source of the error (`example error `__).
There are `different choices `__ for backend compilers for torchdynamo,
with TorchInductor or nvfuser fitting the needs of most users. This
section focuses on TorchInductor as the motivating example, but some of
the tools are also usable with other backend compilers.

Below is the portion of the stack which we are focusing on:

With TorchInductor as the chosen backend, AOTAutograd is used to
generate the backward graph from the forward graph captured by
torchdynamo. It’s important to note that errors can occur during this
tracing and also while TorchInductor lowers the forward and backward
graphs to GPU code or C++. A model can often consist of hundreds or
thousands of FX nodes, so narrowing down the exact nodes where the
problem occurred can be very difficult. Fortunately, there are tools
available to automatically minify these input graphs to the nodes which
are causing the issue. The first step is to determine whether the error
occurs during tracing of the backward graph with AOTAutograd or during
TorchInductor lowering. As mentioned above in step 2, the
``"aot_eager"`` backend can be used to run only AOTAutograd in isolation
without lowering. If the error still occurs with this backend, this
indicates that the error is occurring during AOTAutograd tracing.

Here’s an example:

.. code:: py

    import torch
    import torch._dynamo as dynamo

    model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])

    @dynamo.optimize("inductor")
    def test_backend_error():
        y = torch.ones(200, 200)
        x = torch.ones(200, 200)
        z = x + y
        a = torch.ops.aten._foobar(z)  # dummy function which errors
        return model(a)


    test_backend_error()

Running this should give you the following error (with a longer stack
trace below it):

::

    Traceback (most recent call last):
      File "/scratch/mlazos/torchdynamo/torchinductor/graph.py", line 246, in call_function
        return lowerings[target](*args, **kwargs)
      File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 185, in wrapped
        return decomp_fn(*args, **kwargs)
      File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 810, in _foobar
        assert False
    AssertionError
    ...

`error with full stack trace `__

If you then change ``@dynamo.optimize("inductor")`` to
``@dynamo.optimize("aot_eager")``, it will run without error, because
`the issue `__ is in the TorchInductor lowering process, not in
AOTAutograd.

Minifying TorchInductor Errors
------------------------------

From here, let’s run the minifier to get a minimal repro.
Setting the +environment variable TORCHDYNAMO_REPRO_AFTER=“aot” (or setting +``torch._dynamo.config.repro_after="aot"`` directly) will generate a +python program which reduces the graph produced by AOTAutograd to the +smallest subgraph which reproduces the error. (See below for an example +where we minify the graph produced by torchdynamo) Running the program +with this environment variable should show nearly `identical +output `__, +with an additional line indicating where ``minifier_launcher.py`` has +been written to. The output directory is configurable by setting +``torch._dynamo.config.base_dir`` to a valid directory name. The final +step is to run the minifier and check that it runs successfully. A +successful run looks like +`this `__. +If the minifier runs successfully, it generates runnable python code +which reproduces the exact error. For our example this is the following +code: + +.. code:: py + + import torch + from torch import tensor, device + import torch.fx as fx + from torch._dynamo.testing import rand_strided + from math import inf + from torch.fx.experimental.proxy_tensor import make_fx + + # torch version: 1.13.0a0+gitfddfc44 + # torch cuda version: 11.6 + # torch git version: fddfc4488afb207971c54ad4bf58130fdc8a4dc5 + + + # CUDA Info: + # nvcc: NVIDIA (R) Cuda compiler driver + # Copyright (c) 2005-2022 NVIDIA Corporation + # Built on Thu_Feb_10_18:23:41_PST_2022 + # Cuda compilation tools, release 11.6, V11.6.112 + # Build cuda_11.6.r11.6/compiler.30978841_0 + + # GPU Hardware Info: + # NVIDIA A100-SXM4-40GB : 8 + + + from torch.nn import * + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + + + + def forward(self, add): + _foobar = torch.ops.aten._foobar.default(add); add = None + return (_foobar,) + + args = [((200, 200), (200, 1), torch.float32, 'cpu')] + args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args] + mod = make_fx(Repro())(*args) + from torch._inductor.compile_fx import compile_fx_inner + + compiled = compile_fx_inner(mod, args) + compiled(*args) + +The ``forward`` method of the ``Repro`` module contains the exact op +which causes the issue. When filing an issue, please include any +minified repros to aid in debugging. + +Minifying Backend Compiler Errors +--------------------------------- + +With backend compilers other than TorchInductor the process for finding +the subgraph causing the error is nearly identical to the procedure in +`errors in TorchInductor <#torchinductor-errors>`__ with one important +caveat. Namely, that the minifier will now be run on the graph that is +traced by TorchDynamo, not the output graph of AOTAutograd. Let’s walk +through an example. + +.. code:: py + + import torch + + import torch._dynamo as dynamo + + model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)]) + # toy compiler which fails if graph contains relu + def toy_compiler(gm: torch.fx.GraphModule, _): + for node in gm.graph.nodes: + if node.target == torch.relu: + assert False + + return gm + + + @dynamo.optimize(toy_compiler) + def test_backend_error(): + y = torch.ones(200, 200) + x = torch.ones(200, 200) + z = x + y + a = torch.relu(z) + return model(a) + + + test_backend_error() + +In order to run the code after TorchDynamo has traced the forward graph, +the TORCHDYNAMO_REPRO_AFTER enviornment variable can be used. 
Running +this program with TORCHDYNAMO_REPRO_AFTER=“dynamo” (or +``torch._dynamo.config.repro_after="dynamo"``) should produce `this +output `__\ and +the following code in ``{torch._dynamo.config.base_dir}/repro.py``. +Note: the other option for TORCHDYNAMO_REPRO_AFTER are ``"aot"``, which +will run the minifier after the backward graph has been generated. + +.. code:: py + + import torch + import torch._dynamo as dynamo + from torch import tensor, device + import torch.fx as fx + from torch._dynamo.testing import rand_strided + from math import inf + from torch._dynamo.debug_utils import run_fwd_maybe_bwd + + + from torch.nn import * + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + + + + def forward(self, add): + relu = torch.relu(add); add = None + return (relu,) + + + mod = Repro().cuda() + opt_mod = dynamo.optimize("None")(mod) + + + args = [((200, 200), (200, 1), torch.float32, 'cpu', False)] + args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] + + + with torch.cuda.amp.autocast(enabled=False): + ref = run_fwd_maybe_bwd(mod, args) + res = run_fwd_maybe_bwd(opt_mod, args) + +The minifier successfully reduced the graph to the op that raises the +error in ``toy_compiler``. The other difference from the procedure in +`TorhInductor Errors <#torchinductor-errors>`__ is that the minifier is +automatically run after encountering a backend compiler error. After a +successful run, the minifier writes ``repro.py`` to +``torch._dynamo.config.base_dir``. + +Performance Profiling +~~~~~~~~~~~~~~~~~~~~~ + +Accessing TorchDynamo Profiler +------------------------------ + +TorchDynamo has a builtin stats function for collecting and displaying +the time spent in each compilation phase. These stats can be accessed by +calling ``torch._dynamo.utils.compile_times()`` after executing +Torch._Dynamo. By default, this returns a string representation of the +compile times spent in each TorchDynamo function by name. + +TorchInductor Debug Tracing +--------------------------- + +TorchInductor has a builtin stats and trace function for displaying time +spent in each compilation phase, output code, output graph visualization +and IR dump. This is a debugging tool designed to make it easier to +debug/understand the internals of TorchInductor. + +Setting the environment variable ``TORCHINDUCTOR_TRACE=1`` will cause a +debug trace directory to be created and printed: + +:: + + $ env TORCHINDUCTOR_TRACE=1 python repro.py + torch._inductor.debug: [WARNING] model_forward_0 debug trace: /tmp/torchinductor_jansel/rh/crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.debug + +Here is an `example debug directory +output `__ +for the test program: + +:: + + torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.LayerNorm(10), + torch.nn.ReLU(), + ) + +Note each file in that debug trace can be enabled/disabled via +``torch._inductor.config.trace.*``. The profile and the diagram are both +disabled by default since they are expensive to generate. 
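If you prefer to enable the trace from Python rather than through the
environment variable, here is a minimal sketch using the
``trace.enabled`` flag from the table above on the same toy program;
other ``torch._inductor.config.trace.*`` switches are not shown here:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo
    import torch._inductor.config as inductor_config

    # Equivalent to running with TORCHINDUCTOR_TRACE=1.
    inductor_config.trace.enabled = True

    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.LayerNorm(10),
        torch.nn.ReLU(),
    )

    opt_model = dynamo.optimize("inductor")(model)
    opt_model(torch.randn(2, 10))
    # The [WARNING] line printed during compilation contains the path of
    # the generated debug trace directory.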
A single node in this debug trace looks like:

::

    buf1: SchedulerNode(ComputedBuffer)
    buf1.writes =
        {   MemoryDep(name='buf1', index=0, size=()),
            MemoryDep(name='buf1', index=0, size=(s0,))}
    buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
    buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
    buf1.group.device = cuda:0
    buf1.group.iteration = (1, s0)
    buf1.sizes = ([], [s0])
    class buf1_loop_body:
        var_ranges = {z0: s0}
        index0 = z0
        index1 = 0
        def body(self, ops):
            get_index = self.get_index('index0')
            load = ops.load('buf0', get_index, False)
            get_index_1 = self.get_index('index0')
            load_1 = ops.load('primals_2', get_index_1, False)
            add = ops.add(load, load_1)
            get_index_2 = self.get_index('index1')
            reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add)
            return reduction

See the `example debug directory output `__ for more examples.

Memory Profiling
----------------

TBD

Graph Breaks
------------

Given a program like this:

.. code-block:: python

    @dynamo.optimize(...)
    def some_fun(x):
        ...

    some_fun(x)
    ...

TorchDynamo will attempt to compile all of the torch/tensor operations
within ``some_fun`` into a single FX graph, but it may fail to capture
everything into one graph.

Some graph break reasons are insurmountable to TorchDynamo and cannot be
easily fixed. For example, calling into a C extension other than torch
is invisible to torchdynamo, and could do arbitrary things without
TorchDynamo being able to introduce the necessary
`guards <./GuardsOverviewPt1.md>`__ to ensure that the compiled program
would be safe to reuse. Graph breaks can hinder performance if the
resulting fragments are small. To maximize performance, it’s important
to have as few graph breaks as possible.

Identifying the cause of a graph break
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To identify all graph breaks in a program and the associated reasons for
the breaks, ``torch._dynamo.explain`` can be used. This tool runs
TorchDynamo on the supplied function and aggregates the graph breaks
that are encountered. Here is an example usage:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        print("woo")
        if b.sum() < 0:
            b = b * -1
        return x * b

    explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
    print(explanation)
    """
    Dynamo produced 3 graphs, with 2 graph break and 6 ops.
     Break reasons:
    1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
       File "t2.py", line 16, in toy_example
          print("woo")

    2. generic_jump
       File "t2.py", line 17, in toy_example
          if b.sum() < 0:
    """

A note on the other outputs:

- ``out_guards`` - a list of lists where each sublist contains the
  guards that must pass to ensure the traced graphs are valid.
- ``graphs`` - a list of graph modules which were successfully traced.
- ``ops_per_graph`` - a list of lists where each sublist contains the
  ops that are run in that graph.

To throw an error on the first graph break encountered, ``nopython``
mode can be used. This disables TorchDynamo’s python fallback, and only
succeeds if the entire program is convertible to a single graph. Example
usage:

.. code-block:: python

    @dynamo.optimize(<compiler>, nopython=True)
    def toy_example(a, b):
        ...
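For example, a hedged sketch reusing the ``toy_example`` above, with
``"eager"`` as an arbitrary backend and a generic ``Exception`` catch
since the exact error type is not shown here; with ``nopython=True`` the
``print`` graph break becomes a hard error instead of a fallback:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    @dynamo.optimize("eager", nopython=True)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        print("woo")  # graph break: print cannot be captured in the graph
        if b.sum() < 0:
            b = b * -1
        return x * b

    try:
        toy_example(torch.randn(10), torch.randn(10))
    except Exception as err:
        # With the default nopython=False, the same function runs fine and
        # the graph breaks are handled by falling back to regular Python.
        print(f"nopython turned the first graph break into an error: {err}")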
Excessive Recompilation
-----------------------

When TorchDynamo compiles a function (or part of one), it makes certain
assumptions about locals and globals in order to allow compiler
optimizations, and expresses these assumptions as guards that check
particular values at runtime. If any of these guards fail, Dynamo will
recompile that function (or part) up to
``torch._dynamo.config.cache_size_limit`` times. If your program is
hitting the cache limit, you will first need to determine which guard is
failing and what part of your program is triggering it.

The `recompilation profiler <#recompilation-profiler>`__ automates the
process of setting TorchDynamo’s cache limit to 1 and running your
program under an observation-only ‘compiler’ that records the causes of
any guard failures. You should be sure to run your program for at least
as long (as many iterations) as you were running when you ran into
trouble, and the profiler will accumulate statistics over this duration.

If your program exhibits a bounded amount of dynamism, you may be able
to tune the TorchDynamo cache limit to allow for each variation to be
compiled and cached, but if the cache limit is too high you may find the
cost of recompilation outweighs any optimization benefits.

::

    torch._dynamo.config.cache_size_limit = <your desired cache limit>

TorchDynamo plans to support many common cases of dynamic tensor shapes,
such as varying batch size or sequence length. It does not plan to
support rank-dynamism. In the meantime, setting a specific cache limit
can be used in coordination with bucketing techniques to achieve an
acceptable number of recompilations for some dynamic models.

.. code-block:: python

    prof = dynamo.utils.CompilationProfiler()

    @dynamo.optimize(prof)
    def my_model():
        ...

    my_model()
    print(prof.report())

Accuracy Debugging
~~~~~~~~~~~~~~~~~~

Accuracy issues can also be minified if you set the environment variable
``TORCHDYNAMO_REPRO_LEVEL=4``. The accuracy minifier operates with a
similar git bisect model, and a full repro command might look like
``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4``. The reason
we need this is that downstream compilers generate code, whether Triton
code or the C++ backend, and the numerics from those downstream
compilers can differ in subtle ways yet have a dramatic impact on your
training stability. The accuracy debugger is therefore very useful for
detecting bugs in our codegen or in a backend compiler.

File an Issue
~~~~~~~~~~~~~

You should feel encouraged to `file a github issue `__ and expect a
timely response.

Before filing an issue, read over the `README <../README.md>`__,
`TROUBLESHOOTING <./TROUBLESHOOTING.md>`__, and search for similar
issues.

When filing an issue, please include:

- Your OS/python/pytorch/CUDA/triton info, obtained by running:

  .. code-block:: sh

      python tools/verify_install.py

- A minimal repro script if possible, which can be generated by running
  the Minifier.
- A description of the error.
- The expected behavior.
- A log (set ``torch._dynamo.config.log_file`` to a valid file name to
  dump the logs to a file, and set
  ``torch._dynamo.config.log_level = logging.DEBUG`` and
  ``torch._dynamo.config.verbose = True``).

diff --git a/docs/source/index.rst b/docs/source/index.rst
index e4b6a124d6b..e43160f668f 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,6 +42,13 @@ Features described in this documentation are classified by release status:

    notes/*

..
toctree:: + :glob: + :maxdepth: 1 + :caption: torch.compile + + dynamo/* + .. toctree:: :maxdepth: 1 :caption: Language Bindings