From fdccb593c81fd8d39082e16b38ae5d54e51d6e85 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 22 Dec 2025 16:41:45 +0000
Subject: [PATCH] Add normalization and activation ops to operator benchmarks
 (#169544)

This patch adds more operators to the operator benchmark suite.

Normalization ops:
- LayerNorm
- RMSNorm
- BatchNorm1d
- BatchNorm2d
- BatchNorm3d
- GroupNorm

Activation ops:
- nn.GELU
- nn.SiLU
- nn.ReLU
- nn.LeakyReLU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169544
Approved by: https://github.com/slayton58
---
 .ci/pytorch/test.sh                            |   2 +-
 .../operator_benchmark/benchmark_core.py       |  15 ++
 .../operator_benchmark/pt/activation_test.py   |  67 ++++++
 benchmarks/operator_benchmark/pt/norm_test.py  | 219 ++++++++++++++++++
 4 files changed, 302 insertions(+), 1 deletion(-)
 create mode 100644 benchmarks/operator_benchmark/pt/activation_test.py
 create mode 100644 benchmarks/operator_benchmark/pt/norm_test.py

diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index b514b7363be..c6a440b1f14 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -1811,7 +1811,7 @@ test_operator_microbenchmark() {
   cd "${TEST_DIR}"/benchmarks/operator_benchmark
   # NOTE: When adding a new test here, please update README: ../../benchmarks/operator_benchmark/README.md
-  for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv optimizer; do
+  for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv optimizer activation norm; do
     $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
       --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
       --benchmark-name "PyTorch operator microbenchmark" --use-compile
diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py
index 5e88af6738a..7eadf39cec9 100644
--- a/benchmarks/operator_benchmark/benchmark_core.py
+++ b/benchmarks/operator_benchmark/benchmark_core.py
@@ -292,6 +292,8 @@ class BenchmarkRunner:
             "latency unit": "us",
             "peak memory": results["peak_memory"],
             "memory unit": "KB",
+            "memory bandwidth": results.get("memory_bandwidth_gb_s"),
+            "memory bandwidth unit": "GB/s",
         }

         # parsing test_case.test_config.input_config, adding it as entries to the 'out' dictionary
@@ -559,6 +561,7 @@ class BenchmarkRunner:
         run_type = perf_item.get("run")
         latency = perf_item.get("latency", 0)
         peak_memory = perf_item.get("peak memory", 0)
+        memory_bandwidth = perf_item.get("memory bandwidth", 0)
         device = perf_item.get("device", "unknown")
         dtype = perf_item.get("dtype", "torch.float").split(".")[1]
         runtime = perf_item.get("runtime", None)
@@ -656,6 +659,16 @@ class BenchmarkRunner:
             )
             records.append(asdict(record_memory))

+            # Add record for memory bandwidth
+            record_memory_bandwidth = copy.deepcopy(record_latency)
+            record_memory_bandwidth.metric = MetricInfo(
+                name="memory bandwidth",
+                unit="GB/s",
+                benchmark_values=[memory_bandwidth],
+                target_value=None,
+            )
+            records.append(asdict(record_memory_bandwidth))
+
         # Write all records to the output file
         with open(output_file, "w", encoding="utf-8") as f:
             json.dump(records, f, indent=2)
@@ -671,6 +684,7 @@ class BenchmarkRunner:
             "run_backward",
             "Execution Time",
             "Peak Memory (KB)",
+            "Memory Bandwidth (GB/s)",
         ]

         if self.args.output_json or self.args.output_json_for_dashboard:
@@ -746,6 +760,7 @@ class BenchmarkRunner:
                     test_case.test_config.run_backward,
                     result_dict["reported_run_time_us"][0],
                     result_dict["peak_memory"],
+                    result_dict["memory_bandwidth_gb_s"],
                 ],
             )
             if self.args.output_json or self.args.output_json_for_dashboard:
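
Note on the new metric: alongside latency (us) and peak memory (KB), each record now carries a memory bandwidth value in GB/s, read from results["memory_bandwidth_gb_s"]. This patch does not show how that value is produced; a minimal sketch of the usual derivation (bytes moved divided by measured time), with a hypothetical helper name and a coarse read-plus-write byte count assumed for illustration only:

    import torch

    def estimate_bandwidth_gb_s(tensors, latency_us):
        # Assume each tensor is read once and written once -- a rough
        # heuristic, not the harness's actual accounting.
        total_bytes = sum(2 * t.numel() * t.element_size() for t in tensors)
        return total_bytes / (latency_us * 1e-6) / 1e9

    # A 1024x1024 fp32 tensor touched in 50 us works out to ~168 GB/s.
    print(estimate_bandwidth_gb_s([torch.rand(1024, 1024)], latency_us=50.0))
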
diff --git a/benchmarks/operator_benchmark/pt/activation_test.py b/benchmarks/operator_benchmark/pt/activation_test.py
new file mode 100644
index 00000000000..cd2cbe41914
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/activation_test.py
@@ -0,0 +1,67 @@
+import operator_benchmark as op_bench
+
+import torch
+import torch.nn as nn
+
+
+"""Microbenchmarks for activation operators."""
+
+
+activation_list = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[
+        ["gelu", nn.GELU],
+        ["silu", nn.SiLU],
+        ["relu", nn.ReLU],
+        ["leaky_relu", nn.LeakyReLU],
+    ],
+)
+
+activation_short_configs = op_bench.config_list(
+    attr_names=["shape"],
+    attrs=[
+        [(1,)],
+        [(64,)],
+        [(4096,)],
+        [(8192,)],
+    ],
+    cross_product_configs={
+        "device": ["cuda"],
+    },
+    tags=["short"],
+)
+
+activation_long_configs = op_bench.cross_product_configs(
+    shape=[(1,), (64,), (4096,), (8192,), (131072,), (262144,), (524288,), (1048576,)],
+    device=["cuda"],
+    tags=["long"],
+)
+
+
+class ActivationBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, op_func, device, shape):
+        self.inputs = {
+            "input": torch.rand(shape, device=device, requires_grad=self.auto_set())
+        }
+        self.op_func = op_func()
+        self.set_module_name(op_func.__name__)
+
+    def forward(self, input):
+        return self.op_func(input)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    activation_list,
+    activation_long_configs,
+    ActivationBenchmark,
+)
+
+op_bench.generate_pt_gradient_tests_from_op_list(
+    activation_list,
+    activation_long_configs,
+    ActivationBenchmark,
+)
+
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
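
For reference, each generated case reduces to constructing the module once in init() and timing the call in forward(). A standalone equivalent of one gelu/long configuration (illustrative only, and written without the device argument so it runs even when no CUDA device is available):

    import torch
    import torch.nn as nn

    op = nn.GELU()        # what op_func() produces in ActivationBenchmark.init
    x = torch.rand(4096)  # one shape from activation_long_configs
    y = op(x)             # the work measured by forward()
    print(y.shape)

Per the test.sh change above, CI runs this file as python -m pt.activation_test --tag-filter long.
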
diff --git a/benchmarks/operator_benchmark/pt/norm_test.py b/benchmarks/operator_benchmark/pt/norm_test.py
new file mode 100644
index 00000000000..cd897caf42d
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/norm_test.py
@@ -0,0 +1,219 @@
+import operator_benchmark as op_bench
+
+import torch
+import torch.nn as nn
+
+
+"""Microbenchmarks for normalization operators."""
+
+
+# ==============================================================================
+# LayerNorm and RMSNorm Benchmarks
+# ==============================================================================
+
+layernorm_list = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[
+        ["LayerNorm", nn.LayerNorm],
+        ["RMSNorm", nn.RMSNorm],
+    ],
+)
+
+layernorm_configs = op_bench.cross_product_configs(
+    B=[8, 32],
+    M=[256, 1024],
+    K=[64, 128, 512],
+    device=["cuda"],
+    tags=["long"],
+)
+
+
+class LayerNormBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, op_func, device, B, M, K):
+        self.inputs = {
+            "input": torch.rand(B, M, K, device=device, requires_grad=self.auto_set())
+        }
+        # normalized_shape is the last dimension (hidden dim K)
+        self.op_func = op_func(K, device=device)
+        self.set_module_name(op_func.__name__)
+
+    def forward(self, input):
+        return self.op_func(input)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    layernorm_list,
+    layernorm_configs,
+    LayerNormBenchmark,
+)
+
+op_bench.generate_pt_gradient_tests_from_op_list(
+    layernorm_list,
+    layernorm_configs,
+    LayerNormBenchmark,
+)
+
+
+# ==============================================================================
+# BatchNorm1d Benchmarks (training + eval)
+# ==============================================================================
+
+batchnorm1d_configs = op_bench.cross_product_configs(
+    B=[8, 32],
+    C=[64, 128, 256],
+    M=[256, 1024],
+    device=["cuda"],
+    training=[True, False],
+    tags=["long"],
+)
+
+
+class BatchNorm1dBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, device, B, C, M, training):
+        self.inputs = {
+            "input": torch.rand(B, C, M, device=device, requires_grad=self.auto_set())
+        }
+        self.op_func = nn.BatchNorm1d(C, device=device)
+        self.op_func.train(training)
+        self.set_module_name("BatchNorm1d")
+
+    def forward(self, input):
+        return self.op_func(input)
+
+
+op_bench.generate_pt_test(
+    batchnorm1d_configs,
+    BatchNorm1dBenchmark,
+)
+
+op_bench.generate_pt_gradient_test(
+    batchnorm1d_configs,
+    BatchNorm1dBenchmark,
+)
+
+
+# ==============================================================================
+# BatchNorm2d Benchmarks (training + eval)
+# ==============================================================================
+
+batchnorm2d_configs = op_bench.cross_product_configs(
+    B=[8, 32],
+    C=[64, 128, 256],
+    H=[28, 56],
+    W=[28, 56],
+    device=["cuda"],
+    training=[True, False],
+    tags=["long"],
+)
+
+
+class BatchNorm2dBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, device, B, C, H, W, training):
+        self.inputs = {
+            "input": torch.rand(
+                B, C, H, W, device=device, requires_grad=self.auto_set()
+            )
+        }
+        self.op_func = nn.BatchNorm2d(C, device=device)
+        self.op_func.train(training)
+        self.set_module_name("BatchNorm2d")
+
+    def forward(self, input):
+        return self.op_func(input)
+
+
+op_bench.generate_pt_test(
+    batchnorm2d_configs,
+    BatchNorm2dBenchmark,
+)
+
+op_bench.generate_pt_gradient_test(
+    batchnorm2d_configs,
+    BatchNorm2dBenchmark,
+)
+
+
+# ==============================================================================
+# BatchNorm3d Benchmarks (training + eval)
+# ==============================================================================
+
+batchnorm3d_configs = op_bench.cross_product_configs(
+    B=[8, 32],
+    C=[64, 128, 256],
+    D=[4, 8],
+    H=[14, 28],
+    W=[14, 28],
+    device=["cuda"],
+    training=[True, False],
+    tags=["long"],
+)
+
+
+class BatchNorm3dBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, device, B, C, D, H, W, training):
+        self.inputs = {
+            "input": torch.rand(
+                B, C, D, H, W, device=device, requires_grad=self.auto_set()
+            )
+        }
+        self.op_func = nn.BatchNorm3d(C, device=device)
+        self.op_func.train(training)
+        self.set_module_name("BatchNorm3d")
+
+    def forward(self, input):
+        return self.op_func(input)
+
+
+op_bench.generate_pt_test(
+    batchnorm3d_configs,
+    BatchNorm3dBenchmark,
+)
+
+op_bench.generate_pt_gradient_test(
+    batchnorm3d_configs,
+    BatchNorm3dBenchmark,
+)
+
+
+# ==============================================================================
+# GroupNorm Benchmarks
+# ==============================================================================
+
+groupnorm_configs = op_bench.cross_product_configs(
+    B=[8, 32],
+    C=[64, 128, 256],
+    H=[28, 56],
+    W=[28, 56],
+    num_groups=[8, 16, 32],
+    device=["cuda"],
+    tags=["long"],
+)
+
+
+class GroupNormBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, device, B, C, H, W, num_groups):
+        self.inputs = {
+            "input": torch.rand(
+                B, C, H, W, device=device, requires_grad=self.auto_set()
+            )
+        }
+        self.op_func = nn.GroupNorm(num_groups, C, device=device)
+        self.set_module_name("GroupNorm")
+
+    def forward(self, input):
+        return self.op_func(input)
+
+
+op_bench.generate_pt_test(
+    groupnorm_configs,
+    GroupNormBenchmark,
+)
+
+op_bench.generate_pt_gradient_test(
+    groupnorm_configs,
+    GroupNormBenchmark,
+)
+
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
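
One subtlety in LayerNormBenchmark worth spelling out: passing only K as normalized_shape means LayerNorm (and RMSNorm) normalizes over the last dimension of the (B, M, K) input, i.e. each length-K row independently. A quick standalone check of that behavior (shapes taken from layernorm_configs; the device argument is dropped here so the snippet runs on CPU):

    import torch
    import torch.nn as nn

    B, M, K = 8, 256, 64  # one point from layernorm_configs
    ln = nn.LayerNorm(K)  # normalized_shape = last dim only
    y = ln(torch.rand(B, M, K))
    # With the default affine init (weight=1, bias=0), every (b, m) row of
    # the output has mean ~0 along K, up to float32 rounding.
    assert torch.allclose(y.mean(dim=-1), torch.zeros(B, M), atol=1e-4)

As with the activation tests, CI picks these up via python -m pt.norm_test --tag-filter long per the test.sh change above.
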