Migrate from Tuple -> tuple in benchmarks/instruction_counts/core (#144253)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144253
Approved by: https://github.com/aorenste
This commit is contained in:
bobrenjc93
2025-01-09 22:29:41 +00:00
committed by PyTorch MergeBot
parent a55977f763
commit 3607ff2c1d
4 changed files with 21 additions and 21 deletions

View File

@@ -7,7 +7,7 @@ import enum
import itertools as it
import re
import textwrap
from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import Dict, List, Optional, Set, TYPE_CHECKING, Union
from worker.main import WorkerTimerArgs
@@ -159,11 +159,11 @@ class GroupedBenchmark:
# Described above
setup: GroupedSetup
signature_args: Optional[Tuple[str, ...]]
signature_args: Optional[tuple[str, ...]]
signature_output: Optional[str]
torchscript: bool
autograd: bool
num_threads: Tuple[int, ...]
num_threads: tuple[int, ...]
@classmethod
def init_from_stmts(
@@ -175,7 +175,7 @@ class GroupedBenchmark:
signature: Optional[str] = None,
torchscript: bool = False,
autograd: bool = False,
num_threads: Union[int, Tuple[int, ...]] = 1,
num_threads: Union[int, tuple[int, ...]] = 1,
) -> "GroupedBenchmark":
"""Create a set of benchmarks from free-form statements.
@@ -223,7 +223,7 @@ class GroupedBenchmark:
signature: Optional[str] = None,
torchscript: bool = False,
autograd: bool = False,
num_threads: Union[int, Tuple[int, ...]] = 1,
num_threads: Union[int, tuple[int, ...]] = 1,
) -> "GroupedBenchmark":
"""Create a set of benchmarks using torch.nn Modules.
@@ -260,8 +260,8 @@ class GroupedBenchmark:
cls,
py_block: str = "",
cpp_block: str = "",
num_threads: Union[int, Tuple[int, ...]] = 1,
) -> Dict[Union[Tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
num_threads: Union[int, tuple[int, ...]] = 1,
) -> Dict[Union[tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
py_cases, py_setup, py_global_setup = cls._parse_variants(
py_block, Language.PYTHON
)
@@ -279,7 +279,7 @@ class GroupedBenchmark:
# NB: The key is actually `Tuple[str, ...]`, however MyPy gets confused
# and we use the superset `Union[Tuple[str, ...], Optional[str]` to
# match the expected signature.
variants: Dict[Union[Tuple[str, ...], Optional[str]], GroupedBenchmark] = {}
variants: Dict[Union[tuple[str, ...], Optional[str]], GroupedBenchmark] = {}
seen_labels: Set[str] = set()
for label in it.chain(py_cases.keys(), cpp_cases.keys()):
@@ -333,7 +333,7 @@ class GroupedBenchmark:
@staticmethod
def _parse_signature(
signature: Optional[str],
) -> Tuple[Optional[Tuple[str, ...]], Optional[str]]:
) -> tuple[Optional[tuple[str, ...]], Optional[str]]:
if signature is None:
return None, None
@@ -341,7 +341,7 @@ class GroupedBenchmark:
if match is None:
raise ValueError(f"Invalid signature: `{signature}`")
args: Tuple[str, ...] = tuple(match.groups()[0].split(", "))
args: tuple[str, ...] = tuple(match.groups()[0].split(", "))
output: str = match.groups()[1].strip()
if "," in output:
@@ -357,7 +357,7 @@ class GroupedBenchmark:
@staticmethod
def _model_from_py_stmt(
py_stmt: Optional[str],
signature_args: Optional[Tuple[str, ...]],
signature_args: Optional[tuple[str, ...]],
signature_output: Optional[str],
) -> str:
if py_stmt is None:
@@ -376,10 +376,10 @@ class GroupedBenchmark:
@staticmethod
def _make_model_invocation(
signature_args: Tuple[str, ...],
signature_args: tuple[str, ...],
signature_output: Optional[str],
runtime: RuntimeMode,
) -> Tuple[str, str]:
) -> tuple[str, str]:
py_prefix, cpp_prefix = "", ""
if signature_output is not None:
py_prefix = f"{signature_output} = "
@@ -415,7 +415,7 @@ class GroupedBenchmark:
@staticmethod
def _parse_variants(
block: str, language: Language
) -> Tuple[Dict[str, List[str]], str, str]:
) -> tuple[Dict[str, List[str]], str, str]:
block = textwrap.dedent(block).strip()
comment = "#" if language == Language.PYTHON else "//"
label_pattern = f"{comment} @(.+)$"

View File

@@ -12,7 +12,7 @@ import os
import re
import textwrap
import uuid
from typing import List, Optional, Tuple, TYPE_CHECKING
from typing import List, Optional, TYPE_CHECKING
import torch
@@ -204,7 +204,7 @@ def materialize(benchmarks: FlatIntermediateDefinition) -> FlatDefinition:
GroupedBenchmarks into multiple TimerArgs, and tagging the results with
AutoLabels.
"""
results: List[Tuple[Label, AutoLabels, TimerArgs]] = []
results: List[tuple[Label, AutoLabels, TimerArgs]] = []
for label, args in benchmarks.items():
if isinstance(args, TimerArgs):

View File

@@ -2,7 +2,7 @@
# mypy: ignore-errors
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, Union
from core.api import AutoLabels, GroupedBenchmark, TimerArgs
@@ -66,7 +66,7 @@ TL;DR
# Allow strings in definition for convenience, and None to signify a base
# case. (No subsequent entry needed. See the "add" example above.)
Label = Tuple[str, ...]
Label = tuple[str, ...]
_Label = Union[Label, Optional[str]]
_Value = Union[
@@ -82,4 +82,4 @@ Definition = Dict[_Label, _Value]
FlatIntermediateDefinition = Dict[Label, Union[TimerArgs, GroupedBenchmark]]
# Final parsed schema.
FlatDefinition = Tuple[Tuple[Label, AutoLabels, TimerArgs], ...]
FlatDefinition = tuple[tuple[Label, AutoLabels, TimerArgs], ...]

View File

@@ -3,7 +3,7 @@ import atexit
import re
import shutil
import textwrap
from typing import List, Optional, Tuple
from typing import List, Optional
from core.api import GroupedBenchmark, TimerArgs
from core.types import Definition, FlatIntermediateDefinition, Label
@@ -59,7 +59,7 @@ def flatten(schema: Definition) -> FlatIntermediateDefinition:
return result
def parse_stmts(stmts: str) -> Tuple[str, str]:
def parse_stmts(stmts: str) -> tuple[str, str]:
"""Helper function for side-by-side Python and C++ stmts.
For more complex statements, it can be useful to see Python and C++ code