1818from .config_fragment import ConfigSpecFragment
1919from .config_fragment import EnumFragment
2020from .config_fragment import IntegerFragment
21+ from .config_fragment import ListOf
2122from .config_fragment import NumWarpsFragment
2223from .config_fragment import PermutationFragment
2324from .config_fragment import PowerOfTwoFragment
5051 "num_stages" ,
5152 "pid_type" ,
5253 "indexing" ,
54+ "load_eviction_policies" ,
5355 ]
5456)
5557VALID_PID_TYPES = ("flat" , "xyz" , "persistent_blocked" , "persistent_interleaved" )
58+ VALID_EVICTION_POLICIES = ("" , "first" , "last" )
5659
5760
5861@dataclasses .dataclass
@@ -97,6 +100,11 @@ class ConfigSpec:
97100 default_factory = functools .partial (tuple , VALID_PID_TYPES )
98101 )
99102 grid_block_ids : list [int ] = dataclasses .field (default_factory = list )
103+ load_eviction_policies : ListOf = dataclasses .field (
104+ default_factory = lambda : ListOf (
105+ EnumFragment (choices = VALID_EVICTION_POLICIES ), length = 0
106+ )
107+ )
100108
101109 @staticmethod
102110 def _valid_indexing_types () -> tuple [IndexingLiteral , ...]:
@@ -206,12 +214,16 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
206214 "range_multi_buffers" ,
207215 "range_flattens" ,
208216 "static_ranges" ,
217+ "load_eviction_policies" ,
209218 ):
210- if not config [ name ] :
211- config .pop (name )
219+ if not config . get ( name ) :
220+ config .pop (name , None )
212221
213222 config .setdefault ("num_warps" , DEFAULT_NUM_WARPS )
214223 config .setdefault ("num_stages" , DEFAULT_NUM_STAGES )
224+ config .setdefault (
225+ "load_eviction_policies" , self .load_eviction_policies .default ()
226+ )
215227 # TODO(jansel): include num_ctas and max_nreg
216228
217229 for name , values in (
@@ -266,10 +278,12 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
266278 "num_stages" : fn (IntegerFragment (1 , 8 , DEFAULT_NUM_STAGES )),
267279 "indexing" : fn (EnumFragment (self ._valid_indexing_types ())),
268280 "pid_type" : fn (EnumFragment (self .allowed_pid_types )),
281+ "load_eviction_policies" : fn (self .load_eviction_policies ),
269282 }
270283 # Add tunable parameters
271- for key , fragment in self .user_defined_tunables .items ():
272- config [key ] = fn (fragment )
284+ config .update (
285+ {key : fn (fragment ) for key , fragment in self .user_defined_tunables .items ()}
286+ )
273287
274288 for name in (
275289 "loop_orders" ,
@@ -282,9 +296,10 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
282296 "range_multi_buffers" ,
283297 "range_flattens" ,
284298 "static_ranges" ,
299+ "load_eviction_policies" ,
285300 ):
286- if not config [ name ] :
287- config .pop (name )
301+ if not config . get ( name ) :
302+ config .pop (name , None )
288303 self .normalize (config )
289304 return helion .Config (** config )
290305
0 commit comments