
Commit 7aada66

[Autotuner] Add autotune_benchmark_fn setting (#1199)
1 parent a2f5ed1 commit 7aada66

4 files changed: 71 additions & 2 deletions


docs/api/settings.md

Lines changed: 7 additions & 0 deletions
@@ -254,6 +254,13 @@ See :class:`helion.autotuner.LocalAutotuneCache` for details on cache keys and b
 
    Override the callable that constructs autotuner instances. Accepts the same signature as :func:`helion.runtime.settings.default_autotuner_fn`.
    Pass a replacement callable via ``@helion.kernel(..., autotuner_fn=...)`` or ``helion.kernel(autotuner_fn=...)`` at definition time.
+
+.. autoattribute:: Settings.autotune_benchmark_fn
+
+   Custom benchmark function for rebenchmarking during autotuning. Should have the signature
+   ``(fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None) -> list[float]``.
+   If ``None`` (default), uses the built-in benchmark function.
+   Pass a replacement callable via ``@helion.kernel(..., autotune_benchmark_fn=...)`` at definition time.
 ```
 
 Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"DifferentialEvolutionSearch"``, ``"FiniteSearch"``, and ``"RandomSearch"``.
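
For reference, a minimal sketch of a callable conforming to this documented signature is shown below. It times each candidate sequentially with plain wall-clock timing; the name ``sequential_bench`` and the CUDA synchronization are illustrative assumptions, not part of this commit, and the built-in interleaved benchmark remains the default.

```python
import time
from typing import Callable

import torch


def sequential_bench(
    fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None
) -> list[float]:
    # Illustrative sketch only: time each candidate back-to-back rather than
    # interleaved. Must return one timing per callable, in the same order.
    # desc is only used by the built-in progress-bar path and is ignored here.
    timings: list[float] = []
    for fn in fns:
        fn()  # warm-up call
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(repeat):
            fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        timings.append((time.perf_counter() - start) / repeat)
    return timings
```

As the docs above state, such a callable is attached at definition time, e.g. ``@helion.kernel(autotune_benchmark_fn=sequential_bench)``.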

helion/autotuner/base_search.py

Lines changed: 3 additions & 2 deletions
@@ -922,12 +922,13 @@ def rebenchmark(
         )
         repeat = min(1000, max(3, base_repeat))
         iterator = [functools.partial(m.fn, *self.args) for m in members]
+        bench_fn = self.settings.autotune_benchmark_fn or interleaved_bench
         if self.settings.autotune_progress_bar:
             # pyrefly: ignore [bad-argument-type]
-            new_timings = interleaved_bench(iterator, repeat=repeat, desc=desc)
+            new_timings = bench_fn(iterator, repeat=repeat, desc=desc)
         else:
             # pyrefly: ignore [bad-argument-type]
-            new_timings = interleaved_bench(iterator, repeat=repeat)
+            new_timings = bench_fn(iterator, repeat=repeat)
         for m, t in zip(members, new_timings, strict=True):
             m.perfs.append(t)
             if t < self.best_perf_so_far:
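
One detail worth noting: whatever callable replaces ``interleaved_bench`` must return exactly one timing per candidate, in the same order, because the results are zipped against ``members`` with ``strict=True`` just below. A small sketch of a wrapper enforcing that contract (the name ``with_length_check`` is illustrative, not part of this change):

```python
from typing import Callable


def with_length_check(
    bench_fn: Callable[..., list[float]],
) -> Callable[..., list[float]]:
    # Illustrative wrapper: fail loudly if a custom benchmark function returns
    # the wrong number of timings, rather than raising inside rebenchmark's
    # strict zip.
    def checked(
        fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None
    ) -> list[float]:
        timings = bench_fn(fns, repeat=repeat, desc=desc)
        if len(timings) != len(fns):
            raise ValueError("benchmark fn must return one timing per candidate")
        return timings

    return checked
```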

helion/runtime/settings.py

Lines changed: 7 additions & 0 deletions
@@ -412,6 +412,7 @@ class _Settings:
     autotune_baseline_fn: Callable[..., object] | None = None
     autotune_baseline_atol: float | None = None
     autotune_baseline_rtol: float | None = None
+    autotune_benchmark_fn: Callable[..., list[float]] | None = None
 
 
 class Settings(_Settings):
@@ -502,6 +503,12 @@ class Settings(_Settings):
             "Set HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache to enable strict caching. "
             "Defaults to 'LocalAutotuneCache'."
         ),
+        "autotune_benchmark_fn": (
+            "Custom benchmark function for rebenchmarking during autotuning. "
+            "Should have the following signature: "
+            "(fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None) -> list[float]. "
+            "If None (default), uses the built-in benchmark function."
+        ),
     }
 
     def __init__(self, **settings: object) -> None:

test/test_autotuner.py

Lines changed: 54 additions & 0 deletions
@@ -1284,6 +1284,60 @@ def test_fragment_encoding(self):
         encoded = fragment.encode(value)
         self.assertEqual(len(encoded), dim)
 
+    @skipIfCpu("fails on Triton CPU backend")
+    def test_autotune_benchmark_fn(self) -> None:
+        """Test that custom benchmark function is used during rebenchmarking."""
+        # Track benchmark function calls
+        benchmark_calls: list[tuple[int, int]] = []  # (num_fns, repeat)
+
+        def custom_benchmark_fn(
+            fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None
+        ) -> list[float]:
+            benchmark_calls.append((len(fns), repeat))
+            # Return fake timings
+            return [1.0] * len(fns)
+
+        @helion.kernel(
+            autotune_benchmark_fn=custom_benchmark_fn,
+            autotune_log_level=0,
+        )
+        def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        args = (
+            torch.randn([128], device=DEVICE),
+            torch.randn([128], device=DEVICE),
+        )
+
+        bound_kernel = add.bind(args)
+        # Use PatternSearch which has rebenchmark method
+        search = PatternSearch(bound_kernel, args)
+
+        # Compile two configs
+        config1 = search.config_gen.random_config()
+        config2 = search.config_gen.random_config()
+        fn1 = bound_kernel.compile_config(config1)
+        fn2 = bound_kernel.compile_config(config2)
+
+        # Create population members (flat_values not used in rebenchmark)
+        member1 = PopulationMember(fn1, [1.0], (), config1)
+        member2 = PopulationMember(fn2, [1.1], (), config2)
+
+        search.best_perf_so_far = 1.0
+
+        # Call rebenchmark directly
+        search.rebenchmark([member1, member2])
+
+        # Verify custom benchmark function was called
+        self.assertGreater(
+            len(benchmark_calls), 0, "Custom benchmark function should be called"
+        )
+        # Should have been called with 2 functions
+        self.assertEqual(benchmark_calls[0][0], 2)
+
 
 class TestAutotuneRandomSeed(RefEagerTestDisabled, TestCase):
     def _autotune_and_record(self, **settings: object) -> float:
