@@ -403,6 +403,9 @@ def __init__(
403403 constexpr_args [name ] = arg
404404 else :
405405 self .fake_args .append (self .env .to_fake (arg , ArgumentOrigin (name )))
406+
407+ self ._apply_mark_static (args )
408+
406409 with (
407410 _maybe_skip_dtype_check_in_meta_registrations (),
408411 patch_inductor_lowerings (),
@@ -420,6 +423,20 @@ def __init__(
420423 self .maybe_log_repro (log .warning , args , config = config )
421424 raise
422425
426+ def _apply_mark_static (self , args : tuple [object , ...]) -> None :
427+ """
428+ Apply torch._dynamo.mark_static() markings from input tensors.
429+
430+ This reads _dynamo_static_indices from each tensor argument and marks
431+ the corresponding dimensions as specialized (constant) in the kernel.
432+ """
433+ for arg , fake_arg in zip (args , self .fake_args , strict = True ):
434+ if isinstance (arg , torch .Tensor ) and isinstance (fake_arg , torch .Tensor ):
435+ for dim in getattr (arg , "_dynamo_static_indices" , ()):
436+ size = fake_arg .size (dim )
437+ if isinstance (size , torch .SymInt ):
438+ self .env .specialized_vars .update (size ._sympy_ ().free_symbols )
439+
423440 @property
424441 def settings (self ) -> Settings :
425442 """
@@ -889,12 +906,14 @@ def kernel(
889906def _tensor_key (fn : Kernel , obj : torch .Tensor ) -> Hashable :
890907 # NOTE: If a machine has two different gpu types on the same machine,
891908 # obj.device.type will incorrectly hit
909+ static_indices = frozenset (getattr (obj , "_dynamo_static_indices" , ()))
892910 if fn .settings .static_shapes :
893911 return (
894912 obj .dtype ,
895913 obj .device .type ,
896914 (* obj .size (),),
897915 (* obj .stride (),),
916+ static_indices ,
898917 )
899918 bucketed = tuple ([min (s , 2 ) for s in obj .size ()])
900919 if fn .settings .index_dtype is None :
@@ -907,11 +926,13 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
907926 obj .device .type ,
908927 bucketed ,
909928 needs_int64 ,
929+ static_indices ,
910930 )
911931 return (
912932 obj .dtype ,
913933 obj .device .type ,
914934 bucketed ,
935+ static_indices ,
915936 )
916937
917938
0 commit comments