 from lightllm.utils.device_utils import get_current_device_name
 from lightllm.utils.log_utils import init_logger
 from typing import Callable, Optional, Union, List
-from lightllm.utils.envs_utils import is_triton_autotune_enabled
+from lightllm.utils.envs_utils import get_triton_autotune_level
 from lightllm.common.kernel_config import KernelConfigs
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node
-from lightllm.distributed.communication_op import dist_group_manager

 logger = init_logger(__name__)


+class AutotuneLevel:
+    # Use the configs cached in lightllm/common/triton_utils/autotune_kernel_configs.
+    USE_AUTOTUNE_HIS_CONFIG = 0
+    # Autotune only if no config is cached yet.
+    ADAPTIVE_AUTOTUNE = 1
+    # Always autotune, overwriting any cached config files.
+    FORCE_AUTOTUNE = 2
+    # Disable autotune and use the configs cached in lightllm/common/all_kernel_configs.
+    CLOSE_AUTOTUNE = 3
+
+
 def autotune(
     kernel_name: str,
     configs_gen_func: Callable[[], List],
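As a quick reference, here is how the four levels gate behavior. The sketch below is illustrative only; the real gating lives in `Autotuner.__call__` further down this diff, and it assumes `AutotuneLevel` and `get_triton_autotune_level` are in scope as defined above:

```python
def should_tune(has_cached_config: bool) -> bool:
    # Hypothetical helper mirroring the dispatch in Autotuner.__call__.
    level = get_triton_autotune_level()
    if level == AutotuneLevel.CLOSE_AUTOTUNE:
        return False  # never tune; fall back to all_kernel_configs
    if level == AutotuneLevel.FORCE_AUTOTUNE:
        return True  # always re-tune and overwrite the cache
    if level == AutotuneLevel.ADAPTIVE_AUTOTUNE:
        return not has_cached_config  # tune only on a cache miss
    return False  # USE_AUTOTUNE_HIS_CONFIG: rely on cached autotune configs
```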
@@ -28,6 +38,30 @@ def autotune(
     run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
     mutates_args: List[str] = [],
 ):
+    """Decorator that constructs and returns an Autotuner wrapper for a Triton kernel.
+
+    This decorator configures an Autotuner with the provided configuration
+    generator and key functions, enabling on-demand benchmarking and caching
+    of kernel run configurations across runs and processes.
+
+    Args:
+        kernel_name (str): Human-readable kernel name used for logging and cache paths.
+        configs_gen_func (Callable[[], List]): Function that returns candidate run configurations.
+        static_key_func (Callable): Function that derives a static key (dict-like) from call arguments.
+            This key identifies the cache file that stores tuned configs.
+        run_key_func (Callable): Function that derives a run-time key from call arguments.
+            This key indexes tuned configs within a static key's cache.
+        run_key_distance_func (Callable, optional): Distance metric taking ``(run_key, config_key)`` and
+            returning a comparable value; used to pick the closest config when an exact match is absent.
+            Defaults to ``abs(int(run_key) - int(config_key))``.
+        mutates_args (List[str], optional): Names of arguments that can be mutated by the kernel.
+            During benchmarking, defensive clones are made to avoid side effects. Defaults to ``[]``.
+
+    Returns:
+        Callable: A callable object that wraps the original function and performs autotuning
+        as needed before invocation.
+    """
+
     def decorator(fn):
         return Autotuner(
             fn=fn,
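For orientation, a hypothetical usage sketch of the decorator follows. The kernel, key functions, and config shape are invented for illustration; `run_config` is the keyword the wrapper injects, as shown in `__call__` below:

```python
@autotune(
    kernel_name="demo_scale_kernel",  # hypothetical kernel name
    configs_gen_func=lambda: [{"BLOCK": 64}, {"BLOCK": 128}, {"BLOCK": 256}],
    static_key_func=lambda x: {"dtype": str(x.dtype)},  # one cache file per dtype
    run_key_func=lambda x: x.shape[0],  # configs indexed by leading dimension
)
def scale(x, run_config=None):
    block = (run_config or {"BLOCK": 64})["BLOCK"]
    # ... launch the Triton kernel with BLOCK=block ...
    return x * 2.0
```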
@@ -53,8 +87,6 @@ def __init__(
         run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
         mutates_args: List[str] = [],
     ):
-        # Whether to use this autotune decorator
-        self.disable_autotune = not is_triton_autotune_enabled()

         self.configs_gen_func = configs_gen_func
         self.kernel_name = kernel_name
@@ -81,41 +113,50 @@ def __init__(
         ]
         self._run_key_func_param_names = [name for name, _ in inspect.signature(self.run_key_func).parameters.items()]
         self.mutates_args = mutates_args
+
+        assert get_triton_autotune_level() in [
+            AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG,
+            AutotuneLevel.ADAPTIVE_AUTOTUNE,
+            AutotuneLevel.FORCE_AUTOTUNE,
+            AutotuneLevel.CLOSE_AUTOTUNE,
+        ]
         return

     @torch.no_grad()
     def __call__(self, *args, **kwargs):
         if kwargs.get("run_config", None) is not None:
             return self.fn(*args, **kwargs)

-        if self.disable_autotune:
+        # If autotune_level is AutotuneLevel.CLOSE_AUTOTUNE, skip autotuning entirely.
+        autotune_level = get_triton_autotune_level()
+        if autotune_level == AutotuneLevel.CLOSE_AUTOTUNE:
             return self.fn(*args, **kwargs)

         rank_id = 0 if not dist.is_initialized() else get_global_rank()
         world_size = 1 if not dist.is_initialized() else get_global_world_size()

-        static_key = self._static_key(*args, **kwargs)
+        static_key = frozendict(self._static_key(*args, **kwargs))
         run_key = str(self._run_key(*args, **kwargs))

-        # Lazy load
+        # Lazily load the cached configs from lightllm/common/triton_utils/autotune_kernel_configs.
         self._try_load_cache(static_key)

-        if static_key not in self.cached_configs:
+        if static_key not in self.cached_configs and autotune_level == AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG:
             if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized():
                 logger.warning(
                     f"No kernel config for {self.kernel_name} in {KernelConfigs.get_config_file_name(static_key)}",
                 )
             self.cached_configs[static_key] = {}

-        if is_triton_autotune_enabled():
-            need_tunning = run_key not in self.cached_configs.get(static_key, {})
+        if autotune_level in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
+            need_tuning = (autotune_level == AutotuneLevel.FORCE_AUTOTUNE) or (
+                run_key not in self.cached_configs.get(static_key, {})
+            )
             if world_size > 1:
-                _need_tunnings = [None for _ in range(world_size)]
-                dist.all_gather_object(
-                    _need_tunnings, obj=need_tunning, group=dist_group_manager.get_default_group().autotune_group
-                )
-                need_tunning = any(_need_tunnings)
-            if need_tunning:
+                _need_tunings = [None for _ in range(world_size)]
+                dist.all_gather_object(_need_tunings, obj=need_tuning, group=self._get_autotune_group())
+                need_tuning = any(_need_tunings)
+            if need_tuning:
                 self._autotune(
                     args=args,
                     kwargs=kwargs,
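The `world_size > 1` branch above is a small consensus step: if any rank is missing a config, every rank benchmarks, presumably so that all ranks enter the tuning collectives together. The pattern in isolation (`group` stands in for whatever `_get_autotune_group()` returns):

```python
import torch.distributed as dist

def agree_on_tuning(local_need: bool, world_size: int, group) -> bool:
    # Gather every rank's local decision, then tune if any rank needs it.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, obj=local_need, group=group)
    return any(gathered)
```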
@@ -125,12 +166,12 @@ def __call__(self, *args, **kwargs):
                     world_size=world_size,
                 )

-        if static_key in self.fast_match_configs and run_key in self.fast_match_configs[static_key]:
-            closest_config = self.fast_match_configs[static_key][run_key]
+        closest_config = self.fast_match_configs.get(static_key, {}).get(run_key, None)
+        if closest_config is not None:
             kwargs["run_config"] = closest_config
             return self.fn(*args, **kwargs)

-        all_configs = self.cached_configs.get(static_key)
+        all_configs = self.cached_configs.get(static_key, {})
         if len(all_configs) != 0:
             closest_config = min(
                 list(all_configs.items()), key=lambda item: self.run_key_distance_func(run_key, item[0])
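When the fast-match table has no entry, the fallback above picks the cached entry whose key minimizes `run_key_distance_func`. A self-contained illustration using the default distance metric and made-up cache entries:

```python
cached = {"128": {"BLOCK": 64}, "1024": {"BLOCK": 256}}  # made-up cache entries
distance = lambda run_key, config_key: abs(int(run_key) - int(config_key))

run_key = "900"
closest_key, closest_config = min(cached.items(), key=lambda item: distance(run_key, item[0]))
print(closest_key, closest_config)  # 1024 {'BLOCK': 256}
```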
@@ -146,6 +187,7 @@ def _try_load_cache(self, static_key):

         cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
         if os.path.exists(cache_file):
+            logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
             with open(cache_file, "rb") as f:
                 self.cached_configs[static_key] = orjson.loads(f.read())
             return
@@ -194,9 +236,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
         if world_size > 1:
             all_keys = [None for _ in range(world_size)]
             all_key_str = f"{run_key}_{static_key}"
-            dist.all_gather_object(
-                all_keys, obj=all_key_str, group=dist_group_manager.get_default_group().autotune_group
-            )
+            dist.all_gather_object(all_keys, obj=all_key_str, group=self._get_autotune_group())
             is_key_all_same = all(all_keys[0] == k for k in all_keys)
             if not is_key_all_same:
                 logger.warning(
@@ -237,7 +277,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
             dist.all_gather_object(
                 all_gather_configs,
                 obj=(best_time, run_key, dict(static_key), best_config),
-                group=self._get_autotune_group(),
+                group=self._get_autotune_group(),
             )
             all_gather_configs = sorted(all_gather_configs, key=lambda x: x[0])
             key_set = set()
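Each rank contributes a `(best_time, run_key, static_key, config)` tuple, and the sort by time plus `key_set` suggests a keep-the-fastest-per-key merge. The rest of the loop falls outside this hunk, so this is a plausible reading rather than the verbatim code:

```python
gathered = [  # made-up timings from two ranks
    (0.031, "1024", {"dtype": "float16"}, {"BLOCK": 128}),
    (0.027, "1024", {"dtype": "float16"}, {"BLOCK": 256}),
]
best = {}
for t, run_key, static_key, config in sorted(gathered, key=lambda x: x[0]):
    key = (run_key, str(static_key))
    if key not in best:  # first hit per key is the fastest, thanks to the sort
        best[key] = config
print(best)  # {('1024', "{'dtype': 'float16'}"): {'BLOCK': 256}}
```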
@@ -318,13 +358,19 @@ def _select_args(self, param_names, args, kwargs):

     def _static_key(self, *args, **kwargs):
         params = self._select_args(self._static_key_func_param_names, args, kwargs)
-        key = self.static_key_func(*params)
-        return frozendict(key)
+        return self.static_key_func(*params)

     def _run_key(self, *args, **kwargs):
         params = self._select_args(self._run_key_func_param_names, args, kwargs)
         return self.run_key_func(*params)

+    def _get_autotune_group(
+        self,
+    ):
+        from lightllm.distributed.communication_op import dist_group_manager
+
+        return dist_group_manager.get_default_group().autotune_group
+

 class _BenchmarkState:
     def __init__(self):
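One design note on the new `_get_autotune_group`: it imports `dist_group_manager` inside the function body, which lines up with the removal of the module-level import at the top of this diff. Deferring an import to call time is a common way to break a circular dependency; a minimal sketch of the pattern, with illustrative module names:

```python
# tuner.py
def get_group():
    # Deferred import: comm.py can itself import tuner.py at module load
    # time without tripping a circular-import error.
    from comm import group_manager
    return group_manager.default_group
```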