@@ -403,6 +403,9 @@ def __init__(
403403 constexpr_args [name ] = arg
404404 else :
405405 self .fake_args .append (self .env .to_fake (arg , ArgumentOrigin (name )))
406+
407+ self ._apply_mark_static (args )
408+
406409 with (
407410 _maybe_skip_dtype_check_in_meta_registrations (),
408411 patch_inductor_lowerings (),
@@ -420,6 +423,24 @@ def __init__(
420423 self .maybe_log_repro (log .warning , args , config = config )
421424 raise
422425
426+ def _apply_mark_static (self , args : tuple [object , ...]) -> None :
427+ """
428+ Apply torch._dynamo.mark_static() markings from input tensors.
429+
430+ This reads _dynamo_static_indices from each tensor argument and marks
431+ the corresponding dimensions as specialized (constant) in the kernel.
432+ """
433+ for arg_idx , (arg , fake_arg ) in enumerate (zip (args , self .fake_args , strict = True )):
434+ if isinstance (arg , torch .Tensor ):
435+ static_indices = getattr (arg , "_dynamo_static_indices" , None )
436+ if static_indices :
437+ assert isinstance (fake_arg , torch .Tensor )
438+ for dim in static_indices :
439+ size = fake_arg .size (dim )
440+ if isinstance (size , torch .SymInt ):
441+ sym_expr = size ._sympy_ ()
442+ self .env .specialized_vars .update (sym_expr .free_symbols )
443+
423444 @property
424445 def settings (self ) -> Settings :
425446 """
@@ -889,12 +910,14 @@ def kernel(
889910def _tensor_key (fn : Kernel , obj : torch .Tensor ) -> Hashable :
890911 # NOTE: If a machine has two different gpu types on the same machine,
891912 # obj.device.type will incorrectly hit
913+ static_indices = frozenset (getattr (obj , "_dynamo_static_indices" , ()))
892914 if fn .settings .static_shapes :
893915 return (
894916 obj .dtype ,
895917 obj .device .type ,
896918 (* obj .size (),),
897919 (* obj .stride (),),
920+ static_indices ,
898921 )
899922 bucketed = tuple ([min (s , 2 ) for s in obj .size ()])
900923 if fn .settings .index_dtype is None :
@@ -907,11 +930,13 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
907930 obj .device .type ,
908931 bucketed ,
909932 needs_int64 ,
933+ static_indices ,
910934 )
911935 return (
912936 obj .dtype ,
913937 obj .device .type ,
914938 bucketed ,
939+ static_indices ,
915940 )
916941
917942
0 commit comments