Add normalization and activation ops to operator benchmarks (#169544)

We're adding some more ops to the benchmarking:

Normalization ops:
- LayerNorm
- RMSNorm
- BatchNorm1d
- BatchNorm2d
- BatchNorm3d
- GroupNorm

Activation ops:
- nn.GELU
- nn.SiLU
- nn.ReLU
- nn.LeakyReLU
Pull Request resolved: https://github.com/pytorch/pytorch/pull/169544
Approved by: https://github.com/slayton58
This commit is contained in:
jainapurva
2025-12-22 16:41:45 +00:00
committed by PyTorch MergeBot
parent 229d33f7f9
commit fdccb593c8
4 changed files with 302 additions and 1 deletions

View File

@@ -1811,7 +1811,7 @@ test_operator_microbenchmark() {
cd "${TEST_DIR}"/benchmarks/operator_benchmark
# NOTE: When adding a new test here, please update README: ../../benchmarks/operator_benchmark/README.md
for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv optimizer; do
for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv optimizer activation norm; do
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
--benchmark-name "PyTorch operator microbenchmark" --use-compile

View File

@@ -292,6 +292,8 @@ class BenchmarkRunner:
"latency unit": "us",
"peak memory": results["peak_memory"],
"memory unit": "KB",
"memory bandwidth": results.get("memory_bandwidth_gb_s"),
"memory bandwidth unit": "GB/s",
}
# parsing test_case.test_config.input_config, adding it as entries to the 'out' dictionary
@@ -559,6 +561,7 @@ class BenchmarkRunner:
run_type = perf_item.get("run")
latency = perf_item.get("latency", 0)
peak_memory = perf_item.get("peak memory", 0)
memory_bandwidth = perf_item.get("memory bandwidth", 0)
device = perf_item.get("device", "unknown")
dtype = perf_item.get("dtype", "torch.float").split(".")[1]
runtime = perf_item.get("runtime", None)
@@ -656,6 +659,16 @@ class BenchmarkRunner:
)
records.append(asdict(record_memory))
# Add record for memory bandwidth
record_memory_bandwidth = copy.deepcopy(record_latency)
record_memory_bandwidth.metric = MetricInfo(
name="memory bandwidth",
unit="GB/s",
benchmark_values=[memory_bandwidth],
target_value=None,
)
records.append(asdict(record_memory_bandwidth))
# Write all records to the output file
with open(output_file, "w", encoding="utf-8") as f:
json.dump(records, f, indent=2)
@@ -671,6 +684,7 @@ class BenchmarkRunner:
"run_backward",
"Execution Time",
"Peak Memory (KB)",
"Memory Bandwidth (GB/s)",
]
if self.args.output_json or self.args.output_json_for_dashboard:
@@ -746,6 +760,7 @@ class BenchmarkRunner:
test_case.test_config.run_backward,
result_dict["reported_run_time_us"][0],
result_dict["peak_memory"],
result_dict["memory_bandwidth_gb_s"],
],
)
if self.args.output_json or self.args.output_json_for_dashboard:

View File

@@ -0,0 +1,67 @@
import operator_benchmark as op_bench
import torch
import torch.nn as nn
"""Microbenchmarks for activation operators."""
activation_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["gelu", nn.GELU],
["silu", nn.SiLU],
["relu", nn.ReLU],
["leaky_relu", nn.LeakyReLU],
],
)
activation_short_configs = op_bench.config_list(
attr_names=["shape"],
attrs=[
[(1,)],
[(64,)],
[(4096,)],
[(8192,)],
],
cross_product_configs={
"device": ["cuda"],
},
tags=["short"],
)
activation_long_configs = op_bench.cross_product_configs(
shape=[(1,), (64,), (4096,), (8192,), (131072,), (262144,), (524288,), (1048576,)],
device=["cuda"],
tags=["long"],
)
class ActivationBenchmark(op_bench.TorchBenchmarkBase):
def init(self, op_func, device, shape):
self.inputs = {
"input": torch.rand(shape, device=device, requires_grad=self.auto_set())
}
self.op_func = op_func()
self.set_module_name(op_func.__name__)
def forward(self, input):
return self.op_func(input)
op_bench.generate_pt_tests_from_op_list(
activation_list,
activation_long_configs,
ActivationBenchmark,
)
op_bench.generate_pt_gradient_tests_from_op_list(
activation_list,
activation_long_configs,
ActivationBenchmark,
)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@@ -0,0 +1,219 @@
import operator_benchmark as op_bench
import torch
import torch.nn as nn
"""Microbenchmarks for normalization operators."""
# ==============================================================================
# LayerNorm and RMSNorm Benchmarks
# ==============================================================================
layernorm_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["LayerNorm", nn.LayerNorm],
["RMSNorm", nn.RMSNorm],
],
)
layernorm_configs = op_bench.cross_product_configs(
B=[8, 32],
M=[256, 1024],
K=[64, 128, 512],
device=["cuda"],
tags=["long"],
)
class LayerNormBenchmark(op_bench.TorchBenchmarkBase):
def init(self, op_func, device, B, M, K):
self.inputs = {
"input": torch.rand(B, M, K, device=device, requires_grad=self.auto_set())
}
# normalized_shape is the last dimension (hidden dim K)
self.op_func = op_func(K, device=device)
self.set_module_name(op_func.__name__)
def forward(self, input):
return self.op_func(input)
op_bench.generate_pt_tests_from_op_list(
layernorm_list,
layernorm_configs,
LayerNormBenchmark,
)
op_bench.generate_pt_gradient_tests_from_op_list(
layernorm_list,
layernorm_configs,
LayerNormBenchmark,
)
# ==============================================================================
# BatchNorm1d Benchmarks (training + eval)
# ==============================================================================
batchnorm1d_configs = op_bench.cross_product_configs(
B=[8, 32],
C=[64, 128, 256],
M=[256, 1024],
device=["cuda"],
training=[True, False],
tags=["long"],
)
class BatchNorm1dBenchmark(op_bench.TorchBenchmarkBase):
def init(self, device, B, C, M, training):
self.inputs = {
"input": torch.rand(B, C, M, device=device, requires_grad=self.auto_set())
}
self.op_func = nn.BatchNorm1d(C, device=device)
self.op_func.train(training)
self.set_module_name("BatchNorm1d")
def forward(self, input):
return self.op_func(input)
op_bench.generate_pt_test(
batchnorm1d_configs,
BatchNorm1dBenchmark,
)
op_bench.generate_pt_gradient_test(
batchnorm1d_configs,
BatchNorm1dBenchmark,
)
# ==============================================================================
# BatchNorm2d Benchmarks (training + eval)
# ==============================================================================
batchnorm2d_configs = op_bench.cross_product_configs(
B=[8, 32],
C=[64, 128, 256],
H=[28, 56],
W=[28, 56],
device=["cuda"],
training=[True, False],
tags=["long"],
)
class BatchNorm2dBenchmark(op_bench.TorchBenchmarkBase):
def init(self, device, B, C, H, W, training):
self.inputs = {
"input": torch.rand(
B, C, H, W, device=device, requires_grad=self.auto_set()
)
}
self.op_func = nn.BatchNorm2d(C, device=device)
self.op_func.train(training)
self.set_module_name("BatchNorm2d")
def forward(self, input):
return self.op_func(input)
op_bench.generate_pt_test(
batchnorm2d_configs,
BatchNorm2dBenchmark,
)
op_bench.generate_pt_gradient_test(
batchnorm2d_configs,
BatchNorm2dBenchmark,
)
# ==============================================================================
# BatchNorm3d Benchmarks (training + eval)
# ==============================================================================
batchnorm3d_configs = op_bench.cross_product_configs(
B=[8, 32],
C=[64, 128, 256],
D=[4, 8],
H=[14, 28],
W=[14, 28],
device=["cuda"],
training=[True, False],
tags=["long"],
)
class BatchNorm3dBenchmark(op_bench.TorchBenchmarkBase):
def init(self, device, B, C, D, H, W, training):
self.inputs = {
"input": torch.rand(
B, C, D, H, W, device=device, requires_grad=self.auto_set()
)
}
self.op_func = nn.BatchNorm3d(C, device=device)
self.op_func.train(training)
self.set_module_name("BatchNorm3d")
def forward(self, input):
return self.op_func(input)
op_bench.generate_pt_test(
batchnorm3d_configs,
BatchNorm3dBenchmark,
)
op_bench.generate_pt_gradient_test(
batchnorm3d_configs,
BatchNorm3dBenchmark,
)
# ==============================================================================
# GroupNorm Benchmarks
# ==============================================================================
groupnorm_configs = op_bench.cross_product_configs(
B=[8, 32],
C=[64, 128, 256],
H=[28, 56],
W=[28, 56],
num_groups=[8, 16, 32],
device=["cuda"],
tags=["long"],
)
class GroupNormBenchmark(op_bench.TorchBenchmarkBase):
def init(self, device, B, C, H, W, num_groups):
self.inputs = {
"input": torch.rand(
B, C, H, W, device=device, requires_grad=self.auto_set()
)
}
self.op_func = nn.GroupNorm(num_groups, C, device=device)
self.set_module_name("GroupNorm")
def forward(self, input):
return self.op_func(input)
op_bench.generate_pt_test(
groupnorm_configs,
GroupNormBenchmark,
)
op_bench.generate_pt_gradient_test(
groupnorm_configs,
GroupNormBenchmark,
)
if __name__ == "__main__":
op_bench.benchmark_runner.main()