Commit 914dd58

add AutotuneLevel for more detailed autotune (#1031)

1 parent 237ae00 · commit 914dd58
File tree: 6 files changed, +96 -39 lines

lightllm/common/basemodel/basemodel.py
lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
lightllm/common/fused_moe/grouped_fused_moe_ep.py
lightllm/common/fused_moe/topk_select.py
lightllm/common/triton_utils/autotuner.py
lightllm/utils/envs_utils.py


lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 3 deletions
@@ -24,8 +24,9 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.common.triton_utils.autotuner import AutotuneLevel
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
-from lightllm.utils.envs_utils import set_model_init_status, is_triton_autotune_enabled, disable_triton_autotune
+from lightllm.utils.envs_utils import set_model_init_status, set_triton_autotune_level, get_triton_autotune_level
 from lightllm.utils.infer_utils import post_empty_cache

 logger = init_logger(__name__)
@@ -731,7 +732,7 @@ def autotune_layers(self):
     @torch.no_grad()
     @post_empty_cache
     def _autotune_warmup(self):
-        if not is_triton_autotune_enabled():
+        if get_triton_autotune_level() not in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
             return

         torch.distributed.barrier()
@@ -794,7 +795,7 @@ def _autotune_warmup(self):
         torch.cuda.empty_cache()
         self.layers_num = layer_num_bak
         torch.distributed.barrier()
-        disable_triton_autotune()
+        set_triton_autotune_level(AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG)

     @final
     @torch.no_grad()

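In short, basemodel.py now runs the autotune warmup only when the level asks for tuning, and drops the level back to config replay once warmup finishes. Below is a minimal sketch of that control flow only, assuming a hypothetical run_layers callable in place of the real layer-by-layer warmup loop; it is not the actual method body.

# Sketch of the warmup gating above; run_layers is a hypothetical stand-in
# for the real warmup loop inside _autotune_warmup.
from lightllm.common.triton_utils.autotuner import AutotuneLevel
from lightllm.utils.envs_utils import get_triton_autotune_level, set_triton_autotune_level


def autotune_warmup_sketch(run_layers):
    # Only ADAPTIVE_AUTOTUNE (1) and FORCE_AUTOTUNE (2) trigger warmup tuning.
    if get_triton_autotune_level() not in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
        return
    run_layers()  # drive the kernels once so the Autotuner benchmarks and caches configs
    # Afterwards, fall back to replaying the cached configs for normal serving.
    set_triton_autotune_level(AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG)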
lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,8 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
-from lightllm.utils.envs_utils import is_triton_autotune_enabled
+from lightllm.utils.envs_utils import get_triton_autotune_level
+from lightllm.common.triton_utils.autotuner import AutotuneLevel
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)
@@ -358,7 +359,7 @@ def prefilled_group_gemm(
     ######################################## warning ##################################################
     # here is used to match autotune feature, make moe model run same triton kernel in different rank.
     # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
-    if is_triton_autotune_enabled():
+    if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
         _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
         _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
         silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 3 additions & 2 deletions
@@ -14,7 +14,8 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
-from lightllm.utils.envs_utils import is_triton_autotune_enabled
+from lightllm.utils.envs_utils import get_triton_autotune_level
+from lightllm.common.triton_utils.autotuner import AutotuneLevel
 import numpy as np

 logger = init_logger(__name__)
@@ -191,7 +192,7 @@ def fused_experts_impl(
     ######################################## warning ##################################################
     # here is used to match autotune feature, make moe model run same triton kernel in different rank.
     # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
-    if is_triton_autotune_enabled():
+    if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
         _gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
         _silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
         silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

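The identical guard in the two expert-parallel files above keeps every rank launching the same Triton kernels while tuning is active, since a rank that receives 0 tokens from dispatch would otherwise skip the kernel and, presumably, fall out of step with the Autotuner's cross-rank gathers. A self-contained sketch of the dummy-token trick, with silu_and_mul_fwd passed in as a parameter rather than imported from lightllm:

import torch


def run_dummy_silu_if_autotuning(N, device, dtype, silu_and_mul_fwd):
    # Push a single zero token through the gemm/activation path so this rank
    # still executes the kernel even when it was dispatched no real tokens.
    _gemm_out_a = torch.zeros((1, N), device=device, dtype=dtype)
    _silu_out = torch.zeros((1, N // 2), device=device, dtype=dtype)
    silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)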
lightllm/common/fused_moe/topk_select.py

Lines changed: 10 additions & 0 deletions
@@ -23,6 +23,8 @@
 from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.fused_moe.softmax_topk import softmax_topk
+from lightllm.common.triton_utils.autotuner import AutotuneLevel
+from lightllm.utils.envs_utils import get_triton_autotune_level

 use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]

@@ -221,4 +223,12 @@ def select_experts(
             hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
         )

+    ######################################## warning ##################################################
+    # here is used to match autotune feature, make topk_ids more random
+    if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
+        rand_gen = torch.Generator(device="cuda")
+        rand_gen.manual_seed(router_logits.shape[0])
+        router_logits = torch.randn(size=router_logits.shape, generator=rand_gen, dtype=torch.float32, device="cuda")
+        _, topk_ids = torch.topk(router_logits, k=top_k, dim=1)
+
     return topk_weights, topk_ids

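The block added to select_experts runs only while tuning is active: it replaces the router logits with seeded random values so topk_ids spread tokens across many experts, which should give the tuner a broader mix of expert loads than whatever the warmup inputs happen to route. A standalone sketch of the same randomization on CPU tensors (the real code uses device="cuda" and seeds the generator with router_logits.shape[0]):

import torch

num_tokens, num_experts, top_k = 8, 32, 2
rand_gen = torch.Generator()
rand_gen.manual_seed(num_tokens)  # the commit seeds with the token count, so results are reproducible
fake_logits = torch.randn(size=(num_tokens, num_experts), generator=rand_gen, dtype=torch.float32)
_, topk_ids = torch.topk(fake_logits, k=top_k, dim=1)  # varied expert ids for benchmarking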
lightllm/common/triton_utils/autotuner.py

Lines changed: 71 additions & 25 deletions
@@ -12,14 +12,24 @@
 from lightllm.utils.device_utils import get_current_device_name
 from lightllm.utils.log_utils import init_logger
 from typing import Callable, Optional, Union, List
-from lightllm.utils.envs_utils import is_triton_autotune_enabled
+from lightllm.utils.envs_utils import get_triton_autotune_level
 from lightllm.common.kernel_config import KernelConfigs
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node
-from lightllm.distributed.communication_op import dist_group_manager

 logger = init_logger(__name__)


+class AutotuneLevel:
+    # Use the config of cached files in /lightllm/common/triton_utils/autotune_kernel_configs.
+    USE_AUTOTUNE_HIS_CONFIG = 0
+    # Autotune if no config is cached.
+    ADAPTIVE_AUTOTUNE = 1
+    # Autotune anyway to overwrite the config of cached files.
+    FORCE_AUTOTUNE = 2
+    # Close autotune and use the configs of cached files in lightllm/common/all_kernel_configs.
+    CLOSE_AUTOTUNE = 3
+
+
 def autotune(
     kernel_name: str,
     configs_gen_func: Callable[[], List],
@@ -28,6 +38,30 @@ def autotune(
     run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
     mutates_args: List[str] = [],
 ):
+    """Decorator that constructs and returns an Autotuner wrapper for a Triton kernel.
+
+    This decorator configures an Autotuner with the provided configuration
+    generator and key functions, enabling on-demand benchmarking and caching
+    of kernel run configurations across runs and processes.
+
+    Args:
+        kernel_name (str): Human-readable kernel name used for logging and cache paths.
+        configs_gen_func (Callable[[], List]): Function that returns candidate run configurations.
+        static_key_func (Callable): Function that derives a static key (dict-like) from call arguments.
+            This key identifies the cache file that stores tuned configs.
+        run_key_func (Callable): Function that derives a run-time key from call arguments.
+            This key indexes tuned configs within a static key's cache.
+        run_key_distance_func (Callable, optional): Distance metric taking ``(run_key, config_key)`` and
+            returning a comparable value; used to pick the closest config when an exact match is absent.
+            Defaults to ``abs(int(run_key) - int(config_key))``.
+        mutates_args (List[str], optional): Names of arguments that can be mutated by the kernel.
+            During benchmarking, defensive clones are made to avoid side effects. Defaults to ``[]``.
+
+    Returns:
+        Callable: A callable object that wraps the original function and performs autotuning
+            as needed before invocation.
+    """
+
     def decorator(fn):
         return Autotuner(
             fn=fn,
@@ -53,8 +87,6 @@ def __init__(
         run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
         mutates_args: List[str] = [],
     ):
-        # Whether to use this autotune decorator
-        self.disable_autotune = not is_triton_autotune_enabled()

         self.configs_gen_func = configs_gen_func
         self.kernel_name = kernel_name
@@ -81,41 +113,50 @@ def __init__(
         ]
         self._run_key_func_param_names = [name for name, _ in inspect.signature(self.run_key_func).parameters.items()]
         self.mutates_args = mutates_args
+
+        assert get_triton_autotune_level() in [
+            AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG,
+            AutotuneLevel.ADAPTIVE_AUTOTUNE,
+            AutotuneLevel.FORCE_AUTOTUNE,
+            AutotuneLevel.CLOSE_AUTOTUNE,
+        ]
         return

     @torch.no_grad()
     def __call__(self, *args, **kwargs):
         if kwargs.get("run_config", None) is not None:
             return self.fn(*args, **kwargs)

-        if self.disable_autotune:
+        # if the autotune_level is AutotuneLevel.CLOSE_AUTOTUNE, ignore the autotune
+        autotune_level = get_triton_autotune_level()
+        if autotune_level == AutotuneLevel.CLOSE_AUTOTUNE:
             return self.fn(*args, **kwargs)

         rank_id = 0 if not dist.is_initialized() else get_global_rank()
         world_size = 1 if not dist.is_initialized() else get_global_world_size()

-        static_key = self._static_key(*args, **kwargs)
+        static_key = frozendict(self._static_key(*args, **kwargs))
         run_key = str(self._run_key(*args, **kwargs))

-        # Lazy load
+        # Lazy load the cached configs in lightllm/common/triton_utils/autotune_kernel_configs
         self._try_load_cache(static_key)

-        if static_key not in self.cached_configs:
+        if static_key not in self.cached_configs and autotune_level == AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG:
             if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized():
                 logger.warning(
                     f"No kernel config for {self.kernel_name} in {KernelConfigs.get_config_file_name(static_key)}",
                 )
             self.cached_configs[static_key] = {}

-        if is_triton_autotune_enabled():
-            need_tunning = run_key not in self.cached_configs.get(static_key, {})
+        if autotune_level in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
+            need_tuning = (autotune_level == AutotuneLevel.FORCE_AUTOTUNE) or (
+                run_key not in self.cached_configs.get(static_key, {})
+            )
             if world_size > 1:
-                _need_tunnings = [None for _ in range(world_size)]
-                dist.all_gather_object(
-                    _need_tunnings, obj=need_tunning, group=dist_group_manager.get_default_group().autotune_group
-                )
-                need_tunning = any(_need_tunnings)
-            if need_tunning:
+                _need_tunings = [None for _ in range(world_size)]
+                dist.all_gather_object(_need_tunings, obj=need_tuning, group=self._get_autotune_group())
+                need_tuning = any(_need_tunings)
+            if need_tuning:
                 self._autotune(
                     args=args,
                     kwargs=kwargs,
@@ -125,12 +166,12 @@ def __call__(self, *args, **kwargs):
                     world_size=world_size,
                 )

-        if static_key in self.fast_match_configs and run_key in self.fast_match_configs[static_key]:
-            closest_config = self.fast_match_configs[static_key][run_key]
+        closest_config = self.fast_match_configs.get(static_key, {}).get(run_key, None)
+        if closest_config is not None:
             kwargs["run_config"] = closest_config
             return self.fn(*args, **kwargs)

-        all_configs = self.cached_configs.get(static_key)
+        all_configs = self.cached_configs.get(static_key, {})
         if len(all_configs) != 0:
             closest_config = min(
                 list(all_configs.items()), key=lambda item: self.run_key_distance_func(run_key, item[0])
@@ -146,6 +187,7 @@ def _try_load_cache(self, static_key):

         cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
         if os.path.exists(cache_file):
+            logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
             with open(cache_file, "rb") as f:
                 self.cached_configs[static_key] = orjson.loads(f.read())
             return
@@ -194,9 +236,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
         if world_size > 1:
             all_keys = [None for _ in range(world_size)]
             all_key_str = f"{run_key}_{static_key}"
-            dist.all_gather_object(
-                all_keys, obj=all_key_str, group=dist_group_manager.get_default_group().autotune_group
-            )
+            dist.all_gather_object(all_keys, obj=all_key_str, group=self._get_autotune_group())
             is_key_all_same = all(all_keys[0] == k for k in all_keys)
             if not is_key_all_same:
                 logger.warning(
@@ -237,7 +277,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
             dist.all_gather_object(
                 all_gather_configs,
                 obj=(best_time, run_key, dict(static_key), best_config),
-                group=dist_group_manager.get_default_group().autotune_group,
+                group=self._get_autotune_group(),
             )
             all_gather_configs = sorted(all_gather_configs, key=lambda x: x[0])
             key_set = set()
@@ -318,13 +358,19 @@ def _select_args(self, param_names, args, kwargs):

     def _static_key(self, *args, **kwargs):
         params = self._select_args(self._static_key_func_param_names, args, kwargs)
-        key = self.static_key_func(*params)
-        return frozendict(key)
+        return self.static_key_func(*params)

     def _run_key(self, *args, **kwargs):
         params = self._select_args(self._run_key_func_param_names, args, kwargs)
         return self.run_key_func(*params)

+    def _get_autotune_group(
+        self,
+    ):
+        from lightllm.distributed.communication_op import dist_group_manager
+
+        return dist_group_manager.get_default_group().autotune_group
+

 class _BenchmarkState:
     def __init__(self):

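To make the new docstring concrete, here is a hypothetical application of the decorator. The kernel name, key functions, and config space are illustrative and not taken from the repository; only the decorator parameters and the injected run_config keyword come from the code above.

from lightllm.common.triton_utils.autotuner import autotune


@autotune(
    kernel_name="demo_rmsnorm",  # illustrative name, used for logging and cache file naming
    configs_gen_func=lambda: [{"BLOCK": 64}, {"BLOCK": 128}, {"BLOCK": 256}],
    static_key_func=lambda x: {"dtype": str(x.dtype)},  # selects the cache file
    run_key_func=lambda x: x.shape[0],  # indexes tuned configs within that file
)
def demo_rmsnorm(x, run_config=None):
    # The Autotuner injects run_config with the benchmarked (or closest cached) config;
    # a real kernel would use it to pick launch parameters such as the block size.
    block = (run_config or {"BLOCK": 64})["BLOCK"]
    ...

When run_config is passed explicitly by the caller, the wrapper invokes the function directly and skips tuning, as shown at the top of __call__ above.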
lightllm/utils/envs_utils.py

Lines changed: 5 additions & 7 deletions
@@ -149,15 +149,13 @@ def get_kv_quant_calibration_inference_count():
     return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000))


-def is_triton_autotune_enabled():
-    # Whether Triton autotune is enabled (read-only check)
-    mark = os.getenv("LIGHTLLM_TRITON_AUTOTUNE", "False").upper() in ["ON", "TRUE", "1"]
-    return mark
+def get_triton_autotune_level():
+    return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0))


-def disable_triton_autotune():
-    # Disable Triton autotune (setter)
-    os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "False"
+def set_triton_autotune_level(level: int):
+    os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = str(level)
+    return


 g_model_init_done = False

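Taken together, the level is read from the LIGHTLLM_TRITON_AUTOTUNE_LEVEL environment variable (default 0) and can be changed in-process through the setter. A small round-trip sketch, using the numeric values defined by AutotuneLevel above:

import os

# Assumed to be set before lightllm reads the level; 2 == AutotuneLevel.FORCE_AUTOTUNE.
os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = "2"

from lightllm.utils.envs_utils import get_triton_autotune_level, set_triton_autotune_level

assert get_triton_autotune_level() == 2
# What basemodel.py does after warmup: fall back to replaying cached configs.
set_triton_autotune_level(0)  # 0 == AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG
assert get_triton_autotune_level() == 0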