Move Dynamo docs back to core (#89769)

With contributions from @svekars and @malfet

Waiting for doc build job to complete
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89769
Approved by: https://github.com/soumith, https://github.com/malfet
This commit is contained in:
Mark Saroufim
2022-11-29 04:38:53 +00:00
committed by PyTorch MergeBot
parent 2b522670d2
commit 9048cf16fe
12 changed files with 2168 additions and 0 deletions

Three binary image files added (341 KiB, 301 KiB, and 120 KiB); not shown.


@@ -0,0 +1,154 @@
Custom Backends
===============
Debugging Backend
-----------------
Suppose you want to better understand what is going on during a
compilation. You can create a custom compiler, which we will refer to as a
backend, that pretty prints the FX ``GraphModule`` extracted
from dynamo's bytecode analysis and returns a ``forward()`` callable:
.. code-block:: python
from typing import List
import torch
import torch._dynamo as dynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@dynamo.optimize(my_compiler)
def fn(x, y):
a = torch.cos(x)
b = torch.sin(y)
return a + b
fn(torch.randn(10), torch.randn(10))
Running the above example produces the following output:
::
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ---------- --------
placeholder x x () {}
placeholder y y () {}
call_function cos <built-in method cos of type object at 0x7f1a894649a8> (x,) {}
call_function sin <built-in method sin of type object at 0x7f1a894649a8> (y,) {}
call_function add <built-in function add> (cos, sin) {}
output output output ((add,),) {}
This works for ``torch.nn.Module`` as well, as shown below:
.. code-block:: python
import torch
import torch._dynamo as dynamo
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(torch.cos(x))
mod = MockModule()
optimized_mod = dynamo.optimize(my_compiler)(mod)
optimized_mod(torch.randn(10))
Let's take a look at one more example, with control flow:
.. code-block:: python
from typing import List
import torch
import torch._dynamo as dynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@dynamo.optimize(my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
Running this example produces the following output:
::
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f8d259298a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (b, -1) {}
call_function mul_1 <built-in function mul> (x, mul) {}
output output output ((mul_1,),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- --------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (x, b) {}
output output output ((mul,),) {}
The order of the last two graphs is nondeterministic depending
on which one is encountered first by the just-in-time compiler.
Speedy Backend
--------------
Integrating a custom backend that offers superior performance is also
easy. We will integrate a real one
with `optimize_for_inference <https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html>`__:
.. code-block:: python
def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
scripted = torch.jit.trace(gm, example_inputs)
return torch.jit.optimize_for_inference(scripted)
And then you should be able to optimize any existing code with:
.. code-block:: python
@dynamo.optimize(optimize_for_inference_compiler)
def code_to_accelerate():
...
Composable Backends
-------------------
TorchDynamo includes many backends, which can be found in
`backends.py <https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py>`__
or ``torchdynamo.list_backends()``. You can combine these backends
together with the following code:
.. code-block:: python
from torch._dynamo.optimizations import BACKENDS
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
trt_compiled = BACKENDS["tensorrt"](gm, example_inputs)
if trt_compiled is not None:
return trt_compiled
# first backend failed, try something else...
cudagraphs_compiled = BACKENDS["cudagraphs"](gm, example_inputs)
if cudagraphs_compiled is not None:
return cudagraphs_compiled
return gm.forward


@@ -0,0 +1,145 @@
TorchDynamo Deeper Dive
=======================
**Author**: `Jason Ansel <https://github.com/jansel>`_
What is a guard?
----------------
TorchDynamo operates just-in-time and specializes graphs based on
dynamic properties. For example, the first graph captured for the
``toy_example`` shown earlier has the following guards:
::
GUARDS:
- local 'a' TENSOR_MATCH
- local 'b' TENSOR_MATCH
- global 'torch' FUNCTION_MATCH
If any of those guards fail, the graph will be recaptured and
recompiled. The interesting guard type there is ``TENSOR_MATCH``, which
checks the following torch.Tensor properties:
- Python class of the tensor (tensor subclassing, etc)
- dtype
- device
- requires_grad
- dispatch_key (with thread-local includes/excludes applied)
- ndim
- sizes\* (optional)
- strides\* (optional)
For sizes/strides you can disable this specialization by setting the
following parameter:
.. code-block:: python
torch._dynamo.config.dynamic_shapes = True
The full specialization mode allows the backend compiler to assume an
entirely static graph. Unfortunately, most backends require this.
Operators which return dynamic shapes will trigger a graph break when
not in dynamic shape mode.
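To see these guards in action, here is a minimal sketch that uses a custom counting backend (the name ``counting_compiler`` is illustrative) to observe recompilation when guarded tensor properties change; the exact count depends on your configuration:

.. code-block:: python

    from typing import List
    import torch
    import torch._dynamo as dynamo

    compile_count = 0

    def counting_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        # Count each time dynamo hands us a freshly captured graph.
        global compile_count
        compile_count += 1
        return gm.forward

    @dynamo.optimize(counting_compiler)
    def fn(x):
        return torch.sin(x) + 1

    fn(torch.randn(8))           # first call: compile
    fn(torch.randn(8))           # guards pass: cached code is reused
    fn(torch.randn(16))          # size changed: TENSOR_MATCH fails, recompile
    fn(torch.randn(8).double())  # dtype changed: TENSOR_MATCH fails, recompile
    print(compile_count)         # expected to be 3 with full specialization (static shapes)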
What is dynamo doing?
---------------------
If you want to understand better what TorchDynamo is doing, you can set:
.. code-block:: python
torchdynamo.config.debug = True
which triggers useful (but spammy) printouts.
For example, the printouts for the first graph in the ``toy_example``
above are:
::
__compiled_fn_0 <eval_with_key>.1
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f9ca082f8a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
ORIGINAL BYTECODE toy_example example.py 9
10 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_METHOD 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_METHOD 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
11 18 LOAD_FAST 1 (b)
20 LOAD_METHOD 2 (sum)
22 CALL_METHOD 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 38
12 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
13 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
MODIFIED BYTECODE
9 0 LOAD_GLOBAL 3 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 24
14 LOAD_GLOBAL 4 (__resume_at_30_1)
16 LOAD_FAST 1 (b)
18 LOAD_FAST 2 (x)
20 CALL_FUNCTION 2
22 RETURN_VALUE
>> 24 LOAD_GLOBAL 5 (__resume_at_38_2)
26 LOAD_FAST 1 (b)
28 LOAD_FAST 2 (x)
30 CALL_FUNCTION 2
32 RETURN_VALUE
GUARDS:
- local 'a' TENSOR_MATCH
- local 'b' TENSOR_MATCH
- global 'torch' FUNCTION_MATCH
At the top you can see the FX graph (which we already shared above).
Next you see the original bytecode of the function, followed by the
modified bytecode generated by TorchDynamo. Finally, you see the guards
which we covered above.
In the modified bytecode ``__compiled_fn_0`` is the return value of
``my_compiler()`` (the compiled graph). ``__resume_at_30_1`` and
``__resume_at_38_2`` are both generated continuation functions that pick
up execution after a graph break (at bytecode offsets 30 and 38). Each
of these functions takes the form:
::
__resume_at_<offset>:
... restore stack state if needed ...
JUMP_ABSOLUTE <offset> into toy_example
... original bytecode of toy_example ...
By generating this ``resume_at`` function, we force the remainder of the
function to be executed in a new Python frame, which recursively
triggers TorchDynamo to restart its capture once execution reaches that
point for the first time.

docs/source/dynamo/faq.rst Normal file

@@ -0,0 +1,376 @@
Frequently Asked Questions
==========================
At a high level, the TorchDynamo stack consists of graph capture from
Python code using dynamo and a backend compiler. In this example, the
backend compiler consists of backward graph tracing using AOTAutograd
and graph lowering using TorchInductor. There are, of course, many more
compilers available `here <https://github.com/pytorch/torchdynamo/blob/0b8aaf340dad4777a080ef24bf09623f1aa6f3dd/README.md#existing-backend>`__,
but for this document we will focus on Inductor as a motivating example.
TorchDynamo supports training, using AOTAutograd to capture the backward pass (a minimal training sketch follows this list):

1. The ``.forward()`` graph and ``optimizer.step()`` are captured by TorchDynamo's Python evalframe frontend.
2. For each segment of ``.forward()`` that TorchDynamo captures, it uses AOTAutograd to generate a backward graph segment.
3. Each pair of forward and backward graphs is (optionally) min-cut partitioned to save the minimal state between forward and backward.
4. The forward and backward pairs are wrapped in ``autograd.Function`` modules.
5. User code calling ``.backward()`` still triggers eager's autograd engine, which runs each compiled backward graph as if it were one op, also running any non-compiled eager ops' ``.backward()`` functions.
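As a rough illustration, here is a minimal sketch of a training step under this flow; the toy model, the optimizer, and the choice of the ``aot_eager`` backend are illustrative:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    @dynamo.optimize("aot_eager")  # capture forward/backward graphs, run them with eager PyTorch
    def train_step(x, y):
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()            # eager autograd runs the compiled backward graph segments
        optimizer.step()
        optimizer.zero_grad()
        return loss

    for _ in range(10):
        train_step(torch.randn(4, 16), torch.randn(4, 16))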
Do you support Distributed code?
--------------------------------
DDP has been tested and works; support for other distributed training
libraries is under discussion.
The main reason why distributed code is challenging with dynamo is
that AOTAutograd unrolls both the forward and backward pass and
provides two graphs for backends to optimize. This is a problem for
distributed code because we would ideally like to overlap communication
operations with computations. Eager PyTorch accomplishes this in
different ways for DDP/FSDP, using autograd hooks, module hooks, and
modifications/mutations of module states. In a naive application of
dynamo, hooks that should run directly after an operation during the
backward pass may be delayed until after the entire compiled region of
backward ops, due to how AOTAutograd compiled functions interact with
dispatcher hooks.
The basic strategy for optimizing DDP with Dynamo is outlined in
`distributed.py <https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/distributed.py>`__
where the main idea will be to graph break on `DDP bucket
boundaries <https://pytorch.org/docs/stable/notes/ddp.html#internal-design>`__.
When each node in DDP needs to synchronize its weights with the other
nodes it organizes its gradients and parameters into buckets which
reduces communication times and allows a node to broadcast a fraction of
its gradients to other waiting nodes.
Graph breaks in distributed code mean you can expect dynamo and its
backends to optimize the compute overhead of a distributed program but
not its communication overhead. Graph breaks may interfere with
compilation speedups if the reduced graph size robs the compiler of
fusion opportunities. However, there are diminishing returns with
increasing graph size since most of the current compute optimizations
are local fusions, so in practice this approach may be sufficient.
Do I still need to export whole graphs?
---------------------------------------
For the vast majority of models you probably don't, and you can use
``torch._dynamo.optimize()`` as is, but there are a few situations where
full graphs are necessary, and you can ensure a full graph by running
``torch._dynamo.optimize(..., nopython=True)``:

* Large scale training runs, think $250K+, that require pipeline parallelism and other advanced sharding strategies
* Inference optimizers like `TensorRT <https://github.com/pytorch/TensorRT>`__ or `AITemplate <https://github.com/facebookincubator/AITemplate>`__ that rely on fusing much more aggressively than training optimizers
* Mobile training or inference

Future work will include tracing communication operations into graphs,
coordinating these operations with compute optimizations, and optimizing
the communication operations.
Why is my code crashing?
------------------------
If your code ran just fine without dynamo and started to crash with it
enabled, then the most important first step is figuring out which part of
the stack your failure occurred in. Try running things in the order below,
and only move on to the next step if the previous one succeeded (a sketch
of this workflow follows the list):

1. ``dynamo.optimize("eager")``, which only runs TorchDynamo forward graph capture and then runs the captured graph with PyTorch. If this fails, then there's an issue with TorchDynamo.
2. ``dynamo.optimize("aot_eager")``, which runs TorchDynamo to capture a forward graph, and then AOTAutograd to trace the backward graph without any additional backend compiler steps. PyTorch eager will then be used to run the forward and backward graphs. If this fails, then there's an issue with AOTAutograd.
3. ``dynamo.optimize("inductor")``, which runs TorchDynamo to capture a forward graph, and then AOTAutograd to trace the backward graph with the TorchInductor compiler. If this fails, then there's an issue with TorchInductor.
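Here is a minimal sketch of that workflow; it assumes you can call the failing code as a function and uses ``torch._dynamo.reset()`` to clear compiled caches between attempts:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    def my_model(x):          # stand-in for the code that crashes for you
        return torch.relu(x).sum()

    x = torch.randn(8)
    for backend in ("eager", "aot_eager", "inductor"):
        dynamo.reset()        # drop previously compiled code so each backend starts fresh
        try:
            dynamo.optimize(backend)(my_model)(x)
            print(f"{backend}: ok")
        except Exception as err:
            print(f"{backend}: failed with {type(err).__name__}")
            break             # the first failing backend points at the guilty layer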
TorchDynamo Errors
~~~~~~~~~~~~~~~~~~
If the error occurs with the ``"eager"`` backend, then
TorchDynamo is the most likely source of the error.
To debug these issues we recommend setting
``torch._dynamo.config.verbose=True`` to get a full stack trace to both
the error in TorchDynamo and the user code. In addition to this flag,
you can also set the ``log_level`` of TorchDynamo through
``torch._dynamo.config.log_level``. The available levels are the following:

- ``logging.DEBUG``: Print every instruction that is encountered, in addition to all the log levels below.
- ``logging.INFO``: Print each function that is compiled (original and modified bytecode) and the graph that is captured, in addition to all the log levels below.
- ``logging.WARNING`` (default): Print graph breaks, in addition to all the log levels below.
- ``logging.ERROR``: Print errors only.

If a model is sufficiently large, the logs can become overwhelming. If
an error occurs deep within a model's Python code, it can be useful to
execute only the frame in which the error occurs to enable easier
debugging. There are two tools available to enable this (a short example of the config settings follows):

* ``env TORCHDYNAMO_DEBUG_FUNCTION=<desired_function_name>`` will only run TorchDynamo on functions with that name.
* ``torch._dynamo.config.replay_record_enabled = True``, which dumps an execution record when an error is encountered. This record can then be replayed to run only the frame where the error occurred.
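For example, a minimal sketch of the configuration flags described above:

.. code-block:: python

    import logging
    import torch._dynamo as dynamo

    dynamo.config.verbose = True                 # full stack traces into dynamo and user code
    dynamo.config.log_level = logging.INFO       # print compiled functions and captured graphs
    dynamo.config.replay_record_enabled = True   # dump an execution record when an error occurs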
TorchInductor Errors
--------------------
With TorchInductor as the chosen backend, AOTAutograd is used to
generate the backward graph from the forward graph captured by
TorchDynamo. It's important to note that errors can occur during this
tracing and also while TorchInductor lowers the forward and backward
graphs to GPU code or C++.
A model can often consist of hundreds or thousands of FX nodes, so
narrowing down the exact nodes where the problem occurred can be very
difficult, which is why we highly recommend you use our minifier to
create tiny reproducible examples of the failures you're seeing. We can
minify errors that occur either at the AOTAutograd layer or at the Inductor
layer, which you should try in the following order:

1. ``env TORCHDYNAMO_REPRO_AFTER="aot" python your_model.py``
2. ``env TORCHDYNAMO_REPRO_AFTER="dynamo" python your_model.py``

Minifying your error is the quickest path to getting it fixed.
The minifier will create a ``repro.py`` for you at the location
set by ``env TORCHDYNAMO_REPRO_DIR``, so make sure you have write access to
that directory. You can then run ``python repro.py`` and confirm that
you are getting the same error.
.. note::
For other compilers such as nvfuser, the process is similar but
instead you would leverage ``env TORCHDYNAMO_REPRO_AFTER="dynamo" python your_model.py``.
Why is compilation slow?
------------------------
Dynamo Compilation
~~~~~~~~~~~~~~~~~~
TorchDynamo has a built-in stats function for collecting and displaying
the time spent in each compilation phase. These stats can be accessed by
calling ``torch._dynamo.utils.compile_times()`` after executing code
optimized with TorchDynamo. By default, this returns a string representation of
the compile time spent in each TorchDynamo function by name.
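For example, a minimal sketch:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo
    import torch._dynamo.utils

    @dynamo.optimize("eager")
    def fn(x):
        return torch.sin(x) + torch.cos(x)

    fn(torch.randn(100))                        # trigger compilation
    print(torch._dynamo.utils.compile_times())  # per-function compile time summary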
Inductor Compilation
~~~~~~~~~~~~~~~~~~~~
TorchInductor has a built-in stats and trace function for displaying the time
spent in each compilation phase, the output code, the output graph visualization,
and the IR dump. Enable it with ``env TORCHINDUCTOR_TRACE=1 python repro.py``. This is a
debugging tool designed to make it easier to debug and understand the
internals of TorchInductor, with an output that will look something like
`this <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__.
Each file in that debug trace can be enabled/disabled via
``torch._inductor.config.trace.*``. The profile and the diagram are both
disabled by default since they are expensive to generate. See the
`example debug directory
output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__
for more examples.
Excessive Recompilation
~~~~~~~~~~~~~~~~~~~~~~~
When TorchDynamo compiles a function (or part of one), it makes certain
assumptions about locals and globals in order to allow compiler
optimizations, and expresses these assumptions as guards that check
particular values at runtime. If any of these guards fail, Dynamo will
recompile that function (or part) up to
``torch._dynamo.config.cache_size_limit`` times. If your program is
hitting the cache limit, you will first need to determine which guard is
failing and what part of your program is triggering it.
The `recompilation profiler <#recompilation-profiler>`__ automates the
process of setting TorchDynamo's cache limit to 1 and running your
program under an observation-only compiler that records the causes of
any guard failures. You should be sure to run your program for at least
as long (as many iterations) as you were running when you ran into
trouble, and the profiler will accumulate statistics over this duration.
.. code-block:: python
prof = dynamo.utils.CompilationProfiler()
@dynamo.optimize(prof)
def my_model():
...
my_model()
print(prof.report())
Many of the reasons for graph breaks and excessive recompilation will be
fixed with upcoming support for `tracing dynamic tensor
shapes <https://docs.google.com/document/d/1QJB-GOnbv-9PygGlOMXwiO9K6vVNm8sNg_olixJ9koc/edit?usp=sharing>`__,
more careful choices for guards and better tuned heuristics.
Why are you recompiling in production?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In some cases, you may not want unexpected compiles after a program has
warmed up, for example, if you are serving production traffic in a
latency-critical application. For this, TorchDynamo provides an
alternate mode where prior compiled graphs are used, but no new ones are
generated:
.. code-block:: python
frozen_toy_example = dynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))
How are you speeding up my code?
--------------------------------
There are three major ways to accelerate PyTorch code:

1. Kernel fusion:

   * Vertical fusion fuses sequential operations to avoid excessive reads/writes. For example, fusing two subsequent cosines means you can do one read and one write instead of two reads and two writes.
   * Horizontal fusion: the simplest example is batching, where a single matrix is multiplied with a batch of examples, but the more general scenario is a grouped GEMM, where a group of matrix multiplications are scheduled together.

2. Out of order execution: a general optimization for compilers. By looking ahead at the exact data dependencies within a graph, we can decide on the most opportune time to execute a node and which buffers can be reused.

3. Automatic work placement: similar to the out of order execution point, but by matching nodes of a graph to resources like physical hardware or memory we can design an appropriate schedule.
The above are general principles for accelerating PyTorch code, but
different backends will each make different tradeoffs on what to
optimize. For example, Inductor first takes care of fusing whatever it
can and only then generates `Triton <https://openai.com/blog/triton/>`__
kernels.

Triton, in addition, offers speedups because of automatic memory
coalescing, memory management, and scheduling within each Streaming
Multiprocessor, and has been designed to handle tiled computations.

However, regardless of the backend you use, it's best to take a
benchmark-and-see approach, so try out the PyTorch profiler, visually inspect the
generated kernels, and try to see what's going on for yourself.
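For instance, a minimal profiling sketch, assuming a CUDA device is available:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo
    from torch.profiler import profile, ProfilerActivity

    @dynamo.optimize("inductor")
    def fn(x):
        return torch.softmax(x @ x, dim=-1)

    x = torch.randn(1024, 1024, device="cuda")
    fn(x)  # warm up first so compilation does not dominate the profile

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fn(x)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))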
Why am I not seeing speedups?
-----------------------------
Graph Breaks
~~~~~~~~~~~~
The main reason you won't see the speedups you'd like by using dynamo
is excessive graph breaks. So, what's a graph break?
Given a program like:
.. code-block:: python
@dynamo.optimize(...)
def some_fun(x):
...
some_fun(x)
...
TorchDynamo will attempt to compile all of the torch/tensor operations
within ``some_fun()`` into a single FX graph, but it may fail to capture
everything into one graph.
Some graph break reasons are insurmountable to TorchDynamo. For example, a call
into a C extension other than torch is invisible to TorchDynamo, and
could do arbitrary things without TorchDynamo being able to introduce the
necessary guards to ensure that the compiled program would be safe to reuse.
To maximize performance, it's important to have as few graph breaks
as possible.
Identifying the cause of a graph break
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
To identify all graph breaks in a program and the associated reasons for
the breaks, ``torch._dynamo.explain`` can be used. This tool runs
TorchDynamo on the supplied function and aggregates the graph breaks
that are encountered. Here is an example usage:
.. code-block:: python
import torch
import torch._dynamo as dynamo
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
print("woo")
if b.sum() < 0:
b = b * -1
return x * b
explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
print(explanation)
"""
Dynamo produced 3 graphs, with 2 graph break and 6 ops.
Break reasons:
1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
File "t2.py", line 16, in toy_example
print("woo")
2. generic_jump
File "t2.py", line 17, in toy_example
if b.sum() < 0:
"""
To throw an error on the first graph break encountered, you can
disable Python fallback by using ``nopython=True``. This should be
familiar if you've worked with export-based compilers.
.. code-block:: python
@dynamo.optimize(<compiler>, nopython=True)
def toy_example(a, b):
...
Why didn't my code recompile when I changed it?
-----------------------------------------------
If you enabled dynamic shapes via
``env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py``, then your code
won't recompile on shape changes. We've added support for dynamic shapes,
which avoids recompilations in the case when shapes vary by less than a
factor of 2. This is especially useful in scenarios like varying image
sizes in CV or variable sequence lengths in NLP. In inference scenarios
it's often not possible to know what a batch size will be beforehand,
because you take what you can get from different client apps.
In general, TorchDynamo tries very hard not to recompile things
unnecessarily, so if, for example, TorchDynamo finds 3 graphs and your
change only modified one graph, then only that graph will recompile.
Another tip to avoid potentially slow compilation times is to warm up a
model by compiling it once, after which subsequent compilations will be
much faster. Cold start compile times are still a metric we track
visibly.
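A rough timing sketch of that warmup effect (the exact numbers depend entirely on your hardware and backend):

.. code-block:: python

    import time
    import torch
    import torch._dynamo as dynamo

    @dynamo.optimize("inductor")
    def fn(x):
        return torch.nn.functional.gelu(x @ x)

    x = torch.randn(256, 256)
    t0 = time.time()
    fn(x)                  # cold start: pays the full compilation cost
    t1 = time.time()
    fn(x)                  # warm: guards pass and the cached code is reused
    t2 = time.time()
    print(f"cold: {t1 - t0:.3f}s, warm: {t2 - t1:.3f}s")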
Why am I getting incorrect results?
-----------------------------------
Accuracy issues can also be minified if you set the environment variable
``TORCHDYNAMO_REPRO_LEVEL=4``. It operates with a similar git bisect
model, and a full repro might be something like
``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4``. The reason
we need this is that downstream compilers will generate code, whether it's
Triton code or the C++ backend, and the numerics from those downstream
compilers can be different in subtle ways yet have a dramatic impact on
your training stability. So the accuracy debugger is very useful for us
to detect bugs in our codegen or with a backend compiler.
Why am I getting OOMs?
----------------------
Dynamo is still an alpha product, so there are a few sources of OOMs. If
you're seeing an OOM, try disabling the following configurations in this
order, and then open an issue on GitHub so we can solve the root problem:

1. If you're using dynamic shapes, try disabling them; we've disabled them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py``
2. CUDA graphs with Triton are enabled by default in Inductor, but removing them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.


@@ -0,0 +1,181 @@
Getting Started
===============
Let's start with a simple example and make things more complicated step
by step. Please note that you're likely to see more significant speedups
the newer your GPU is.
.. code:: python
from torch._dynamo import optimize
import torch
def fn(x, y):
a = torch.cos(x).cuda()
b = torch.sin(y).cuda()
return a + b
new_fn = optimize("inductor")(fn)
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn(input_tensor, input_tensor)
This example will not actually run faster. Its purpose is to demonstrate
``torch.cos()`` and ``torch.sin()``, which are
examples of pointwise ops, that is, they operate element by element on a
vector. A more famous pointwise op you might actually want to use would
be something like ``torch.relu()``. Pointwise ops in eager mode are
suboptimal because each one needs to read a tensor from
memory, make some changes, and then write back those changes. The single
most important optimization that Inductor does is fusion. So, back to our
example, we can turn two reads and two writes into one read and one write, which
is crucial especially for newer GPUs where the bottleneck is memory
bandwidth (how quickly you can send data to a GPU) instead of compute
(how quickly your GPU can crunch floating point operations).
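If you want to see the effect of fusion for yourself, here is a rough benchmarking sketch, assuming a CUDA device; the speedup (if any) depends heavily on your GPU:

.. code:: python

    import torch
    from torch._dynamo import optimize

    def pointwise(x):
        return torch.sin(torch.cos(x))     # two pointwise ops that inductor can fuse

    compiled = optimize("inductor")(pointwise)
    x = torch.randn(10_000_000, device="cuda")

    def bench(fn):
        fn(x)                              # warm up (and compile, for the optimized version)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(x)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end)     # milliseconds

    print("eager ms:   ", bench(pointwise))
    print("compiled ms:", bench(compiled))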
Another major optimization that inductor makes available is automatic
support for CUDA graphs.
CUDA graphs help eliminate the overhead from launching individual
kernels from a Python program, which is especially relevant for newer GPUs.
Dynamo supports many different backends, but Inductor specifically works
by generating `Triton <https://github.com/openai/triton>`__ kernels.
We can inspect them by running ``TORCHINDUCTOR_TRACE=1 python trig.py``,
with the actual generated kernel being:
.. code:: python
@pointwise(size_hints=[16384], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 10000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tl.sin(tmp0)
tmp2 = tl.sin(tmp1)
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
And you can verify that fusing the two ``sin`` ops did actually occur,
because the two ``sin`` operations occur within a single Triton kernel
and the temporary variables are held in registers with very fast access.
You can read up a lot more on Triton's performance
`here <https://openai.com/blog/triton/>`__, but the key is that it's in Python,
so you can easily understand it even if you haven't written all that
many CUDA kernels.
As a next step, let's try a real model, like a ResNet from the PyTorch
Hub:
.. code:: python
import torch
import torch._dynamo as dynamo
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = dynamo.optimize("inductor")(model)
opt_model(torch.randn(1,3,64,64))
And that's not the only available backend; you can run
``dynamo.list_backends()`` in a REPL to see all the available ones. Try out
``aot_cudagraphs`` or ``nvfuser`` next for inspiration.
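For example (the exact list you see depends on your installation):

.. code:: python

    import torch._dynamo as dynamo

    print(dynamo.list_backends())
    # something like: ['aot_cudagraphs', 'aot_eager', 'eager', 'inductor', 'nvfuser', ...]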
Let's do something a bit more interesting now. Our community frequently
uses pretrained models from
`transformers <https://github.com/huggingface/transformers>`__ or
`TIMM <https://github.com/rwightman/pytorch-image-models>`__, and one of
our design goals is for dynamo and inductor to work out of the box with
any model that people would like to author.
So we're going to directly download a pretrained model from the
HuggingFace hub and optimize it:
.. code:: python
import torch
from transformers import BertTokenizer, BertModel
import torch._dynamo as dynamo
# Copy pasted from here https://huggingface.co/bert-base-uncased
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
model = dynamo.optimize("inductor")(model) # This is the only line of code that we changed
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
output = model(**encoded_input)
If you remove the ``to(device="cuda:0")`` from the model and
``encoded_input``, then Inductor will generate C++ kernels that are
optimized for running on your CPU. You can inspect both the Triton and the C++
kernels for BERT. They're obviously more complex than the trigonometry
example we had above, but you can similarly skim them and understand them if you
understand PyTorch.
Similarly, let's try out a TIMM example:
.. code:: python
import timm
import torch._dynamo as dynamo
import torch
model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2)
opt_model = dynamo.optimize("inductor")(model)
opt_model(torch.randn(64,3,7,7))
Our goal with dynamo and inductor is to build the highest-coverage ML compiler, one that should work with any model you throw at it.
Existing Backends
~~~~~~~~~~~~~~~~~
TorchDynamo has a growing list of backends, which can be found in
`backends.py <https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py>`__
or via ``torchdynamo.list_backends()``, each with its own optional dependencies.
Some of the most commonly used backends include:
* **Debugging backends**:

  * ``dynamo.optimize("eager")`` - Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo issues.
  * ``dynamo.optimize("aot_eager")`` - Uses AotAutograd with no compiler, i.e., just using PyTorch eager for AotAutograd's extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.

* **Training & inference backends**:

  * ``dynamo.optimize("inductor")`` - Uses the TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton kernels. `Read more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
  * ``dynamo.optimize("nvfuser")`` - nvFuser with TorchScript. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
  * ``dynamo.optimize("aot_nvfuser")`` - nvFuser with AotAutograd. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
  * ``dynamo.optimize("aot_cudagraphs")`` - cudagraphs with AotAutograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__

* **Inference-only backends**:

  * ``dynamo.optimize("ofi")`` - Uses TorchScript's ``optimize_for_inference``. `Read more <https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html>`__
  * ``dynamo.optimize("fx2trt")`` - Uses Nvidia TensorRT for inference optimizations. `Read more <https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst>`__
  * ``dynamo.optimize("onnxrt")`` - Uses ONNX Runtime for inference on CPU/GPU. `Read more <https://onnxruntime.ai/>`__
  * ``dynamo.optimize("ipex")`` - Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
Why do you need another way of optimizing PyTorch code?
-------------------------------------------------------
While a number of other code optimization tools exist in the PyTorch
ecosystem, each of them has its own flow. Here are a few examples of
existing methods and their limitations:
- ``torch.jit.trace()`` is silently wrong if it cannot trace, e.g., during
  control flow (see the sketch after this list)
- ``torch.jit.script()`` requires modifications to user or library code
  by adding type annotations and removing non-PyTorch code
- ``torch.fx.symbolic_trace()`` either traces correctly or gives a hard
  error, but it's limited to traceable code so it still can't handle
  control flow
- ``torch._dynamo`` works out of the box and produces partial graphs.
  It still has the option of producing a single graph with
  ``nopython=True``, which is needed for `some
  situations <./documentation/FAQ.md#do-i-still-need-to-export-whole-graphs>`__,
  but allows a smoother transition where partial graphs can be
  optimized without code modification
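To make the first and last points concrete, here is a small sketch contrasting ``torch.jit.trace()`` with ``torch._dynamo`` on data-dependent control flow (the tensors and branch are illustrative):

.. code:: python

    import torch
    import torch._dynamo as dynamo

    def f(x):
        if x.sum() < 0:      # data-dependent control flow
            return -x
        return x

    # jit.trace bakes in the branch taken with the example input and is then silently wrong.
    traced = torch.jit.trace(f, torch.ones(3))
    print(traced(-torch.ones(3)))   # returns the input unchanged, instead of negating it

    # dynamo graph-breaks on the branch and re-evaluates it, so the result stays correct.
    compiled = dynamo.optimize("eager")(f)
    print(compiled(-torch.ones(3)))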
.. |image0| image:: ../_static/img/dynamo/TorchDynamo.png


@@ -0,0 +1,513 @@
Guards Overview
===============
From a UX perspective, TorchDynamo is very easy to use. The user invokes
``torchdynamo.optimize`` as a decorator:
.. code-block:: python
@torchdynamo.optimize(my_compiler)
def fn_foo(bar):
Where a complete example looks like this:
.. code-block:: python
from typing import List
import torch
import torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
This allows TorchDynamo to capture the interpreted Python frames, grab
any and all relevant information, and speed things up wherever it can.
The speedup comes from a few places, and can be rather dependent on the
backend (my_compiler above) provided, but the one speedup we care about
most for today's overview is **caching**. Caching itself is not a direct
speedup so much as a critical enabler that allows us to prevent
recompilation. We dig a hole with dynamo, and caching allows us to get
out. It's a speedup from that perspective, but relatively neutral when
all things are considered; however, it enables us to hold performance
neutrality while then enabling backends, the true source of our
speedups.
With even a pass-through no-op backend provided:
.. code-block:: python
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
return gm.forward
We can see TorchDynamo speeding up Python execution quite a bit, even on
regular Python, not just PyTorch.
Caching and Guards Overview
---------------------------
TorchDynamo operates through caching transformed (by TorchDynamo) user
bytecode. When we receive a frame for evaluation, we check if the
**objects referenced in the frame have changed** in certain ways, and if
not, we read the previously transformed user bytecode to evaluate it.
The details of how we do this will be saved for a later writeup.
Instead, we will focus on how we can identify whether or not the
**objects referenced in the frame have changed**. This is a critical
piece of functionality in TorchDynamo, because it drives the entire
invalidation lifecycle. We refer to this functionality as **guards**.
At a very high level, the vastly oversimplified TL;DR flow is this (a small sketch of the resulting caching behavior follows the list):
1) We receive a python frame
2) We convert the given frame from (1), passing it through instruction
translation
3) For the objects captured in (2), we create tracking objects that are
(a) tracked on an output graph, which is an internal specialization
of a torch.fx.Tracer (and the topic of a later writeup), and (b)
guards, the topic of this document.
4) We process the guard objects created in (3), turning them into a
generated python function, check_fn, associated with a piece of code.
5) The check_fn is evaluated whenever we encounter this code a
subsequent time - if a check_fn passes and evaluates to True, we know
the code in the cache and the code encountered here is the same, and
can be safely used. If it fails and evaluates to False, we know the
code in the cache is not valid, and can be thrown out in favor of a
new entry, through recompilation or a graph break.
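As a small sketch of that behavior, here is an example where a guard on a global fails and forces recompilation (the backend and the exact guard types produced are illustrative):

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    SCALE = 2.0  # a global referenced by the frame, so it gets guarded

    def noisy_compiler(gm: torch.fx.GraphModule, example_inputs):
        print("recompiling")
        return gm.forward

    @dynamo.optimize(noisy_compiler)
    def fn(x):
        return x * SCALE

    fn(torch.randn(4))   # prints "recompiling": first capture
    fn(torch.randn(4))   # guards pass: the cached bytecode and compiled graph are reused
    SCALE = 3.0          # the guard on the global no longer holds ...
    fn(torch.randn(4))   # ... so this prints "recompiling" again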
Python Frame Evaluation and PEP 523
-----------------------------------
The functionality of TorchDynamo is based on
`PEP 523 <https://peps.python.org/pep-0523/>`__.
TorchDynamo installs a frame evaluation function on Python, via
`_PyInterpreterState_SetEvalFrameFunc`. The overview of function
selection, thread management, and cleanup is out of scope for this
writeup, but the important part is that TorchDynamo has a hook where
Python can hand control back to us during evaluation.
The function we have installed is ``convert_frame`` or
``convert_frame_assert`` in the ``nopython=True`` case, but glossing
over that nuance for now, let's take a look at ``convert_frame_assert``,
as ``convert_frame`` proxies to it anyway.
We can find it in `convert_frame.py
<https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/convert_frame.py#L200>`__,
with a signature as follows:
.. code-block:: python
def convert_frame_assert(compiler_fn: Callable, one_graph=True):
This function wraps the entry point of where Python invokes TorchDynamo
with a frame, glossing over the nuances of ``wrap_convert_context`` for
now:
.. code-block:: python
def _convert_frame_assert(frame: types.FrameType, cache_size: int):
Here is what this function does:
1) Checks if it has seen this ``code``\ (see: f_code `here
<https://docs.python.org/3/library/inspect.html>`__) before and exits
early if it did.
2) Checks if the code is an unsupported case.
3) Checks if the ``cache_size`` (second arg above) crosses the limit
defined in the config, ``cache_size_limit``. If it has, the function
drops the frame and logs warnings. This helps to avoid constant
recompilation of a frame as it generally means that the frame is hot
in an unexpected way and caching it produces needless overhead,
as it is likely to get evicted the next time it is encountered.
4) Passes the frame, alongside a function that creates an
``InstructionTranslator`` through bytecode
transformation, via ``transform_code_object``. A few crucial things
happen under the hood here:
1) New code is produced through ``transform_code_object``.
2) An FX tracer named ``output`` is produced through
``InstructionTranslator``.
This can be a bit confusing,
as ``InstructionTranslator`` is not an ``fx`` tracer, but it is stored
in a variable named ``tracer``, and its ``output`` **is** an ``fx`` tracer.
3) The function produces guards and stores them on ``output`` above.
4) The function produces ``output_instructions`` and stores them on
``output`` above.
5) The function maps the newly produced transformed code to the initial code it
read off the frame. This mapping is worth remembering, we will
refer to it much later on below where we cover guard failures.
5) Using the transformed code from 4.1 and the guards from 4.3,
the function produces a ``GuardedCode``.
Now that we have learned about frame evaluation, let's review
``InstructionTranslator`` and see how it turns the frame we handed
it into TorchDynamo internal types.
InstructionTranslator
---------------------
``InstructionTranslator`` does a lot! We won't cover the details of
everything it does, but most importantly for this document, it produces
``symbolic_locals``, which maintains a mapping from the
frame's ``f_locals`` to TorchDynamo internal Variable objects (more on these
in a moment). ``symbolic_locals`` is filled by traversing the frame's
locals:
.. code-block:: python
self.symbolic_locals = collections.OrderedDict(
(k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
for k in vars
if k in f_locals
)
We will get to how this works later, through a few other examples that lead
us to understanding ``VariableTracker`` and ``VariableBuilder``. The
important component here, for now, is the invocation of a call
into ``VariableBuilder``. ``VariableBuilder``'s call implementation
proxies into a function called ``_wrap``, which in turn both constructs
instances of ``VariableTracker`` and calls ``make_guards`` on them. More
on that later.
This mapping, in turn, is critical as each Variable has associated
guards, which are then passed to ``self.output``, the instance of
``OutputGraph``, an fx tracer, mentioned in 4.2 of the section above. If
you recall, this ``OutputGraph``, stored in a variable called ``output``,
is where our guards are stored before being passed on to become
``GuardedCode``.
How does ``InstructionTranslator`` do this? At the heart of it, there is
a loop that is pumped, which drives a function ``step``.
``step`` is just that - a single processing step, taking exactly one
instruction and doing *something* with it. Note: These are real
instructions processed by TorchDynamo's ``transform_code_object``, and
it's pretty cool.
.. note:: This section purposely skips the details of
`dis.get_instructions <https://docs.python.org/3/library/dis.html>`__,
and how we set up the ``Instruction`` class.
For the toy example above, here is a snippet of what a few
``Instruction``\ s may look like:
.. code-block:: python
Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)
This is the core functionality of this function. Take a look at the ``opname``,
and then take a look at this little snippet from inside ``step``:
.. code-block:: python
if not hasattr(self, inst.opname):
unimplemented(f"missing: {inst.opname}")
getattr(self, inst.opname)(inst)
As we can see, we check if the current class, the
``InstructionTranslator``, has an attribute set matching the operator name
(for example, ``LOAD_CONST``). If it does, we invoke it, passing the whole
instruction object in. If it does not, we drop the frame as
unimplemented.
For the LOAD_CONST example, we can see that we do indeed support it,
with a relatively straightforward definition:
::
def LOAD_CONST(self, inst):
self.push(ConstantVariable(value=inst.argval))
Passing over, for now, the other details of ``InstructionTranslator``,
we can see that this function creates a new instance of the class
``ConstantVariable``, with a value (in our example case, -1), and then
pushes it onto the stack.
There are dozens of such methods; see ``symbolic_convert.py`` for all of
them. Generally, we implement methods matching as many Python
bytecode instructions as possible.
Across both the logic downstream of ``step`` and the logic from invoking
``VariableBuilder``, we now have a lot of ``VariableTracker``\ s, and of
course, we've spoken about creating guards quite a bit. Let's dig into
what Variables are, and get a little closer to understanding guards.
Variables
---------
A ``ConstantVariable`` is an instance of ``VariableTracker``.
``VariableTracker`` represents a tracked Python local or stack value.
When it comes to representing an object inside TorchDynamo, a
``VariableTracker`` does exactly what it says: it tracks a given variable.
It's an extremely flexible class, but there are a few points to keep in
mind:
- It manages the ``guard`` relationship around the underlying object
through:
- `make_guard`
- `replace_guards`
- `add_guard(s)`
- `propagate` - ``propagate(*vars: List[List["VariableTracker"]])`` -
Perhaps the most important of all, in that it combines guards from
all the provided VariableTracker instances passed in. It visits
the guards and combines the guards from these onto itself.
- It acts as a proxy on behalf of the underlying object, implementing
methods for the rest of TorchDynamo to get information about the
tracked object:
- `call_method`
- `call_function`
- `python_type`
- `as_proxy`
- `is/as_python_proxy`
- It stores the variable ``source`` of type ``Source``, from
  torchdynamo/source.py. This source type is a relatively self-contained
  class that helps us organize and bookkeep where the original
  source came from, and provides convenience methods for things
  like getting the name and, importantly for us, producing guards.
And this class (``VariableTracker``) is built around subclassing,
somewhere between a full abstract base class and a fully fleshed-out class;
it leaves many methods raising ``NotImplementedError``, relying on
subclasses (see torchdynamo/variables/ for all subclasses) to fulfill
contracts and custom behaviors.
Knowing what we know now, we can see an example of how an instruction
from ``dis``, ``BUILD_TUPLE``, is handled:

    ``BUILD_TUPLE(count)``: Creates a tuple consuming *count* items from the
    stack, and pushes the resulting tuple onto the stack.

In our case, our signature will be a *little* different due to the way
we create ``Instruction`` objects, but the gist of it will be the same.
Instead of passing in ``count``, we pass in an object with a little
extra bookkeeping, and of course, we deal with turning regular old
Python objects into TorchDynamo notions:
::
def BUILD_TUPLE(self, inst):
items = self.popn(inst.argval)
options = VariableTracker.propagate(items)
self.push(TupleVariable(items, **options))
What is happening here?

1) We read ``argval``, which in this case is analogous to ``count`` in the
   pydoc for the equivalent instruction.
2) We ``popn`` the items. In this case, the signature is
   ``def popn(self, n: int) -> List[TensorVariable]:``, which hints at an
   underlying contract: we are returning ``TensorVariable``\ s. If we
   take a closer look at symbolic_convert.py and
   ``InstructionTranslatorBase``/``InstructionTranslator``, we see that
   the only things pushed onto and popped from our stack are
   ``VariableTracker``\ s.
3) We call ``VariableTracker.propagate`` (remember it, from above?). This
   takes the guards from every single item popped off the stack in (2),
   recursively traverses them, and combines all the guards into
   ``options``, which ends up as ``{"guards": guards}``.
4) We then make a new instance of a ``VariableTracker``,
   ``TupleVariable``, out of the ``items`` and ``options``. This then
   allows us to install all the appropriate guards from the ``items``
   that make up the new ``TupleVariable``.
Note: You may wonder, where did the first guards come from? Propagation
is good and all, but don't we need something created before it can be
propagated? Yes! Remember that ``VariableBuilder`` above? It calls
``make_guards`` as it creates ``VariableTracker`` instances from
``f_locals``. This, in turn, calls into the ``source`` to have it create
guards.
After all this, bytecode translation is done and we are one step closer
to producing ``GuardedCode``. We now understand how locals become
``VariableTracker``\ s, how instructions are handled, and where guards
are called on for creation. Before we can go into seeing how code and
guards are combined into a GuardedCode object, we need to dig a little
bit into those ``make_guard`` and ``source.make_guard`` calls above. We
can then understand, really, what was going on when we made guards
alongside, and on, ``VariableTracker`` instances.
Making Guards
-------------
Guards are just Python objects of the class ``Guard``; however, there's
a good amount of detail around this little class.
Looking at the definition of the dataclass (and therefore, ctor
signature), we see that it has a name, a source, and a create function.
::
@dataclasses.dataclass
class Guard:
name: str
source: GuardSource
create_fn: Callable
The ``name`` should be the name of the variable.
The ``source`` here is an enum indicating what *kind* of source the guard
belongs to (note: not to be confused with ``Source`` and the other types
in source.py, as stored on ``VariableTracker``, discussed above).
And ``create_fn`` is the heart of how we go from having this simple
dataclass to actually producing valid Python code to be invoked to
know whether or not things have changed between invocations, and
whether we can safely read from the code cache or not (in case you
forgot what all this was for!).
The most common code paths for getting an instance of a guard are
through ``make_guards`` on ``VariableTracker``.
``make_guards``->``source.make_guard``->``return Guard(self.name(), self.guard_source(), fn)``
Or, in a concrete example:
.. code-block:: python
...
elif istype(value, range):
guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
return RangeVariable(value=value, guards=guards)
Since ``source`` was set at the construction time of this
``VariableTracker``, all that was needed here was to provide the fn,
``GuardBuilder.EQUALS_MATCH`` to the ``create_fn`` field.
This ``create_fn`` must be a method on ``GuardBuilder``. The reason for
this becomes apparent in our next step. Once we have all the guards
created for a frame, we move on to ``CheckFunctionManager`` and
``compile_check_fn``.
Remember that ``convert_frame`` function way above, in the first
section? Before it can produce a ``GuardedCode``, it needs to run the
``CheckFunctionManager``, with all the guards, to produce a ``check_fn``
which will then, in turn get passed in alongside the code into
``GuardedCode``. This is the same ``check_fn`` that we store in our
cache entry, and the same one we run to know whether or not to retrieve
the code stored alongside. For reference, here is that code:
.. code-block:: cpp
static CacheEntry *create_cache_entry(CacheEntry *next,
PyObject *guarded_code) {
CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
DEBUG_NULL_CHECK(e);
e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
NULL_CHECK(e->check_fn);
e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
NULL_CHECK(e->code);
e->next = next;
return e;
}
We now know how a ``check_fn`` function is used, and who makes it, and
what it is composed of, but what we do not yet know is how. How does a
list of ``Guard`` objects become a function we can run later on?
First, we iterate these guards:
.. code-block:: python
for guard in sorted(guards or [], key=Guard.sort_key):
if not config.guard_nn_modules and guard.is_nn_module():
continue
guard.create(local_builder, global_builder)
Calling ``guard.create`` runs the ``create_fn`` we set on the ``Guard``
class above (don't confuse it with the ``check_fn`` we are working on
producing; the names are similar, so it can get a little confusing). In
our example above, our ``create_fn`` is ``GuardBuilder.EQUALS_MATCH``.
So we are now invoking it, passing in the guard itself.
The signature is: ``def EQUALS_MATCH(self, guard: Guard):``
And internally to that function, we can use the ``name`` on the guard to
get back our original object, querying it for data and type information,
which in turn gets us to the most important bit: appending code.
At its simplest, ``EQUALS_MATCH`` appends just one line of code:
``self.code.append(f"{ref} == {val!r}")``, where ``ref`` is the name of
the variable and ``val`` is the value. It might produce code like this:
.. code-block::
y == 2
Pretty simple, but if we append a few other kinds of ``GuardBuilder``
functions on (For a more complex case), and then combine them all with
``and`` in between each statement (as we do), we might get something
like this:
.. code-block::
___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)
Now we're talking! Let's see what we have here:

1) A check for ``.valid`` (we will come back to invalidation later on)
2) A type id check
3) A value check
4) A tensor check
This becomes the heart of the code of our ``check_fn``, which in turn, as
you recall, is evaluated the **next** time we encounter this code. It
will then check:
1) Is this code still valid?
2) If (1), does ``y`` still have a type of ``94367738391392``?
3) If (2), is ``y`` still 2?
4) If (3), check whether tensor ``x`` changed in some specific ways
If all of these are still true, then we can use the code cached
alongside this ``check_fn``! Joyous day! (Note: a deeper dive into how
and where this happens is saved for a later writeup, but reading
``static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals)`` in
``_eval_frame.c`` is a good place to start for the inquisitive reader
who has made it this far.)
If not, then we can move on to recompiling the code anew, storing it
in the cache alongside a whole new ``check_fn``,
again to be checked on a subsequent frame.
There are lots of other such functions on ``GuardBuilder`` which get
coalesced into, at times massive, strings, which then get evaluated as
Python code and stored into ``check_fn``. Our example above is
illustrative of a simple case, but I urge you to read the other
functions on ``GuardBuilder``, or better yet, dump the ``code`` variable
in ``compile_check_fn`` to really see what's getting produced,
especially on larger, real models!
Summary
-------
In this writeup, we have glossed over:

- The role of ``.valid`` and invalidation around weak references (and potentially soon-to-be NN Module invalidations)
- How the C++ side of guard functions (``___check_type_id``, ``___check_tensors``, etc.) operate
- What happens when guards fail
- What happens if we produce invalid guard code
Despite all that, I hope this has been a useful read. We covered how
user-provided code, wrapped in a TorchDynamo context, goes on to get
traced and tracked internally, organized into ``VariableTracker``\ s,
``Source``\ s, and subsequently ``Guard``\ s, and how those ``Guard``\ s in
turn guide cache entry selection and invalidation when handling Python
code.


@@ -0,0 +1,44 @@
TorchDynamo Documentation
=========================
**TorchDynamo** is a Python-level JIT compiler designed to make unmodified
PyTorch programs faster. TorchDynamo hooks into the frame evaluation API
in CPython (`PEP 523 <https://peps.python.org/pep-0523/>`__) to
dynamically modify Python bytecode right before it is executed. It
rewrites Python bytecode in order to extract sequences of PyTorch
operations into an `FX Graph <https://pytorch.org/docs/stable/fx.html>`__
which is then just-in-time compiled with a customizable backend.
It creates this FX Graph through bytecode analysis and is designed to
mix Python execution with compiled backends to get the best of both
worlds: usability and performance.
TorchDynamo makes it easy to experiment with different compiler
backends to make PyTorch code faster with a single-line decorator,
``torch._dynamo.optimize()``.
.. image:: ../_static/img/dynamo/TorchDynamo.png
`TorchInductor`, one of the backends supported by TorchDynamo, compiles
the `FX Graph <https://pytorch.org/docs/stable/fx.html>`__
into `Triton <https://github.com/openai/triton>`__ kernels for GPUs or
`C++/OpenMP <https://www.openmp.org/>`__ for CPUs. We have a
`training performance dashboard <https://github.com/pytorch/torchdynamo/issues/681#issuecomment-1233828468>`__
that provides performance comparisons for different training backends. You can read
more in the `TorchInductor post on PyTorch
dev-discuss <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__.
.. seealso::
* `TorchDynamo deep-dive video <https://www.youtube.com/watch?v=egZB5Uxki0I>`__
* `dev-discuss topics <https://dev-discuss.pytorch.org/search?q=TorchDynamo%20order%3Alatest>`__
.. toctree::
:hidden:
installation
get-started
guards-overview
custom-backends
deep-dive
troubleshooting
faq
@@ -0,0 +1,83 @@
Installing TorchDynamo
======================
This section describes how to install TorchDynamo.
Requirements and Setup
----------------------
Python 3.8 is recommended. Python 3.7 through 3.10 are supported and
tested. Make sure to have a development version of Python installed
locally as well.
TorchDynamo is included in the nightly binaries of PyTorch. You can
find more information `here <https://pytorch.org/get-started/locally/>`__.
Install GPU/CUDA version requirements
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
To use GPU backends (and in particular Triton), please make sure that
the CUDA that you have installed locally matches the PyTorch version you
are running.
The following command installs GPU PyTorch+TorchDynamo along with GPU
TorchDynamo dependencies (for CUDA 11.7):
.. code-block:: shell
pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
CPU requirements
~~~~~~~~~~~~~~~~
There are no additional requirements for CPU TorchDynamo. CPU
TorchDynamo is included in the nightly versions of PyTorch, which, for
reference, can be installed with the following command:
.. code-block:: shell
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
Install from local source
~~~~~~~~~~~~~~~~~~~~~~~~~
Build PyTorch from source:
https://github.com/pytorch/pytorch#from-source, which has TorchDynamo
included.
To install GPU TorchDynamo dependencies, run ``make triton`` in the
PyTorch repo root directory.
Verify Installation
~~~~~~~~~~~~~~~~~~~
If you built PyTorch from source, then you can run the following
commands (from the PyTorch repo root directory) that run minimal
examples to check that TorchDynamo is installed correctly:
.. code:: shell
cd tools/dynamo
python verify_dynamo.py
If you do not have the PyTorch source locally, you can alternatively
copy the script (``tools/dynamo/verify_dynamo.py``) from the PyTorch
repo and run it locally.
Docker installation
-------------------
We also provide all the required dependencies in the PyTorch nightly
binaries, which you can download with:

.. code-block:: bash
docker pull ghcr.io/pytorch/pytorch-nightly
For ad hoc experiments, just make sure that your container has access
to all of your GPUs:
.. code-block:: bash
docker run --gpus all -it ghcr.io/pytorch/pytorch-nightly:latest /bin/bash
@@ -0,0 +1,665 @@
TorchDynamo Troubleshooting
===========================
**Author**: `Michael Lazos <https://github.com/mlazos>`_
TorchDynamo is still in active development, and many of the reasons for
graph breaks and excessive recompilation will be fixed with upcoming
support for `tracing dynamic tensor
shapes <https://docs.google.com/document/d/1QJB-GOnbv-9PygGlOMXwiO9K6vVNm8sNg_olixJ9koc/edit?usp=sharing>`__,
more careful choices for guards, and better-tuned heuristics.
In the meantime, you may need to diagnose a particular issue and
determine if it is easy to work around with a change to your model, or
file an issue for support.
Also, we are actively developing debug tools and profilers, and
improving our errors/warnings. Please give us feedback if you have an issue with this
infra, or an idea for an improvement. Below is a table of the available
tools and their typical usage. For additional help see
`Diagnosing Runtime Errors <#diagnosing-runtime-errors>`__.
.. list-table:: TorchDynamo debugging tools
:widths: 25 25 50
:header-rows: 1
* - Tool
- Purpose
- Usage
* - Info logging
- View summarized steps of compilation
- ``torch._dynamo.config.log_level = logging.INFO``
* - Debug logging
- View detailed steps of compilation (print every instruction traced)
- ``torch._dynamo.config.log_level = logging.DEBUG`` and
``torch._dynamo.config.verbose = True``
* - Minifier for any backend
- Find smallest subgraph which reproduces errors for any backend
- set environment variable ``TORCHDYNAMO_REPRO_AFTER="dynamo"``
* - Minifier for ``TorchInductor``
- If the error is known to occur after ``AOTAutograd``, find the
smallest subgraph which reproduces errors during TorchInductor lowering
- set environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"``
* - Accuracy minifier
- Finds the smallest subgraph which reproduces an accuracy issue
between an eager model and an optimized model
- ``TORCHDYNAMO_REPRO_AFTER=<"aot"/"dynamo"> TORCHDYNAMO_REPRO_LEVEL=4``
* - ``torch._dynamo.explain``
- Find graph breaks and display reasoning for them
- ``torch._dynamo.explain(fn, *inputs)``
* - Record/Replay
- Record and replay frames to reproduce errors during graph capture
- ``torch._dynamo.config.replay_record_enabled = True``
* - TorchDynamo function name filtering
- Only compile functions with the given name to reduce noise when
debugging an issue
- set environment variable ``TORCHDYNAMO_DEBUG_FUNCTION=<name>``
* - TorchInductor Debug logging
- Print general TorchInductor debug info and generated Triton/C++ code
- ``torch._inductor.config.debug = True``
* - TorchInductor Tracing
- Show time taken in each TorchInductor stage + output code and graph
visualization
- set the environment variable ``TORCHINDUCTOR_TRACE=1`` or
``torch._inductor.config.trace.enabled = True``
Diagnosing Runtime Errors
~~~~~~~~~~~~~~~~~~~~~~~~~
Below is the TorchDynamo compiler stack.
At a high level, the TorchDynamo stack consists of a graph capture from
Python code (TorchDynamo) and a backend compiler. In this example, the
backend compiler consists of backward graph tracing (AOTAutograd) and
graph lowering (TorchInductor)*. Errors can occur in any component of
the stack, and a full stack trace will be provided when they do.
You may use info logging
(``torch._dynamo.config.log_level = logging.INFO``) and look for
``Step #: ...`` outputs in order to determine in which component the
error occurred. Logs are emitted at the beginning and end of each step,
so the step that an error corresponds to is the most recently logged
step whose end has not yet been logged. The steps correspond to the
following parts of the stack (according to the image above):
==== ================
Step Component
==== ================
1 TorchDynamo
2 Compiler Backend
3 TorchInductor
==== ================
The beginning and end of AOTAutograd is currently not logged, but we
plan to add it soon.
If info logging is insufficient, then there are also some backend
options which can enable you to determine which component is causing the
error if you're unable to understand the error message that is
generated. These are the following:
- ``"eager"``: only runs torchdynamo forward graph capture and then
runs the captured graph with PyTorch. This provides an indication as
to whether TorchDynamo is raising the error.
- ``"aot_eager"``: runs torchdynamo to capture a forward graph, and
then AOTAutograd to trace the backward graph without any additional
backend compiler steps. PyTorch eager will then be used to run the
forward and backward graphs. This is useful to narrow down the issue
to AOTAutograd.
The general procedure to narrow down an issue is the following (a short
sketch of the backend swap follows the list):

1. Run your program with the ``"eager"`` backend. If the error no longer
   occurs, the issue is in the backend compiler that is being used (if
   using TorchInductor, proceed to step 2; if not, see `this
   section <#minifying-backend-compiler-errors>`__). If the error still
   occurs with the ``"eager"`` backend, it is an `error while running
   torchdynamo <#torchdynamo-errors>`__.
2. This step is only necessary if TorchInductor is used as the backend
compiler. Run the model with the ``"aot_eager"`` backend. If this
backend raises an error then the error is occurring during
AOTAutograd tracing. If the error no longer occurs with this backend,
then `the error is in
TorchInductor\* <#minifying-torchinductor-errors>`__.
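For example, the backend swap in steps 1 and 2 is a one-line change (a
sketch; ``fn`` here is just a stand-in for your own model or function):

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    # Step 1: run the captured graph eagerly to check whether TorchDynamo
    # itself is the source of the error. For step 2 (TorchInductor only),
    # swap "eager" for "aot_eager" to also trace the backward graph with
    # AOTAutograd while still skipping TorchInductor lowering.
    @dynamo.optimize("eager")
    def fn(x):
        return torch.nn.functional.gelu(x)

    fn(torch.randn(4))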
Each of these cases is analyzed in the following sections.
\*Note on TorchInductor naming: The TorchInductor backend consists of
both AOTAutograd tracing and the TorchInductor compiler itself. We will
disambiguate by referring to TorchInductor as the backend, and
TorchInductor lowering as the phase which lowers the graph traced by
AOTAutograd.
Torchdynamo Errors
------------------
If the error that is generated occurs with the ``"eager"`` backend, then
torchdynamo is the most likely source of the error. Here is example code
which will generate an error.
.. code:: py
import torch
import torch._dynamo as dynamo
@dynamo.optimize("eager")
def test_assertion_error():
y = torch.ones(200, 200)
z = {y: 5}
return z
test_assertion_error()
Which will generate the following error:
::
torch._dynamo.convert_frame: [ERROR] WON'T CONVERT test_assertion_error /scratch/mlazos/torchdynamo/../test/errors.py line 26
due to:
Traceback (most recent call last):
File "/scratch/mlazos/torchdynamo/torchdynamo/symbolic_convert.py", line 837, in BUILD_MAP
assert isinstance(k, ConstantVariable) or (
AssertionError
from user code:
File "/scratch/mlazos/torchdynamo/../test/errors.py", line 34, in test_assertion_error
z = {y: 5}
Set torch._dynamo.config.verbose=True for more information
==========
As the message suggests, you can set
``torch._dynamo.config.verbose=True`` to get a full stack trace to both
the error in torchdynamo and the user code. In addition to this flag,
you can also set the ``log_level`` of torchdynamo through
``torch._dynamo.config.log_level``. The available levels are the
following:

- ``logging.DEBUG``: Print every instruction that is encountered, in
  addition to all the log levels below
- ``logging.INFO``: Print each function that is compiled (original and
  modified bytecode) and the graph that is captured, in addition to all
  the log levels below
- ``logging.WARNING`` (default): Print graph breaks, in addition to all
  the log levels below
- ``logging.ERROR``: Print errors only
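For example, a minimal sketch of turning on the most verbose settings
before compiling a function (``fn`` is just a placeholder here):

.. code-block:: python

    import logging

    import torch
    import torch._dynamo as dynamo

    # Print every traced instruction and include full stack traces on errors
    torch._dynamo.config.log_level = logging.DEBUG
    torch._dynamo.config.verbose = True

    @dynamo.optimize("eager")
    def fn(x):
        return torch.relu(x).sum()

    fn(torch.randn(8))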
If a model is sufficiently large, the logs can become overwhelming. If
an error occurs deep within a model's Python code, it can be useful to
execute only the frame in which the error occurs to enable easier
debugging. There are two tools available to enable this:

- Setting the environment variable ``TORCHDYNAMO_DEBUG_FUNCTION`` to the
  desired function name will only run torchdynamo on functions with that
  name.
- There is a record/replay tool (set
  ``torch._dynamo.config.replay_record_enabled = True``) which dumps an
  execution record when an error is encountered. This record can then be
  replayed to run only the frame where the error occurred.
TorchInductor Errors
--------------------
If the error doesn't occur with the ``"eager"`` backend, then the
backend compiler is the source of the error (`example
error <https://gist.github.com/mlazos/2f13681e3cc6c43b3911f336327032de>`__).
There are `different
choices <https://github.com/pytorch/torchdynamo/blob/0b8aaf340dad4777a080ef24bf09623f1aa6f3dd/README.md#existing-backends>`__
for backend compilers for torchdynamo, with TorchInductor or nvfuser
fitting the needs of most users. This section focuses on TorchInductor
as the motivating example, but some tools will be usable with other
backend compilers.
Below is the portion of the stack which we are focusing on:
With TorchInductor as the chosen backend, AOTAutograd is used to
generate the backward graph from the forward graph captured by
torchdynamo. It's important to note that errors can occur during this
tracing and also while TorchInductor lowers the forward and backward
graphs to GPU code or C++. A model can often consist of hundreds or
thousands of FX nodes, so narrowing down the exact nodes where the
problem occurred can be very difficult. Fortunately, there are tools
available to automatically minify these input graphs to the nodes which
are causing the issue. The first step is to determine whether the error
occurs during tracing of the backward graph with AOTAutograd or during
TorchInductor lowering. As mentioned above in step 2, the
``"aot_eager"`` backend can be used to run only AOTAutograd in isolation
without lowering. If the error still occurs with this backend, this
indicates that the error is occurring during AOTAutograd tracing.
Here's an example:
.. code:: py
import torch
import torch._dynamo as dynamo
model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])
@dynamo.optimize("inductor")
def test_backend_error():
y = torch.ones(200, 200)
x = torch.ones(200, 200)
z = x + y
a = torch.ops.aten._foobar(z) # dummy function which errors
return model(a)
test_backend_error()
Running this should give you this error (with a longer stack trace below
it):
::
Traceback (most recent call last):
File "/scratch/mlazos/torchdynamo/torchinductor/graph.py", line 246, in call_function
return lowerings[target](*args, **kwargs)
File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 185, in wrapped
return decomp_fn(*args, **kwargs)
File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 810, in _foobar
assert False
AssertionError
...
`error with full stack
trace <https://gist.github.com/mlazos/d6947854aa56d686800259a164c62100>`__
If you then change ``@dynamo.optimize("inductor")`` to
``@dynamo.optimize("aot_eager")``, it will run without error, because
`the
issue <https://github.com/pytorch/torchdynamo/blob/d09e50fbee388d466b5252a63045643166006f77/torchinductor/lowering.py#:~:text=%23%20This%20shouldn%27t%20be,assert%20False>`__
is in the TorchInductor lowering process, not in AOTAutograd.
Minifying TorchInductor Errors
------------------------------
From here, let's run the minifier to get a minimal repro. Setting the
environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"`` (or setting
``torch._dynamo.config.repro_after="aot"`` directly) will generate a
Python program which reduces the graph produced by AOTAutograd to the
smallest subgraph that reproduces the error. (See below for an example
where we minify the graph produced by torchdynamo.) Running the program
with this environment variable set should show nearly `identical
output <https://gist.github.com/mlazos/0458ab828aa403c779fe73c012aa5982>`__,
with an additional line indicating where ``minifier_launcher.py`` has
been written. The output directory is configurable by setting
``torch._dynamo.config.base_dir`` to a valid directory name. The final
step is to run the minifier and check that it runs successfully. A
successful run looks like
`this <https://gist.github.com/mlazos/e6ea41ccce68a7b1b8a7a09acb1b206a>`__.
If the minifier runs successfully, it generates runnable Python code
which reproduces the exact error. For our example, this is the following
code:
.. code:: py
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx
# torch version: 1.13.0a0+gitfddfc44
# torch cuda version: 11.6
# torch git version: fddfc4488afb207971c54ad4bf58130fdc8a4dc5
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0
# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8
from torch.nn import *
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, add):
_foobar = torch.ops.aten._foobar.default(add); add = None
return (_foobar,)
args = [((200, 200), (200, 1), torch.float32, 'cpu')]
args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]
mod = make_fx(Repro())(*args)
from torch._inductor.compile_fx import compile_fx_inner
compiled = compile_fx_inner(mod, args)
compiled(*args)
The ``forward`` method of the ``Repro`` module contains the exact op
which causes the issue. When filing an issue, please include any
minified repros to aid in debugging.
Minifying Backend Compiler Errors
---------------------------------
With backend compilers other than TorchInductor, the process for finding
the subgraph causing the error is nearly identical to the procedure in
`errors in TorchInductor <#torchinductor-errors>`__, with one important
caveat: the minifier will now be run on the graph that is
traced by TorchDynamo, not the output graph of AOTAutograd. Let's walk
through an example.
.. code:: py
import torch
import torch._dynamo as dynamo
model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])
# toy compiler which fails if graph contains relu
def toy_compiler(gm: torch.fx.GraphModule, _):
for node in gm.graph.nodes:
if node.target == torch.relu:
assert False
return gm
@dynamo.optimize(toy_compiler)
def test_backend_error():
y = torch.ones(200, 200)
x = torch.ones(200, 200)
z = x + y
a = torch.relu(z)
return model(a)
test_backend_error()
In order to run the code after TorchDynamo has traced the forward graph,
the ``TORCHDYNAMO_REPRO_AFTER`` environment variable can be used. Running
this program with ``TORCHDYNAMO_REPRO_AFTER="dynamo"`` (or
``torch._dynamo.config.repro_after="dynamo"``) should produce `this
output <https://gist.github.com/mlazos/244e3d5b53667e44078e194762c0c92b>`__
and the following code in ``{torch._dynamo.config.base_dir}/repro.py``.
Note: the other option for ``TORCHDYNAMO_REPRO_AFTER`` is ``"aot"``, which
will run the minifier after the backward graph has been generated.
.. code:: py
import torch
import torch._dynamo as dynamo
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch.nn import *
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, add):
relu = torch.relu(add); add = None
return (relu,)
mod = Repro().cuda()
opt_mod = dynamo.optimize("None")(mod)
args = [((200, 200), (200, 1), torch.float32, 'cpu', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
with torch.cuda.amp.autocast(enabled=False):
ref = run_fwd_maybe_bwd(mod, args)
res = run_fwd_maybe_bwd(opt_mod, args)
The minifier successfully reduced the graph to the op that raises the
error in ``toy_compiler``. The other difference from the procedure in
`TorchInductor Errors <#torchinductor-errors>`__ is that the minifier is
automatically run after encountering a backend compiler error. After a
successful run, the minifier writes ``repro.py`` to
``torch._dynamo.config.base_dir``.
Performance Profiling
~~~~~~~~~~~~~~~~~~~~~
Accessing TorchDynamo Profiler
------------------------------
TorchDynamo has a built-in stats function for collecting and displaying
the time spent in each compilation phase. These stats can be accessed by
calling ``torch._dynamo.utils.compile_times()`` after executing code
through TorchDynamo. By default, this returns a string representation of
the compile times spent in each TorchDynamo function, by name.
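For example, a minimal sketch (``fn`` is just a placeholder; any code
compiled through TorchDynamo will populate these stats):

.. code-block:: python

    import torch
    import torch._dynamo as dynamo
    import torch._dynamo.utils

    @dynamo.optimize("eager")
    def fn(x):
        return (x * 2).sum()

    fn(torch.randn(16))

    # String summary of the time spent in each TorchDynamo function, by name
    print(torch._dynamo.utils.compile_times())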
TorchInductor Debug Tracing
---------------------------
TorchInductor has a built-in stats and trace function for displaying the
time spent in each compilation phase, output code, output graph
visualization, and IR dump. This is a debugging tool designed to make it
easier to debug and understand the internals of TorchInductor.
Setting the environment variable ``TORCHINDUCTOR_TRACE=1`` will cause a
debug trace directory to be created and printed:
::
$ env TORCHINDUCTOR_TRACE=1 python repro.py
torch._inductor.debug: [WARNING] model_forward_0 debug trace: /tmp/torchinductor_jansel/rh/crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.debug
Here is an `example debug directory
output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__
for the test program:
::
torch.nn.Sequential(
torch.nn.Linear(10, 10),
torch.nn.LayerNorm(10),
torch.nn.ReLU(),
)
Note that each file in that debug trace can be enabled and disabled
through ``torch._inductor.config.trace.*``. The profile and the diagram
are both disabled by default, since they are expensive to generate.
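The trace can also be enabled from Python rather than through the
environment variable; a minimal sketch (only ``trace.enabled`` is shown
here, and the other ``trace.*`` toggles follow the same pattern):

.. code-block:: python

    import torch._inductor.config

    # Equivalent to running with TORCHINDUCTOR_TRACE=1
    torch._inductor.config.trace.enabled = True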
A single node in this new debug format looks like:
::
buf1: SchedulerNode(ComputedBuffer)
buf1.writes =
{ MemoryDep(name='buf1', index=0, size=()),
MemoryDep(name='buf1', index=0, size=(s0,))}
buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf1.group.device = cuda:0
buf1.group.iteration = (1, s0)
buf1.sizes = ([], [s0])
class buf1_loop_body:
var_ranges = {z0: s0}
index0 = z0
index1 = 0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('buf0', get_index, False)
get_index_1 = self.get_index('index0')
load_1 = ops.load('primals_2', get_index_1, False)
add = ops.add(load, load_1)
get_index_2 = self.get_index('index1')
reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add)
return reduction
See the `example debug directory
output <https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396>`__
for more examples.
Memory Profiling
----------------
TBD
Graph Breaks
------------
Given a program like this:
.. code-block:: python
@dynamo.optimize(...)
def some_fun(x):
...
some_fun(x)
...
TorchDynamo will attempt to compile all of the torch/tensor operations
within ``some_fun`` into a single FX graph, but it may fail to capture
everything into one graph.
Some graph break reasons are insurmountable to TorchDynamo and can't be
easily fixed. For example, calling into a C extension other than torch
is invisible to torchdynamo, and could do arbitrary things without
TorchDynamo being able to introduce the necessary
`guards <./GuardsOverviewPt1.md>`__ to ensure that the compiled program
would be safe to reuse.

Graph breaks can hinder performance if the resulting fragments are
small. To maximize performance, it's important to have as few graph
breaks as possible.
Identifying the cause of a graph break
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
To identify all graph breaks in a program and the associated reasons for
the breaks, ``torch._dynamo.explain`` can be used. This tool runs
TorchDynamo on the supplied function and aggregates the graph breaks
that are encountered. Here is an example usage:
.. code-block:: python
import torch
import torch._dynamo as dynamo
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
print("woo")
if b.sum() < 0:
b = b * -1
return x * b
explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
print(explanation)
"""
Dynamo produced 3 graphs, with 2 graph break and 6 ops.
Break reasons:
1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
File "t2.py", line 16, in toy_example
print("woo")
2. generic_jump
File "t2.py", line 17, in toy_example
if b.sum() < 0:
"""
Note on other outputs:

- ``out_guards`` - a list of lists where each sublist contains the
  guards that must pass to ensure the traced graphs are valid
- ``graphs`` - a list of graph modules which were successfully traced
- ``ops_per_graph`` - a list of lists where each sublist contains the
  ops that are run in the graph
To throw an error on the first graph break encountered, ``nopython``
mode can be used. This disables TorchDynamo's Python fallback, and only
succeeds if the entire program is convertible to a single graph. Example
usage:
.. code-block:: python
@dynamo.optimize(<compiler>, nopython=True)
def toy_example(a, b):
...
Excessive Recompilation
-----------------------
When TorchDynamo compiles a function (or part of one), it makes certain
assumptions about locals and globals in order to allow compiler
optimizations, and expresses these assumptions as guards that check
particular values at runtime. If any of these guards fail, Dynamo will
recompile that function (or part) up to
``torch._dynamo.config.cache_size_limit`` times. If your program is
hitting the cache limit, you will first need to determine which guard is
failing and what part of your program is triggering it.
The `recompilation profiler <#recompilation-profiler>`__ automates the
process of setting TorchDynamo's cache limit to 1 and running your
program under an observation-only compiler that records the causes of
any guard failures. You should be sure to run your program for at least
as long (as many iterations) as you were running when you ran into
trouble, and the profiler will accumulate statistics over this duration.
If your program exhibits a bounded amount of dynamism, you may be able
to tune the TorchDynamo cache limit to allow for each variation to be
compiled and cached, but if the cache limit is too high you may find the
cost of recompilation outweighs any optimization benefits.
::
torch._dynamo.config.cache_size_limit = <your desired cache limit>
TorchDynamo plans to support many common cases of dynamic tensor shapes,
such as varying batch size or sequence length. It does not plan to
support rank-dynamism. In the meantime, setting a specific cache limit
can be used in coordination with bucketing techniques to achieve an
acceptable number of recompilations for some dynamic models.
.. code-block:: python
prof = dynamo.utils.CompilationProfiler()
@dynamo.optimize(prof)
def my_model():
...
my_model()
print(prof.report())
Accuracy Debugging
~~~~~~~~~~~~~~~~~~
Accuracy issues can also be minified if you set the environment variable
``TORCHDYNAMO_REPRO_LEVEL=4``; it operates with a similar git bisect
model, and a full repro might be something like
``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4``. The reason
we need this is that downstream compilers generate code, whether it is
Triton code or the C++ backend, and the numerics from those downstream
compilers can differ in subtle ways yet have a dramatic impact on your
training stability. The accuracy debugger is therefore very useful for
detecting bugs in our codegen or in a backend compiler.
File an Issue
~~~~~~~~~~~~~
You should feel encouraged to `file a GitHub
issue <https://github.com/pytorch/torchdynamo/issues>`__ and expect a
timely response.
Before filing an issue, read over the `README <../README.md>`__,
`TROUBLESHOOTING <./TROUBLESHOOTING.md>`__, and search for similar
issues.
When filing an issue, please include:

- Your OS/Python/PyTorch/CUDA/Triton info, collected by running:

  .. code-block:: sh

     python tools/verify_install.py

- A minimal repro script if possible, which can be generated by running
  the Minifier
- A description of the error
- The expected behavior
- A log (set ``torch._dynamo.config.log_file`` to a valid file name to
  dump the logs to a file and
  ``torch._dynamo.config.log_level = logging.DEBUG`` and
  ``torch._dynamo.config.verbose = True``)
@@ -42,6 +42,13 @@ Features described in this documentation are classified by release status:
notes/*
.. toctree::
:glob:
:maxdepth: 1
:caption: torch.compile
dynamo/*
.. toctree::
:maxdepth: 1
:caption: Language Bindings