Skip to content

Commit 4ffac30

Browse files
authored
Autotune eviction policy (#823)
1 parent 55d6aa0 commit 4ffac30

File tree

12 files changed

+360
-27
lines changed

12 files changed

+360
-27
lines changed

helion/_compiler/compile_environment.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
9696
self.specialized_vars: set[sympy.Symbol] = set()
9797
self.loop_dependency_checker = LoopDependencyChecker()
9898
self._symint_cache: dict[object, torch.SymInt] = {}
99+
self.device_load_count = (
100+
0 # Track number of loads in all device code for eviction policy tuning
101+
)
99102

100103
def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
101104
from .device_function import contains_only_block_size_symbols

helion/_compiler/device_function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
242242
self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config)
243243

244244
self.rng_seed_count = 0
245+
self.device_load_index = 0 # Track which load in device code we're generating (for eviction policy tuning)
245246
# Name of the RNG seed buffer parameter in kernel signature
246247
self.rng_seed_buffer_param_name = None
247248

helion/_compiler/device_ir.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,55 @@ def visit_For(self, node: ast.For) -> None:
10651065
self.generic_visit(node)
10661066

10671067

1068+
def _count_device_loads(device_ir: DeviceIR) -> int:
1069+
"""Count the number of load operations in all device code for eviction policy tuning."""
1070+
from ..language import memory_ops
1071+
1072+
# Build set of rolled graph IDs to exclude (these are duplicates)
1073+
rolled_graph_ids = {
1074+
info.new_graph_id
1075+
for info in device_ir.rolled_reductions
1076+
if info.new_graph_id is not None
1077+
}
1078+
1079+
load_count = 0
1080+
# Walk all graphs except rolled duplicates
1081+
for graph_info in device_ir.graphs:
1082+
if graph_info.graph_id in rolled_graph_ids:
1083+
continue
1084+
1085+
for node in graph_info.graph.nodes:
1086+
# Check if this is a load operation
1087+
if node.op == "call_function" and node.target is memory_ops.load:
1088+
# Only count loads without explicit eviction policy
1089+
# (user can still specify eviction_policy to override tuning)
1090+
# Check kwargs first, then check if 4th arg (eviction_policy) is None
1091+
eviction_policy_arg = node.kwargs.get("eviction_policy")
1092+
if eviction_policy_arg is None:
1093+
# Check if eviction_policy was passed as positional arg (index 3)
1094+
if len(node.args) >= 4:
1095+
eviction_policy_arg = node.args[3]
1096+
if eviction_policy_arg is None:
1097+
load_count += 1
1098+
return load_count
1099+
1100+
1101+
def _register_eviction_policy_tunable(load_count: int) -> None:
1102+
"""Register the eviction policy tunable for all device loads."""
1103+
if load_count == 0:
1104+
return
1105+
1106+
from ..autotuner.config_fragment import EnumFragment
1107+
from ..autotuner.config_fragment import ListOf
1108+
from ..autotuner.config_spec import VALID_EVICTION_POLICIES
1109+
1110+
env = CompileEnvironment.current()
1111+
# Register a tunable for eviction policies for all device loads
1112+
fragment = ListOf(EnumFragment(choices=VALID_EVICTION_POLICIES), length=load_count)
1113+
env.config_spec.load_eviction_policies = fragment
1114+
env.device_load_count = load_count
1115+
1116+
10681117
def lower_to_device_ir(func: HostFunction) -> DeviceIR:
10691118
device_ir = DeviceIR()
10701119
with func, device_ir, compile_lock:
@@ -1085,6 +1134,11 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
10851134
if len(device_ir.root_ids) > 1:
10861135
# xyz not supported with shared program IDs, but persistent kernels are allowed
10871136
CompileEnvironment.current().config_spec.disallow_pid_type("xyz")
1137+
1138+
# Count all device loads and register eviction policy tunable
1139+
load_count = _count_device_loads(device_ir)
1140+
_register_eviction_policy_tunable(load_count)
1141+
10881142
return device_ir
10891143

10901144

helion/autotuner/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .config_fragment import BooleanFragment as BooleanFragment
44
from .config_fragment import EnumFragment as EnumFragment
55
from .config_fragment import IntegerFragment as IntegerFragment
6+
from .config_fragment import ListOf as ListOf
67
from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment
78
from .config_spec import ConfigSpec as ConfigSpec
89
from .differential_evolution import (

helion/autotuner/config_fragment.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,49 @@ def category(self) -> Category:
221221
class NumWarpsFragment(PowerOfTwoFragment):
222222
def category(self) -> Category:
223223
return Category.NUM_WARPS
224+
225+
226+
@dataclasses.dataclass
227+
class ListOf(ConfigSpecFragment):
228+
"""Wrapper that creates a list of independently tunable fragments.
229+
230+
Example:
231+
ListOf(EnumFragment(choices=("a", "b", "c")), length=5)
232+
creates a list of 5 independently tunable enum values.
233+
"""
234+
235+
inner: ConfigSpecFragment
236+
length: int
237+
238+
def default(self) -> list[object]:
239+
"""Return a list of default values."""
240+
return [self.inner.default() for _ in range(self.length)]
241+
242+
def random(self) -> list[object]:
243+
"""Return a list of random values."""
244+
return [self.inner.random() for _ in range(self.length)]
245+
246+
def pattern_neighbors(self, current: object) -> list[object]:
247+
"""Return neighbors by changing one element at a time."""
248+
if not isinstance(current, list) or len(current) != self.length:
249+
raise ValueError(f"Expected list of length {self.length}, got {current!r}")
250+
251+
neighbors: list[object] = []
252+
# For each position, try all neighbors from the inner fragment
253+
for i in range(self.length):
254+
for neighbor_value in self.inner.pattern_neighbors(current[i]):
255+
neighbor = current.copy()
256+
neighbor[i] = neighbor_value
257+
neighbors.append(neighbor)
258+
return neighbors
259+
260+
def differential_mutation(self, a: object, b: object, c: object) -> list[object]:
261+
"""Create a new value by combining a, b, and c element-wise."""
262+
assert isinstance(a, list) and len(a) == self.length
263+
assert isinstance(b, list) and len(b) == self.length
264+
assert isinstance(c, list) and len(c) == self.length
265+
266+
return [
267+
self.inner.differential_mutation(a[i], b[i], c[i])
268+
for i in range(self.length)
269+
]

helion/autotuner/config_spec.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .config_fragment import ConfigSpecFragment
1919
from .config_fragment import EnumFragment
2020
from .config_fragment import IntegerFragment
21+
from .config_fragment import ListOf
2122
from .config_fragment import NumWarpsFragment
2223
from .config_fragment import PermutationFragment
2324
from .config_fragment import PowerOfTwoFragment
@@ -50,9 +51,11 @@
5051
"num_stages",
5152
"pid_type",
5253
"indexing",
54+
"load_eviction_policies",
5355
]
5456
)
5557
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
58+
VALID_EVICTION_POLICIES = ("", "first", "last")
5659

5760

5861
@dataclasses.dataclass
@@ -97,6 +100,11 @@ class ConfigSpec:
97100
default_factory=functools.partial(tuple, VALID_PID_TYPES)
98101
)
99102
grid_block_ids: list[int] = dataclasses.field(default_factory=list)
103+
load_eviction_policies: ListOf = dataclasses.field(
104+
default_factory=lambda: ListOf(
105+
EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
106+
)
107+
)
100108

101109
@staticmethod
102110
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -206,12 +214,16 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
206214
"range_multi_buffers",
207215
"range_flattens",
208216
"static_ranges",
217+
"load_eviction_policies",
209218
):
210-
if not config[name]:
211-
config.pop(name)
219+
if not config.get(name):
220+
config.pop(name, None)
212221

213222
config.setdefault("num_warps", DEFAULT_NUM_WARPS)
214223
config.setdefault("num_stages", DEFAULT_NUM_STAGES)
224+
config.setdefault(
225+
"load_eviction_policies", self.load_eviction_policies.default()
226+
)
215227
# TODO(jansel): include num_ctas and max_nreg
216228

217229
for name, values in (
@@ -266,10 +278,12 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
266278
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
267279
"indexing": fn(EnumFragment(self._valid_indexing_types())),
268280
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
281+
"load_eviction_policies": fn(self.load_eviction_policies),
269282
}
270283
# Add tunable parameters
271-
for key, fragment in self.user_defined_tunables.items():
272-
config[key] = fn(fragment)
284+
config.update(
285+
{key: fn(fragment) for key, fragment in self.user_defined_tunables.items()}
286+
)
273287

274288
for name in (
275289
"loop_orders",
@@ -282,9 +296,10 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
282296
"range_multi_buffers",
283297
"range_flattens",
284298
"static_ranges",
299+
"load_eviction_policies",
285300
):
286-
if not config[name]:
287-
config.pop(name)
301+
if not config.get(name):
302+
config.pop(name, None)
288303
self.normalize(config)
289304
return helion.Config(**config)
290305

helion/language/memory_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616

1717
__all__ = ["load", "store"]
1818

19+
# Map short config names to full Triton API names for eviction policies
20+
_EVICTION_POLICY_MAP = {
21+
"": None,
22+
"first": "evict_first",
23+
"last": "evict_last",
24+
}
25+
1926

2027
@has_side_effect
2128
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
@@ -242,6 +249,16 @@ def _(state: CodegenState) -> ast.AST:
242249
extra_mask = state.ast_args[2]
243250
assert isinstance(extra_mask, (type(None), ast.AST))
244251
eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None
252+
253+
# If no explicit eviction_policy and we're in device code, use tunable
254+
if eviction_policy is None and state.codegen.on_device:
255+
policies = state.config.load_eviction_policies
256+
idx = state.device_function.device_load_index
257+
if idx < len(policies):
258+
policy_value = policies[idx]
259+
eviction_policy = _EVICTION_POLICY_MAP.get(policy_value, policy_value)
260+
state.device_function.device_load_index += 1
261+
245262
if eviction_policy is not None:
246263
assert isinstance(eviction_policy, str)
247264
eviction_policy = ast.Constant(value=eviction_policy)

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
range_multi_buffers: list[bool | None] | None = None,
3535
range_flattens: list[bool | None] | None = None,
3636
static_ranges: list[bool] | None = None,
37+
load_eviction_policies: list[str] | None = None,
3738
num_warps: int | None = None,
3839
num_stages: int | None = None,
3940
pid_type: PidTypeLiteral | None = None,
@@ -55,6 +56,7 @@ def __init__(
5556
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
5657
range_flattens: Controls flatten parameter for tl.range calls.
5758
static_ranges: Whether to use tl.static_range instead tl.range.
59+
load_eviction_policies: Eviction policies for load operations ("", "first", "last").
5860
num_warps: Number of warps per block.
5961
num_stages: Number of stages for software pipelining.
6062
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
@@ -74,6 +76,7 @@ def __init__(
7476
"range_multi_buffers": range_multi_buffers,
7577
"range_flattens": range_flattens,
7678
"static_ranges": static_ranges,
79+
"load_eviction_policies": load_eviction_policies,
7780
"num_warps": num_warps,
7881
"num_stages": num_stages,
7982
"indexing": indexing,
@@ -189,6 +192,10 @@ def range_flattens(self) -> list[bool | None]:
189192
def static_ranges(self) -> list[bool]:
190193
return cast("list[bool]", self.config.get("static_ranges", []))
191194

195+
@property
196+
def load_eviction_policies(self) -> list[str]:
197+
return cast("list[str]", self.config.get("load_eviction_policies", []))
198+
192199
@property
193200
def indexing(self) -> IndexingLiteral:
194201
return self.config.get("indexing", "pointer") # type: ignore[return-value]

0 commit comments

Comments
 (0)