diff --git a/benchmarks/transformer/config_utils.py b/benchmarks/transformer/config_utils.py
new file mode 100644
index 00000000000..34214cdb59a
--- /dev/null
+++ b/benchmarks/transformer/config_utils.py
@@ -0,0 +1,157 @@
+"""Configuration utilities for parsing JSON and YAML config files."""
+
+import json
+import re
+
+
+def heads_input_type(s: str) -> tuple[int, int]:
+    """Convert string format 'Hq,Hkv' to tuple (Hq, Hkv)."""
+    try:
+        hq, hkv = map(int, s.split(","))
+        return hq, hkv
+    except Exception as e:
+        raise ValueError("Heads must be Hq,Hkv") from e
+
+
+default_config = {
+    "dynamic": False,
+    "calculate_bwd": False,
+    "dtype": "bfloat16",
+    "b": [2, 8, 16],
+    "nh": ["16,16", "16,2"],
+    "s": [512, 1024, 4096],
+    "d": [64, 128],
+    "mods": ["noop", "causal", "alibi", "sliding_window"],
+    "backend": ["efficient"],
+    "max_autotune": False,
+    "decoding": False,
+    "kv_size": None,
+    "throughput": True,
+    "save_path": None,
+    "output_json_for_dashboard": None,
+    "benchmark_name": "PyTorch operator microbenchmark",
+}
+
+
+def load_config_file(config_path: str) -> dict:
+    """Load configuration from JSON or YAML file.
+
+    Automatically converts 'nh' field from strings to tuples.
+
+    Args:
+        config_path: Path to the configuration file
+
+    Returns:
+        Dictionary containing the configuration
+
+    Raises:
+        FileNotFoundError: If config file doesn't exist
+        ValueError: If config file format is invalid
+    """
+    with open(config_path) as f:
+        config_str = f.read()
+
+    # Try to load as JSON first
+    try:
+        config = json.loads(config_str)
+    except json.JSONDecodeError:
+        # Fall back to YAML parsing
+        config = _parse_simple_yaml(config_str)
+
+    # Apply automatic conversions for 'nh' field
+    if "nh" in config and isinstance(config["nh"], list):
+        config["nh"] = [
+            heads_input_type(h) if isinstance(h, str) else h for h in config["nh"]
+        ]
+
+    return config
+
+
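As a quick illustration of the loader's behavior (this sketch is not part of the patch; the file name and contents are hypothetical):

# Suppose my_config.json contains: {"b": [2, 8], "nh": ["16,16", "16,2"], "dtype": "float16"}
config = load_config_file("my_config.json")
print(config["nh"])  # [(16, 16), (16, 2)] -- 'nh' strings converted via heads_input_type
print(config["b"])   # [2, 8] -- everything else passes through unchanged
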
+def _parse_simple_yaml(yaml_str: str) -> dict:
+    """Simple YAML parser for basic configs (without external dependencies).
+
+    Supports:
+    - key: value pairs
+    - booleans (true/false)
+    - null values
+    - integers and floats
+    - strings (quoted and unquoted)
+    - lists in JSON format [item1, item2, ...]
+    - comments (lines starting with # or after #)
+
+    Args:
+        yaml_str: YAML content as string
+
+    Returns:
+        Dictionary containing parsed YAML content
+    """
+    config = {}
+
+    for line in yaml_str.split("\n"):
+        # Remove comments
+        line = line.split("#")[0].strip()
+
+        if not line or ":" not in line:
+            continue
+
+        key, value = line.split(":", 1)
+        key = key.strip()
+        value = value.strip()
+
+        # Parse value based on type
+        if value.lower() == "true":
+            config[key] = True
+        elif value.lower() == "false":
+            config[key] = False
+        elif value.lower() in ("null", "none", ""):
+            config[key] = None
+        elif value.startswith("[") and value.endswith("]"):
+            # Parse list - handle quoted strings properly
+            pattern = r'"([^"]+)"|\'([^\']+)\'|([^,\[\]\s]+)'
+            matches = re.findall(pattern, value[1:-1])  # Remove [ ]
+            parsed_items = []
+            for match in matches:
+                # match is a tuple of (double_quoted, single_quoted, unquoted)
+                item = match[0] or match[1] or match[2]
+                item = item.strip()
+                if item:
+                    try:
+                        parsed_items.append(int(item))
+                    except ValueError:
+                        parsed_items.append(item)
+            config[key] = parsed_items
+        elif value.startswith(('"', "'")):
+            config[key] = value.strip("\"'")
+        else:
+            # Try to parse as number
+            try:
+                config[key] = int(value)
+            except ValueError:
+                try:
+                    config[key] = float(value)
+                except ValueError:
+                    config[key] = value
+
+    return config
+
+
+def print_default_config(output_format: str) -> None:
+    """Print a default configuration template in JSON or YAML format.
+
+    Args:
+        output_format: Either "json" or "yaml"
+    """
+    if output_format == "json":
+        print(json.dumps(default_config, indent=2))
+    else:  # yaml
+        for key, value in default_config.items():
+            if value is None:
+                print(f"{key}: null")
+            elif isinstance(value, bool):
+                print(f"{key}: {str(value).lower()}")
+            elif isinstance(value, str):
+                print(f'{key}: "{value}"')
+            elif isinstance(value, list):
+                print(f"{key}: {json.dumps(value)}")
+            else:
+                print(f"{key}: {value}")
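A small sketch of the YAML subset the fallback parser accepts (illustrative input, not part of the patch):

sample = """
dynamic: false          # booleans
kv_size: null           # null values
dtype: "bfloat16"       # quoted strings
b: [2, 8, 16]           # JSON-style lists
s: 1024                 # integers
"""
print(_parse_simple_yaml(sample))
# {'dynamic': False, 'kv_size': None, 'dtype': 'bfloat16', 'b': [2, 8, 16], 's': 1024}
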
diff --git a/benchmarks/transformer/configs/config_basic.yaml b/benchmarks/transformer/configs/config_basic.yaml
new file mode 100644
index 00000000000..9f5e7631301
--- /dev/null
+++ b/benchmarks/transformer/configs/config_basic.yaml
@@ -0,0 +1,29 @@
+# Basic benchmark configuration for PyTorch transformer benchmarks
+# Usage: python score_mod.py --config config_basic.yaml
+
+# Core parameters
+dynamic: false
+calculate_bwd: true
+dtype: "bfloat16"
+
+# Shape parameters - larger sweep
+b: [1, 2, 4, 8, 16]  # batch sizes
+nh: ["16,16", "16,2", "32,32", "32,4"]  # [query_heads,key_value_heads]
+s: [512, 1024, 2048, 4096, 8192]  # sequence lengths
+d: [64, 128]  # head dimensions (limited to 128 for Flash Attention/cuDNN compatibility)
+
+# All attention types
+mods: ["noop", "causal", "rel", "head_bias", "alibi", "sliding_window", "prefix_lm", "softcap"]
+
+# Multiple backends for comparison (SDPA + Flash Attention) - flex is always included internally
+backend: ["efficient", "math", "cudnn", "fav2"]
+max_autotune: true  # Enable torch.compile with max-autotune for optimal performance
+
+# Decoding and cache settings
+decoding: false
+kv_size: null
+
+# Metrics and output
+throughput: true  # Calculate memory bandwidth & TFLOPS
+save_path: "comprehensive_results.csv"  # Save to CSV
+output_json_for_dashboard: "attn_bench_basic.json"
diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py
index 520fb26994e..928cbf27df5 100644
--- a/benchmarks/transformer/score_mod.py
+++ b/benchmarks/transformer/score_mod.py
@@ -1,15 +1,19 @@
 import argparse
 import csv
+import gc
 import itertools
+import json
 import random
+import sys
 from collections import defaultdict
 from collections.abc import Callable
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
-from functools import partial
-from typing import Optional, Union
+from functools import partial, wraps
+from typing import Literal, Optional, Union
 
 import numpy as np
+from config_utils import heads_input_type, load_config_file, print_default_config
 from tabulate import tabulate
 from tqdm import tqdm
 
@@ -33,6 +37,96 @@
 torch._dynamo.config.recompile_limit = 1000
 from torch._inductor.runtime.benchmarking import benchmarker
 
 
+def cleanup_memory():
+    """Aggressively free GPU memory"""
+    torch.cuda.empty_cache()
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+
+def safe_backend(backend_name=None, return_dict=False):
+    """Decorator that wraps backend functions with error handling
+
+    Args:
+        backend_name: Name of the backend for error messages
+        return_dict: If True, returns dict of results for all backends (for run_single_experiment)
+                     If False, returns single ExperimentResults (for individual backend functions)
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(config, *args, **kwargs):
+            try:
+                return func(config, *args, **kwargs)
+            except torch.OutOfMemoryError:
+                print(
+                    f"[SKIP] OOM for {backend_name or func.__name__} with shape {config.shape}"
+                )
+                cleanup_memory()
+            except RuntimeError as e:
+                error_msg = str(e)
+                if "out of resource" in error_msg or "OutOfMemoryError" in error_msg:
+                    print(
+                        f"[SKIP] Triton OOM for {backend_name or func.__name__} with shape {config.shape}"
+                    )
+                    cleanup_memory()
+                elif "No valid triton configs" in error_msg:
+                    print(
+                        f"[SKIP] No valid Triton config for {backend_name or func.__name__} with shape {config.shape}"
+                    )
+                else:
+                    print(
+                        f"[SKIP] Runtime error for {backend_name or func.__name__} with shape {config.shape}: {str(e)[:100]}"
+                    )
+            except Exception as e:
+                print(
+                    f"[SKIP] Error for {backend_name or func.__name__} with shape {config.shape}: {str(e)[:100]}"
+                )
+
+            # Return appropriate NaN result based on function type
+            if return_dict:
+                # For run_single_experiment: return dict with NaN for all backends
+                nan_result = ExperimentResults(
+                    fwd_time=float("nan"),
+                    bwd_time=float("nan") if config.calculate_bwd_time else None,
+                )
+                results = dict.fromkeys(config.backends, nan_result)
+                results["flex"] = ExperimentResults(
+                    fwd_time=float("nan"),
+                    bwd_time=float("nan") if config.calculate_bwd_time else None,
+                    sparsity=None,
+                )
+                return results
+            else:
+                # For individual backend functions: return single ExperimentResults
+                return ExperimentResults(
+                    fwd_time=float("nan"),
+                    bwd_time=float("nan") if config.calculate_bwd_time else None,
+                )

+        return wrapper
+
+    return decorator
+
+
+# Type definitions
+Backend = Literal["math", "efficient", "cudnn", "fav2", "fav3", "fakv", "og-eager"]
+AttentionType = Literal[
+    "noop",
+    "causal",
+    "rel",
+    "head_bias",
+    "alibi",
+    "sliding_window",
+    "document_mask",
+    "prefix_lm",
+    "softcap",
+]
+DtypeString = Literal["bfloat16", "float16", "float32"]
+SpeedupType = Literal["fwd", "bwd"]
+
+
 def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
     # warmup
     for _ in range(5):
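A minimal illustration of how the safe_backend decorator degrades to NaN timings instead of aborting the sweep (not part of the patch; the config object is a hypothetical stand-in with only the attributes the wrapper reads):

# Hypothetical stand-in; real code passes an ExperimentConfig.
class _FakeConfig:
    shape = (2, 16, 1024, 16, 1024, 64)
    calculate_bwd_time = True
    backends = ["efficient"]

@safe_backend("demo")
def _always_oom(config):
    raise torch.OutOfMemoryError("simulated OOM")

result = _always_oom(_FakeConfig())
print(result.fwd_time, result.bwd_time)  # nan nan -- the sweep keeps going
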
@@ -48,6 +142,7 @@ class ExperimentConfig:
     calculate_bwd_time: bool
     cal_bandwidth: bool
     backends: list[str]
+    max_autotune: bool
 
     def __post_init__(self):
         assert len(self.shape) == 6, (
@@ -62,6 +157,7 @@ class ExperimentConfig:
         d.pop("cal_bandwidth", None)
         d["shape(B,Hq,M,Hkv,N,D)"] = d.pop("shape")
         d.pop("backends", None)
+        d.pop("max_autotune", False)
         return d
 
 
@@ -209,6 +305,7 @@ def query_key_value_clones(
     return query_ref, key_ref, value_ref
 
 
+@safe_backend("SDPA")
 def run_single_backend_sdpa(
     config: ExperimentConfig,
     query: torch.Tensor,
@@ -223,6 +320,7 @@ def run_single_backend_sdpa(
     backend_context = get_backend_context(backend)
     with backend_context:
         _device = torch.device("cuda")
+
         eager_sdpa = generate_eager_sdpa(
             config.attn_type, config.shape, config.dtype, block_mask, score_mod
         )
@@ -290,6 +388,7 @@ def run_single_backend_sdpa(
     )
 
 
+@safe_backend("FlashAttention")
 def run_single_backend_FA(
     config: ExperimentConfig,
     query: torch.Tensor,
@@ -301,9 +400,9 @@ def run_single_backend_FA(
     mask_kwargs,
     backend: str,
 ) -> ExperimentResults:
-    assert backend in ["fav2", "fav3", "fakv"]
+    assert backend in ["fav3", "fakv"]
     # Generate callable for specific backend.
-    if backend in ["fav2", "fav3"]:
+    if backend in ["fav3"]:
         FA = generate_FA_callable(
             config.attn_type, config.shape, config.dtype, backend, **mask_kwargs
         )
@@ -354,10 +453,10 @@ def run_single_backend_FA(
     )
 
 
+@safe_backend("flex_attention", return_dict=True)
 def run_single_experiment(
     config: ExperimentConfig,
     dynamic=False,
-    max_autotune=False,
 ) -> dict[str, ExperimentResults]:
     device = torch.device("cuda")
     batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
@@ -377,7 +476,7 @@ def run_single_experiment(
     block_mask, mask_kwargs = generate_block_mask(config.attn_type, config.shape)
     kernel_options = get_kernel_options(config.attn_type, config.shape)
 
-    if max_autotune:
+    if config.max_autotune:
         compiled_sdpa = torch.compile(
             flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
         )
@@ -407,7 +506,7 @@ def run_single_experiment(
 
     results = {}
     for backend in config.backends:
-        if backend in ["fav2", "fav3", "fakv"]:
+        if backend in ["fav3", "fakv"]:
             results[backend] = run_single_backend_FA(
                 config,
                 query,
@@ -419,7 +518,7 @@ def run_single_experiment(
                 mask_kwargs,
                 backend,
             )
-        else:  # sdpa
+        else:  # sdpa (also supports fav2)
             results[backend] = run_single_backend_sdpa(
                 config,
                 query,
@@ -440,7 +539,7 @@ def run_single_experiment(
     sparsity = block_mask.sparsity() / 100.0 if block_mask is not None else 0.0
     sparsity = sparsity if config.attn_type != "document_mask" else 0.5
 
-    results["compiled"] = ExperimentResults(
+    results["flex"] = ExperimentResults(
         fwd_time=forward_compiled_time,
         bwd_time=backward_compile_time if config.calculate_bwd_time else None,
         sparsity=sparsity,
@@ -501,15 +600,15 @@ def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> fl
     softmax_flops = M * N * 2  # Not counting online softmax overhead
     o_flops = M * D * N * 2  # Not counting split k overhead
-    total_flops = B * Hq * (qk_flops + softmax_flops + o_flops) * (1 - results.sparsity)
+    sparsity = results.sparsity if results.sparsity is not None else 0.0
+    total_flops = B * Hq * (qk_flops + softmax_flops + o_flops) * (1 - sparsity)
     return total_flops / results.fwd_time / 1e6  # in TFLOPs/s
 
 
 def get_average_speedups(results: list[Experiment], type: str, backend: str):
     # Calculate speedups
     speedups = [
-        calculate_speedup(r.results["compiled"], r.results[backend], type)
-        for r in results
+        calculate_speedup(r.results["flex"], r.results[backend], type) for r in results
     ]
 
     # Find indices of max and min speedups
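For reference, the throughput arithmetic in calculate_tflops works out as follows for one hypothetical shape (numbers invented purely to show the units):

# B=2, Hq=16, M=N=1024, D=64, dense mask (sparsity=0.0), fwd_time in microseconds.
B, Hq, M, N, D = 2, 16, 1024, 1024, 64
qk_flops = M * N * D * 2       # Q @ K^T
softmax_flops = M * N * 2      # exp + normalize (online-softmax overhead ignored)
o_flops = M * D * N * 2        # P @ V
total_flops = B * Hq * (qk_flops + softmax_flops + o_flops)  # * (1 - sparsity)
fwd_time_us = 100.0
print(total_flops / fwd_time_us / 1e6)  # ~86.6 TFLOPs/s
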
@@ -537,7 +636,7 @@
 def print_results(results: list[Experiment], save_path: Optional[str] = None):
     table_data = defaultdict(list)
     for experiment in results:
-        backends = experiment.config.backends + ["compiled"]
+        backends = experiment.config.backends + ["flex"]
         for key, value in experiment.asdict().items():
             if key in backends:
                 if value.fwd_time:
@@ -550,45 +649,43 @@ def print_results(results: list[Experiment], save_path: Optional[str] = None):
     # Calculate speedups
     for backend in results[0].config.backends:
         fwd_speedups = [
-            calculate_speedup(r.results["compiled"], r.results[backend], type="fwd")
+            calculate_speedup(r.results["flex"], r.results[backend], type="fwd")
             for r in results
         ]
-        table_data[f"fwd_{backend}_speedup"] = fwd_speedups
+        table_data[f"fwd_speedup_flex_over_{backend}"] = fwd_speedups
 
     if results[0].config.calculate_bwd_time:
         for backend in results[0].config.backends:
             bwd_speedups = [
-                calculate_speedup(r.results["compiled"], r.results[backend], type="bwd")
+                calculate_speedup(r.results["flex"], r.results[backend], type="bwd")
                 for r in results
             ]
-            table_data[f"bwd_{backend}_speedup"] = bwd_speedups
+            table_data[f"bwd_speedup_flex_over_{backend}"] = bwd_speedups
 
     # Calculate mem + computational throughput
     if results[0].config.cal_bandwidth:
         fwd_bandwidth = [
-            calculate_bandwidth(r.config, r.results["compiled"], type="fwd")
+            calculate_bandwidth(r.config, r.results["flex"], type="fwd")
             for r in results
         ]
         table_data["fwd_mem_bw (TB/s)"] = fwd_bandwidth
-        fwd_tflops = [
-            calculate_tflops(r.config, r.results["compiled"]) for r in results
-        ]
+        fwd_tflops = [calculate_tflops(r.config, r.results["flex"]) for r in results]
         table_data["TFlops/s"] = fwd_tflops
 
     print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
     for backend in results[0].config.backends:
-        if np.isnan(table_data[f"fwd_{backend}_speedup"]).all():
+        if np.isnan(table_data[f"fwd_speedup_flex_over_{backend}"]).all():
             continue
         print("\n")
-        print(f"FWD Speedups vs. {backend}".center(125, "="))
+        print(f"FWD Speedup of Flex over {backend}".center(125, "="))
        print("\n")
        average_data = get_average_speedups(results, type="fwd", backend=backend)
        print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
         if results[0].config.calculate_bwd_time:
             print("\n")
-            print(f"BWD Speedups vs. {backend}".center(125, "="))
+            print(f"BWD Speedup of Flex over {backend}".center(125, "="))
             print("\n")
             average_data = get_average_speedups(results, type="bwd", backend=backend)
             print(
@@ -791,14 +888,14 @@ def get_backend_context(backend: str):
     Returns a context manager for the specified backend.
     Args:
         backend (str): The name of the backend to use.
-            Valid options are 'fav2', 'cudnn', 'math', 'efficient', 'fav3', 'fakv', 'og-eager'.
+            Valid options are 'math', 'efficient', 'cudnn', 'fav2', 'fav3', 'fakv', 'og-eager'.
     Returns:
         A context manager for the specified backend.
     Raises:
         ValueError: If an invalid backend is specified.
     """
     backends = {
-        "fav2": nullcontext(),
+        "fav2": sdpa_kernel(SDPBackend.FLASH_ATTENTION),
         "cudnn": sdpa_kernel(SDPBackend.CUDNN_ATTENTION),
         "math": sdpa_kernel(SDPBackend.MATH),
         "efficient": sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION),
@@ -820,15 +917,7 @@ def generate_FA_callable(
 ) -> Callable | None:
     if dtype not in [torch.float16, torch.bfloat16]:
         return None
-    if backend == "fav2":
-        try:
-            from flash_attn import flash_attn_func, flash_attn_varlen_func
-        except ImportError:
-            print(
-                "Flash attention 2 is not installed. Please install it to run fav2 backend."
-            )
-            raise
-    elif backend == "fav3":
+    if backend == "fav3":
         try:
             from flash_attn.flash_attn_interface import (
                 flash_attn_func,
@@ -1034,6 +1123,7 @@ def generate_experiment_configs(
     kv_cache_size: list[int],
     cal_bandwidth: bool,
     backends: list[str],
+    max_autotune: bool,
 ) -> list[ExperimentConfig]:
     assert not (calculate_bwd and decoding), "Decoding does not support backward"
 
@@ -1077,52 +1167,333 @@ def generate_experiment_configs(
                     calculate_bwd_time=calculate_bwd,
                     cal_bandwidth=cal_bandwidth,
                     backends=backends,
+                    max_autotune=max_autotune,
                 )
             )
 
     return all_configs
 
 
-def main(args):
+def _output_json_for_dashboard(
+    experiments,
+    output_file,
+    benchmark_name="PyTorch operator microbenchmark",
+):
+    """
+    Write the result into JSON format for PyTorch OSS dashboard.
+    The JSON format is defined at
+    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
+
+    Args:
+        experiments: List of experiment results
+        output_file: Path to output JSON file
+        benchmark_name: Name of the benchmark
+    """
+    if not experiments:
+        return
+
+    import math
+    import platform
+    from dataclasses import asdict, dataclass
+    from typing import Any, Optional
+
+    # Prepare headers and records for JSON output
+    records = []
+    for experiment in experiments:
+        config = experiment.config
+        results_dict = (
+            experiment.results
+        )  # This is a dict: backend -> ExperimentResults
+
+        # Process each backend result
+        for backend, results in results_dict.items():
+            # Skip backends that were not run (NaN results)
+            if math.isnan(results.fwd_time):
+                continue
+
+            # Extract data from experiment
+            test_name = f"{backend}_{config.attn_type}_"
+            input_config = f"shape: {config.shape}, dtype: {config.dtype}"
+
+            # Determine mode based on backward pass
+            mode = "training" if config.calculate_bwd_time else "inference"
+
+            # Extract dtype
+            dtype = (
+                str(config.dtype).split(".")[1]
+                if "." in str(config.dtype)
+                else str(config.dtype)
+            )
+
+            # Determine device
+            device = "cuda"
+
+            # Get device architecture
+            device_arch = (
+                torch.cuda.get_device_name(0)
+                if device == "cuda"
+                else platform.processor()
+                if device == "cpu"
+                else "unknown"
+            )
+
+            # Create dataclasses for JSON structure
+            @dataclass
+            class BenchmarkInfo:
+                name: str
+                mode: Optional[str]
+                dtype: str
+                extra_info: dict[str, Any]
+
+            @dataclass
+            class ModelInfo:
+                name: str
+                type: str
+                origins: list[str]
+                extra_info: dict[str, Any]
+
+            @dataclass
+            class MetricInfo:
+                name: str
+                unit: str
+                benchmark_values: list[float]
+                target_value: Optional[float]
+
+            @dataclass
+            class BenchmarkRecord:
+                benchmark: BenchmarkInfo
+                model: ModelInfo
+                metric: MetricInfo
+
+            # Benchmark extra info
+            benchmark_extra_info = {
+                "input_config": input_config,
+                "device": device,
+                "arch": device_arch,
+                "operator_name": backend,
+                "attn_type": config.attn_type,
+                "shape": str(config.shape),
+                "max_autotune": config.max_autotune,
+            }
+            # Add record for forward latency
+            record_fwd_latency = BenchmarkRecord(
+                benchmark=BenchmarkInfo(
+                    name=benchmark_name,
+                    mode=mode,
+                    dtype=dtype,
+                    extra_info=benchmark_extra_info,
+                ),
+                model=ModelInfo(
+                    name=test_name + str(config.shape),
+                    type="attention-benchmark",
+                    origins=["pytorch"],
+                    extra_info={
+                        "operator_name": backend,
+                        "attn_type": config.attn_type,
+                    },
+                ),
+                metric=MetricInfo(
+                    name="forward latency",
+                    unit="us",
+                    benchmark_values=[results.fwd_time],
+                    target_value=None,
+                ),
+            )
+            records.append(asdict(record_fwd_latency))
+
+            # Add record for forward memory bandwidth (if available)
+            if config.cal_bandwidth:
+                record_fwd_bandwidth = BenchmarkRecord(
+                    benchmark=BenchmarkInfo(
+                        name=benchmark_name,
+                        mode=mode,
+                        dtype=dtype,
+                        extra_info=benchmark_extra_info,
+                    ),
+                    model=ModelInfo(
+                        name=test_name + str(config.shape),
+                        type="attention-benchmark",
+                        origins=["pytorch"],
+                        extra_info={
+                            "operator_name": backend,
+                        },
+                    ),
+                    metric=MetricInfo(
+                        name="memory bandwidth",
+                        unit="TB/s",
+                        benchmark_values=[calculate_bandwidth(config, results, "fwd")],
+                        target_value=None,
+                    ),
+                )
+                records.append(asdict(record_fwd_bandwidth))
+
+            # Add record for forward TFLOPS (if available)
+            if config.cal_bandwidth:
+                record_fwd_tflops = BenchmarkRecord(
+                    benchmark=BenchmarkInfo(
+                        name=benchmark_name,
+                        mode=mode,
+                        dtype=dtype,
+                        extra_info=benchmark_extra_info,
+                    ),
+                    model=ModelInfo(
+                        name=test_name + str(config.shape),
+                        type="attention-benchmark",
+                        origins=["pytorch"],
+                        extra_info={
+                            "operator_name": backend,
+                        },
+                    ),
+                    metric=MetricInfo(
+                        name="tflops",
+                        unit="TFLOPS/s",
+                        benchmark_values=[calculate_tflops(config, results)],
+                        target_value=None,
+                    ),
+                )
+                records.append(asdict(record_fwd_tflops))
+
+            # Add record for backward latency (if available and not NaN)
+            if (
+                config.calculate_bwd_time
+                and results.bwd_time is not None
+                and not math.isnan(results.bwd_time)
+            ):
+                record_bwd_latency = BenchmarkRecord(
+                    benchmark=BenchmarkInfo(
+                        name=benchmark_name,
+                        mode=mode,
+                        dtype=dtype,
+                        extra_info=benchmark_extra_info,
+                    ),
+                    model=ModelInfo(
+                        name=test_name + str(config.shape),
+                        type="attention-benchmark",
+                        origins=["pytorch"],
+                        extra_info={
+                            "operator_name": backend,
+                        },
+                    ),
+                    metric=MetricInfo(
+                        name="backward latency",
+                        unit="us",
+                        benchmark_values=[results.bwd_time],
+                        target_value=None,
+                    ),
+                )
+                records.append(asdict(record_bwd_latency))
+
+    # Write all records to the output file
+    with open(output_file, "w", encoding="utf-8") as f:
+        json.dump(records, f, indent=2)
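For orientation, a single entry in the emitted JSON list looks roughly like the dict below (field names follow the dataclasses above; all values are invented for illustration):

# Roughly what asdict(record_fwd_latency) contributes to the output file (values invented):
example_record = {
    "benchmark": {
        "name": "PyTorch operator microbenchmark",
        "mode": "inference",
        "dtype": "bfloat16",
        "extra_info": {
            "input_config": "shape: (2, 16, 1024, 16, 1024, 64), dtype: torch.bfloat16",
            "device": "cuda",
            "arch": "NVIDIA H100",
            "operator_name": "efficient",
            "attn_type": "causal",
            "shape": "(2, 16, 1024, 16, 1024, 64)",
            "max_autotune": False,
        },
    },
    "model": {
        "name": "efficient_causal_(2, 16, 1024, 16, 1024, 64)",
        "type": "attention-benchmark",
        "origins": ["pytorch"],
        "extra_info": {"operator_name": "efficient", "attn_type": "causal"},
    },
    "metric": {
        "name": "forward latency",
        "unit": "us",
        "benchmark_values": [123.4],
        "target_value": None,
    },
}
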
+
+
+def main(
+    dynamic: bool = False,
+    calculate_bwd: bool = False,
+    dtype: DtypeString = "bfloat16",
+    b: list[int] | None = None,
+    nh: list[str] | None = None,
+    s: list[int] | None = None,
+    d: list[int] | None = None,
+    mods: list[AttentionType] | None = None,
+    backend: list[Backend] | None = None,
+    max_autotune: bool = False,
+    decoding: bool = False,
+    kv_size: Optional[list[int]] = None,
+    throughput: bool = True,
+    save_path: Optional[str] = None,
+    output_json_for_dashboard: Optional[str] = None,
+    benchmark_name: str = "PyTorch operator microbenchmark",
+) -> None:
+    """Run sweep over sizes and score mods for flex attention.
+
+    Usage Examples:
+        # Use a yml config file
+        python score_mod.py --config basic_config.yaml
+
+        # Use a json config file
+        python score_mod.py --config my_config.json
+
+        # Generate a config template
+        python score_mod.py --print-config json > my_config.json  # For a json config
+        python score_mod.py --print-config yaml > my_config.yaml  # For a yaml config
+
+        # Override config with CLI args
+        python score_mod.py --config my_config.json -dtype float16 --max-autotune
+
+        # Pure CLI usage
+        python score_mod.py -b 4 8 -s 1024 2048 -mods causal alibi --backend efficient
+
+    Args:
+        dynamic: Runs a dynamic shapes version of compiled flex attention
+        calculate_bwd: Calculate backward pass times
+        dtype: Data type for tensors (bfloat16, float16, float32)
+        b: Batch sizes to benchmark
+        nh: Number of query and key/value heads in format "Hq,Hkv"
+        s: Sequence lengths to benchmark
+        d: Head dimensions to benchmark
+        mods: Score modifications: noop, causal, rel, head_bias, alibi, sliding_window, document_mask, prefix_lm, softcap
+        backend: Backends for attention computation: math, efficient, cudnn, fav2, fav3, fakv, og-eager
+        max_autotune: Turn on max-autotune optimization
+        decoding: Benchmark decoding mode (query sequence length = 1)
+        kv_size: Key/value cache size in MiB (ignores batch size if specified)
+        throughput: Calculate kernel memory bandwidth & computational throughput (always True)
+        save_path: Path to save the results CSV file
+        output_json_for_dashboard: Path to save results in JSON format for PyTorch OSS dashboard
+        benchmark_name: Name of the benchmark for dashboard output
+    """
+    # Convert dtype string to torch dtype (if not already converted)
+    import torch
+
+    if isinstance(dtype, str):
+        dtype = getattr(torch, dtype)
+
+    # Always calculate throughput
+    throughput = True
+    print("Backend: ", backend)
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
 
     results = []
-    for config in tqdm(
-        generate_experiment_configs(
-            args.calculate_bwd,
-            args.dtype,
-            args.b,
-            args.nh,
-            args.s,
-            args.d,
-            args.mods,
-            args.decoding,
-            args.kv_size,
-            args.throughput,
-            args.backend,
-        )
+    for experiment_count, config in enumerate(
+        tqdm(
+            generate_experiment_configs(
+                calculate_bwd,
+                dtype,
+                b,
+                nh,
+                s,
+                d,
+                mods,
+                decoding,
+                kv_size,
+                throughput,
+                backend,
+                max_autotune,
+            )
+        ),
+        start=1,
     ):
         results.append(
             Experiment(
                 config,
                 run_single_experiment(
                     config,
-                    dynamic=args.dynamic,
-                    max_autotune=args.max_autotune,
+                    dynamic=dynamic,
                 ),
             )
         )
 
-    print_results(results, args.save_path)
+        # Periodic memory cleanup every 50 experiments
+        if experiment_count % 50 == 0:
+            cleanup_memory()
 
+    print_results(results, save_path)
 
-def heads_input_type(s):
-    try:
-        hq, hkv = map(int, s.split(","))
-        return hq, hkv
-    except Exception as e:
-        raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
+    # Output JSON for dashboard if requested
+    if output_json_for_dashboard:
+        _output_json_for_dashboard(results, output_json_for_dashboard, benchmark_name)
 
 
 if __name__ == "__main__":
@@ -1130,6 +1501,12 @@
     parser = argparse.ArgumentParser(
         description="Run sweep over sizes and score mods for flex attention"
     )
+    parser.add_argument(
+        "--config",
+        type=str,
+        help="Path to JSON or YAML config file. CLI args override config file values.",
+        default=None,
+    )
     parser.add_argument(
         "--dynamic",
         action="store_true",
@@ -1199,8 +1576,49 @@ Ignores -b batch size and calculate batch size from kv size instead when specifi
         default=["efficient"],
         help="Backend to use for attention computation",
     )
+    parser.add_argument(
+        "--output-json-for-dashboard",
+        type=str,
+        help="Path to save results in JSON format for PyTorch OSS dashboard",
+        default=None,
+    )
+    parser.add_argument(
+        "--benchmark-name",
+        type=str,
+        help="Name of the benchmark for dashboard output",
+        default="PyTorch operator microbenchmark",
+    )
+    parser.add_argument(
+        "--print-config",
+        type=str,
+        choices=["json", "yaml"],
+        help="Print a default config template in JSON or YAML format and exit",
+        default=None,
+    )
 
     # Parse arguments
     args = parser.parse_args()
-    args.dtype = getattr(torch, args.dtype)
 
-    main(args)
+    # Handle --print-config
+    if args.print_config:
+        print_default_config(args.print_config)
+        sys.exit(0)
+
+    # Load and merge config if provided
+    if args.config:
+        config = load_config_file(args.config)
+
+        # Merge config with CLI args (CLI args take precedence)
+        json_args = argparse.Namespace()
+        json_args.__dict__ = config
+        args = parser.parse_args(namespace=json_args)
+
+    # Convert dtype string to torch dtype (only if it's still a string)
+    if isinstance(args.dtype, str):
+        args.dtype = getattr(torch, args.dtype)
+
+    # Remove config and print_config from args before passing to main
+    args_dict = vars(args)
+    args_dict.pop("config", None)
+    args_dict.pop("print_config", None)
+
+    main(**args_dict)
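End to end, the config/CLI merge behaves as sketched below (hypothetical file name and values; this is only an illustration of the precedence rule, not part of the patch). Because parse_args is re-run on top of a namespace seeded from the file, any flag given on the command line wins, while file-only keys survive:

# Suppose my_config.json contains {"dtype": "bfloat16", "b": [2, 8]} and the user runs
#   python score_mod.py --config my_config.json -dtype float16
config = load_config_file("my_config.json")
json_args = argparse.Namespace()
json_args.__dict__ = config
args = parser.parse_args(namespace=json_args)  # CLI -dtype overrides the seeded value
# args.dtype == "float16" (from the CLI), args.b == [2, 8] (from the file)
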