add a curve for customized compilation in the kernel benchmarking scripts (#166697)
It's nice to add a curve with customized compilation options so that we can compare side-by-side the perf improvement of new features.
E.g., for mix-order reduction, by running the following command
```
python benchmarks/dynamo/genai_layers/benchmark.py --tolerance=1e-2 --exit-on-accuracy-failure --visualize rmsnorm_backward --custom-compile-name="compiled-no-fusion" --custom-compile-options='{"triton.mix_order_reduction":false}'
```
I get the following output:
```
Geomean speedup for benchmark RMSNormBackward
eager 11 data points
compiled 11 data points, 15.82x speedup
quack 11 data points, 15.45x speedup
liger 11 data points, 14.06x speedup
compiled-no-fusion 11 data points, 10.26x speedup
```
The output shows that the feature improves perf by `15.82 / 10.26 = 1.54x` on average across all the shapes tested. (I removed the (32768, 32768) shape, whose rnumel is too large to be representative.)
The new curve also shows up in the figure:
![RMSNormBackward_bench](https://github.com/user-attachments/assets/1ffac2bc-e726-4f1e-806d-e9e5de711492)
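For reference, here is a minimal standalone sketch (not part of this PR) of what the custom curve corresponds to: the JSON options are applied through `torch._inductor.config.patch` around a fresh `torch.compile`, after a `torch._dynamo.reset()` so the recompile actually happens. The `rmsnorm` function, shapes, and CUDA device below are illustrative assumptions; `triton.mix_order_reduction` is the config knob from the stack this PR depends on.

```
import torch

# Illustrative only: a hand-rolled RMSNorm, not the benchmark's kernel.
def rmsnorm(x, weight, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

x = torch.randn(8192, 4096, device="cuda", requires_grad=True)
w = torch.randn(4096, device="cuda", requires_grad=True)

# Baseline compiled variant (default inductor options).
compiled = torch.compile(rmsnorm)
compiled(x, w).sum().backward()

# "compiled-no-fusion" variant: same function, recompiled with the custom options.
torch._dynamo.reset()  # drop cached graphs so the next call recompiles
with torch._inductor.config.patch({"triton.mix_order_reduction": False}):
    compiled_no_fusion = torch.compile(rmsnorm)
    compiled_no_fusion(x, w).sum().backward()  # backward compiles under the patch
```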
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166697
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #166053, #166382, #166461, #166585, #166675
Committed by: PyTorch MergeBot
Parent: a19e92d433
Commit: 9f9dbe0a9a
```
@@ -163,8 +163,37 @@ Examples:
         help="Whether to print the raw benchmarking result. Easier to quickly check the benchmark results on a server without GUI",
     )
 
+    parser.add_argument(
+        "--custom-compile-name",
+        type=str,
+        default=None,
+        help="Name for the curve with customized compilation options",
+    )
+
+    parser.add_argument(
+        "--custom-compile-options",
+        type=str,
+        default=None,
+        help="Json string for the custom compile options.",
+    )
+
     args = parser.parse_args()
 
+    if args.custom_compile_options:
+        import json
+
+        try:
+            args.custom_compile_options = json.loads(args.custom_compile_options)
+        except json.decoder.JSONDecodeError as e:
+            raise RuntimeError(
+                f"Invalid json string for --custom-compile-options: {args.custom_compile_options}"
+            ) from e
+
+        if not args.custom_compile_options:
+            raise RuntimeError("Found no options for --custom-compile-options")
+        if not args.custom_compile_name:
+            raise RuntimeError("Missing label name for the custom compilation")
+
     # Handle list option
     if args.list:
         list_benchmarks()
```
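As a quick illustration of the validation above (a standalone snippet mirroring the parsing logic rather than calling the script), the example flag values from the commit message parse like this:

```
import json

# Values taken from the example command in the commit message.
name = "compiled-no-fusion"                   # --custom-compile-name
raw = '{"triton.mix_order_reduction":false}'  # --custom-compile-options

options = json.loads(raw)  # JSON false becomes Python False
print(options)             # {'triton.mix_order_reduction': False}

# The script rejects an empty option dict or a missing curve name.
if not options:
    raise RuntimeError("Found no options for --custom-compile-options")
if not name:
    raise RuntimeError("Missing label name for the custom compilation")
```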
```
@@ -447,7 +447,6 @@ class RMSNormBackward(BenchmarkKernel):
             (32768, 4096),
             (32768, 8192),
             (32768, 16384),
-            (32768, 32768),
         ) + extra_shapes_for_norm
 
     def get_memory_bytes(self, args, kwargs) -> int:
```
```
@@ -108,6 +108,18 @@ class BenchmarkKernel:
         for backend in self.available_backends:
             args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
             res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
 
+        if (
+            "compiled" in self.available_backends
+            and self.script_args.custom_compile_options
+        ):
+            torch._dynamo.reset()  # cause recompile
+            with torch._inductor.config.patch(self.script_args.custom_compile_options):
+                args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
+                res[self.script_args.custom_compile_name] = self.compiled(
+                    args_ref, kwargs_ref
+                )()
+
         gold = res["eager"]
 
         tol = {}
```
```
@@ -116,7 +128,7 @@ class BenchmarkKernel:
                 "atol": self.script_args.tolerance,
                 "rtol": self.script_args.tolerance,
             }
-        for backend in self.available_backends:
+        for backend in res:
             if backend == "eager":
                 continue
             try:
```
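The accuracy loop above now iterates over `res` instead of `self.available_backends` because the custom variant is stored under a name that is not a backend method. Below is a simplified standalone sketch of that pattern (hypothetical rmsnorm, shapes, and CUDA device; not the actual `BenchmarkKernel` code):

```
import torch

def rmsnorm(x, w, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

x = torch.randn(512, 1024, device="cuda")
w = torch.randn(1024, device="cuda")

res = {
    "eager": rmsnorm(x, w),
    "compiled": torch.compile(rmsnorm)(x, w),
}

# Extra entry keyed by the custom curve name, produced under the patched config.
torch._dynamo.reset()
with torch._inductor.config.patch({"triton.mix_order_reduction": False}):
    res["compiled-no-fusion"] = torch.compile(rmsnorm)(x, w)

gold = res["eager"]
for backend in res:  # iterating available_backends would skip the custom entry
    if backend == "eager":
        continue
    torch.testing.assert_close(res[backend], gold, atol=1e-2, rtol=1e-2)
```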
```
@@ -135,25 +147,48 @@ class BenchmarkKernel:
             print("Exit right away since --exit-on-accuracy-failure is set")
             sys.exit(1)
 
+    def benchmark_single_shape_for_backend(
+        self, backend, args, kwargs, setting, fn=None
+    ) -> bool:
+        if fn is None:
+            fn = getattr(self, backend)
+        args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
+        try:
+            avg_time = benchmark_kernel_in_milliseconds(fn(args_ref, kwargs_ref))
+        except Exception as e:
+            print(
+                f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
+            )
+            self.available_backends.remove(backend)  # noqa: B909
+            return False
+        mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
+        perf = Performance(setting, avg_time, mem_bytes)
+        print(f"{self.name} kernel on {backend} backend. {perf}")
+        self.profiling_results[backend].append(perf)
+        return True
+
     def benchmark_single_shape(
         self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
     ):
         for backend in self.available_backends:
-            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
-            try:
-                avg_time = benchmark_kernel_in_milliseconds(
-                    getattr(self, backend)(args_ref, kwargs_ref)
-                )
-            except Exception as e:
-                print(
-                    f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
-                )
-                self.available_backends.remove(backend)  # noqa: B909
-                continue
-            mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
-            perf = Performance(setting, avg_time, mem_bytes)
-            print(f"{self.name} kernel on {backend} backend. {perf}")
-            self.profiling_results[backend].append(perf)
+            self.benchmark_single_shape_for_backend(backend, args, kwargs, setting)
+
+        if (
+            "compiled" in self.available_backends
+            and self.script_args.custom_compile_options
+        ):
+            torch._dynamo.reset()  # cause recompile
+            with torch._inductor.config.patch(self.script_args.custom_compile_options):
+                status = self.benchmark_single_shape_for_backend(
+                    self.script_args.custom_compile_name,
+                    args,
+                    kwargs,
+                    setting,
+                    fn=self.compiled,
+                )
+            if not status:
+                self.script_args.custom_compile_options = (
+                    None  # once fail, don't run again
+                )
 
         if should_check_accuracy:
             self.check_accuracy(args, kwargs)
```