@@ -268,12 +268,9 @@ def __init__(
         super().__init__()

         self.num_classes = num_classes
-        if not isinstance(img_size, (tuple, list)):
-            img_size = to_2tuple(img_size)
-        self.img_size = img_size
-        if not isinstance(img_scale, (tuple, list)):
-            img_scale = to_2tuple(img_scale)
-        self.img_size_scaled = [tuple([int(sj * si) for sj in img_size]) for si in img_scale]
+        self.img_size = to_2tuple(img_size)
+        img_scale = to_2tuple(img_scale)
+        self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
         num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
         self.num_branches = len(patch_size)
         self.embed_dim = embed_dim
@@ -346,7 +343,7 @@ def forward_features(self, x):
         xs = []
         for i, patch_embed in enumerate(self.patch_embed):
             ss = self.img_size_scaled[i]
-            x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic') if H != ss[0] else x
+            x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x
             tmp = patch_embed(x_)
             cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
             cls_tokens = cls_tokens.expand(B, -1, -1)
@@ -361,15 +358,12 @@ def forward_features(self, x):

         # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
         xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
-        return tuple([x[:, 0] for x in xs])
+        return [x[:, 0] for x in xs]

     def forward(self, x):
         xs = self.forward_features(x)
         ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
-        if isinstance(self.head[0], nn.Identity):
-            # FIXME to pass current passthrough features tests, could use better approach
-            ce_logits = tuple(ce_logits)
-        else:
+        if not isinstance(self.head[0], nn.Identity):
             ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
         return ce_logits
