@@ -403,6 +403,9 @@ def __init__(
403403 constexpr_args [name ] = arg
404404 else :
405405 self .fake_args .append (self .env .to_fake (arg , ArgumentOrigin (name )))
406+
407+ self ._apply_mark_static (args )
408+
406409 with (
407410 _maybe_skip_dtype_check_in_meta_registrations (),
408411 patch_inductor_lowerings (),
@@ -420,6 +423,20 @@ def __init__(
420423 self .maybe_log_repro (log .warning , args , config = config )
421424 raise
422425
426+ def _apply_mark_static (self , args : tuple [object , ...]) -> None :
427+ """
428+ Apply torch._dynamo.mark_static() markings from input tensors.
429+
430+ This reads _dynamo_static_indices from each tensor argument and marks
431+ the corresponding dimensions as specialized (constant) in the kernel.
432+ """
433+ for arg , fake_arg in zip (args , self .fake_args , strict = True ):
434+ if isinstance (arg , torch .Tensor ) and isinstance (fake_arg , torch .Tensor ):
435+ for dim in getattr (arg , "_dynamo_static_indices" , ()):
436+ size = fake_arg .size (dim )
437+ if isinstance (size , torch .SymInt ):
438+ self .env .specialized_vars .update (size ._sympy_ ().free_symbols )
439+
423440 @property
424441 def settings (self ) -> Settings :
425442 """
@@ -891,12 +908,14 @@ def kernel(
891908def _tensor_key (fn : Kernel , obj : torch .Tensor ) -> Hashable :
892909 # NOTE: If a machine has two different gpu types on the same machine,
893910 # obj.device.type will incorrectly hit
911+ static_indices = frozenset (getattr (obj , "_dynamo_static_indices" , ()))
894912 if fn .settings .static_shapes :
895913 return (
896914 obj .dtype ,
897915 obj .device .type ,
898916 (* obj .size (),),
899917 (* obj .stride (),),
918+ static_indices ,
900919 )
901920 bucketed = tuple ([min (s , 2 ) for s in obj .size ()])
902921 if fn .settings .index_dtype is None :
@@ -909,11 +928,13 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
909928 obj .device .type ,
910929 bucketed ,
911930 needs_int64 ,
931+ static_indices ,
912932 )
913933 return (
914934 obj .dtype ,
915935 obj .device .type ,
916936 bucketed ,
937+ static_indices ,
917938 )
918939
919940
0 commit comments