 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
-from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \
-    ClassifierHead, LayerNorm2d, _assert
+from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
+    get_attn, get_act_layer, get_norm_layer, _assert
 from .registry import register_model
 from .vision_transformer_relpos import RelPosMlp, RelPosBias  # FIXME move to common location

@@ -321,7 +321,7 @@ def __init__(
             depth: int,
             num_heads: int,
             feat_size: Tuple[int, int],
-            window_size: int,
+            window_size: Tuple[int, int],
             downsample: bool = True,
             global_norm: bool = False,
             stage_norm: bool = False,
@@ -347,8 +347,9 @@ def __init__(
         else:
             self.downsample = nn.Identity()
         self.feat_size = feat_size
+        window_size = to_2tuple(window_size)

-        feat_levels = int(math.log2(min(feat_size) / window_size))
+        feat_levels = int(math.log2(min(feat_size) / min(window_size)))
         self.global_block = FeatureBlock(dim, feat_levels)
         self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()

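The feat_levels computation above sets how many reduction levels the global FeatureBlock is built with, i.e. how many 2x steps are needed to shrink the stage feature map down to the window size. A quick check with illustrative numbers, not taken from the diff: assuming a 224x224 input, the first stage sees a 56x56 feature map and, with the default ratio of 32, a 7x7 window.

import math

# Illustrative values: 224x224 input -> 56x56 first-stage feature map;
# 224 // 32 -> 7x7 window (assumed, not stated in the diff).
feat_size = (56, 56)
window_size = (7, 7)

# min() over both tuples keeps the expression well-defined for
# non-square feature maps and windows.
feat_levels = int(math.log2(min(feat_size) / min(window_size)))
print(feat_levels)  # 3 -> three 2x reductions: 56 -> 28 -> 14 -> 7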
@@ -400,7 +401,8 @@ def __init__(
             num_classes: int = 1000,
             global_pool: str = 'avg',
             img_size: Tuple[int, int] = 224,
-            window_size: Tuple[int, ...] = (7, 7, 14, 7),
+            window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
+            window_size: Tuple[int, ...] = None,
             embed_dim: int = 64,
             depths: Tuple[int, ...] = (3, 4, 19, 5),
             num_heads: Tuple[int, ...] = (2, 4, 8, 16),
@@ -411,7 +413,7 @@ def __init__(
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
-            weight_init='vit',
+            weight_init='',
             act_layer: str = 'gelu',
             norm_layer: str = 'layernorm2d',
             norm_layer_cl: str = 'layernorm',
@@ -429,6 +431,11 @@ def __init__(
         self.drop_rate = drop_rate
         num_stages = len(depths)
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))
+        if window_size is not None:
+            window_size = to_ntuple(num_stages)(window_size)
+        else:
+            assert window_ratio is not None
+            window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])

         self.stem = Stem(
             in_chs=in_chans,
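With the default img_size of 224, the new window_ratio defaults resolve to the same window geometry the removed hard-coded defaults produced, so the hunks above should be behavior-preserving for stock configs. A minimal sketch of that resolution, assuming to_2tuple/to_ntuple behave as the timm.models.layers helpers imported in this diff:

from timm.models.layers import to_2tuple, to_ntuple

img_size = to_2tuple(224)          # (224, 224)
window_ratio = (32, 32, 16, 32)    # new defaults from this hunk
num_stages = 4

window_size = tuple(
    (img_size[0] // r, img_size[1] // r)
    for r in to_ntuple(num_stages)(window_ratio))
print(window_size)  # ((7, 7), (7, 7), (14, 14), (7, 7)) -- matches the old (7, 7, 14, 7)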
@@ -480,7 +487,7 @@ def _init_weights(self, module, name, scheme='vit'):
                 nn.init.zeros_(module.bias)
         else:
             if isinstance(module, nn.Linear):
-                trunc_normal_tf_(module.weight, std=.02)
+                nn.init.normal_(module.weight, std=.02)
                 if module.bias is not None:
                     nn.init.zeros_(module.bias)

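Since the default weight_init now falls back to '' (see the earlier hunk), this else branch becomes the default path, trading timm's TF-style truncated normal for a plain normal draw. A minimal comparison sketch; both calls exist and accept a std keyword:

import torch
from torch import nn
from timm.models.layers import trunc_normal_tf_  # the initializer this hunk removes

w = torch.empty(64, 64)
trunc_normal_tf_(w, std=.02)  # truncated at +/-2 std before scaling
nn.init.normal_(w, std=.02)   # unbounded normal, as used after this commit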
@@ -490,7 +497,6 @@ def no_weight_decay(self):
             k for k, _ in self.named_parameters()
             if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}

-
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
@@ -567,7 +573,6 @@ def gcvit_small(pretrained=False, **kwargs):
     model_kwargs = dict(
         depths=(3, 4, 19, 5),
         num_heads=(3, 6, 12, 24),
-        window_size=(7, 7, 14, 7),
         embed_dim=96,
         mlp_ratio=2,
         layer_scale=1e-5,
@@ -580,7 +585,6 @@ def gcvit_base(pretrained=False, **kwargs):
     model_kwargs = dict(
         depths=(3, 4, 19, 5),
         num_heads=(4, 8, 16, 32),
-        window_size=(7, 7, 14, 7),
         embed_dim=128,
         mlp_ratio=2,
         layer_scale=1e-5,