add a curve for customized compilation in the kernel benchmarking scripts (#166697)

It's nice to add a curve with customized compilation options so that we can compare side-by-side the perf improvement from new features.

E.g., for mix-order reduction, I run the following command:
```
python benchmarks/dynamo/genai_layers/benchmark.py --tolerance=1e-2 --exit-on-accuracy-failure --visualize rmsnorm_backward --custom-compile-name="compiled-no-fusion" --custom-compile-options='{"triton.mix_order_reduction":false}'
```
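
Under the hood, the new flags parse the JSON string and run one extra `compiled` variant with the inductor config patched (see the diffs below). A minimal standalone sketch of that mechanism, using a stand-in workload rather than the benchmark's own kernels; the `triton.mix_order_reduction` knob comes from the earlier PRs in this stack:
```
import json

import torch

# Parse the option string the same way the benchmark script does.
opts = json.loads('{"triton.mix_order_reduction": false}')

# The extra curve recompiles under the patched inductor config; the default
# "compiled" curve keeps the stock options.
torch._dynamo.reset()  # drop cached compiles so the patched config takes effect
with torch._inductor.config.patch(opts):
    # Stand-in workload for illustration only.
    fn = torch.compile(
        lambda x, w: torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=w)
    )
    x = torch.randn(64, 128, requires_grad=True)
    w = torch.randn(128, requires_grad=True)
    fn(x, w).sum().backward()
```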

Running the command above, I get the following output:
```
Geomean speedup for benchmark RMSNormBackward
  eager 11 data points
  compiled 11 data points, 15.82x speedup
  quack 11 data points, 15.45x speedup
  liger 11 data points, 14.06x speedup
  compiled-no-fusion 11 data points, 10.26x speedup
```

The output shows that, on average across all the shapes tested, the feature improves perf by `15.82 / 10.26 = 1.54x`. (I removed the (32768, 32768) shape, whose rnumel is too large to be representative.)
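
For reference, each number above is the geometric mean of the per-shape speedups over eager. A rough sketch of that aggregation (not the script's exact code):
```
import math

# Per-shape speedup relative to eager, combined with a geometric mean.
def geomean_speedup(eager_ms: list[float], backend_ms: list[float]) -> float:
    speedups = [e / b for e, b in zip(eager_ms, backend_ms, strict=True)]
    return math.exp(sum(math.log(s) for s in speedups) / len(speedups))

# The 1.54x figure is then just the ratio of two such geomeans:
# geomean(compiled) / geomean(compiled-no-fusion) = 15.82 / 10.26 ≈ 1.54
```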

The new curve also shows up in the figure:
<img width="3564" height="2368" alt="RMSNormBackward_bench" src="https://github.com/user-attachments/assets/1ffac2bc-e726-4f1e-806d-e9e5de711492" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166697
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #166053, #166382, #166461, #166585, #166675
Shunting Zhang authored on 2025-10-31 14:49:55 -07:00; committed by PyTorch MergeBot
commit 9f9dbe0a9a (parent a19e92d433)
3 changed files with 78 additions and 15 deletions

```
@@ -163,8 +163,37 @@ Examples:
         help="Whether to print the raw benchmarking result. Easier to quickly check the benchmark results on a server without GUI",
     )
+    parser.add_argument(
+        "--custom-compile-name",
+        type=str,
+        default=None,
+        help="Name for the curve with customized compilation options",
+    )
+    parser.add_argument(
+        "--custom-compile-options",
+        type=str,
+        default=None,
+        help="Json string for the custom compile options.",
+    )
     args = parser.parse_args()

+    if args.custom_compile_options:
+        import json
+
+        try:
+            args.custom_compile_options = json.loads(args.custom_compile_options)
+        except json.decoder.JSONDecodeError as e:
+            raise RuntimeError(
+                f"Invalid json string for --custom-compile-options: {args.custom_compile_options}"
+            ) from e
+        if not args.custom_compile_options:
+            raise RuntimeError("Found no options for --custom-compile-options")
+
+        if not args.custom_compile_name:
+            raise RuntimeError("Missing label name for the custom compilation")
+
     # Handle list option
     if args.list:
         list_benchmarks()
```

```
@@ -447,7 +447,6 @@ class RMSNormBackward(BenchmarkKernel):
             (32768, 4096),
             (32768, 8192),
             (32768, 16384),
-            (32768, 32768),
         ) + extra_shapes_for_norm

     def get_memory_bytes(self, args, kwargs) -> int:
```

```
@@ -108,6 +108,18 @@ class BenchmarkKernel:
         for backend in self.available_backends:
             args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
             res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()

+        if (
+            "compiled" in self.available_backends
+            and self.script_args.custom_compile_options
+        ):
+            torch._dynamo.reset()  # cause recompile
+            with torch._inductor.config.patch(self.script_args.custom_compile_options):
+                args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
+                res[self.script_args.custom_compile_name] = self.compiled(
+                    args_ref, kwargs_ref
+                )()
+
         gold = res["eager"]
         tol = {}
@@ -116,7 +128,7 @@
             "atol": self.script_args.tolerance,
             "rtol": self.script_args.tolerance,
         }
-        for backend in self.available_backends:
+        for backend in res:
             if backend == "eager":
                 continue
             try:
@@ -135,25 +147,48 @@
                 print("Exit right away since --exit-on-accuracy-failure is set")
                 sys.exit(1)

+    def benchmark_single_shape_for_backend(
+        self, backend, args, kwargs, setting, fn=None
+    ) -> bool:
+        if fn is None:
+            fn = getattr(self, backend)
+        args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
+        try:
+            avg_time = benchmark_kernel_in_milliseconds(fn(args_ref, kwargs_ref))
+        except Exception as e:
+            print(
+                f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
+            )
+            self.available_backends.remove(backend)  # noqa: B909
+            return False
+        mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
+        perf = Performance(setting, avg_time, mem_bytes)
+        print(f"{self.name} kernel on {backend} backend. {perf}")
+        self.profiling_results[backend].append(perf)
+        return True
+
     def benchmark_single_shape(
         self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
     ):
         for backend in self.available_backends:
-            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
-            try:
-                avg_time = benchmark_kernel_in_milliseconds(
-                    getattr(self, backend)(args_ref, kwargs_ref)
-                )
-            except Exception as e:
-                print(
-                    f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
-                )
-                self.available_backends.remove(backend)  # noqa: B909
-                continue
-            mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
-            perf = Performance(setting, avg_time, mem_bytes)
-            print(f"{self.name} kernel on {backend} backend. {perf}")
-            self.profiling_results[backend].append(perf)
+            self.benchmark_single_shape_for_backend(backend, args, kwargs, setting)
+
+        if (
+            "compiled" in self.available_backends
+            and self.script_args.custom_compile_options
+        ):
+            torch._dynamo.reset()  # cause recompile
+            with torch._inductor.config.patch(self.script_args.custom_compile_options):
+                status = self.benchmark_single_shape_for_backend(
+                    self.script_args.custom_compile_name,
+                    args,
+                    kwargs,
+                    setting,
+                    fn=self.compiled,
+                )
+            if not status:
+                self.script_args.custom_compile_options = (
+                    None  # once fail, don't run again
+                )

         if should_check_accuracy:
             self.check_accuracy(args, kwargs)
```