@galv pointed out that #170148 is extraneous, since the default stream within a green context should be capturable. Indeed, that was not the reason graph capture wasn't working; rather, the parent stream at capture time was not properly restored, and the check for it was also too restrictive (the underlying `CUDAStream` needs to be equal, not the wrapped object).
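As a rough illustration of the wrapper-vs-underlying-stream distinction (a minimal sketch, not the PR's code; assumes a CUDA-enabled build):
```python
import torch

# Two Python Stream wrappers can refer to the same underlying CUDA stream, so an
# identity check on the wrapper objects is too restrictive; compare the streams instead.
s1 = torch.cuda.current_stream()
s2 = torch.cuda.current_stream()
print(s1 is s2)                          # may be False: distinct wrapper objects
print(s1 == s2)                          # True: equality compares the underlying stream
print(s1.cuda_stream == s2.cuda_stream)  # True: same raw stream handle
```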
Pull Request resolved: https://github.com/pytorch/pytorch/pull/170317
Approved by: https://github.com/malfet, https://github.com/ngimel
On ROCm, the fast path routes to group_gemm_ck and the slow path to _grouped_mm_fallback. By default the slow path (fast path = False) is taken, since the CK path is not yet performant. To activate the CK path, set the ROCM_ALLOW_GROUP_GEMM_CK=1 environment variable.
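A minimal opt-in sketch (assuming the flag may be cached on first read, so it is safest to set it before importing torch):
```python
# Opt in to the CK grouped-GEMM fast path on ROCm; leaving the variable unset
# (or "0") keeps the default _grouped_mm_fallback route.
import os
os.environ["ROCM_ALLOW_GROUP_GEMM_CK"] = "1"

import torch  # noqa: E402  -- imported after setting the env var in case the flag is cached
```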
Pull Request resolved: https://github.com/pytorch/pytorch/pull/170159
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Similar to args_strategy, which extracts a flat list of OpStrategies from the args_spec, args_meta provides a version of args_spec usable by single_dim_strategy functions: it has the same pytree structure as the original args_spec and contains all non-tensor args, but OpStrategy values are replaced with TensorMeta values and TupleStrategy values are replaced with tuples of TensorMeta.
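A rough stand-in sketch of the transformation (simplified types for illustration only, not the actual DTensor classes):
```python
from dataclasses import dataclass
import torch.utils._pytree as pytree

@dataclass
class TensorMeta:      # stand-in for the real TensorMeta (shape/stride/dtype)
    shape: tuple

@dataclass
class OpStrategy:      # stand-in: per-tensor strategy carrying its tensor meta
    meta: TensorMeta

@dataclass
class TupleStrategy:   # stand-in: strategies for a tuple-of-tensors argument
    children: tuple

def to_meta(leaf):
    if isinstance(leaf, OpStrategy):
        return leaf.meta
    if isinstance(leaf, TupleStrategy):
        return tuple(child.meta for child in leaf.children)
    return leaf        # non-tensor args pass through unchanged

args_spec = (OpStrategy(TensorMeta((4, 4))), 1.5, TupleStrategy((OpStrategy(TensorMeta((4,))),)))
args_meta = pytree.tree_map(to_meta, args_spec)
print(args_meta)  # (TensorMeta(shape=(4, 4)), 1.5, (TensorMeta(shape=(4,)),))
```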
Pull Request resolved: https://github.com/pytorch/pytorch/pull/170358
Approved by: https://github.com/weifengpy
ghstack dependencies: #170197
Fixes the `test_grad_accuracy_check` unit test.
The root cause is that in backward graphs, a tangent's aliasing behavior can change between calls: on the first call two tangents may be aliases of each other, while on the next call they are not. If both calls share the same backward graph, the gradients can be wrong.
This doesn't happen in the forward graph, because dynamo already handles it; no inputs will be aliases.
We add the aliasing information to the invoke_subgraph_cache, so if the inputs' aliasing changes, we re-trace the backward graph.
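A minimal eager-mode illustration of the aliasing pattern that can change between backward calls (illustrative only, not the actual unit test):
```python
import torch

def f(x, y):
    return x * 2, y * 3

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)

g = torch.randn(4)
out1, out2 = f(x, y)
# first backward: the two tangents are the same tensor, i.e. they alias
torch.autograd.grad((out1, out2), (x, y), grad_outputs=(g, g))

out1, out2 = f(x, y)
# second backward: the tangents are distinct tensors; if a compiled backward
# graph specialized on the earlier aliasing were reused, gradients could be wrong
torch.autograd.grad((out1, out2), (x, y), grad_outputs=(g.clone(), torch.randn(4)))
```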
Pull Request resolved: https://github.com/pytorch/pytorch/pull/170485
Approved by: https://github.com/zou3519
Before this PR, graph capturing a program that performed an input mutation on a DTensor under training would hard error.
The context:
(1) When AOTAutograd traces out a joint graph and detects an input mutation, it makes a call to `old_input.copy_(new_input)`, and relies on the current make_fx call to capture this copy_ as a node in the joint graph ([code](https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/graph_capture_wrappers.py#L733))
(2) ordinarily, capturing `old_input.copy_(new_input)` doesn't **need** to generate any fresh proxies in the graph, as the output of `copy_` is the self argument, which we expect to already have a proxy. Why does this matter? @IvanKobzarev added some logic to handle the case where a buffer is mutated during both the fw and the bw ([PR](https://github.com/pytorch/pytorch/pull/155354)), and as part of doing so, tweaked the input mutation handling in AOTAutograd so that these copy_ calls are generated under a [context manager](https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/graph_capture_wrappers.py#L979C29-L979C72) that prevents proxy_tensor from adding new proxies to the graph. The idea is that we apply this context manager in a very limited region, where we know no new proxies need to be created.
(3) However, this is not true for DTensor. When you call `dtensor.copy_(dtensor)`, DTensor runs fake tensor prop under the hood, which involves constructing fresh FakeTensors for the inputs with which to run the fake prop on ([code](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_op_schema.py#L510))
The net result is that we end up *not* constructing proxies for these fake tensor inputs, and we get a "proxy not found" error immediately afterwards when attempting to use them when DTensor runs fake prop ([here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_sharding_prop.py#L243))
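A rough single-process repro sketch of the failure mode (hypothetical setup; the gloo/CPU mesh and shapes are illustrative, not the PR's test):
```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

buf = distribute_tensor(torch.zeros(4), mesh, [Replicate()])                   # mutated input
x = distribute_tensor(torch.ones(4, requires_grad=True), mesh, [Replicate()])

@torch.compile
def fn(buf, x):
    buf.copy_(x)        # input mutation: AOTAutograd emits this copy_ into the joint graph
    return (buf * 2).sum()

fn(buf, x)  # previously hard-errored with "proxy not found" during DTensor fake prop
```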
The way I fixed this was just by tweaking the "don't clobber proxies" context manager to be a bit more general: it will still generate proxies for inputs that don't already have proxies, and it simply won't overwrite an existing proxy with a new one when you trace an inplace op.
One alternative would have been to disable proxy tracing when DTensor runs fake prop, since after all we don't really care about the ops that DTensor ran during fake prop. I decided not to do this because that code has changed a bunch recently and is pretty fragile, but I'm open to it if people prefer that path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/170467
Approved by: https://github.com/IvanKobzarev
Summary:
Disable `mm_template` as a template option for `scaled_mm`. `mm_template` leverages epilogue scaling for `scaled_mm`; however, block-wise scaling, which uses the main loop scaling template, must apply scaling to each input tensor before accumulation, rather than as an epilogue after accumulation.
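A small numeric sketch of why per-block scales cannot be folded into an epilogue (illustrative only; shapes and block size are arbitrary):
```python
import torch

a = torch.randn(1, 4, dtype=torch.float64)
b = torch.randn(4, 1, dtype=torch.float64)
scale = torch.tensor([2.0, 0.5], dtype=torch.float64)  # one scale per K-block of size 2

# correct: apply each block's scale before accumulating across K-blocks (main-loop scaling)
main_loop = sum(scale[i] * (a[:, 2 * i:2 * i + 2] @ b[2 * i:2 * i + 2, :]) for i in range(2))

# an epilogue can only apply a factor after the full accumulation, which cannot
# reproduce per-block scaling in general
epilogue = scale.mean() * (a @ b)

print(torch.allclose(main_loop, epilogue))  # generally False
```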
NOTE: It is interesting that for small-enough shapes, `mm_template` seems to actually pass in Inductor without a CUDA error or IMA, and it is the winner over the main loop scaling template. The same generated kernel fails when attempting to repro locally. As a follow-up, it would be useful to investigate whether there is an implementation gap in `mm_template` or `main_loop_scaling_template`.
Test Plan:
```
CUDA_LAUNCH_BLOCKING=1 MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TRITON_PRINT_AUTOTUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir TORCH_LOGS=+inductor,"output_code" python3 run.py --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics accuracy,tflops,latency --m 32768 --n 8192 --k 8192 --output ~/personal/fp8_scaling_benchmarks/blockwise1x128_blockwise128x128.csv --scaling-pair BlockWise1x128,BlockWise128x128 --bypass-fail 2>&1 | tee ~/personal/fp8_scaling_benchmarks/blockwise1x128_blockwise128x128_2.log
```
Differential Revision: D88900710
Pull Request resolved: https://github.com/pytorch/pytorch/pull/170139
Approved by: https://github.com/NikhilAPatel
Fixes the following unused parameter warning
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal:288:27: warning: unused parameter 'nnz' [-Wunused-parameter]
constant uint& nnz [[buffer(2)]],
```
Also, use short-circuit evaluation to make the kernel more compact
Pull Request resolved: https://github.com/pytorch/pytorch/pull/170403
Approved by: https://github.com/Skylion007
Flex-decoding was previously active in vLLM only in very narrow cases. We use flex-decoding only if the workload can be processed with a single BLOCK_M, i.e. `BLOCK_M >= seq_len_q * G`, where G is the ratio between the number of Q heads and KV heads; BLOCK_M is 16 in the vLLM integration.
Take Llama3 8B as an example: G is 4 there, so flex-decoding is used only if `seq_len_q` <= 4. This is very restrictive and causes flex-decoding to be skipped most of the time.
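A small sketch of the old single-block dispatch condition described above (hypothetical helper, just to make the arithmetic concrete):
```python
BLOCK_M = 16  # block size used in the vLLM integration

def uses_flex_decoding(seq_len_q: int, num_q_heads: int, num_kv_heads: int) -> bool:
    G = num_q_heads // num_kv_heads  # ratio between Q heads and KV heads
    return BLOCK_M >= seq_len_q * G

# Llama3 8B: 32 query heads, 8 KV heads -> G = 4, so only seq_len_q <= 4 qualifies
print([q for q in range(1, 9) if uses_flex_decoding(q, 32, 8)])  # [1, 2, 3, 4]
```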
This PR enhances the flex-decoding kernel to work with multiple BLOCK_M blocks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/170343
Approved by: https://github.com/drisspg