Nikita Vedeneev 8982850fb6 [Inductor] ReLU/GELU(Addmm) fusions (#168157)
Similar to https://github.com/pytorch/pytorch/pull/158137 (thank you, @AaronWang04, for the instructional tips and for answering my questions!), but this PR performs the
`Activation(Addmm) -> _addmm_activation` replacement instead of `Activation(add(mm)) -> _addmm_activation`.
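
For concreteness, here is the eager-level equivalence the pattern exploits (a minimal sketch; the shapes and dtype are illustrative, and `_addmm_activation` is a private ATen op):

```python
import torch

# Illustrative shapes/dtype; any 2D GEMM with a broadcastable bias works.
bias = torch.randn(1024, device="cuda", dtype=torch.float16)
m1 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
m2 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Unfused form the pattern matcher looks for: Activation(addmm(...)).
out_unfused = torch.relu(torch.addmm(bias, m1, m2))

# Fused form it is replaced with: one cuBLASLt call with a ReLU epilogue.
# use_gelu=False selects the ReLU epilogue; use_gelu=True selects GELU.
out_fused = torch._addmm_activation(bias, m1, m2, use_gelu=False)

torch.testing.assert_close(out_unfused, out_fused)
```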

The reasons for preferring this mapping over the one in https://github.com/pytorch/pytorch/pull/158137 are:
- Prior work extended cuBLASLt coverage in `addmm` beyond just 1D bias and `beta=1, alpha=1`. As long as there is an activation after `addmm`, we can call Lt. This makes the check for pattern replacement leaner and agnostic to the inputs' metadata (we get `addmm`'s checks for free).
- Inductor intercepts `addmm` and decomposes it into
`alpha * (m1 @ m2) + beta * input` (the `alpha` and `beta` multiplications are emitted only when they differ from 1) whenever it is followed by pointwise consumers, including activation functions.
So it is far easier and cleaner to intercept just `addmm` (rather than a combinatorial set of decomposed patterns) before such replacements; see the registration sketch after this list.
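
A hedged sketch of how such a replacement is typically registered with Inductor's pattern matcher (not the PR's actual code; the standalone pass object, example shapes, and guard-free registration here are illustrative):

```python
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

# Hypothetical standalone pass; the real fusion hooks into Inductor's own
# post-grad pass dicts and adds extra eligibility checks.
patterns = PatternMatcherPass()

def relu_addmm(inp, mat1, mat2):
    # Search pattern: the activation applied directly to addmm's output.
    return torch.relu(torch.addmm(inp, mat1, mat2))

def relu_addmm_replacement(inp, mat1, mat2):
    # Replacement: a single fused kernel with a ReLU epilogue.
    return torch._addmm_activation(inp, mat1, mat2, use_gelu=False)

# Example inputs only pin down ranks and dtypes for tracing the pattern.
example_inputs = [
    torch.empty(8, device="cuda", dtype=torch.float16),     # bias
    torch.empty(8, 8, device="cuda", dtype=torch.float16),  # mat1
    torch.empty(8, 8, device="cuda", dtype=torch.float16),  # mat2
]

register_replacement(
    relu_addmm, relu_addmm_replacement, example_inputs, fwd_only, patterns
)
```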

Re-running the benchmark script from https://github.com/pytorch/pytorch/pull/158137 on H100 yields:

`float16`:
```
============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas):     0.0096 ms
Average Time per Iteration (torch compile):      0.0407 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas):     0.0270 ms
Average Time per Iteration (torch compile):      0.0409 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas):     0.1828 ms
Average Time per Iteration (torch compile):      0.2415 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas):     1.5971 ms
Average Time per Iteration (torch compile):      1.9723 ms
```

`bfloat16`:
```
============================================================
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):     0.0093 ms
Average Time per Iteration (torch compile):      0.0416 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):     0.0264 ms
Average Time per Iteration (torch compile):      0.0411 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):     0.1768 ms
Average Time per Iteration (torch compile):      0.2430 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):     1.5564 ms
Average Time per Iteration (torch compile):      1.8916 ms
```
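
For reference, a minimal timing harness in this spirit might look like the following (a sketch, not the actual script from #158137; sizes, warmup, and iteration counts are arbitrary):

```python
import torch

def bench(fn, iters=100, warmup=10):
    # CUDA-event timing; returns average milliseconds per iteration.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

M = N = K = 1024
dtype = torch.float16
bias = torch.randn(N, device="cuda", dtype=dtype)
a = torch.randn(M, K, device="cuda", dtype=dtype)
b = torch.randn(K, N, device="cuda", dtype=dtype)

# Direct fused call vs. the compiled Activation(addmm) pattern.
fused = lambda: torch._addmm_activation(bias, a, b, use_gelu=False)
compiled = torch.compile(lambda: torch.relu(torch.addmm(bias, a, b)))

print(f"cublas:        {bench(fused):.4f} ms")
print(f"torch compile: {bench(compiled):.4f} ms")
```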

Pull Request resolved: https://github.com/pytorch/pytorch/pull/168157
Approved by: https://github.com/eellison, https://github.com/eqy
2025-12-08 14:54:00 +00:00