Use an op counter to decide when to realize a kernel (#117030)

The decision to realize a kernel is now based on the number of ops in it,
instead of checking the number of bytes in the string representation
of the kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117030
Approved by: https://github.com/lezcano, https://github.com/peterbell10
This commit is contained in:
Isuru Fernando
2024-01-26 20:28:36 -06:00
committed by PyTorch MergeBot
parent 800e2e823f
commit 978faf1fa2
4 changed files with 54 additions and 21 deletions

View File

@@ -447,7 +447,7 @@ class SchedulerFusionTests(TestCase):
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(patch.object(config, "realize_bytes_threshold", 0))
cls._stack.enter_context(patch.object(config, "realize_opcount_threshold", 0))
@classmethod
def tearDownClass(cls):
@@ -808,7 +808,7 @@ class WouldBeNiceIfItWorked:
self.assertExpectedInline(count_numel(f, *inp), """200""")
# TODO: The greedy fusion strategy results in suboptimal grouping
@patch.object(config, "realize_bytes_threshold", 0)
@patch.object(config, "realize_opcount_threshold", 0)
def test_fusion_choice4(self):
def f(a, b, b2):
c = a + b

View File

@@ -247,7 +247,7 @@ warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
# For fanouts, rematerialization can lead to exponential blowup. So, have
# smaller threshold
realize_reads_threshold = 4
# Limit compared against the length of a kernel's inner_fn string representation
realize_bytes_threshold = 2000
# Limit compared against the number of ops counted in a kernel's inner_fn
realize_opcount_threshold = 30
# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8

View File

@@ -968,7 +968,7 @@ class GraphLowering(torch.fx.Interpreter):
curr = result.data.data
if isinstance(curr, Pointwise):
# Use inner fn as a rough proxy. Good enough.
if curr.inner_fn_str_len() > config.realize_bytes_threshold:
if curr.has_large_inner_fn():
result.realize()
# This is not complete, but it doesn't have to be: origin_node

View File

@@ -333,6 +333,31 @@ class IRNode:
realize_hint: Callable[[], None]
class _OpCounterCSE:
"""Shim to count how many ops are used"""
def __init__(self, inner):
super().__init__()
self.parent_handler = inner
self.op_count = 0
self.var_names = {}
def __getattr__(self, name):
def inner(*args, **kwargs):
val = getattr(self.parent_handler, name)(*args, **kwargs)
if name == "indirect_indexing":
return val
if val not in self.var_names:
varname = f"tmp{self.op_count}"
self.op_count += 1
self.var_names[val] = varname
return varname
else:
return self.var_names[val]
return inner
@dataclasses.dataclass
class Loops(IRNode):
device: torch.device
@@ -400,12 +425,27 @@ class Loops(IRNode):
]
@cache_on_self
def inner_fn_str_len(self):
return len(self.inner_fn_str())
def inner_fn_opcount(self):
    """Return the op count recorded while tracing ``inner_fn``.

    ``inner_fn`` is called with ``inner_fn_args()`` while an
    ``_OpCounterCSE`` wrapping ``V.MockHandler()`` is installed as the ops
    handler. The counter's ``op_count`` after that call is returned.
    """
    from .ir import FlexibleLayout

    opcounter = _OpCounterCSE(V.MockHandler())

    # FlexibleLayout.allow_indexing is patched to True only for the duration
    # of the trace.
    with V.set_ops_handler(opcounter), patch.object(
        FlexibleLayout, "allow_indexing", True
    ):
        # Only the side effect on ``opcounter`` matters; the traced value
        # itself is not needed, so it is not bound to a name.
        self.inner_fn(*self.inner_fn_args())
        return opcounter.op_count
def inner_fn_args(self):
    """Return the positional arguments for ``inner_fn`` as a 1-tuple.

    The single element is the index produced by ``self._index(self.ranges)``.
    """
    index = self._index(self.ranges)
    return (index,)
def inner_fn_str(self):
index = self._index(self.ranges)
return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index)
return V.KernelFormatterHandler.ir_to_string(
self.inner_fn, *self.inner_fn_args()
)
def has_large_inner_fn(self):
    """Return True when ``inner_fn_opcount()`` exceeds ``config.realize_opcount_threshold``."""
    opcount = self.inner_fn_opcount()
    return opcount > config.realize_opcount_threshold
def inner_fn_free_unbacked_symbols(self):
index = self._index(self.ranges)
@@ -612,14 +652,10 @@ class Reduction(Loops):
def index_length(self):
    """Return ``len(self.ranges)`` plus ``len(self.reduction_ranges)``."""
    n_ranges = len(self.ranges)
    n_reduction_ranges = len(self.reduction_ranges)
    return n_ranges + n_reduction_ranges
def inner_fn_str(self):
def inner_fn_args(self):
index = self._index(self.ranges)
rindex = self._index(self.reduction_ranges, "r")
return V.KernelFormatterHandler.ir_to_string(
self.inner_fn,
index,
rindex,
)
return (index, rindex)
def inner_fn_free_unbacked_symbols(self):
index = self._index(self.ranges)
@@ -1605,14 +1641,11 @@ class Scan(Loops):
def index_length(self):
    """Return ``len(self.ranges)`` plus ``len(self.scan_ranges)``."""
    n_ranges = len(self.ranges)
    n_scan_ranges = len(self.scan_ranges)
    return n_ranges + n_scan_ranges
def inner_fn_str(self):
def inner_fn_args(self):
index = self._index(self.ranges)
rindex = self._index(self.scan_ranges, "r")
idx = self.reindex(index, rindex)
return V.KernelFormatterHandler.ir_to_string(
self.inner_fn,
idx,
)
return (idx,)
def inner_fn_free_unbacked_symbols(self):
index = self._index(self.ranges)
@@ -6473,7 +6506,7 @@ class StorageBox(MutableBox):
def has_exceeded_max_reads(self):
return isinstance(self.data, Pointwise) and (
self.num_reads() > config.realize_acc_reads_threshold
or self.inner_fn_str_len() > config.realize_bytes_threshold
or self.has_large_inner_fn()
)
def mark_reuse(self, users):
@@ -6495,7 +6528,7 @@ class StorageBox(MutableBox):
and isinstance(self.data, (Pointwise, Reduction))
and (
self.num_reads() > config.realize_reads_threshold
or len(self.inner_fn_str()) > config.realize_bytes_threshold
or self.has_large_inner_fn()
or (is_cpu(self.data) and should_realize_on_cpu(self.data))
)
):