Similar to https://github.com/pytorch/pytorch/pull/158137 (thank you, @AaronWang04, for the instructive tips and for answering my questions!), but performs the `Activation(Addmm) -> _addmm_activation` replacement instead of `Activation(add(mm)) -> _addmm_activation`. The reasons for preferring this mapping over the one in https://github.com/pytorch/pytorch/pull/158137 are:

- Prior work has extended cuBLASLt coverage in `addmm` beyond just 1D bias and `beta=1, alpha=1`. As long as there is an activation after `addmm`, we can call Lt. This makes the check for the pattern replacement leaner and agnostic to the inputs' metadata (we get `addmm`'s checks for free).
- Inductor intercepts `addmm` and decomposes it into `alpha * (m1 @ m2) + beta * input` (dropping the `alpha`/`beta` factors when they equal 1) whenever it is followed by point-wise consumers, including activation functions. So it is much easier and cleaner to intercept just `addmm` (rather than a combinatorial set of patterns) before such decompositions take place.

Re-running the benchmark script from https://github.com/pytorch/pytorch/pull/158137 on H100 yields:

`float16`:
```
============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.0096 ms
Average Time per Iteration (torch compile): 0.0407 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.0270 ms
Average Time per Iteration (torch compile): 0.0409 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.1828 ms
Average Time per Iteration (torch compile): 0.2415 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas): 1.5971 ms
Average Time per Iteration (torch compile): 1.9723 ms
```

`bfloat16`:
```
============================================================
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0093 ms
Average Time per Iteration (torch compile): 0.0416 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0264 ms
Average Time per Iteration (torch compile): 0.0411 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.1768 ms
Average Time per Iteration (torch compile): 0.2430 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 1.5564 ms
Average Time per Iteration (torch compile): 1.8916 ms
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/168157
Approved by: https://github.com/eellison, https://github.com/eqy
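For illustration, below is a minimal sketch of the shape of code this replacement targets: an `addmm` immediately consumed by an activation, which maps onto a single `torch._addmm_activation` call (and hence a cuBLASLt epilogue). The helper function names and shapes here are made up for the example, and whether the rewrite actually fires under `torch.compile` may depend on the Inductor configuration in use.

```python
import torch

# Eager-mode pattern that the replacement targets: Activation(addmm(...)).
def eager_linear_relu(bias, m1, m2):
    return torch.relu(torch.addmm(bias, m1, m2))

# What the pattern is rewritten to: a single fused op whose ReLU
# (or GELU, via use_gelu=True) epilogue can be served by cuBLASLt.
def fused_linear_relu(bias, m1, m2):
    return torch._addmm_activation(bias, m1, m2)

if __name__ == "__main__":
    # Shapes/dtype chosen to mirror the smallest benchmark case above.
    m1 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    m2 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    bias = torch.randn(1024, device="cuda", dtype=torch.float16)

    # Sanity check: the compiled eager pattern and the fused op should
    # agree up to the usual low-precision GEMM tolerance.
    compiled = torch.compile(eager_linear_relu)
    torch.testing.assert_close(
        compiled(bias, m1, m2),
        fused_linear_relu(bias, m1, m2),
        rtol=2e-2,
        atol=2e-2,
    )
```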