Add normalization and activation ops to operator benchmarks (#169544)
We're adding more ops to the operator benchmarks.

Normalization ops:
- LayerNorm
- RMSNorm
- BatchNorm1d
- BatchNorm2d
- BatchNorm3d
- GroupNorm

Activation ops:
- nn.GELU
- nn.SiLU
- nn.ReLU
- nn.LeakyReLU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169544
Approved by: https://github.com/slayton58
Committed by: PyTorch MergeBot
Parent: 229d33f7f9
Commit: fdccb593c8
@@ -1811,7 +1811,7 @@ test_operator_microbenchmark() {
   cd "${TEST_DIR}"/benchmarks/operator_benchmark
 
   # NOTE: When adding a new test here, please update README: ../../benchmarks/operator_benchmark/README.md
-  for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv optimizer; do
+  for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv optimizer activation norm; do
     $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
       --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
       --benchmark-name "PyTorch operator microbenchmark" --use-compile
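For a local spot check, the two new suites can be invoked the same way the loop above does. A minimal sketch, assuming a CUDA machine (the configs in both new files pin device to "cuda") and omitting the dashboard/compile flags:

cd benchmarks/operator_benchmark
python -m pt.activation_test --tag-filter long
python -m pt.norm_test --tag-filter long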
@@ -292,6 +292,8 @@ class BenchmarkRunner:
             "latency unit": "us",
             "peak memory": results["peak_memory"],
             "memory unit": "KB",
+            "memory bandwidth": results.get("memory_bandwidth_gb_s"),
+            "memory bandwidth unit": "GB/s",
         }
 
         # parsing test_case.test_config.input_config, adding it as entries to the 'out' dictionary
@@ -559,6 +561,7 @@ class BenchmarkRunner:
         run_type = perf_item.get("run")
         latency = perf_item.get("latency", 0)
         peak_memory = perf_item.get("peak memory", 0)
+        memory_bandwidth = perf_item.get("memory bandwidth", 0)
         device = perf_item.get("device", "unknown")
         dtype = perf_item.get("dtype", "torch.float").split(".")[1]
         runtime = perf_item.get("runtime", None)
@@ -656,6 +659,16 @@ class BenchmarkRunner:
             )
             records.append(asdict(record_memory))
 
+            # Add record for memory bandwidth
+            record_memory_bandwidth = copy.deepcopy(record_latency)
+            record_memory_bandwidth.metric = MetricInfo(
+                name="memory bandwidth",
+                unit="GB/s",
+                benchmark_values=[memory_bandwidth],
+                target_value=None,
+            )
+            records.append(asdict(record_memory_bandwidth))
+
         # Write all records to the output file
         with open(output_file, "w", encoding="utf-8") as f:
             json.dump(records, f, indent=2)
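Each record is an asdict() of a dataclass whose metric field is the MetricInfo built above, so the new metric can be fished back out of the dashboard JSON with jq. A sketch, assuming the output path produced by the CI loop above:

# Print the memory-bandwidth values (GB/s) for every benchmark record.
jq '.[] | select(.metric.name == "memory bandwidth") | .metric.benchmark_values[0]' \
  "${TEST_REPORTS_DIR}/operator_microbenchmark_norm_compile.json"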
@@ -671,6 +684,7 @@ class BenchmarkRunner:
             "run_backward",
             "Execution Time",
             "Peak Memory (KB)",
+            "Memory Bandwidth (GB/s)",
         ]
 
         if self.args.output_json or self.args.output_json_for_dashboard:
@@ -746,6 +760,7 @@ class BenchmarkRunner:
                     test_case.test_config.run_backward,
                     result_dict["reported_run_time_us"][0],
                     result_dict["peak_memory"],
+                    result_dict["memory_bandwidth_gb_s"],
                 ],
             )
         if self.args.output_json or self.args.output_json_for_dashboard:
benchmarks/operator_benchmark/pt/activation_test.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import operator_benchmark as op_bench

import torch
import torch.nn as nn


"""Microbenchmarks for activation operators."""


activation_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["gelu", nn.GELU],
        ["silu", nn.SiLU],
        ["relu", nn.ReLU],
        ["leaky_relu", nn.LeakyReLU],
    ],
)

activation_short_configs = op_bench.config_list(
    attr_names=["shape"],
    attrs=[
        [(1,)],
        [(64,)],
        [(4096,)],
        [(8192,)],
    ],
    cross_product_configs={
        "device": ["cuda"],
    },
    tags=["short"],
)

activation_long_configs = op_bench.cross_product_configs(
    shape=[(1,), (64,), (4096,), (8192,), (131072,), (262144,), (524288,), (1048576,)],
    device=["cuda"],
    tags=["long"],
)


class ActivationBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, op_func, device, shape):
        self.inputs = {
            "input": torch.rand(shape, device=device, requires_grad=self.auto_set())
        }
        self.op_func = op_func()
        self.set_module_name(op_func.__name__)

    def forward(self, input):
        return self.op_func(input)


op_bench.generate_pt_tests_from_op_list(
    activation_list,
    activation_long_configs,
    ActivationBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    activation_list,
    activation_long_configs,
    ActivationBenchmark,
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
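Note that CI only runs the "long" tag; the "short" configs above are handy for a quick local smoke test. A sketch, assuming a CUDA device is available:

cd benchmarks/operator_benchmark
python -m pt.activation_test --tag-filter short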
benchmarks/operator_benchmark/pt/norm_test.py (new file, 219 lines)
@@ -0,0 +1,219 @@
import operator_benchmark as op_bench

import torch
import torch.nn as nn


"""Microbenchmarks for normalization operators."""


# ==============================================================================
# LayerNorm and RMSNorm Benchmarks
# ==============================================================================

layernorm_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["LayerNorm", nn.LayerNorm],
        ["RMSNorm", nn.RMSNorm],
    ],
)

layernorm_configs = op_bench.cross_product_configs(
    B=[8, 32],
    M=[256, 1024],
    K=[64, 128, 512],
    device=["cuda"],
    tags=["long"],
)


class LayerNormBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, op_func, device, B, M, K):
        self.inputs = {
            "input": torch.rand(B, M, K, device=device, requires_grad=self.auto_set())
        }
        # normalized_shape is the last dimension (hidden dim K)
        self.op_func = op_func(K, device=device)
        self.set_module_name(op_func.__name__)

    def forward(self, input):
        return self.op_func(input)


op_bench.generate_pt_tests_from_op_list(
    layernorm_list,
    layernorm_configs,
    LayerNormBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    layernorm_list,
    layernorm_configs,
    LayerNormBenchmark,
)


# ==============================================================================
# BatchNorm1d Benchmarks (training + eval)
# ==============================================================================

batchnorm1d_configs = op_bench.cross_product_configs(
    B=[8, 32],
    C=[64, 128, 256],
    M=[256, 1024],
    device=["cuda"],
    training=[True, False],
    tags=["long"],
)


class BatchNorm1dBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, device, B, C, M, training):
        self.inputs = {
            "input": torch.rand(B, C, M, device=device, requires_grad=self.auto_set())
        }
        self.op_func = nn.BatchNorm1d(C, device=device)
        self.op_func.train(training)
        self.set_module_name("BatchNorm1d")

    def forward(self, input):
        return self.op_func(input)


op_bench.generate_pt_test(
    batchnorm1d_configs,
    BatchNorm1dBenchmark,
)

op_bench.generate_pt_gradient_test(
    batchnorm1d_configs,
    BatchNorm1dBenchmark,
)


# ==============================================================================
# BatchNorm2d Benchmarks (training + eval)
# ==============================================================================

batchnorm2d_configs = op_bench.cross_product_configs(
    B=[8, 32],
    C=[64, 128, 256],
    H=[28, 56],
    W=[28, 56],
    device=["cuda"],
    training=[True, False],
    tags=["long"],
)


class BatchNorm2dBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, device, B, C, H, W, training):
        self.inputs = {
            "input": torch.rand(
                B, C, H, W, device=device, requires_grad=self.auto_set()
            )
        }
        self.op_func = nn.BatchNorm2d(C, device=device)
        self.op_func.train(training)
        self.set_module_name("BatchNorm2d")

    def forward(self, input):
        return self.op_func(input)


op_bench.generate_pt_test(
    batchnorm2d_configs,
    BatchNorm2dBenchmark,
)

op_bench.generate_pt_gradient_test(
    batchnorm2d_configs,
    BatchNorm2dBenchmark,
)


# ==============================================================================
# BatchNorm3d Benchmarks (training + eval)
# ==============================================================================

batchnorm3d_configs = op_bench.cross_product_configs(
    B=[8, 32],
    C=[64, 128, 256],
    D=[4, 8],
    H=[14, 28],
    W=[14, 28],
    device=["cuda"],
    training=[True, False],
    tags=["long"],
)


class BatchNorm3dBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, device, B, C, D, H, W, training):
        self.inputs = {
            "input": torch.rand(
                B, C, D, H, W, device=device, requires_grad=self.auto_set()
            )
        }
        self.op_func = nn.BatchNorm3d(C, device=device)
        self.op_func.train(training)
        self.set_module_name("BatchNorm3d")

    def forward(self, input):
        return self.op_func(input)


op_bench.generate_pt_test(
    batchnorm3d_configs,
    BatchNorm3dBenchmark,
)

op_bench.generate_pt_gradient_test(
    batchnorm3d_configs,
    BatchNorm3dBenchmark,
)


# ==============================================================================
# GroupNorm Benchmarks
# ==============================================================================

groupnorm_configs = op_bench.cross_product_configs(
    B=[8, 32],
    C=[64, 128, 256],
    H=[28, 56],
    W=[28, 56],
    num_groups=[8, 16, 32],
    device=["cuda"],
    tags=["long"],
)


class GroupNormBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, device, B, C, H, W, num_groups):
        self.inputs = {
            "input": torch.rand(
                B, C, H, W, device=device, requires_grad=self.auto_set()
            )
        }
        self.op_func = nn.GroupNorm(num_groups, C, device=device)
        self.set_module_name("GroupNorm")

    def forward(self, input):
        return self.op_func(input)


op_bench.generate_pt_test(
    groupnorm_configs,
    GroupNormBenchmark,
)

op_bench.generate_pt_gradient_test(
    groupnorm_configs,
    GroupNormBenchmark,
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()