mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Use an op counter to decide when to realize a kernel (#117030)
Instead of checking the number of bytes in the string representation of the kernel Pull Request resolved: https://github.com/pytorch/pytorch/pull/117030 Approved by: https://github.com/lezcano, https://github.com/peterbell10
This commit is contained in:
committed by
PyTorch MergeBot
parent
800e2e823f
commit
978faf1fa2
@@ -447,7 +447,7 @@ class SchedulerFusionTests(TestCase):
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls._stack = contextlib.ExitStack()
|
||||
cls._stack.enter_context(patch.object(config, "realize_bytes_threshold", 0))
|
||||
cls._stack.enter_context(patch.object(config, "realize_opcount_threshold", 0))
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
@@ -808,7 +808,7 @@ class WouldBeNiceIfItWorked:
|
||||
self.assertExpectedInline(count_numel(f, *inp), """200""")
|
||||
|
||||
# TODO: The greedy fusion strategy results in suboptimal grouping
|
||||
@patch.object(config, "realize_bytes_threshold", 0)
|
||||
@patch.object(config, "realize_opcount_threshold", 0)
|
||||
def test_fusion_choice4(self):
|
||||
def f(a, b, b2):
|
||||
c = a + b
|
||||
|
||||
@@ -247,7 +247,7 @@ warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
|
||||
# For fanouts, rematerialization can lead to exponential blowup. So, have
|
||||
# smaller threshold
|
||||
realize_reads_threshold = 4
|
||||
realize_bytes_threshold = 2000
|
||||
realize_opcount_threshold = 30
|
||||
|
||||
# Threshold to prevent excessive accumulation of ops in one buffer during lowering
|
||||
realize_acc_reads_threshold = 8
|
||||
|
||||
@@ -968,7 +968,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
curr = result.data.data
|
||||
if isinstance(curr, Pointwise):
|
||||
# Use inner fn as a rough proxy. Good enough.
|
||||
if curr.inner_fn_str_len() > config.realize_bytes_threshold:
|
||||
if curr.has_large_inner_fn():
|
||||
result.realize()
|
||||
|
||||
# This is not complete, but it doesn't have to be: origin_node
|
||||
|
||||
@@ -333,6 +333,31 @@ class IRNode:
|
||||
realize_hint: Callable[[], None]
|
||||
|
||||
|
||||
class _OpCounterCSE:
|
||||
"""Shim to count how many ops are used"""
|
||||
|
||||
def __init__(self, inner):
|
||||
super().__init__()
|
||||
self.parent_handler = inner
|
||||
self.op_count = 0
|
||||
self.var_names = {}
|
||||
|
||||
def __getattr__(self, name):
|
||||
def inner(*args, **kwargs):
|
||||
val = getattr(self.parent_handler, name)(*args, **kwargs)
|
||||
if name == "indirect_indexing":
|
||||
return val
|
||||
if val not in self.var_names:
|
||||
varname = f"tmp{self.op_count}"
|
||||
self.op_count += 1
|
||||
self.var_names[val] = varname
|
||||
return varname
|
||||
else:
|
||||
return self.var_names[val]
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Loops(IRNode):
|
||||
device: torch.device
|
||||
@@ -400,12 +425,27 @@ class Loops(IRNode):
|
||||
]
|
||||
|
||||
@cache_on_self
|
||||
def inner_fn_str_len(self):
|
||||
return len(self.inner_fn_str())
|
||||
def inner_fn_opcount(self):
|
||||
from .ir import FlexibleLayout
|
||||
|
||||
opcounter = _OpCounterCSE(V.MockHandler())
|
||||
|
||||
with V.set_ops_handler(opcounter), patch.object(
|
||||
FlexibleLayout, "allow_indexing", True
|
||||
):
|
||||
result = self.inner_fn(*self.inner_fn_args())
|
||||
return opcounter.op_count
|
||||
|
||||
def inner_fn_args(self):
|
||||
return (self._index(self.ranges),)
|
||||
|
||||
def inner_fn_str(self):
|
||||
index = self._index(self.ranges)
|
||||
return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index)
|
||||
return V.KernelFormatterHandler.ir_to_string(
|
||||
self.inner_fn, *self.inner_fn_args()
|
||||
)
|
||||
|
||||
def has_large_inner_fn(self):
|
||||
return self.inner_fn_opcount() > config.realize_opcount_threshold
|
||||
|
||||
def inner_fn_free_unbacked_symbols(self):
|
||||
index = self._index(self.ranges)
|
||||
@@ -612,14 +652,10 @@ class Reduction(Loops):
|
||||
def index_length(self):
|
||||
return len(self.ranges) + len(self.reduction_ranges)
|
||||
|
||||
def inner_fn_str(self):
|
||||
def inner_fn_args(self):
|
||||
index = self._index(self.ranges)
|
||||
rindex = self._index(self.reduction_ranges, "r")
|
||||
return V.KernelFormatterHandler.ir_to_string(
|
||||
self.inner_fn,
|
||||
index,
|
||||
rindex,
|
||||
)
|
||||
return (index, rindex)
|
||||
|
||||
def inner_fn_free_unbacked_symbols(self):
|
||||
index = self._index(self.ranges)
|
||||
@@ -1605,14 +1641,11 @@ class Scan(Loops):
|
||||
def index_length(self):
|
||||
return len(self.ranges) + len(self.scan_ranges)
|
||||
|
||||
def inner_fn_str(self):
|
||||
def inner_fn_args(self):
|
||||
index = self._index(self.ranges)
|
||||
rindex = self._index(self.scan_ranges, "r")
|
||||
idx = self.reindex(index, rindex)
|
||||
return V.KernelFormatterHandler.ir_to_string(
|
||||
self.inner_fn,
|
||||
idx,
|
||||
)
|
||||
return (idx,)
|
||||
|
||||
def inner_fn_free_unbacked_symbols(self):
|
||||
index = self._index(self.ranges)
|
||||
@@ -6473,7 +6506,7 @@ class StorageBox(MutableBox):
|
||||
def has_exceeded_max_reads(self):
|
||||
return isinstance(self.data, Pointwise) and (
|
||||
self.num_reads() > config.realize_acc_reads_threshold
|
||||
or self.inner_fn_str_len() > config.realize_bytes_threshold
|
||||
or self.has_large_inner_fn()
|
||||
)
|
||||
|
||||
def mark_reuse(self, users):
|
||||
@@ -6495,7 +6528,7 @@ class StorageBox(MutableBox):
|
||||
and isinstance(self.data, (Pointwise, Reduction))
|
||||
and (
|
||||
self.num_reads() > config.realize_reads_threshold
|
||||
or len(self.inner_fn_str()) > config.realize_bytes_threshold
|
||||
or self.has_large_inner_fn()
|
||||
or (is_cpu(self.data) and should_realize_on_cpu(self.data))
|
||||
)
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user