Builds the scaffolding for a full Inductor backend based on NVIDIA’s new [Python universal GEMM template](https://github.com/NVIDIA/cutlass/tree/cutlass_api/python/cutlass_api), including the initial setup of the `mm` execution path.

Can be benchmarked on trunk Tritonbench with:

```
clear; TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TRITON_ALWAYS_COMPILE=1 TRITON_PRINT_AUTOTUNING=1 TORCH_LOGS=+inductor CUDA_VISIBLE_DEVICES=7 python run.py --op gemm --only aten_matmul,pt2_triton_matmul,pt2_nv_universal_gemm_matmul --metrics tflops,accuracy --force --cudagraph --num-inputs 1
```

Unit tests: `test/inductor/test_nv_universal_gemm.py`

Follow-up PRs planned:

- [ ] nvMatmulHeuristics to narrow the search space
- [ ] Parallelize autotune compilation
- [ ] Epilogue fusions
- [ ] BMM
- [ ] AddMM
- [ ] Dynamic shape support
- [ ] AOTI support
- [ ] Cutlass API kernels requiring a device workspace

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170623
Approved by: https://github.com/drisspg
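
For orientation, a minimal sketch of the kind of call path this backend targets: a matmul compiled under Inductor's max-autotune mode, where GEMM templates compete during autotuning. `torch._inductor.config.max_autotune_gemm_backends` is an existing Inductor knob; the `NV_UNIVERSAL_GEMM` backend name used below is an assumption for illustration only, since the PR description does not specify how the new backend is selected.

```python
import torch
import torch._inductor.config as inductor_config

# Hypothetical: list the new backend alongside the existing ones. Inductor
# treats this config as a comma-separated allowlist; the actual name for
# the universal GEMM backend is not confirmed by this PR and unknown names
# are expected to be ignored by current builds.
inductor_config.max_autotune_gemm_backends = "ATEN,TRITON,NV_UNIVERSAL_GEMM"

def mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b

# max-autotune mode triggers GEMM template autotuning during compilation.
compiled_mm = torch.compile(mm, mode="max-autotune")

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Check the compiled result against eager within bf16 tolerances.
torch.testing.assert_close(compiled_mm(a, b), a @ b, rtol=1e-2, atol=1e-2)
```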