@@ -40,7 +40,7 @@
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
+        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
         'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
         'classifier': ('head.0', 'head.1'),
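
For context: with a resize-then-center-crop eval pipeline of the kind timm uses, `crop_pct` sets the pre-crop resize size as roughly `input_size / crop_pct`. A minimal sketch of that arithmetic (the helper name is illustrative, not part of this diff):

    import math

    def eval_resize_size(input_size: int, crop_pct: float) -> int:
        # Pre-crop resize target for a resize-then-center-crop eval transform.
        return int(math.floor(input_size / crop_pct))

    assert eval_resize_size(240, 0.875) == 274  # new default for the 240px models
    assert eval_resize_size(408, 1.0) == 408    # crop_pct=1.0 below: no border cropped

So the 0.875 default resizes the 240px models from 274px, while the `crop_pct=1.0` entries added below feed the full 408px frame with no crop.
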
@@ -56,7 +56,7 @@ def _cfg(url='', **kwargs):
     ),
     'crossvit_15_dagger_408': _cfg(
         url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
-        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
     ),
     'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
     'crossvit_18_dagger_240': _cfg(
@@ -65,7 +65,7 @@ def _cfg(url='', **kwargs):
     ),
     'crossvit_18_dagger_408': _cfg(
         url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
-        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
     ),
    'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
    'crossvit_9_dagger_240': _cfg(
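
Both 408px dagger entries reuse 384-trained checkpoints and pin `crop_pct=1.0`, relying on `_cfg`'s keyword overrides to win over the 0.875 default (presumably merged via `**kwargs` further down in `_cfg`, as is usual in timm). A hypothetical sketch of that override pattern:

    def make_cfg(url='', **overrides):
        # Illustrative stand-in for _cfg: defaults first, caller overrides win.
        cfg = {'num_classes': 1000, 'input_size': (3, 240, 240), 'crop_pct': 0.875}
        cfg.update(overrides)
        return {'url': url, **cfg}

    cfg = make_cfg(input_size=(3, 408, 408), crop_pct=1.0)
    assert cfg['crop_pct'] == 1.0 and cfg['input_size'] == (3, 408, 408)
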
@@ -263,14 +263,15 @@ def __init__(
             self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
             embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
             qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False
+            norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
     ):
         super().__init__()
 
         self.num_classes = num_classes
         self.img_size = to_2tuple(img_size)
         img_scale = to_2tuple(img_scale)
         self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
+        self.crop_scale = crop_scale  # crop instead of interpolate for scale
         num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
         self.num_branches = len(patch_size)
         self.embed_dim = embed_dim
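
`img_size_scaled` gives each branch its own input resolution; the 240px models presumably pair the full size with a 224px branch via `img_scale=(1.0, 224/240)`. A standalone re-implementation of that one line, for checking (not imported from the diff):

    def scaled_sizes(img_size=(240, 240), img_scale=(1.0, 224 / 240)):
        # Mirrors: [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
        return [tuple(int(sj * si) for sj in img_size) for si in img_scale]

    print(scaled_sizes())  # [(240, 240), (224, 224)] -> one size per branch

The new `crop_scale` flag then decides, per branch, whether the smaller size is met by center-cropping or by bicubic interpolation (see `forward_features` below).
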
@@ -307,8 +308,7 @@ def __init__(
             for i in range(self.num_branches)])
 
         for i in range(self.num_branches):
-            if hasattr(self, f'pos_embed_{i}'):
-                trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
+            trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
             trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
 
         self.apply(self._init_weights)
@@ -324,9 +324,12 @@ def _init_weights(self, m):
 
     @torch.jit.ignore
    def no_weight_decay(self):
-        out = {'cls_token'}
-        if self.pos_embed[0].requires_grad:
-            out.add('pos_embed')
+        out = set()
+        for i in range(self.num_branches):
+            out.add(f'cls_token_{i}')
+            pe = getattr(self, f'pos_embed_{i}', None)
+            if pe is not None and pe.requires_grad:
+                out.add(f'pos_embed_{i}')
         return out
 
     def get_classifier(self):
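
The `no_weight_decay` rewrite matters because the parameters are registered per branch as `cls_token_0`, `pos_embed_1`, etc., so the old literal names `'cls_token'`/`'pos_embed'` matched nothing. A simplified sketch of how such a name set is typically consumed when building optimizer param groups (not timm's actual factory code):

    def split_decay_groups(model, skip_names, weight_decay=0.05):
        # Parameters named in skip_names are excluded from weight decay.
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            (no_decay if name in skip_names else decay).append(param)
        return [
            {'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.0},
        ]
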
@@ -342,23 +345,29 @@ def forward_features(self, x):
         B, C, H, W = x.shape
         xs = []
         for i, patch_embed in enumerate(self.patch_embed):
+            x_ = x
             ss = self.img_size_scaled[i]
-            x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x
-            tmp = patch_embed(x_)
+            if H != ss[0] or W != ss[1]:
+                if self.crop_scale and ss[0] <= H and ss[1] <= W:
+                    cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
+                    x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
+                else:
+                    x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
+            x_ = patch_embed(x_)
             cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
             cls_tokens = cls_tokens.expand(B, -1, -1)
-            tmp = torch.cat((cls_tokens, tmp), dim=1)
+            x_ = torch.cat((cls_tokens, x_), dim=1)
             pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1  # hard-coded for torch jit script
-            tmp = tmp + pos_embed
-            tmp = self.pos_drop(tmp)
-            xs.append(tmp)
+            x_ = x_ + pos_embed
+            x_ = self.pos_drop(x_)
+            xs.append(x_)
 
         for i, blk in enumerate(self.blocks):
             xs = blk(xs)
 
         # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
         xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
-        return [x[:, 0] for x in xs]
+        return [xo[:, 0] for xo in xs]
 
     def forward(self, x):
         xs = self.forward_features(x)
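
The branch-scaling logic above, extracted into a standalone function to show the two paths (illustrative, mirrors the diff):

    import torch
    import torch.nn.functional as F

    def scale_branch_input(x, ss, crop_scale=False):
        B, C, H, W = x.shape
        if H != ss[0] or W != ss[1]:
            if crop_scale and ss[0] <= H and ss[1] <= W:
                # center crop: keep original pixels, drop the border
                cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
                x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
            else:
                # resample the full frame down (or up) to the branch size
                x = F.interpolate(x, size=ss, mode='bicubic', align_corners=False)
        return x

    x = torch.randn(2, 3, 408, 408)
    assert scale_branch_input(x, (384, 384), crop_scale=True).shape[-2:] == (384, 384)
    assert scale_branch_input(x, (384, 384), crop_scale=False).shape[-2:] == (384, 384)

With `crop_scale=True` a 408px input keeps the central 384px region ((408 - 384) / 2 = 12px trimmed per side); with the default it is bicubic-resampled instead.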