Skip to content

Commit f8a215c

Browse files
committed
A few more crossvit tweaks, fix training w/ no_weight_decay names, add crop option for scaling, adjust default crop_pct for large img size to 1.0 for better results
1 parent 7ab2491 commit f8a215c

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

timm/models/crossvit.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
def _cfg(url='', **kwargs):
4141
return {
4242
'url': url,
43-
'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
43+
'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
4444
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
4545
'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
4646
'classifier': ('head.0', 'head.1'),
@@ -56,7 +56,7 @@ def _cfg(url='', **kwargs):
5656
),
5757
'crossvit_15_dagger_408': _cfg(
5858
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
59-
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
59+
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
6060
),
6161
'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
6262
'crossvit_18_dagger_240': _cfg(
@@ -65,7 +65,7 @@ def _cfg(url='', **kwargs):
6565
),
6666
'crossvit_18_dagger_408': _cfg(
6767
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
68-
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
68+
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
6969
),
7070
'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
7171
'crossvit_9_dagger_240': _cfg(
@@ -263,14 +263,15 @@ def __init__(
263263
self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
264264
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
265265
qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
266-
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False
266+
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
267267
):
268268
super().__init__()
269269

270270
self.num_classes = num_classes
271271
self.img_size = to_2tuple(img_size)
272272
img_scale = to_2tuple(img_scale)
273273
self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
274+
self.crop_scale = crop_scale # crop instead of interpolate for scale
274275
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
275276
self.num_branches = len(patch_size)
276277
self.embed_dim = embed_dim
@@ -307,8 +308,7 @@ def __init__(
307308
for i in range(self.num_branches)])
308309

309310
for i in range(self.num_branches):
310-
if hasattr(self, f'pos_embed_{i}'):
311-
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
311+
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
312312
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
313313

314314
self.apply(self._init_weights)
@@ -324,9 +324,12 @@ def _init_weights(self, m):
324324

325325
@torch.jit.ignore
326326
def no_weight_decay(self):
327-
out = {'cls_token'}
328-
if self.pos_embed[0].requires_grad:
329-
out.add('pos_embed')
327+
out = set()
328+
for i in range(self.num_branches):
329+
out.add(f'cls_token_{i}')
330+
pe = getattr(self, f'pos_embed_{i}', None)
331+
if pe is not None and pe.requires_grad:
332+
out.add(f'pos_embed_{i}')
330333
return out
331334

332335
def get_classifier(self):
@@ -342,23 +345,29 @@ def forward_features(self, x):
342345
B, C, H, W = x.shape
343346
xs = []
344347
for i, patch_embed in enumerate(self.patch_embed):
348+
x_ = x
345349
ss = self.img_size_scaled[i]
346-
x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x
347-
tmp = patch_embed(x_)
350+
if H != ss[0] or W != ss[1]:
351+
if self.crop_scale and ss[0] <= H and ss[1] <= W:
352+
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
353+
x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
354+
else:
355+
x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
356+
x_ = patch_embed(x_)
348357
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
349358
cls_tokens = cls_tokens.expand(B, -1, -1)
350-
tmp = torch.cat((cls_tokens, tmp), dim=1)
359+
x_ = torch.cat((cls_tokens, x_), dim=1)
351360
pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
352-
tmp = tmp + pos_embed
353-
tmp = self.pos_drop(tmp)
354-
xs.append(tmp)
361+
x_ = x_ + pos_embed
362+
x_ = self.pos_drop(x_)
363+
xs.append(x_)
355364

356365
for i, blk in enumerate(self.blocks):
357366
xs = blk(xs)
358367

359368
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
360369
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
361-
return [x[:, 0] for x in xs]
370+
return [xo[:, 0] for xo in xs]
362371

363372
def forward(self, x):
364373
xs = self.forward_features(x)

0 commit comments

Comments
 (0)