Claude generated the change and I reviewed/tested/published it. The tests will fail in fbcode (thanks Ed for flagging), so I am disabling them in fbcode for now; I think the OSS signal is enough. If we really want signal in fbcode, we probably need to replace the 'python' command running the benchmark script with a 'buck' command. One instance of the failure in fbcode: https://www.internalfb.com/tasks/?t=246383644

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169074
Approved by: https://github.com/eellison, https://github.com/v0i0
# Owner(s): ["module: inductor"]
import contextlib
import os
import pathlib
import subprocess
import sys
import tempfile
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_cache
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_GPU_AND_TRITON,
    IS_BIG_GPU,
    IS_FBCODE,
)

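# Resolves to the repository root, assuming this file lives three directory
# levels down (presumably test/inductor/); used below to locate the
# huggingface benchmark script.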
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent


@instantiate_parametrized_tests
class DeterministicTest(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self._exit_stack = contextlib.ExitStack()
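        # fresh_cache() gives each test an isolated Inductor cache directory,
        # so cached compilation artifacts from one test do not bleed into the
        # next.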
        self._exit_stack.enter_context(fresh_cache())

    def tearDown(self) -> None:
        self._exit_stack.close()
        super().tearDown()

    def test_use_deterministic_algorithms(self):
        old_val = torch.are_deterministic_algorithms_enabled()
        try:
            for new_val in [True, False, True]:
                torch.use_deterministic_algorithms(new_val, warn_only=True)
                self.assertEqual(inductor_config.deterministic, new_val)
        finally:
            torch.use_deterministic_algorithms(old_val, warn_only=True)

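    # Illustrative sketch (not part of the original test file): the
    # save/restore pattern above, packaged as a context manager. The name
    # `_deterministic_mode` is hypothetical, not an existing PyTorch API.
    @contextlib.contextmanager
    def _deterministic_mode(self, enabled):
        old_val = torch.are_deterministic_algorithms_enabled()
        torch.use_deterministic_algorithms(enabled, warn_only=True)
        try:
            yield
        finally:
            torch.use_deterministic_algorithms(old_val, warn_only=True)
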
    @skipIfXpu(msg="pad_mm is not enabled for XPU.")
    @parametrize("deterministic", [False, True])
    def test_mm_padding(self, deterministic):
        with inductor_config.patch(deterministic=deterministic):

            @torch.compile()
            def foo(x, y):
                return x @ y

            inps = [torch.rand([2049, 2049], device=GPU_TYPE) for _ in range(2)]
            out = foo(*inps)
            self.assertEqual(out, inps[0] @ inps[1])

            if deterministic:
                self.assertEqual(counters["inductor"]["pad_mm_bench"], 0)
            else:
                self.assertGreater(counters["inductor"]["pad_mm_bench"], 0)

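    # Illustrative sketch (not part of the original file): this test and the
    # two parametrized tests below all assert "benchmarking ran iff
    # deterministic mode is off", which could be factored into a hypothetical
    # helper like this one.
    def _assert_benchmarked_iff_not_deterministic(self, counter_name, deterministic):
        if deterministic:
            self.assertEqual(counters["inductor"][counter_name], 0)
        else:
            self.assertGreater(counters["inductor"][counter_name], 0)
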
@parametrize("deterministic", [False, True])
|
|
@inductor_config.patch(max_autotune=True)
|
|
@unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
|
|
def test_max_autotune(self, deterministic):
|
|
with inductor_config.patch(deterministic=deterministic):
|
|
|
|
@torch.compile()
|
|
def foo(x, y):
|
|
return x @ y
|
|
|
|
inps = [torch.rand([2048, 2048], device=GPU_TYPE) for _ in range(2)]
|
|
out = foo(*inps)
|
|
self.assertEqual(out, inps[0] @ inps[1])
|
|
|
|
if deterministic:
|
|
self.assertTrue(counters["inductor"]["select_algorithm_autotune"] == 0)
|
|
else:
|
|
self.assertTrue(counters["inductor"]["select_algorithm_autotune"] > 0)
|
|
|
|
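    # No deterministic parametrization here: this is a sanity check that
    # coordinate-descent tuning does benchmark pointwise kernels under
    # max-autotune in the default (non-deterministic) configuration.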
    def test_pointwise_coordesc_tuning(self):
        @torch.compile(mode="max-autotune")
        def f(x):
            return x + 1

        x = torch.randn(2048, device=GPU_TYPE)
        self.assertEqual(f(x), x + 1)

        self.assertGreater(counters["inductor"]["coordesc_tuning_bench"], 0)

@parametrize("deterministic", [False, True])
|
|
def test_reduction_coordesc_tuning(self, deterministic):
|
|
with inductor_config.patch(
|
|
deterministic=deterministic, coordinate_descent_tuning=True
|
|
):
|
|
|
|
@torch.compile()
|
|
def foo(x):
|
|
return x.sum(dim=-1)
|
|
|
|
inp = torch.rand([2048, 2048], device=GPU_TYPE)
|
|
|
|
out = foo(inp)
|
|
self.assertEqual(out, inp.sum(dim=-1))
|
|
|
|
if deterministic:
|
|
self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] == 0)
|
|
else:
|
|
self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] > 0)
|
|
|
|
    @unittest.skipIf(IS_FBCODE, "Skipping run2run determinism test in fbcode")
    @parametrize("model_name", ["GoogleFnet", "BertForMaskedLM", "DistillGPT2"])
    @parametrize("training_or_inference", ["training", "inference"])
    @parametrize("precision", ["float32", "bfloat16", "float16", "amp"])
    def test_run2run_determinism(self, model_name, training_or_inference, precision):
        """
        Test run-to-run determinism for a few HuggingFace models.

        The test assumes benchmarks/dynamo/huggingface.py exists under
        REPO_ROOT.
        """

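        # The subprocess environment below forces cold-start compilation:
        # with all caches disabled, both invocations must re-autotune from
        # scratch, which is exactly what run-to-run determinism has to
        # survive.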
        def _setup_env(env):
            env["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"  # disable autotune cache
            env["TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE"] = "0"
            env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"
            if enable_determinism:
                env["TORCHINDUCTOR_DETERMINISTIC"] = "1"

        # Set to False if you want to check how the test fails without
        # the deterministic mode.
        enable_determinism = True
        with tempfile.TemporaryDirectory() as tmpdir:
            saved_pkl = os.path.join(tmpdir, "saved.pkl")
            cmd = (
                f"{sys.executable} {REPO_ROOT}/benchmarks/dynamo/huggingface.py --backend inductor"
                f" --{precision} --accuracy --only {model_name} --{training_or_inference}"
                f" --disable-cudagraphs --save-model-outputs-to={saved_pkl}"
            )
            print("Command", cmd)
            env = os.environ.copy()
            _setup_env(env)
            out = subprocess.run(cmd.split(), capture_output=True, env=env)

            # We don't check the accuracy against eager here because some
            # combinations of model and precision cannot pass that accuracy
            # test. But it's still valuable to make sure we generate bitwise
            # equivalent results from run to run.
            # self.assertTrue("pass" in out.stdout.decode())

            cmd = (
                f"{sys.executable} {REPO_ROOT}/benchmarks/dynamo/huggingface.py --backend inductor"
                f" --{precision} --accuracy --only {model_name} --{training_or_inference}"
                f" --disable-cudagraphs --compare-model-outputs-with={saved_pkl}"
            )
            print("Command", cmd)

            # Distort benchmarking results for the second run; deterministic
            # mode should still produce bitwise-identical outputs.
            env["TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT"] = "inverse"
            out = subprocess.run(cmd.split(), capture_output=True, env=env)
            self.assertTrue(
                "The result is bitwise equivalent to the previously saved result"
                in out.stdout.decode(),
                f"stdout: {out.stdout.decode()}, stderr: {out.stderr.decode()}",
            )


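# Illustrative sketch (not part of the original file): the two benchmark
# invocations above differ only in the save/compare flag, so the command
# could be built by a hypothetical helper like this.
def _benchmark_cmd(precision, model_name, training_or_inference, extra_flag):
    return (
        f"{sys.executable} {REPO_ROOT}/benchmarks/dynamo/huggingface.py"
        f" --backend inductor --{precision} --accuracy --only {model_name}"
        f" --{training_or_inference} --disable-cudagraphs {extra_flag}"
    )

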
if __name__ == "__main__":
    if HAS_GPU_AND_TRITON:
        run_tests()