From 978faf1fa29444f78a7ca805f8abc032cb29e0d8 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 26 Jan 2024 20:28:36 -0600 Subject: [PATCH] Use an op counter to decide when to realize a kernel (#117030) Instead of checking the number of bytes in the string representation of the kernel Pull Request resolved: https://github.com/pytorch/pytorch/pull/117030 Approved by: https://github.com/lezcano, https://github.com/peterbell10 --- test/inductor/test_perf.py | 4 +-- torch/_inductor/config.py | 2 +- torch/_inductor/graph.py | 2 +- torch/_inductor/ir.py | 67 ++++++++++++++++++++++++++++---------- 4 files changed, 54 insertions(+), 21 deletions(-) diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 1ac817738b0..39fc70dd687 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -447,7 +447,7 @@ class SchedulerFusionTests(TestCase): def setUpClass(cls): super().setUpClass() cls._stack = contextlib.ExitStack() - cls._stack.enter_context(patch.object(config, "realize_bytes_threshold", 0)) + cls._stack.enter_context(patch.object(config, "realize_opcount_threshold", 0)) @classmethod def tearDownClass(cls): @@ -808,7 +808,7 @@ class WouldBeNiceIfItWorked: self.assertExpectedInline(count_numel(f, *inp), """200""") # TODO: The greedy fusion strategy results in suboptimal grouping - @patch.object(config, "realize_bytes_threshold", 0) + @patch.object(config, "realize_opcount_threshold", 0) def test_fusion_choice4(self): def f(a, b, b2): c = a + b diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 0946c0a16f9..1357aab6c83 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -247,7 +247,7 @@ warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1" # For fanouts, rematerialization can lead to exponential blowup. So, have # smaller threshold realize_reads_threshold = 4 -realize_bytes_threshold = 2000 +realize_opcount_threshold = 30 # Threshold to prevent excessive accumulation of ops in one buffer during lowering realize_acc_reads_threshold = 8 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index cdae338b2c9..32770b1f546 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -968,7 +968,7 @@ class GraphLowering(torch.fx.Interpreter): curr = result.data.data if isinstance(curr, Pointwise): # Use inner fn as a rough proxy. Good enough. - if curr.inner_fn_str_len() > config.realize_bytes_threshold: + if curr.has_large_inner_fn(): result.realize() # This is not complete, but it doesn't have to be: origin_node diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index f6662b4b8e2..37958adf177 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -333,6 +333,31 @@ class IRNode: realize_hint: Callable[[], None] +class _OpCounterCSE: + """Shim to count how many ops are used""" + + def __init__(self, inner): + super().__init__() + self.parent_handler = inner + self.op_count = 0 + self.var_names = {} + + def __getattr__(self, name): + def inner(*args, **kwargs): + val = getattr(self.parent_handler, name)(*args, **kwargs) + if name == "indirect_indexing": + return val + if val not in self.var_names: + varname = f"tmp{self.op_count}" + self.op_count += 1 + self.var_names[val] = varname + return varname + else: + return self.var_names[val] + + return inner + + @dataclasses.dataclass class Loops(IRNode): device: torch.device @@ -400,12 +425,27 @@ class Loops(IRNode): ] @cache_on_self - def inner_fn_str_len(self): - return len(self.inner_fn_str()) + def inner_fn_opcount(self): + from .ir import FlexibleLayout + + opcounter = _OpCounterCSE(V.MockHandler()) + + with V.set_ops_handler(opcounter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + result = self.inner_fn(*self.inner_fn_args()) + return opcounter.op_count + + def inner_fn_args(self): + return (self._index(self.ranges),) def inner_fn_str(self): - index = self._index(self.ranges) - return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index) + return V.KernelFormatterHandler.ir_to_string( + self.inner_fn, *self.inner_fn_args() + ) + + def has_large_inner_fn(self): + return self.inner_fn_opcount() > config.realize_opcount_threshold def inner_fn_free_unbacked_symbols(self): index = self._index(self.ranges) @@ -612,14 +652,10 @@ class Reduction(Loops): def index_length(self): return len(self.ranges) + len(self.reduction_ranges) - def inner_fn_str(self): + def inner_fn_args(self): index = self._index(self.ranges) rindex = self._index(self.reduction_ranges, "r") - return V.KernelFormatterHandler.ir_to_string( - self.inner_fn, - index, - rindex, - ) + return (index, rindex) def inner_fn_free_unbacked_symbols(self): index = self._index(self.ranges) @@ -1605,14 +1641,11 @@ class Scan(Loops): def index_length(self): return len(self.ranges) + len(self.scan_ranges) - def inner_fn_str(self): + def inner_fn_args(self): index = self._index(self.ranges) rindex = self._index(self.scan_ranges, "r") idx = self.reindex(index, rindex) - return V.KernelFormatterHandler.ir_to_string( - self.inner_fn, - idx, - ) + return (idx,) def inner_fn_free_unbacked_symbols(self): index = self._index(self.ranges) @@ -6473,7 +6506,7 @@ class StorageBox(MutableBox): def has_exceeded_max_reads(self): return isinstance(self.data, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold - or self.inner_fn_str_len() > config.realize_bytes_threshold + or self.has_large_inner_fn() ) def mark_reuse(self, users): @@ -6495,7 +6528,7 @@ class StorageBox(MutableBox): and isinstance(self.data, (Pointwise, Reduction)) and ( self.num_reads() > config.realize_reads_threshold - or len(self.inner_fn_str()) > config.realize_bytes_threshold + or self.has_large_inner_fn() or (is_cpu(self.data) and should_realize_on_cpu(self.data)) ) ):