Move Dynamo docs back to core (#89769)
With contributions from @svekars and @malfet. Waiting for doc build job to complete.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89769
Approved by: https://github.com/soumith, https://github.com/malfet
Commit 9048cf16fe (parent 2b522670d2), committed by PyTorch MergeBot
BIN  docs/source/_static/img/dynamo/TorchDynamo.png           (new file, binary file not shown, 341 KiB)
BIN  docs/source/_static/img/dynamo/td_stack.png              (new file, binary file not shown, 301 KiB)
BIN  docs/source/_static/img/dynamo/torchinductor_backend.png (new file, binary file not shown, 120 KiB)

154  docs/source/dynamo/custom-backends.rst  (new file)
@@ -0,0 +1,154 @@
Custom Backends
===============

Debugging Backend
-----------------

Suppose you want to better understand what is going on during a
compilation. You can create a custom compiler, which we'll refer to as a
backend, that pretty prints the FX ``GraphModule`` extracted from
Dynamo's bytecode analysis and returns a ``forward()`` callable.

.. code-block:: python

    from typing import List
    import torch
    import torch._dynamo as dynamo

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @dynamo.optimize(my_compiler)
    def fn(x, y):
        a = torch.cos(x)
        b = torch.sin(y)
        return a + b

    fn(torch.randn(10), torch.randn(10))

Running the above example produces the following output:

::

    my_compiler() called with FX graph:
    opcode         name    target                                                   args        kwargs
    -------------  ------  -------------------------------------------------------  ----------  --------
    placeholder    x       x                                                        ()          {}
    placeholder    y       y                                                        ()          {}
    call_function  cos     <built-in method cos of type object at 0x7f1a894649a8>   (x,)        {}
    call_function  sin     <built-in method sin of type object at 0x7f1a894649a8>   (y,)        {}
    call_function  add     <built-in function add>                                  (cos, sin)  {}
    output         output  output                                                   ((add,),)   {}

This works for ``torch.nn.Module`` as well, as shown below:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.cos(x))

    mod = MockModule()
    optimized_mod = dynamo.optimize(my_compiler)(mod)
    optimized_mod(torch.randn(10))

Let's take a look at one more example with control flow:

.. code-block:: python

    from typing import List
    import torch
    import torch._dynamo as dynamo

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @dynamo.optimize(my_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))
Running this example produces the following output:

::

    my_compiler() called with FX graph:
    opcode         name     target                                                   args              kwargs
    -------------  -------  -------------------------------------------------------  ----------------  --------
    placeholder    a        a                                                        ()                {}
    placeholder    b        b                                                        ()                {}
    call_function  abs_1    <built-in method abs of type object at 0x7f8d259298a0>   (a,)              {}
    call_function  add      <built-in function add>                                  (abs_1, 1)        {}
    call_function  truediv  <built-in function truediv>                              (a, add)          {}
    call_method    sum_1    sum                                                      (b,)              {}
    call_function  lt       <built-in function lt>                                   (sum_1, 0)        {}
    output         output   output                                                   ((truediv, lt),)  {}

    my_compiler() called with FX graph:
    opcode         name    target                   args         kwargs
    -------------  ------  -----------------------  -----------  --------
    placeholder    b       b                        ()           {}
    placeholder    x       x                        ()           {}
    call_function  mul     <built-in function mul>  (b, -1)      {}
    call_function  mul_1   <built-in function mul>  (x, mul)     {}
    output         output  output                   ((mul_1,),)  {}

    my_compiler() called with FX graph:
    opcode         name    target                   args       kwargs
    -------------  ------  -----------------------  ---------  --------
    placeholder    b       b                        ()         {}
    placeholder    x       x                        ()         {}
    call_function  mul     <built-in function mul>  (x, b)     {}
    output         output  output                   ((mul,),)  {}

The order of the last two graphs is nondeterministic, depending
on which one is encountered first by the just-in-time compiler.
Speedy Backend
--------------

Integrating a custom backend that offers better performance is also
easy. Here we integrate a real one,
`optimize_for_inference <https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html>`__:

.. code-block:: python

    def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        scripted = torch.jit.trace(gm, example_inputs)
        return torch.jit.optimize_for_inference(scripted)

And then you should be able to optimize any existing code with:

.. code-block:: python

    @dynamo.optimize(optimize_for_inference_compiler)
    def code_to_accelerate():
        ...

Composable Backends
-------------------

TorchDynamo includes many backends, which can be found in
`backends.py <https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py>`__
or via ``torchdynamo.list_backends()``. You can combine these backends
with the following code:

.. code-block:: python

    from torch._dynamo.optimizations import BACKENDS

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        trt_compiled = BACKENDS["tensorrt"](gm, example_inputs)
        if trt_compiled is not None:
            return trt_compiled
        # first backend failed, try something else...
        cudagraphs_compiled = BACKENDS["cudagraphs"](gm, example_inputs)
        if cudagraphs_compiled is not None:
            return cudagraphs_compiled
        return gm.forward
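
The combined compiler above is used like any other custom backend: pass it to
``dynamo.optimize``. A minimal usage sketch (``my_model`` is a stand-in for your
own module, and the TensorRT/cudagraphs backends assume a CUDA setup):

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    # my_compiler is the composed backend defined above
    opt_model = dynamo.optimize(my_compiler)(my_model)
    opt_model(torch.randn(8, 3, 224, 224, device="cuda"))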

145  docs/source/dynamo/deep-dive.rst  (new file)
@@ -0,0 +1,145 @@
TorchDynamo Deeper Dive
=======================

**Author**: `Jason Ansel <https://github.com/jansel>`_

What is a guard?
----------------

TorchDynamo operates just-in-time and specializes graphs based on
dynamic properties. For example, the first graph above has the following
guards:

::

    GUARDS:
     - local 'a' TENSOR_MATCH
     - local 'b' TENSOR_MATCH
     - global 'torch' FUNCTION_MATCH

If any of those guards fail, the graph will be recaptured and
recompiled. The interesting guard type there is ``TENSOR_MATCH``, which
checks the following ``torch.Tensor`` properties:

- Python class of the tensor (tensor subclassing, etc.)
- dtype
- device
- requires_grad
- dispatch_key (with thread-local includes/excludes applied)
- ndim
- sizes\* (optional)
- strides\* (optional)

For sizes/strides you can disable this specialization by setting the
following parameter:

.. code-block:: python

    torch._dynamo.config.dynamic_shapes = True

The full specialization mode allows the backend compiler to assume an
entirely static graph. Unfortunately, most backends require this.
Operators which return dynamic shapes will trigger a graph break when
not in dynamic shape mode.
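
As a quick illustration of how these guards drive recompilation, here is a
minimal sketch (the exact behavior and printouts vary by TorchDynamo version):
changing the dtype of an input fails the ``TENSOR_MATCH`` guard, so the
function is traced and compiled a second time.

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    def noisy_compiler(gm, example_inputs):
        # Illustrative debugging backend: report each (re)compilation, run eagerly.
        print(f"compiling a graph with {len(list(gm.graph.nodes))} nodes")
        return gm.forward

    @dynamo.optimize(noisy_compiler)
    def fn(x):
        return torch.sin(x) + 1

    fn(torch.randn(8))                        # compiles once
    fn(torch.randn(8))                        # guards pass, cached code is reused
    fn(torch.randn(8, dtype=torch.float64))   # dtype guard fails -> recompile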
What is dynamo doing?
---------------------

If you want to understand better what TorchDynamo is doing, you can set:

.. code-block:: python

    torchdynamo.config.debug = True

which triggers useful (but spammy) printouts.

For example, the printouts for the first graph in the ``toy_example``
above are:

::

    __compiled_fn_0 <eval_with_key>.1
    opcode         name     target                                                   args              kwargs
    -------------  -------  -------------------------------------------------------  ----------------  --------
    placeholder    a        a                                                        ()                {}
    placeholder    b        b                                                        ()                {}
    call_function  abs_1    <built-in method abs of type object at 0x7f9ca082f8a0>   (a,)              {}
    call_function  add      <built-in function add>                                  (abs_1, 1)        {}
    call_function  truediv  <built-in function truediv>                              (a, add)          {}
    call_method    sum_1    sum                                                      (b,)              {}
    call_function  lt       <built-in function lt>                                   (sum_1, 0)        {}
    output         output   output                                                   ((truediv, lt),)  {}

    ORIGINAL BYTECODE toy_example example.py 9
     10           0 LOAD_FAST                0 (a)
                  2 LOAD_GLOBAL              0 (torch)
                  4 LOAD_METHOD              1 (abs)
                  6 LOAD_FAST                0 (a)
                  8 CALL_METHOD              1
                 10 LOAD_CONST               1 (1)
                 12 BINARY_ADD
                 14 BINARY_TRUE_DIVIDE
                 16 STORE_FAST               2 (x)

     11          18 LOAD_FAST                1 (b)
                 20 LOAD_METHOD              2 (sum)
                 22 CALL_METHOD              0
                 24 LOAD_CONST               2 (0)
                 26 COMPARE_OP               0 (<)
                 28 POP_JUMP_IF_FALSE       38

     12          30 LOAD_FAST                1 (b)
                 32 LOAD_CONST               3 (-1)
                 34 BINARY_MULTIPLY
                 36 STORE_FAST               1 (b)

     13     >>   38 LOAD_FAST                2 (x)
                 40 LOAD_FAST                1 (b)
                 42 BINARY_MULTIPLY
                 44 RETURN_VALUE

    MODIFIED BYTECODE
      9           0 LOAD_GLOBAL              3 (__compiled_fn_0)
                  2 LOAD_FAST                0 (a)
                  4 LOAD_FAST                1 (b)
                  6 CALL_FUNCTION            2
                  8 UNPACK_SEQUENCE          2
                 10 STORE_FAST               2 (x)
                 12 POP_JUMP_IF_FALSE       24
                 14 LOAD_GLOBAL              4 (__resume_at_30_1)
                 16 LOAD_FAST                1 (b)
                 18 LOAD_FAST                2 (x)
                 20 CALL_FUNCTION            2
                 22 RETURN_VALUE
            >>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
                 26 LOAD_FAST                1 (b)
                 28 LOAD_FAST                2 (x)
                 30 CALL_FUNCTION            2
                 32 RETURN_VALUE

    GUARDS:
     - local 'a' TENSOR_MATCH
     - local 'b' TENSOR_MATCH
     - global 'torch' FUNCTION_MATCH

At the top you can see the FX graph (which we already shared above).
Next you see the original bytecode of the function, followed by the
modified bytecode generated by TorchDynamo. Finally, you see the guards,
which we covered above.

In the modified bytecode, ``__compiled_fn_0`` is the return value of
``my_compiler()`` (the compiled graph). ``__resume_at_30_1`` and
``__resume_at_38_2`` are both generated continuation functions that pick
up execution after a graph break (at bytecode offsets 30 and 38). Each
of these functions takes the form:

::

    __resume_at_<offset>:
        ... restore stack state if needed ...
        JUMP_ABSOLUTE <offset> into toy_example
        ... original bytecode of toy_example ...

By generating this ``resume_at`` function we force the remainder of the
function to be executed in a new Python frame, which recursively
triggers TorchDynamo to restart its capture once execution reaches that
point for the first time.

376  docs/source/dynamo/faq.rst  (new file)
@@ -0,0 +1,376 @@
Frequently Asked Questions
==========================

At a high level, the TorchDynamo stack consists of graph capture from
Python code using dynamo and a backend compiler. In this example the
backend compiler consists of backward graph tracing using AOTAutograd
and graph lowering using TorchInductor. There are of course many more
compilers available `here <https://github.com/pytorch/torchdynamo/blob/0b8aaf340dad4777a080ef24bf09623f1aa6f3dd/README.md#existing-backend>`__,
but for this document we will focus on inductor as a motivating example.

TorchDynamo supports training, using AOTAutograd to capture backwards:

1. the ``.forward()`` graph and ``optimizer.step()`` are captured by TorchDynamo's python evalframe frontend
2. for each segment of ``.forward()`` that TorchDynamo captures, it uses AOTAutograd to generate a backward graph segment
3. each pair of forward and backward graphs is (optionally) min-cut partitioned to save the minimal state between forward and backward
4. the forward and backward pairs are wrapped in ``autograd.Function`` modules
5. user code calling ``.backward()`` still triggers eager's autograd engine, which runs each "compiled backward" graph as if it were one op, also running any non-compiled eager ops' ``.backward()`` functions

Do you support Distributed code?
--------------------------------

DDP has been tested and works; support for other distributed training
libraries is under discussion.

The main reason why distributed code is challenging with dynamo is
that AOTAutograd unrolls both the forward and backward pass and
provides two graphs for backends to optimize. This is a problem for
distributed code because we'd ideally like to overlap communication
operations with computations. Eager PyTorch accomplishes this in
different ways for DDP/FSDP, using autograd hooks, module hooks, and
modifications/mutations of module states. In a naive application of
dynamo, hooks that should run directly after an operation during
backwards may be delayed until after the entire compiled region of
backwards ops, due to how AOTAutograd compiled functions interact with
dispatcher hooks.

The basic strategy for optimizing DDP with Dynamo is outlined in
`distributed.py <https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/distributed.py>`__,
where the main idea is to graph break on `DDP bucket
boundaries <https://pytorch.org/docs/stable/notes/ddp.html#internal-design>`__.

When each node in DDP needs to synchronize its weights with the other
nodes, it organizes its gradients and parameters into buckets, which
reduces communication times and allows a node to broadcast a fraction of
its gradients to other waiting nodes.

Graph breaks in distributed code mean you can expect dynamo and its
backends to optimize the compute overhead of a distributed program but
not its communication overhead. Graph breaks may interfere with
compilation speedups if the reduced graph size robs the compiler of
fusion opportunities. However, there are diminishing returns with
increasing graph size since most of the current compute optimizations
are local fusions, so in practice this approach may be sufficient.

Do I still need to export whole graphs?
---------------------------------------

For the vast majority of models you probably don't, and you can use
``torch._dynamo.optimize`` as is, but there are a few situations where
full graphs are necessary, and you can ensure a full graph by simply
running ``dynamo.optimize(..., nopython=True)``:

* Large scale training runs (think $250K+) that require pipeline parallelism and other advanced sharding strategies
* Inference optimizers like `TensorRT <https://github.com/pytorch/TensorRT>`__ or
  `AITemplate <https://github.com/facebookincubator/AITemplate>`__ that rely
  on fusing much more aggressively than training optimizers
* Mobile training or inference

Future work will include tracing communication operations into graphs,
coordinating these operations with compute optimizations, and optimizing
the communication operations.

Why is my code crashing?
------------------------

If your code ran just fine without dynamo and started to crash with it
enabled, then the most important first step is figuring out which part of
the stack your failure occurred in. Try running things in the below
order, and only try the next step if the previous step succeeded (a
sketch of this triage loop follows the list).

1. ``dynamo.optimize("eager")``, which only runs TorchDynamo forward graph
   capture and then runs the captured graph with PyTorch. If this fails,
   then there's an issue with TorchDynamo.

2. ``dynamo.optimize("aot_eager")``,
   which runs TorchDynamo to capture a forward graph, and then AOTAutograd
   to trace the backward graph, without any additional backend compiler
   steps. PyTorch eager is then used to run the forward and backward
   graphs. If this fails, then there's an issue with AOTAutograd.

3. ``dynamo.optimize("inductor")``, which runs TorchDynamo to capture a
   forward graph, and then AOTAutograd to trace the backward graph with the
   TorchInductor compiler. If this fails, then there's an issue with TorchInductor.
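
Here is a minimal sketch of that triage loop (``run_my_model`` and ``my_model``
are stand-ins for your own code; ``dynamo.reset()`` clears previously cached
compilations between attempts):

.. code-block:: python

    import torch._dynamo as dynamo

    for backend in ("eager", "aot_eager", "inductor"):
        dynamo.reset()
        try:
            run_my_model(dynamo.optimize(backend)(my_model))
        except Exception as exc:
            print(f"first failure with backend={backend!r}: {exc}")
            break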
TorchDynamo Errors
~~~~~~~~~~~~~~~~~~

If the error that is generated occurs with the ``"eager"`` backend, then
TorchDynamo is the most likely source of the error.

To debug these issues we recommend setting
``torch._dynamo.config.verbose=True`` to get a full stack trace to both
the error in TorchDynamo and the user code. In addition to this flag,
you can also set the ``log_level`` of TorchDynamo through
``torch._dynamo.config.log_level``. The available levels are the
following:

- ``logging.DEBUG``: Print every instruction that is encountered, in addition to all below log levels
- ``logging.INFO``: Print each function that is compiled (original and modified bytecode) and the graph that is captured, in addition to all below log levels
- ``logging.WARNING`` (default): Print graph breaks, in addition to all below log levels
- ``logging.ERROR``: Print errors only
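
For example (a sketch; these configuration attributes may move or be renamed in
newer releases):

.. code-block:: python

    import logging
    import torch._dynamo

    torch._dynamo.config.verbose = True
    torch._dynamo.config.log_level = logging.INFO  # one of the levels listed above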
If a model is sufficiently large, the logs can become overwhelming. If
an error occurs deep within a model's Python code, it can be useful to
execute only the frame in which the error occurs to enable easier
debugging. There are two tools available to enable this:

* ``env TORCHDYNAMO_DEBUG_FUNCTION=<desired_function_name>`` will only run TorchDynamo on functions with that name.

* ``torch._dynamo.config.replay_record_enabled = True``, which dumps an execution record when an error is encountered. This record can then be replayed to run only the frame where the error occurred.

TorchInductor Errors
--------------------

With TorchInductor as the chosen backend, AOTAutograd is used to
generate the backward graph from the forward graph captured by
TorchDynamo. It's important to note that errors can occur during this
tracing and also while TorchInductor lowers the forward and backward
graphs to GPU code or C++.

A model can often consist of hundreds or thousands of FX nodes, so
narrowing down the exact nodes where the problem occurred can be very
difficult, which is why we highly recommend you use our minifier to
create tiny reproducible examples of failures you're seeing. We can
minify errors that occur either at the AOTAutograd layer or the Inductor
layer, which you should try in the following order:

1. ``env TORCHDYNAMO_REPRO_AFTER="aot" python your_model.py``
2. ``env TORCHDYNAMO_REPRO_AFTER="dynamo" python your_model.py``

Minifying your error is the quickest path to getting it fixed.

The minifier will create a ``repro.py`` for you at the location
set by ``env TORCHDYNAMO_REPRO_DIR``, so make sure you have write access to
that directory. You can then run ``python repro.py`` and confirm that
you are getting the same error.

.. note::
   For other compilers such as nvfuser, the process is similar, but
   instead you would leverage ``env TORCHDYNAMO_REPRO_AFTER="dynamo" python your_model.py``.

Why is compilation slow?
------------------------

Dynamo Compilation
~~~~~~~~~~~~~~~~~~

TorchDynamo has a builtin stats function for collecting and displaying
the time spent in each compilation phase. These stats can be accessed by
calling ``torch._dynamo.utils.compile_times()`` after executing
``torch._dynamo``. By default, this returns a string representation of
the compile times spent in each TorchDynamo function by name.
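
For example (a minimal sketch; the exact report format differs between
releases):

.. code-block:: python

    import torch
    import torch._dynamo as dynamo
    from torch._dynamo.utils import compile_times

    @dynamo.optimize("eager")
    def fn(x):
        return torch.sin(x) + torch.cos(x)

    fn(torch.randn(8))
    # Prints the per-function compile-time totals accumulated so far.
    print(compile_times())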
Inductor Compilation
~~~~~~~~~~~~~~~~~~~~

TorchInductor has a builtin stats and trace function for displaying the time
spent in each compilation phase, output code, output graph visualization
and IR dump: ``env TORCHINDUCTOR_TRACE=1 python repro.py``. This is a
debugging tool designed to make it easier to debug and understand the
internals of TorchInductor, with an output that will look something like
`this <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__.

Each file in that debug trace can be enabled/disabled via
``torch._inductor.config.trace.*``. The profile and the diagram are both
disabled by default since they are expensive to generate. See the
`example debug directory
output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__
for more examples.

Excessive Recompilation
~~~~~~~~~~~~~~~~~~~~~~~

When TorchDynamo compiles a function (or part of one), it makes certain
assumptions about locals and globals in order to allow compiler
optimizations, and expresses these assumptions as guards that check
particular values at runtime. If any of these guards fail, Dynamo will
recompile that function (or part) up to
``torch._dynamo.config.cache_size_limit`` times. If your program is
hitting the cache limit, you will first need to determine which guard is
failing and what part of your program is triggering it.
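
If you need more headroom while investigating, the limit itself can be raised
(a sketch; the default value and the exact location of this knob may change
between releases):

.. code-block:: python

    import torch._dynamo

    torch._dynamo.config.cache_size_limit = 64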
The `recompilation profiler <#recompilation-profiler>`__ automates the
process of setting TorchDynamo's cache limit to 1 and running your
program under an observation-only 'compiler' that records the causes of
any guard failures. You should be sure to run your program for at least
as long (as many iterations) as you were running when you ran into
trouble, and the profiler will accumulate statistics over this duration.

.. code-block:: python

    prof = dynamo.utils.CompilationProfiler()

    @dynamo.optimize(prof)
    def my_model():
        ...

    my_model()
    print(prof.report())

Many of the reasons for graph breaks and excessive recompilation will be
fixed with upcoming support for `tracing dynamic tensor
shapes <https://docs.google.com/document/d/1QJB-GOnbv-9PygGlOMXwiO9K6vVNm8sNg_olixJ9koc/edit?usp=sharing>`__,
more careful choices for guards and better tuned heuristics.

Why are you recompiling in production?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you may not want unexpected compiles after a program has
warmed up, for example if you are serving production traffic in a
latency critical application. For this, TorchDynamo provides an
alternate mode where prior compiled graphs are used, but no new ones are
generated:

.. code-block:: python

    frozen_toy_example = dynamo.run(toy_example)
    frozen_toy_example(torch.randn(10), torch.randn(10))

How are you speeding up my code?
--------------------------------

There are three major ways to accelerate PyTorch code (a sketch of the first
one follows this list):

1. Kernel fusion. Vertical fusion fuses sequential operations to avoid
   excessive reads/writes; for example, fusing two subsequent cosines means you
   can do 1 read and 1 write instead of 2 reads and 2 writes. Horizontal fusion:
   the simplest example is batching, where a single matrix is multiplied
   with a batch of examples, but the more general scenario is a grouped GEMM,
   where a group of matrix multiplications are scheduled together.

2. Out of order execution: a general optimization for compilers. By looking ahead
   at the exact data dependencies within a graph, we can decide on the most
   opportune time to execute a node and which buffers can be reused.

3. Automatic work placement: similar to the out of order execution point,
   but by matching nodes of a graph to resources like physical hardware or
   memory, we can design an appropriate schedule.
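
A minimal sketch of the vertical-fusion case described in item 1 (not a
benchmark; whether and how the two ops are fused depends on the backend):

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    # Two chained pointwise ops: eager mode does 2 reads and 2 writes,
    # while a fusing backend such as inductor can emit a single kernel
    # that does 1 read and 1 write.
    @dynamo.optimize("inductor")
    def double_cos(x):
        return torch.cos(torch.cos(x))

    double_cos(torch.randn(4096))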
The above are general principles for accelerating PyTorch code, but
different backends will each make different tradeoffs on what to
optimize. For example, Inductor first takes care of fusing whatever it
can and only then generates `Triton <https://openai.com/blog/triton/>`__
kernels.

Triton in addition offers speedups because of automatic memory
coalescing, memory management and scheduling within each Streaming
Multiprocessor, and has been designed to handle tiled computations.

However, regardless of the backend you use, it's best to benchmark and
compare approaches, so try out the PyTorch profiler, visually inspect the
generated kernels and try to see what's going on for yourself.

Why am I not seeing speedups?
-----------------------------

Graph Breaks
~~~~~~~~~~~~

The main reason you won't see the speedups you'd like by using dynamo
is excessive graph breaks. So what's a graph break?

Given a program like:

.. code-block:: python

    @dynamo.optimize(...)
    def some_fun(x):
        ...

    some_fun(x)
    ...

TorchDynamo will attempt to compile all of the torch/tensor operations
within ``some_fun()`` into a single FX graph, but it may fail to capture
everything into one graph.

Some graph break reasons are insurmountable to TorchDynamo. For example, a
call into a C extension other than torch is invisible to TorchDynamo, and
could do arbitrary things without TorchDynamo being able to introduce the
necessary guards to ensure that the compiled program would be safe to reuse.

To maximize performance, it's important to have as few graph breaks
as possible.

Identifying the cause of a graph break
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To identify all graph breaks in a program and the associated reasons for
the breaks, ``torch._dynamo.explain`` can be used. This tool runs
TorchDynamo on the supplied function and aggregates the graph breaks
that are encountered. Here is an example usage:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        print("woo")
        if b.sum() < 0:
            b = b * -1
        return x * b

    explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
    print(explanation)
    """
    Dynamo produced 3 graphs, with 2 graph break and 6 ops.
    Break reasons:
    1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
       File "t2.py", line 16, in toy_example
        print("woo")

    2. generic_jump
       File "t2.py", line 17, in toy_example
        if b.sum() < 0:
    """

To throw an error on the first graph break encountered, you can
disable the Python fallback by using ``nopython=True``. This should be
familiar if you've worked with export-based compilers.

.. code-block:: python

    @dynamo.optimize(<compiler>, nopython=True)
    def toy_example(a, b):
        ...

Why didn't my code recompile when I changed it?
-----------------------------------------------

If you went ahead and enabled dynamic shapes via
``env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py``, then your code
won't recompile on shape changes. We've added support for dynamic shapes,
which avoids recompilations in the case when shapes vary by less than a
factor of 2. This is especially useful in scenarios like varying image
sizes in CV or variable sequence lengths in NLP. In inference scenarios
it's often not possible to know what a batch size will be beforehand,
because you take what you can get from different client apps.
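
The same behavior can be requested from inside a script (a sketch; this flag is
assumed to be the in-process equivalent of the environment variable above):

.. code-block:: python

    import torch._dynamo

    torch._dynamo.config.dynamic_shapes = True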
In general, TorchDynamo tries very hard not to recompile things
unnecessarily, so if for example TorchDynamo finds 3 graphs and your
change only modified one graph, then only that graph will recompile. So
another tip to avoid potentially slow compilation times is to warm up a
model by compiling it once, after which subsequent compilations will be
much faster. Cold start compile times are still a metric we track
visibly.

Why am I getting incorrect results?
-----------------------------------

Accuracy issues can also be minified if you set the environment variable
``TORCHDYNAMO_REPRO_LEVEL=4``; it operates with a similar git bisect
model and a full repro might be something like
``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4``. The reason
we need this is that downstream compilers will codegen code, whether it's
Triton code or the C++ backend; the numerics from those downstream
compilers can be different in subtle ways yet have dramatic impact on
your training stability. So the accuracy debugger is very useful for us
to detect bugs in our codegen or with a backend compiler.

Why am I getting OOMs?
----------------------

Dynamo is still an alpha product, so there are a few sources of OOMs. If
you're seeing an OOM, try disabling the following configurations in this
order and then open an issue on GitHub so we can solve the root problem:

1. If you're using dynamic shapes, try disabling them; we've disabled
   them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py``
2. CUDA graphs with Triton are enabled by default in inductor, but removing
   them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

181  docs/source/dynamo/get-started.rst  (new file)
@@ -0,0 +1,181 @@
Getting Started
===============

Let's start with a simple example and make things more complicated step
by step. Please note that you're likely to see more significant speedups
the newer your GPU is.

.. code:: python

    from torch._dynamo import optimize
    import torch

    def fn(x, y):
        a = torch.cos(x).cuda()
        b = torch.sin(y).cuda()
        return a + b

    new_fn = optimize("inductor")(fn)
    input_tensor = torch.randn(10000).to(device="cuda:0")
    a = new_fn(input_tensor, input_tensor)

This example will not actually run faster. Its purpose is to demonstrate
``torch.cos()`` and ``torch.sin()``, which are
examples of pointwise ops, as in they operate element by element on a
vector. A more famous pointwise op you might actually want to use would
be something like ``torch.relu()``. Pointwise ops in eager mode are
suboptimal because each one would need to read a tensor from
memory, make some changes, and then write back those changes. The single
most important optimization that inductor does is fusion. So back to our
example, we can turn 2 reads and 2 writes into 1 read and 1 write, which
is crucial especially for newer GPUs where the bottleneck is memory
bandwidth (how quickly you can send data to a GPU) instead of compute
(how quickly your GPU can crunch floating point operations).

Another major optimization that inductor makes available is automatic
support for CUDA graphs.
CUDA graphs help eliminate the overhead from launching individual
kernels from a Python program, which is especially relevant for newer GPUs.

dynamo supports many different backends, but inductor specifically works
by generating `Triton <https://github.com/openai/triton>`__ kernels, and
we can inspect them by running ``TORCHINDUCTOR_TRACE=1 python trig.py``,
with the actual generated kernel being

.. code:: python

    @pointwise(size_hints=[16384], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
    @triton.jit
    def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
        xnumel = 10000
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tmp1 = tl.sin(tmp0)
        tmp2 = tl.sin(tmp1)
        tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)

And you can verify that fusing the two ``sin`` calls did actually occur,
because the two ``sin`` operations occur within a single Triton kernel
and the temporary variables are held in registers with very fast access.

You can read up a lot more on Triton's performance
`here <https://openai.com/blog/triton/>`__, but the key is it's in Python,
so you can easily understand it even if you haven't written all that
many CUDA kernels.

As a next step let's try a real model like resnet18 from the PyTorch
hub.

.. code:: python

    import torch
    import torch._dynamo as dynamo

    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
    opt_model = dynamo.optimize("inductor")(model)
    opt_model(torch.randn(1, 3, 64, 64))

And that's not the only available backend; you can run
``dynamo.list_backends()`` in a REPL to see all the available ones. Try out
``aot_cudagraphs`` or ``nvfuser`` next as inspiration.
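
For example (the exact set of backends listed depends on your build and
installed optional dependencies):

.. code:: python

    import torch._dynamo as dynamo

    print(dynamo.list_backends())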
Let's do something a bit more interesting now. Our community frequently
uses pretrained models from
`transformers <https://github.com/huggingface/transformers>`__ or
`TIMM <https://github.com/rwightman/pytorch-image-models>`__, and one of
our design goals is for dynamo and inductor to work out of the box with
any model that people would like to author.

So we're going to directly download a pretrained model from the
HuggingFace hub and optimize it:

.. code:: python

    import torch
    from transformers import BertTokenizer, BertModel
    import torch._dynamo as dynamo

    # Copy pasted from here https://huggingface.co/bert-base-uncased
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
    model = dynamo.optimize("inductor")(model)  # This is the only line of code that we changed
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
    output = model(**encoded_input)

If you remove the ``to(device="cuda:0")`` from the model and
``encoded_input``, then Triton will generate C++ kernels that will be
optimized for running on your CPU. You can inspect both Triton and C++
kernels for BERT. They're obviously more complex than the trigonometry
example we had above, but you can similarly skim them and understand them
if you understand PyTorch.

Similarly, let's try out a TIMM example:

.. code:: python

    import timm
    import torch._dynamo as dynamo
    import torch

    model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2)
    opt_model = dynamo.optimize("inductor")(model)
    opt_model(torch.randn(64, 3, 7, 7))

Our goal with dynamo and inductor is to build the highest coverage ML compiler, which should work with any model you throw at it.

Existing Backends
~~~~~~~~~~~~~~~~~

TorchDynamo has a growing list of backends, which can be found in
`backends.py <https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py>`__
or via ``torchdynamo.list_backends()``, each with its own optional dependencies.

Some of the most commonly used backends include:

* **Debugging backends**:

  * ``dynamo.optimize("eager")`` - Uses PyTorch
    to run the extracted GraphModule. This is quite useful in debugging
    TorchDynamo issues.
  * ``dynamo.optimize("aot_eager")`` - Uses
    AOTAutograd with no compiler, i.e., just using PyTorch eager for the
    AOTAutograd's extracted forward and backward graphs. This is useful for
    debugging, and unlikely to give speedups.

* **Training & inference backends**:

  * ``dynamo.optimize("inductor")`` -
    Uses the TorchInductor backend with AOTAutograd and cudagraphs by leveraging
    codegened Triton kernels. `Read
    more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
  * ``dynamo.optimize("nvfuser")`` - nvFuser with TorchScript. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
  * ``dynamo.optimize("aot_nvfuser")`` - nvFuser with AOTAutograd. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
  * ``dynamo.optimize("aot_cudagraphs")`` - cudagraphs with AOTAutograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__

* **Inference-only backends**:

  * ``dynamo.optimize("ofi")`` - Uses
    TorchScript optimize_for_inference. `Read
    more <https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html>`__
  * ``dynamo.optimize("fx2trt")`` - Uses Nvidia TensorRT for inference optimizations. `Read more <https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst>`__
  * ``dynamo.optimize("onnxrt")`` - Uses ONNX Runtime for inference on CPU/GPU. `Read more <https://onnxruntime.ai/>`__
  * ``dynamo.optimize("ipex")`` - Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__

Why do you need another way of optimizing PyTorch code?
-------------------------------------------------------

While a number of other code optimization tools exist in the PyTorch
ecosystem, each of them has its own flow. Here are a few examples of
existing methods and their limitations:

- ``torch.jit.trace()`` is silently wrong if it cannot trace, e.g.,
  during control flow
- ``torch.jit.script()`` requires modifications to user or library code
  by adding type annotations and removing non-PyTorch code
- ``torch.fx.symbolic_trace()`` either traces correctly or gives a hard
  error, but it's limited to traceable code so it still can't handle
  control flow
- ``torch._dynamo`` works out of the box and produces partial graphs.
  It still has the option of producing a single graph with
  ``nopython=True``, which is needed for `some
  situations <./documentation/FAQ.md#do-i-still-need-to-export-whole-graphs>`__,
  but allows a smoother transition where partial graphs can be
  optimized without code modification

.. |image0| image:: ../_static/img/dynamo/TorchDynamo.png

513  docs/source/dynamo/guards-overview.rst  (new file)
@@ -0,0 +1,513 @@
Guards Overview
===============

From a UX perspective, TorchDynamo is very easy to use. The user invokes
``torchdynamo.optimize`` as an annotation:

.. code-block:: python

    @torchdynamo.optimize(my_compiler)
    def fn_foo(bar):

Where a complete example looks like this:

.. code-block:: python

    from typing import List
    import torch
    import torchdynamo

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torchdynamo.optimize(my_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

This allows TorchDynamo to capture the interpreted Python frames, grab
any and all relevant information, and speed things up wherever it can.
The speedup comes from a few places, and can be rather dependent on the
backend (``my_compiler`` above) provided, but the one speedup we care about
most for today's overview is **caching**. Caching itself is not a direct
speedup so much as a critical enablement that allows us to prevent
recompilation. We dig a hole with dynamo, and caching allows us to get
out. It is a speedup from that perspective, but relatively neutral when
all things are considered; however, it enables us to hold perf
neutrality while then enabling backends, the true source of our
speedups.

With even a pass-through no-op backend provided:

.. code-block:: python

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        return gm.forward

We can see TorchDynamo speeding up Python execution quite a bit, even on
regular Python, not just PyTorch.

Caching and Guards Overview
---------------------------

TorchDynamo operates through caching transformed (by TorchDynamo) user
bytecode. When we receive a frame for evaluation, we check if the
**objects referenced in the frame have changed** in certain ways, and if
not, we read the previously transformed user bytecode to evaluate it.
The details of how we do this will be saved for a later writeup.
Instead, we will focus on how we can identify whether or not the
**objects referenced in the frame have changed**. This is a critical
piece of functionality in TorchDynamo, because it drives the entire
invalidation lifecycle. We refer to this functionality as **guards**.

At a very high level, the vastly oversimplified TLDR flow is this
(a small illustration of step 5 follows the list):

1) We receive a Python frame.
2) We convert the given frame from (1), passing it through instruction
   translation.
3) For the objects captured in (2), we create tracking objects that are
   (a) tracked on an output graph, which is an internal specialization
   of a torch.fx.Tracer (and the topic of a later writeup), and (b)
   guards, the topic of this document.
4) We process the guard objects created in (3), turning them into a
   generated Python function, ``check_fn``, associated with a piece of code.
5) The ``check_fn`` is evaluated whenever we encounter this code a
   subsequent time - if a ``check_fn`` passes and evaluates to True, we know
   the code in the cache and the code encountered here is the same, and
   can be safely used. If it fails and evaluates to False, we know the
   code in the cache is not valid, and can be thrown out in favor of a
   new entry, through recompilation or a graph break.
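
To make step 5 concrete, here is a purely illustrative sketch of what a
``check_fn`` amounts to for the ``toy_example`` frame above. The real function
is generated by ``CheckFunctionManager``/``compile_check_fn`` (covered at the
end of this document) and looks different in detail:

.. code-block:: python

    import torch

    def check_fn(f_locals):
        # Hypothetical TENSOR_MATCH-style checks on the frame's locals: the
        # cached bytecode is only reused if 'a' and 'b' still look like the
        # tensors seen at compile time.
        a, b = f_locals["a"], f_locals["b"]
        return (
            type(a) is torch.Tensor and a.dtype == torch.float32 and a.ndim == 1
            and type(b) is torch.Tensor and b.dtype == torch.float32 and b.ndim == 1
        )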
Python Frame Evaluation and PEP 523
-----------------------------------

The functionality of TorchDynamo is based on
`PEP 523 <https://peps.python.org/pep-0523/>`__.

TorchDynamo installs a frame evaluation function on Python, via
``_PyInterpreterState_SetEvalFrameFunc``. The overview of function
selection, thread management, and cleanup is out of scope for this
writeup, but the important part is that TorchDynamo has a hook where
Python can hand control back to us during evaluation.

The function we have installed is ``convert_frame`` or
``convert_frame_assert`` in the ``nopython=True`` case, but glossing
over that nuance for now, let's take a look at ``convert_frame_assert``,
as ``convert_frame`` proxies to it anyway.

We can find it on `line 200 of convert_frame.py
<https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/convert_frame.py#L200>`__,
with a signature as follows:

.. code-block:: python

    def convert_frame_assert(compiler_fn: Callable, one_graph=True):

This function wraps the entry point of where Python invokes TorchDynamo
with a frame, glossing over the nuances of ``wrap_convert_context`` for
now:

.. code-block:: python

    def _convert_frame_assert(frame: types.FrameType, cache_size: int):

Here is what this function does:

1) Checks if it has seen this ``code`` (see: f_code `here
   <https://docs.python.org/3/library/inspect.html>`__) before and exits
   early if it did.
2) Checks if the code is an unsupported case.
3) Checks if the ``cache_size`` (second arg above) crosses the limit
   defined in the config, ``cache_size_limit``. If it has, the function
   drops the frame and logs warnings. This helps to avoid constant
   recompilation of a frame, as it generally means that the frame is hot
   in an unexpected way and caching it produces needless overhead,
   since it is likely to get evicted the next time it is encountered.
4) Passes the frame, alongside a function that creates an
   ``InstructionTranslator``, through bytecode
   transformation, via ``transform_code_object``. A few crucial things
   happen under the hood here:

   1) New code is produced through ``transform_code_object``.

   2) An FX tracer named ``output`` is produced through
      ``InstructionTranslator``.

      This can be a bit confusing,
      as ``InstructionTranslator`` is not an `fx` tracer, but it is stored
      in a variable named ``tracer``, and its output *is* an `fx` tracer.

   3) The function produces guards and stores them on ``output`` above.

   4) The function produces ``output_instructions`` and stores them on
      ``output`` above.

   5) The function maps the newly produced transformed code to the initial code it
      read off the frame. This mapping is worth remembering; we will
      refer to it much later on below where we cover guard failures.

5) Using the transformed code from 4.1 and the guards from 4.3,
   the function produces a ``GuardedCode``.

Now that we have learned about frame evaluation, let's review
``InstructionTranslator``, and see how it turns the frame we handed
it over into TorchDynamo internal types.
InstructionTranslator
---------------------

``InstructionTranslator`` does a lot! We won't cover the details of
everything it does, but most importantly for this document, it produces
a mapping of ``symbolic_locals`` which maintains a mapping from the
frame's ``f_locals`` to TorchDynamo internal Variable objects (more on these
in a moment). ``symbolic_locals`` is filled via traversing the frame's
locals:

.. code-block:: python

    self.symbolic_locals = collections.OrderedDict(
        (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
        for k in vars
        if k in f_locals
    )

We will get to how this works later, from a few other examples that lead
us to understanding ``VariableTracker`` and ``VariableBuilder``. The
important component here, for us, for now, is the invocation of a call
into ``VariableBuilder``. ``VariableBuilder``'s call implementation
proxies into a function called ``_wrap``, which in turn both constructs
instances of ``VariableTracker`` and calls ``make_guards`` on them. More
on that later.

This mapping, in turn, is critical as each Variable has associated
guards, which are then passed to ``self.output``, the instance of
``OutputGraph``, an fx tracer, mentioned in 4.2 of the section above. If
you recall, this ``OutputGraph``, stored in a variable called ``output``,
is where our guards are stored before being passed on to become
``GuardedCode``.

How does ``InstructionTranslator`` do this? At the heart of it, there is
a loop that is pumped, which drives a function ``step``.

``step`` is just that - a single processing step, taking exactly one
instruction and doing *something* with it. Note: These are real
instructions processed by TorchDynamo's ``transform_code_object``, and
it's pretty cool.

.. note:: This section purposely skips the details of
   `dis.get_instructions <https://docs.python.org/3/library/dis.html>`__,
   and how we set up the ``Instruction`` class.

For the toy example above, here is a snippet of what a few
``Instruction``\ s may look like:

.. code-block:: python

    Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
    Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
    Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)

This is the core functionality of this function. Take a look at the ``opname``,
and then take a look at this little snippet from inside ``step``:

.. code-block:: python

    if not hasattr(self, inst.opname):
        unimplemented(f"missing: {inst.opname}")
    getattr(self, inst.opname)(inst)

As we can see, we check if the current class, the
``InstructionTranslator``, has an attribute set matching the operator name
(e.g. LOAD_CONST). If it does, we invoke it, passing the whole
instruction object in. If it does not, we drop the frame as
unimplemented.

For the LOAD_CONST example, we can see that we do indeed support it,
with a relatively straightforward definition:

::

    def LOAD_CONST(self, inst):
        self.push(ConstantVariable(value=inst.argval))

Passing over, for now, the other details of ``InstructionTranslator``,
we can see that this function creates a new instance of the class
``ConstantVariable``, with a value (in our example case, -1), and then
pushes it onto the stack.

There are dozens of such methods - see symbolic_convert.py for all of
them. Generally, we implement as many matching methods to Python
bytecode instructions as possible.

Across both the logic downstream of ``step`` and the logic from invoking
``VariableBuilder``, we now have a lot of ``VariableTracker``\ s, and of
course, we've spoken about creating guards quite a bit. Let's dig into
what Variables are, and get a little closer to understanding guards.
Variables
|
||||
---------
|
||||
|
||||
A ``ConstantVariable`` is an instance of\ ``VariableTracker``.
|
||||
``VariableTracker`` represents a tracked python local or stack value.
|
||||
|
||||
When it comes to representing an object inside TorchDynamo, a
|
||||
VariableTracker does exactly what it says - it tracks a given variable.
|
||||
Its an extremely flexible class, but there are a few points to keep in
|
||||
mind:
|
||||
|
||||
- It manages the ``guard`` relationship around the underlying object
|
||||
through:
|
||||
|
||||
- `make_guard`
|
||||
- `replace_guards`
|
||||
- `add_guard(s)`
|
||||
- `propagate` - ``propagate(*vars: List[List["VariableTracker"]])`` -
|
||||
Perhaps the most important of all, in that it combines guards from
|
||||
all the provided VariableTracker instances passed in. It visits
|
||||
the guards and combines the guards from these onto itself.
|
||||
|
||||
- It acts as a proxy on behalf of the underlying object, implementing
|
||||
methods for the rest of TorchDynamo to get information about the
|
||||
tracked object:
|
||||
|
||||
- `call_method`
|
||||
- `call_function`
|
||||
- `python_type`
|
||||
- `as_proxy`
|
||||
- `is/as_python_proxy`
|
||||
|
||||
- It stores the variable ``source`` of type ``Source``, from
|
||||
torchdynamo/source.py. This source type is a relatively self
|
||||
contained class to help us organize and bookeep where the original
|
||||
source came from, and helps provide convenience methods for things
|
||||
like getting the name, and importantly for us, producing guards.
|
||||
|
||||
And this class (``VariableTracker``) is built around subclassing,
|
||||
somewhere between a full Abstract Base Class and fully fleshed out class
|
||||
- it leaves many methods raising NotImplementedError - with reliance on
|
||||
subclasses (see: torchdynamo/variables/ for all subclasses) to fulfill
|
||||
contracts and custom behaviors.
|
||||
|
||||
Knowing what we know now, we can see an example of how an instruction
|
||||
from ``dis``, ``BUILD_TUPLE``
|
||||
|
||||
BUILD_TUPLE(count) Creates a tuple consuming count items from the
|
||||
stack, and pushes the resulting tuple onto the stack.
|
||||
|
||||
In our case, our signature will be a *little* different due to the way
|
||||
we create ``Instruction`` objects, but the gist of it will be the same.
|
||||
Instead of passing in ``count``, we pass in an object with a little
|
||||
extra bookkeeping, and of course, we deal with turning regular old
|
||||
python objects into TorchDynamo notions:
|
||||
|
||||
::
|
||||
|
||||
def BUILD_TUPLE(self, inst):
|
||||
items = self.popn(inst.argval)
|
||||
options = VariableTracker.propagate(items)
|
||||
self.push(TupleVariable(items, **options))
|
||||
|
||||

What is happening here?

1) We read ``argval``, which in this case is analogous to ``count`` in
   the pydoc for the equivalent instruction.

2) We ``popn`` the items. In this case, the signature is
   ``def popn(self, n: int) -> List[TensorVariable]:`` which hints at an
   underlying contract - we are returning ``TensorVariables``. If we
   take a closer look at symbolic_convert.py and
   ``InstructionTranslatorBase``/``InstructionTranslator`` we see that
   the only thing pushed onto and popped from our stack are
   ``VariableTracker``\ s.

3) We call ``VariableTracker.propagate`` (remember it, from above?). This
   takes the guards from every single item popped off the stack in (2),
   recursively traverses them, and combines all the guards into
   ``options``, essentially returning ``{"guards": guards}``.

4) We then make a new instance of a ``VariableTracker``,
   ``TupleVariable``, out of the ``items`` and ``options``. This then
   allows us to install all the appropriate guards from the ``items``
   that make up the new ``TupleVariable``.

Note: You may wonder - where did the first guards come from? Propagation
is good and all, but don’t we need something created before it can be
propagated? Yes! Remember that ``VariableBuilder`` above? It calls
``make_guards`` as it creates ``VariableTracker`` instances, from
``f_locals``. This in turn calls into the ``source``, to have it create
guards.
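
To make the propagation step concrete, here is a minimal, standalone
sketch. It is not TorchDynamo’s actual implementation - the class and
the string-based guard representation below are simplified stand-ins -
but it illustrates a tracker that carries a set of guards and a
``propagate``-style helper that unions guards from several trackers,
much like ``BUILD_TUPLE`` does before constructing a ``TupleVariable``:

.. code-block:: python

   from typing import Any, List, Set


   class ToyTracker:
       """A simplified stand-in for a VariableTracker: a value plus its guards."""

       def __init__(self, value: Any, guards: Set[str]):
           self.value = value
           self.guards = set(guards)

       @classmethod
       def propagate(cls, items: List["ToyTracker"]) -> dict:
           # Union the guards from every tracked item, mirroring how the
           # propagation step collects guards into ``options``.
           guards: Set[str] = set()
           for item in items:
               guards |= item.guards
           return {"guards": guards}


   # Two "locals" that were tracked with guards created at builder time.
   x = ToyTracker(2, {"check_type_id(x)", "x == 2"})
   y = ToyTracker(3, {"check_type_id(y)"})

   # A toy BUILD_TUPLE: pop the items, propagate guards, build a new tracker.
   options = ToyTracker.propagate([x, y])
   tup = ToyTracker((x.value, y.value), options["guards"])
   print(tup.value, sorted(tup.guards))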

After all this, bytecode translation is done and we are one step closer
to producing ``GuardedCode``. We now understand how locals become
``VariableTracker``\ s, how instructions are handled, and where guards
are called on for creation. Before we can go into seeing how code and
guards are combined into a ``GuardedCode`` object, we need to dig a
little bit into those ``make_guard`` and ``source.make_guard`` calls
above. We can then understand, really, what was going on when we made
guards alongside, and on, ``VariableTracker`` instances.

Making Guards
-------------

Guards are just Python objects of the class ``Guard``; however, there’s
a good amount of detail around this little class.

Looking at the definition of the dataclass (and therefore, ctor
signature), we see that it has a name, a source, and a create function.

::

   @dataclasses.dataclass
   class Guard:
       name: str
       source: GuardSource
       create_fn: Callable

The name should be the name of the variable.

The source here is an enum indicating what *kind* of source the guard
belongs to. [Note: not to be confused with ``Source`` and the other
types in source.py, as stored on ``VariableTracker``, as discussed
above.]

And ``create_fn`` is the heart of how we go from having this simple
dataclass to actually producing valid Python code to be invoked for
knowing whether or not things have changed in between invocations, and
whether we can safely read from the code cache or not (in case you
forgot what all this was for!).

The most common code paths for getting an instance of a guard are
through ``make_guards`` on ``VariableTracker``:
``make_guards`` -> ``source.make_guard`` -> ``return Guard(self.name(), self.guard_source(), fn)``

Or, in a concrete example:

.. code-block:: python

   ...
   elif istype(value, range):
       guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
       return RangeVariable(value=value, guards=guards)

Since ``source`` was set at the construction time of this
``VariableTracker``, all that was needed here was to provide the fn,
``GuardBuilder.EQUALS_MATCH``, to the ``create_fn`` field.

This ``create_fn`` must be a method on ``GuardBuilder``. The reason for
this becomes apparent in our next step. Once we have all the guards
created for a frame, we move on to ``CheckFunctionManager`` and
``compile_check_fn``.

Remember that ``convert_frame`` function way above, in the first
section? Before it can produce a ``GuardedCode``, it needs to run the
``CheckFunctionManager``, with all the guards, to produce a ``check_fn``
which will then, in turn get passed in alongside the code into
``GuardedCode``. This is the same ``check_fn`` that we store in our
cache entry, and the same one we run to know whether or not to retrieve
the code stored alongside. For reference, here is that code:

.. code-block:: cpp

   static CacheEntry *create_cache_entry(CacheEntry *next,
                                         PyObject *guarded_code) {
     CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
     DEBUG_NULL_CHECK(e);
     e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
     NULL_CHECK(e->check_fn);
     e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
     NULL_CHECK(e->code);
     e->next = next;
     return e;
   }

We now know how a ``check_fn`` function is used, and who makes it, and
what it is composed of, but what we do not yet know is how. How does a
list of ``Guard`` objects become a function we can run later on?

First, we iterate these guards:

.. code-block:: python

   for guard in sorted(guards or [], key=Guard.sort_key):
       if not config.guard_nn_modules and guard.is_nn_module():
           continue
       guard.create(local_builder, global_builder)

Calling ``guard.create`` runs that ``create_fn`` we set on the ``Guard``
class above (don’t confuse it with the ``check_fn`` we are working on
producing, the names are similar, so it can get a little confusing). In
our example above, our ``create_fn`` is ``GuardBuilder.EQUALS_MATCH``.
So we are now invoking it, passing the guard itself in.

The signature is: ``def EQUALS_MATCH(self, guard: Guard):``

And internally to that function, we can use the ``name`` on the guard to
get back our original object, querying it for data and type information,
which in turn gets us to the most important bit: appending code.

At its simplest, ``EQUALS_MATCH`` appends just one line of code:
``self.code.append(f"{ref} == {val!r}")``, where ``ref`` is the name of
the variable and ``val`` is the value. It might produce code like this:

.. code-block::

   y == 2

Pretty simple, but if we append a few other kinds of ``GuardBuilder``
functions on (for a more complex case), and then combine them all with
``and`` in between each statement (as we do), we might get something
like this:

.. code-block::

   ___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)

Now we’re talking! Let’s see what we have here:

1) A check for ``.valid`` (we will come back to invalidation later on)
2) A type id check
3) A value check
4) A tensor check

This becomes the heart of the code of our ``check_fn``, which in turn,
as you recall, is evaluated the **next** time we encounter this code. It
will then check:

1) Is this code still valid?
2) If (1), does ``y`` still have the type with id ``94367738391392``?
3) If (2), is ``y`` still 2?
4) If (3), let’s check whether tensor ``x`` changed in some specific ways

If all of these are still true, then we can use the code cached
alongside this ``check_fn``! Joyous day! [Note: a deeper dive for how
and where this happens is saved for a later writeup, but reading
``static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) {`` of
``_eval_frame.c`` is a good place to start for the inquisitive reader
who has made it thus far].

If not, then we can move on to recompiling the code anew, and storing
that in the cache alongside this code, and a whole new ``check_fn``,
again to be checked on yet another subsequent frame.

There are lots of other such functions on ``GuardBuilder`` which get
coalesced into, at times massive, strings which then get evaluated as
Python code and stored into ``check_fn``. Our example above is
illustrative of a simple case, but I urge you to read the other
functions on ``GuardBuilder``, or better yet, dump the ``code`` variable
in ``compile_check_fn`` to really see what’s getting produced,
especially on larger, real models!
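
To make the "strings of guard code become a callable" step concrete,
here is a minimal, standalone sketch of the idea. It is not the actual
``compile_check_fn``, and the fragment names below are simplified
stand-ins for the real ``___check_*`` helpers:

.. code-block:: python

   # A toy version of turning guard code fragments into a callable check_fn.
   code_parts = [
       "___guarded_code_valid",
       "isinstance(y, int)",   # stands in for ___check_type_id(y, ...)
       "y == 2",
   ]

   # Join the fragments with `and` and compile them into a single expression.
   guard_expr = " and ".join(code_parts)
   check_fn = eval(f"lambda y: {guard_expr}", {"___guarded_code_valid": True})

   print(check_fn(2))   # True  -> safe to reuse the cached code
   print(check_fn(3))   # False -> the value guard failed, so recompile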

Summary
-------

In this, we have glossed over:

- The role of ``.valid`` and invalidation around weak references (and
  potentially soon-to-be NN Module invalidations)
- How the C++ side of guard functions (``___check_type_id``,
  ``___check_tensors``, etc) operate
- What happens when guards fail
- What happens if we produce invalid guard code

Despite all that, I hope this has been a useful read. We covered how
user-provided code, wrapped in a TorchDynamo context, goes on to get
traced and tracked internally, organized into ``VariableTracker``\ s,
``Source``\ s, and subsequently ``Guard``\ s, and how those ``Guards``
in turn guide cache entry selection and invalidation when handling
Python code.

44
docs/source/dynamo/index.rst
Normal file
@@ -0,0 +1,44 @@

TorchDynamo Documentation
=========================

**TorchDynamo** is a Python-level JIT compiler designed to make unmodified
PyTorch programs faster. TorchDynamo hooks into the frame evaluation API
in CPython (`PEP 523 <https://peps.python.org/pep-0523/>`__) to
dynamically modify Python bytecode right before it is executed. It
rewrites Python bytecode in order to extract sequences of PyTorch
operations into an `FX Graph <https://pytorch.org/docs/stable/fx.html>`__
which is then just-in-time compiled with a customizable backend.
It creates this FX Graph through bytecode analysis and is designed to
mix Python execution with compiled backends to get the best of both
worlds: usability and performance.

TorchDynamo makes it easy to experiment with different compiler
backends to make PyTorch code faster with a one-line decorator,
``torch._dynamo.optimize()``.
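
For instance, a minimal sketch (``"inductor"`` is just one of the
available backend names, and the function here is only illustrative):

.. code-block:: python

   import torch
   import torch._dynamo as dynamo

   @dynamo.optimize("inductor")  # any supported backend name works here
   def add_relu(x, y):
       return torch.relu(x + y)

   add_relu(torch.randn(8), torch.randn(8))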

.. image:: ../_static/img/dynamo/TorchDynamo.png

`TorchInductor`, one of the backends supported by TorchDynamo, compiles
the `FX Graph <https://pytorch.org/docs/stable/fx.html>`__ into
`Triton <https://github.com/openai/triton>`__ for GPUs or
`C++/OpenMP <https://www.openmp.org/>`__ for CPUs. We have a
`training performance dashboard <https://github.com/pytorch/torchdynamo/issues/681#issuecomment-1233828468>`__
that provides performance comparisons for different training backends. You can read
more in the `TorchInductor post on PyTorch
dev-discuss <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__.

.. seealso::

   * `TorchDynamo deep-dive video <https://www.youtube.com/watch?v=egZB5Uxki0I>`__
   * `dev-discuss topics <https://dev-discuss.pytorch.org/search?q=TorchDynamo%20order%3Alatest>`__

.. toctree::
   :hidden:

   installation
   get-started
   guards-overview
   custom-backends
   deep-dive
   troubleshooting
   faq

83
docs/source/dynamo/installation.rst
Normal file
@@ -0,0 +1,83 @@

Installing TorchDynamo
======================

This section describes how to install TorchDynamo.

Requirements and Setup
----------------------

Python 3.8 is recommended. Python 3.7 through 3.10 are supported and
tested. Make sure to have a development version of Python installed
locally as well.

TorchDynamo is included in the nightly binaries of PyTorch. You can
find more information `here <https://pytorch.org/get-started/locally/>`__.

Install GPU/CUDA version requirements
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To use GPU backends (and in particular Triton), please make sure that
the CUDA that you have installed locally matches the PyTorch version you
are running.

The following command installs GPU PyTorch+TorchDynamo along with GPU
TorchDynamo dependencies (for CUDA 11.7):

.. code-block:: shell

   pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117

CPU requirements
~~~~~~~~~~~~~~~~

There are no additional requirements for CPU TorchDynamo. CPU
TorchDynamo is included in the nightly versions of PyTorch, which, for
reference, can be installed with the following command:

.. code-block:: shell

   pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu

Install from local source
~~~~~~~~~~~~~~~~~~~~~~~~~

Build PyTorch from source:
https://github.com/pytorch/pytorch#from-source, which has TorchDynamo
included.

To install GPU TorchDynamo dependencies, run ``make triton`` in the
PyTorch repo root directory.

Verify Installation
~~~~~~~~~~~~~~~~~~~

If you built PyTorch from source, then you can run the following
commands (from the PyTorch repo root directory) that run minimal
examples to check that TorchDynamo is installed correctly:

.. code:: shell

   cd tools/dynamo
   python verify_dynamo.py

If you do not have the PyTorch source locally, you can alternatively
copy the script (``tools/dynamo/verify_dynamo.py``) from the PyTorch
repo and run it locally.

Docker installation
-------------------

We also provide all the required dependencies in the PyTorch nightly
binaries, which you can download with:

.. code-block::

   docker pull ghcr.io/pytorch/pytorch-nightly

And for ad hoc experiments, just make sure that your container has
access to all your GPUs:

.. code-block:: bash

   docker run --gpus all -it ghcr.io/pytorch/pytorch-nightly:latest /bin/bash

665
docs/source/dynamo/troubleshooting.rst
Normal file
@@ -0,0 +1,665 @@

TorchDynamo Troubleshooting
===========================

**Author**: `Michael Lazos <https://github.com/mlazos>`_

TorchDynamo is still in active development, and many of the reasons for
graph breaks and excessive recompilation will be fixed with upcoming
support for `tracing dynamic tensor
shapes <https://docs.google.com/document/d/1QJB-GOnbv-9PygGlOMXwiO9K6vVNm8sNg_olixJ9koc/edit?usp=sharing>`__,
more careful choices for guards, and better-tuned heuristics.

In the meantime, you may need to diagnose a particular issue and
determine if it is easy to work around with a change to your model, or
file an issue for support.

We are also actively developing debug tools and profilers, and improving
our errors and warnings. Please give us feedback if you have an issue
with this infra, or an idea for an improvement. Below is a table of the
available tools and their typical usage. For additional help see
`Diagnosing Runtime Errors <#diagnosing-runtime-errors>`__.

.. list-table:: TorchDynamo debugging tools
   :widths: 25 25 50
   :header-rows: 1

   * - Tool
     - Purpose
     - Usage
   * - Info logging
     - View summarized steps of compilation
     - ``torch._dynamo.config.log_level = logging.INFO``
   * - Debug logging
     - View detailed steps of compilation (print every instruction traced)
     - ``torch._dynamo.config.log_level = logging.DEBUG`` and
       ``torch._dynamo.config.verbose = True``
   * - Minifier for any backend
     - Find the smallest subgraph which reproduces errors for any backend
     - set environment variable ``TORCHDYNAMO_REPRO_AFTER="dynamo"``
   * - Minifier for ``TorchInductor``
     - If the error is known to occur after ``AOTAutograd``, find the
       smallest subgraph which reproduces errors during TorchInductor lowering
     - set environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"``
   * - Accuracy minifier
     - Find the smallest subgraph which reproduces an accuracy issue
       between an eager model and the optimized model
     - ``TORCHDYNAMO_REPRO_AFTER=<"aot"/"dynamo"> TORCHDYNAMO_REPRO_LEVEL=4``
   * - ``torch._dynamo.explain``
     - Find graph breaks and display reasoning for them
     - ``torch._dynamo.explain(fn, *inputs)``
   * - Record/Replay
     - Record and replay frames to reproduce errors during graph capture
     - ``torch._dynamo.config.replay_record_enabled = True``
   * - TorchDynamo function name filtering
     - Only compile functions with the given name to reduce noise when
       debugging an issue
     - set environment variable ``TORCHDYNAMO_DEBUG_FUNCTION=<name>``
   * - TorchInductor Debug logging
     - Print general TorchInductor debug info and generated Triton/C++ code
     - ``torch._inductor.config.debug = True``
   * - TorchInductor Tracing
     - Show time taken in each TorchInductor stage + output code and graph
       visualization
     - set the environment variable ``TORCHINDUCTOR_TRACE=1`` or
       ``torch._inductor.config.trace.enabled = True``

Diagnosing Runtime Errors
~~~~~~~~~~~~~~~~~~~~~~~~~

Below is the TorchDynamo compiler stack.

At a high level, the TorchDynamo stack consists of a graph capture from
Python code (TorchDynamo) and a backend compiler. In this example the
backend compiler consists of backward graph tracing (AOTAutograd) and
graph lowering (TorchInductor)*. Errors can occur in any component of
the stack and will provide full stack traces.

You may use info logging
(``torch._dynamo.config.log_level = logging.INFO``) and look for
``Step #: ...`` outputs in order to determine in which component the
error occurred. Logs are made at the beginning and end of each step,
so the step that an error should correspond to is the most recent logged
step whose end has not yet been logged. The steps correspond to the
following parts of the stack (according to the image above):

==== ================
Step Component
==== ================
1    TorchDynamo
2    Compiler Backend
3    TorchInductor
==== ================

The beginning and end of AOTAutograd is currently not logged, but we
plan to add it soon.

If info logging is insufficient, then there are also some backend
options which can enable you to determine which component is causing the
error if you’re unable to understand the error message that is
generated. These are the following:

- ``"eager"``: only runs torchdynamo forward graph capture and then
  runs the captured graph with PyTorch. This provides an indication as
  to whether TorchDynamo is raising the error.

- ``"aot_eager"``: runs torchdynamo to capture a forward graph, and
  then AOTAutograd to trace the backward graph without any additional
  backend compiler steps. PyTorch eager will then be used to run the
  forward and backward graphs. This is useful to narrow down the issue
  to AOTAutograd.

The general procedure to narrow down an issue is the following:

1. Run your program with the ``"eager"`` backend. If the error no longer
   occurs, the issue is in the backend compiler that is being used (if
   using TorchInductor, proceed to step 2; if not, see `this
   section <#minifying-backend-compiler-errors>`__). If the error still
   occurs with the ``"eager"`` backend, it is an `error while running
   torchdynamo <#torchdynamo-errors>`__.

2. This step is only necessary if TorchInductor is used as the backend
   compiler. Run the model with the ``"aot_eager"`` backend. If this
   backend raises an error then the error is occurring during
   AOTAutograd tracing. If the error no longer occurs with this backend,
   then `the error is in
   TorchInductor\* <#minifying-torchinductor-errors>`__.

Each of these cases is analyzed in the following sections.

\*Note on TorchInductor naming: The TorchInductor backend consists of
both AOTAutograd tracing and the TorchInductor compiler itself. We will
disambiguate by referring to TorchInductor as the backend, and
TorchInductor lowering as the phase which lowers the graph traced by
AOTAutograd.
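
As a small, hedged sketch of this procedure (``fn`` below is just a
stand-in for your own function or model), you can swap the backend
string passed to ``dynamo.optimize`` and re-run:

.. code-block:: python

   import torch
   import torch._dynamo as dynamo

   def fn(x):
       return torch.relu(x) * 2

   # Step 1: if this raises, the problem is in TorchDynamo's graph capture.
   dynamo.optimize("eager")(fn)(torch.randn(8))

   # Step 2: if "eager" passed but this raises, the problem is in AOTAutograd tracing.
   dynamo.reset()  # clear cached compilations before switching backends
   dynamo.optimize("aot_eager")(fn)(torch.randn(8))

   # Otherwise, the problem lies in the backend compiler (e.g. TorchInductor lowering).
   dynamo.reset()
   dynamo.optimize("inductor")(fn)(torch.randn(8))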

TorchDynamo Errors
------------------

If the error that is generated occurs with the ``"eager"`` backend, then
torchdynamo is the most likely source of the error. Here is example code
which will generate an error.

.. code:: py

   import torch

   import torch._dynamo as dynamo


   @dynamo.optimize("eager")
   def test_assertion_error():
       y = torch.ones(200, 200)
       z = {y: 5}
       return z


   test_assertion_error()

Which will generate the following error:

::

   torch._dynamo.convert_frame: [ERROR] WON'T CONVERT test_assertion_error /scratch/mlazos/torchdynamo/../test/errors.py line 26
   due to:
   Traceback (most recent call last):
     File "/scratch/mlazos/torchdynamo/torchdynamo/symbolic_convert.py", line 837, in BUILD_MAP
       assert isinstance(k, ConstantVariable) or (
   AssertionError

   from user code:
      File "/scratch/mlazos/torchdynamo/../test/errors.py", line 34, in test_assertion_error
       z = {y: 5}

   Set torch._dynamo.config.verbose=True for more information
   ==========

As the message suggests, you can set
``torch._dynamo.config.verbose=True`` to get a full stack trace to both
the error in torchdynamo and the user code. In addition to this flag,
you can also set the ``log_level`` of torchdynamo through
``torch._dynamo.config.log_level``. The available levels are the
following (see the example after this list):

- ``logging.DEBUG``: Print every instruction that is encountered in
  addition to all below log levels
- ``logging.INFO``: Print each function that is compiled (original and
  modified bytecode) and the graph that is captured in addition to all
  below log levels
- ``logging.WARNING`` (default): Print graph breaks in addition to all
  below log levels
- ``logging.ERROR``: Print errors only
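
For example, a small sketch that combines the two most verbose settings
mentioned above while reproducing an issue:

.. code-block:: python

   import logging
   import torch._dynamo as dynamo

   # Print every traced instruction and include full stack traces for
   # both torchdynamo internals and user code when an error occurs.
   dynamo.config.log_level = logging.DEBUG
   dynamo.config.verbose = True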

If a model is sufficiently large, the logs can become overwhelming. If
an error occurs deep within a model’s Python code, it can be useful to
execute only the frame in which the error occurs to enable easier
debugging. There are two tools available to enable this:

- Setting the environment variable ``TORCHDYNAMO_DEBUG_FUNCTION`` to the
  desired function name will only run torchdynamo on functions with that
  name.
- There is a record/replay tool (set
  ``torch._dynamo.config.replay_record_enabled = True``) which dumps an
  execution record when an error is encountered. This record can then be
  replayed to run only the frame where an error occurred.

TorchInductor Errors
--------------------

If the error doesn’t occur with the ``"eager"`` backend, then the
backend compiler is the source of the error (`example
error <https://gist.github.com/mlazos/2f13681e3cc6c43b3911f336327032de>`__).
There are `different
choices <https://github.com/pytorch/torchdynamo/blob/0b8aaf340dad4777a080ef24bf09623f1aa6f3dd/README.md#existing-backends>`__
for backend compilers for torchdynamo, with TorchInductor or nvfuser
fitting the needs of most users. This section focuses on TorchInductor
as the motivating example, but some tools will be usable with other
backend compilers.

Below is the portion of the stack which we are focusing on:

With TorchInductor as the chosen backend, AOTAutograd is used to
generate the backward graph from the forward graph captured by
torchdynamo. It’s important to note that errors can occur during this
tracing and also while TorchInductor lowers the forward and backward
graphs to GPU code or C++. A model can often consist of hundreds or
thousands of FX nodes, so narrowing down the exact nodes where this
problem occurred can be very difficult. Fortunately, there are tools
available to automatically minify these input graphs to the nodes which
are causing the issue. The first step is to determine whether the error
occurs during tracing of the backward graph with AOTAutograd or during
TorchInductor lowering. As mentioned above in step 2, the
``"aot_eager"`` backend can be used to run only AOTAutograd in isolation
without lowering. If the error still occurs with this backend, this
indicates that the error is occurring during AOTAutograd tracing.

Here’s an example:

.. code:: py

   import torch

   import torch._dynamo as dynamo

   model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])

   @dynamo.optimize("inductor")
   def test_backend_error():
       y = torch.ones(200, 200)
       x = torch.ones(200, 200)
       z = x + y
       a = torch.ops.aten._foobar(z)  # dummy function which errors
       return model(a)


   test_backend_error()

Running this should give you this error (with a longer stack trace below
it):

::

   Traceback (most recent call last):
     File "/scratch/mlazos/torchdynamo/torchinductor/graph.py", line 246, in call_function
       return lowerings[target](*args, **kwargs)
     File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 185, in wrapped
       return decomp_fn(*args, **kwargs)
     File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 810, in _foobar
       assert False
   AssertionError
   ...

`error with full stack
trace <https://gist.github.com/mlazos/d6947854aa56d686800259a164c62100>`__

If you then change ``@dynamo.optimize("inductor")`` to
``@dynamo.optimize("aot_eager")``, it will run without error, because
`the
issue <https://github.com/pytorch/torchdynamo/blob/d09e50fbee388d466b5252a63045643166006f77/torchinductor/lowering.py#:~:text=%23%20This%20shouldn%27t%20be,assert%20False>`__
is in the TorchInductor lowering process, not in AOTAutograd.

Minifying TorchInductor Errors
------------------------------

From here, let’s run the minifier to get a minimal repro. Setting the
environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"`` (or setting
``torch._dynamo.config.repro_after="aot"`` directly) will generate a
Python program which reduces the graph produced by AOTAutograd to the
smallest subgraph which reproduces the error. (See below for an example
where we minify the graph produced by torchdynamo.) Running the program
with this environment variable should show nearly `identical
output <https://gist.github.com/mlazos/0458ab828aa403c779fe73c012aa5982>`__,
with an additional line indicating where ``minifier_launcher.py`` has
been written to. The output directory is configurable by setting
``torch._dynamo.config.base_dir`` to a valid directory name. The final
step is to run the minifier and check that it runs successfully. A
successful run looks like
`this <https://gist.github.com/mlazos/e6ea41ccce68a7b1b8a7a09acb1b206a>`__.
If the minifier runs successfully, it generates runnable Python code
which reproduces the exact error. For our example this is the following
code:

.. code:: py

   import torch
   from torch import tensor, device
   import torch.fx as fx
   from torch._dynamo.testing import rand_strided
   from math import inf
   from torch.fx.experimental.proxy_tensor import make_fx

   # torch version: 1.13.0a0+gitfddfc44
   # torch cuda version: 11.6
   # torch git version: fddfc4488afb207971c54ad4bf58130fdc8a4dc5


   # CUDA Info:
   # nvcc: NVIDIA (R) Cuda compiler driver
   # Copyright (c) 2005-2022 NVIDIA Corporation
   # Built on Thu_Feb_10_18:23:41_PST_2022
   # Cuda compilation tools, release 11.6, V11.6.112
   # Build cuda_11.6.r11.6/compiler.30978841_0

   # GPU Hardware Info:
   # NVIDIA A100-SXM4-40GB : 8


   from torch.nn import *

   class Repro(torch.nn.Module):
       def __init__(self):
           super().__init__()

       def forward(self, add):
           _foobar = torch.ops.aten._foobar.default(add); add = None
           return (_foobar,)

   args = [((200, 200), (200, 1), torch.float32, 'cpu')]
   args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]
   mod = make_fx(Repro())(*args)
   from torch._inductor.compile_fx import compile_fx_inner

   compiled = compile_fx_inner(mod, args)
   compiled(*args)

The ``forward`` method of the ``Repro`` module contains the exact op
which causes the issue. When filing an issue, please include any
minified repros to aid in debugging.

Minifying Backend Compiler Errors
---------------------------------

With backend compilers other than TorchInductor, the process for finding
the subgraph causing the error is nearly identical to the procedure in
`errors in TorchInductor <#torchinductor-errors>`__ with one important
caveat: the minifier will now be run on the graph that is traced by
TorchDynamo, not the output graph of AOTAutograd. Let’s walk through an
example.

.. code:: py

   import torch

   import torch._dynamo as dynamo

   model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])

   # toy compiler which fails if graph contains relu
   def toy_compiler(gm: torch.fx.GraphModule, _):
       for node in gm.graph.nodes:
           if node.target == torch.relu:
               assert False

       return gm


   @dynamo.optimize(toy_compiler)
   def test_backend_error():
       y = torch.ones(200, 200)
       x = torch.ones(200, 200)
       z = x + y
       a = torch.relu(z)
       return model(a)


   test_backend_error()

In order to run the code after TorchDynamo has traced the forward graph,
the ``TORCHDYNAMO_REPRO_AFTER`` environment variable can be used. Running
this program with ``TORCHDYNAMO_REPRO_AFTER="dynamo"`` (or
``torch._dynamo.config.repro_after="dynamo"``) should produce `this
output <https://gist.github.com/mlazos/244e3d5b53667e44078e194762c0c92b>`__ and
the following code in ``{torch._dynamo.config.base_dir}/repro.py``.
Note: the other option for ``TORCHDYNAMO_REPRO_AFTER`` is ``"aot"``, which
will run the minifier after the backward graph has been generated.

.. code:: py

   import torch
   import torch._dynamo as dynamo
   from torch import tensor, device
   import torch.fx as fx
   from torch._dynamo.testing import rand_strided
   from math import inf
   from torch._dynamo.debug_utils import run_fwd_maybe_bwd


   from torch.nn import *

   class Repro(torch.nn.Module):
       def __init__(self):
           super().__init__()

       def forward(self, add):
           relu = torch.relu(add); add = None
           return (relu,)


   mod = Repro().cuda()
   opt_mod = dynamo.optimize("None")(mod)


   args = [((200, 200), (200, 1), torch.float32, 'cpu', False)]
   args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


   with torch.cuda.amp.autocast(enabled=False):
       ref = run_fwd_maybe_bwd(mod, args)
       res = run_fwd_maybe_bwd(opt_mod, args)

The minifier successfully reduced the graph to the op that raises the
error in ``toy_compiler``. The other difference from the procedure in
`TorchInductor Errors <#torchinductor-errors>`__ is that the minifier is
automatically run after encountering a backend compiler error. After a
successful run, the minifier writes ``repro.py`` to
``torch._dynamo.config.base_dir``.

Performance Profiling
~~~~~~~~~~~~~~~~~~~~~

Accessing TorchDynamo Profiler
------------------------------

TorchDynamo has a built-in stats function for collecting and displaying
the time spent in each compilation phase. These stats can be accessed by
calling ``torch._dynamo.utils.compile_times()`` after executing
TorchDynamo. By default, this returns a string representation of the
compile times spent in each TorchDynamo function by name.
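
As a small sketch of how you might surface these stats (the toy function
here is only illustrative):

.. code-block:: python

   import torch
   import torch._dynamo as dynamo
   import torch._dynamo.utils

   @dynamo.optimize("eager")
   def fn(x):
       return torch.sin(x) + torch.cos(x)

   fn(torch.randn(10))  # the first call triggers compilation

   # String summary of time spent in each TorchDynamo compilation phase.
   print(torch._dynamo.utils.compile_times())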

TorchInductor Debug Tracing
---------------------------

TorchInductor has a builtin stats and trace function for displaying time
spent in each compilation phase, output code, output graph visualization
and IR dump. This is a debugging tool designed to make it easier to
debug/understand the internals of TorchInductor.

Setting the environment variable ``TORCHINDUCTOR_TRACE=1`` will cause a
debug trace directory to be created and printed:

::

   $ env TORCHINDUCTOR_TRACE=1 python repro.py
   torch._inductor.debug: [WARNING] model_forward_0 debug trace: /tmp/torchinductor_jansel/rh/crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.debug

Here is an `example debug directory
output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__
for the test program:

::

   torch.nn.Sequential(
       torch.nn.Linear(10, 10),
       torch.nn.LayerNorm(10),
       torch.nn.ReLU(),
   )

Note each file in that debug trace can be enabled/disabled via
``torch._inductor.config.trace.*``. The profile and the diagram are both
disabled by default since they are expensive to generate.

A single node in this new debug format looks like:

::

   buf1: SchedulerNode(ComputedBuffer)
   buf1.writes =
       { MemoryDep(name='buf1', index=0, size=()),
         MemoryDep(name='buf1', index=0, size=(s0,))}
   buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
   buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
   buf1.group.device = cuda:0
   buf1.group.iteration = (1, s0)
   buf1.sizes = ([], [s0])
   class buf1_loop_body:
       var_ranges = {z0: s0}
       index0 = z0
       index1 = 0
       def body(self, ops):
           get_index = self.get_index('index0')
           load = ops.load('buf0', get_index, False)
           get_index_1 = self.get_index('index0')
           load_1 = ops.load('primals_2', get_index_1, False)
           add = ops.add(load, load_1)
           get_index_2 = self.get_index('index1')
           reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add)
           return reduction

See the `example debug directory
output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__
for more examples.

Memory Profiling
----------------

TBD

Graph Breaks
------------

Given a program like this:

.. code-block:: python

   @dynamo.optimize(...)
   def some_fun(x):
       ...

   some_fun(x)
   ...

TorchDynamo will attempt to compile all of the torch/tensor operations
within ``some_fun`` into a single FX graph, but it may fail to capture
everything into one graph.

Some graph break reasons are insurmountable to TorchDynamo, and can’t be
easily fixed.

- Calling into a C extension other than torch is invisible to
  torchdynamo, and could do arbitrary things without TorchDynamo being
  able to introduce necessary `guards <./GuardsOverviewPt1.md>`__ to
  ensure that the compiled program would be safe to reuse.

Graph breaks can hinder performance if the resulting fragments are
small. To maximize performance, it’s important to have as few graph
breaks as possible.

Identifying the cause of a graph break
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To identify all graph breaks in a program and the associated reasons for
the breaks, ``torch._dynamo.explain`` can be used. This tool runs
TorchDynamo on the supplied function and aggregates the graph breaks
that are encountered. Here is an example usage:

.. code-block:: python

   import torch
   import torch._dynamo as dynamo

   def toy_example(a, b):
       x = a / (torch.abs(a) + 1)
       print("woo")
       if b.sum() < 0:
           b = b * -1
       return x * b

   explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
   print(explanation)
   """
   Dynamo produced 3 graphs, with 2 graph break and 6 ops.
   Break reasons:
   1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
      File "t2.py", line 16, in toy_example
         print("woo")

   2. generic_jump
      File "t2.py", line 17, in toy_example
         if b.sum() < 0:
   """

Note on other outputs:

- ``out_guards`` - a list of lists where each sublist contains the
  guards that must pass to ensure the traced graphs are valid
- ``graphs`` - a list of graph modules which were successfully traced
- ``ops_per_graph`` - a list of lists where each sublist contains the
  ops that are run in the graph

To throw an error on the first graph break encountered, ``nopython``
mode can be used. This disables TorchDynamo’s Python fallback, and only
succeeds if the entire program is convertible to a single graph. Example
usage:

.. code-block:: python

   @dynamo.optimize(<compiler>, nopython=True)
   def toy_example(a, b):
       ...

Excessive Recompilation
-----------------------

When TorchDynamo compiles a function (or part of one), it makes certain
assumptions about locals and globals in order to allow compiler
optimizations, and expresses these assumptions as guards that check
particular values at runtime. If any of these guards fail, Dynamo will
recompile that function (or part) up to
``torch._dynamo.config.cache_size_limit`` times. If your program is
hitting the cache limit, you will first need to determine which guard is
failing and what part of your program is triggering it.

The `recompilation profiler <#recompilation-profiler>`__ automates the
process of setting TorchDynamo’s cache limit to 1 and running your
program under an observation-only ‘compiler’ that records the causes of
any guard failures. You should be sure to run your program for at least
as long (as many iterations) as you were running when you ran into
trouble, and the profiler will accumulate statistics over this duration.

If your program exhibits a bounded amount of dynamism, you may be able
to tune the TorchDynamo cache limit to allow for each variation to be
compiled and cached, but if the cache limit is too high you may find the
cost of recompilation outweighs any optimization benefits.

::

   torch._dynamo.config.cache_size_limit = <your desired cache limit>

TorchDynamo plans to support many common cases of dynamic tensor shapes,
such as varying batch size or sequence length. It does not plan to
support rank-dynamism. In the meantime, setting a specific cache limit
can be used in coordination with bucketing techniques to achieve an
acceptable number of recompilations for some dynamic models.

.. code-block:: python

   prof = dynamo.utils.CompilationProfiler()

   @dynamo.optimize(prof)
   def my_model():
       ...

   my_model()
   print(prof.report())

Accuracy Debugging
~~~~~~~~~~~~~~~~~~

Accuracy issues can also be minified if you set the environment variable
``TORCHDYNAMO_REPRO_LEVEL=4``; it operates with a similar git bisect
model, and a full repro might be something like
``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4``. The reason
we need this is that downstream compilers generate code, whether it’s
Triton code or the C++ backend; the numerics from those downstream
compilers can be different in subtle ways, yet have a dramatic impact on
your training stability. So the accuracy debugger is very useful for us
to detect bugs in our codegen or with a backend compiler.
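
As a small sketch of enabling this from Python, assuming the
``repro_level`` config entry mirrors the ``TORCHDYNAMO_REPRO_LEVEL``
environment variable the same way ``repro_after`` mirrors
``TORCHDYNAMO_REPRO_AFTER``:

.. code-block:: python

   import torch._dynamo as dynamo

   # Equivalent in spirit to TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4:
   # minify the AOTAutograd graph while bisecting for an accuracy divergence.
   dynamo.config.repro_after = "aot"
   dynamo.config.repro_level = 4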

File an Issue
~~~~~~~~~~~~~

You should feel encouraged to `file a github
issue <https://github.com/pytorch/torchdynamo/issues>`__ and expect a
timely response.

Before filing an issue, read over the `README <../README.md>`__,
`TROUBLESHOOTING <./TROUBLESHOOTING.md>`__, and search for similar
issues.

When filing an issue, please include:

- Your OS/python/pytorch/CUDA/triton info, gathered by running:

  .. code-block:: sh

     python tools/verify_install.py

- A minimal repro script if possible, which can be generated by running
  the Minifier
- A description of the error
- The expected behavior
- A log (set ``torch._dynamo.config.log_file`` to a valid file name to
  dump the logs to a file, plus
  ``torch._dynamo.config.log_level = logging.DEBUG`` and
  ``torch._dynamo.config.verbose = True``; see the snippet after this
  list)
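
For reference, a small sketch of the logging configuration described in
the last bullet (the log file name is just an illustrative placeholder):

.. code-block:: python

   import logging
   import torch._dynamo as dynamo

   # Capture the most detailed logs into a file to attach to the issue.
   dynamo.config.log_file = "dynamo_debug.log"  # placeholder path
   dynamo.config.log_level = logging.DEBUG
   dynamo.config.verbose = True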

@@ -42,6 +42,13 @@ Features described in this documentation are classified by release status:

   notes/*

.. toctree::
   :glob:
   :maxdepth: 1
   :caption: torch.compile

   dynamo/*

.. toctree::
   :maxdepth: 1
   :caption: Language Bindings