Skip to content

Commit 511a8e8

Browse files
committed
Add official ResMLP weights.
1 parent b9cfb64 commit 511a8e8

File tree

1 file changed

+134
-12
lines changed

1 file changed

+134
-12
lines changed

timm/models/mlp_mixer.py

Lines changed: 134 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
year={2021}
1515
}
1616
17-
Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more...
17+
Also supporting ResMlp, and a preliminary (not verified) implementation of gMLP
1818
19+
Code: https://github.com/facebookresearch/deit
1920
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
2021
@misc{touvron2021resmlp,
2122
title={ResMLP: Feedforward networks for image classification with data-efficient training},
@@ -94,11 +95,36 @@ def _cfg(url='', **kwargs):
9495
gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
9596
gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
9697

97-
resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
98+
resmlp_12_224=_cfg(
99+
url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
100+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
98101
resmlp_24_224=_cfg(
99-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
100-
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.89),
101-
resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
102+
url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
103+
#url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
104+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
105+
resmlp_36_224=_cfg(
106+
url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
107+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
108+
resmlp_big_24_224=_cfg(
109+
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
110+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
111+
112+
resmlp_12_distilled_224=_cfg(
113+
url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
114+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
115+
resmlp_24_distilled_224=_cfg(
116+
url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
117+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
118+
resmlp_36_distilled_224=_cfg(
119+
url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
120+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
121+
resmlp_big_24_distilled_224=_cfg(
122+
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
123+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
124+
125+
resmlp_big_24_224_in22ft1k=_cfg(
126+
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
127+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
102128

103129
gmlp_ti16_224=_cfg(),
104130
gmlp_s16_224=_cfg(),
@@ -266,20 +292,27 @@ def forward(self, x):
266292
return x
267293

268294

269-
def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
295+
def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
270296
""" Mixer weight initialization (trying to match Flax defaults)
271297
"""
272298
if isinstance(module, nn.Linear):
273299
if name.startswith('head'):
274300
nn.init.zeros_(module.weight)
275301
nn.init.constant_(module.bias, head_bias)
276302
else:
277-
nn.init.xavier_uniform_(module.weight)
278-
if module.bias is not None:
279-
if 'mlp' in name:
280-
nn.init.normal_(module.bias, std=1e-6)
281-
else:
303+
if flax:
304+
# Flax defaults
305+
lecun_normal_(module.weight)
306+
if module.bias is not None:
282307
nn.init.zeros_(module.bias)
308+
else:
309+
# like MLP init in vit (my original init)
310+
nn.init.xavier_uniform_(module.weight)
311+
if module.bias is not None:
312+
if 'mlp' in name:
313+
nn.init.normal_(module.bias, std=1e-6)
314+
else:
315+
nn.init.zeros_(module.bias)
283316
elif isinstance(module, nn.Conv2d):
284317
lecun_normal_(module.weight)
285318
if module.bias is not None:
@@ -293,13 +326,31 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
293326
module.init_weights()
294327

295328

329+
def checkpoint_filter_fn(state_dict, model):
    """ Remap pretrained checkpoints into timm's MLP-Mixer naming scheme.

    Official FB ResMLP checkpoints are recognized by the presence of a
    'patch_embed.proj.weight' key; their keys are renamed to timm conventions
    and the Affine alpha/beta parameters are reshaped to (1, 1, -1).
    Any other checkpoint is returned untouched.
    """
    if 'patch_embed.proj.weight' not in state_dict:
        return state_dict
    # Remap FB ResMlp models -> timm
    rename_pairs = (
        ('patch_embed.', 'stem.'),
        ('attn.', 'linear_tokens.'),
        ('mlp.', 'mlp_channels.'),
        ('gamma_', 'ls'),
    )
    remapped = {}
    for key, value in state_dict.items():
        for old, new in rename_pairs:
            key = key.replace(old, new)
        if key.endswith(('.alpha', '.beta')):
            # Affine scale/shift params stored flat; add broadcast dims
            value = value.reshape(1, 1, -1)
        remapped[key] = value
    return remapped
296346
def _create_mixer(variant, pretrained=False, **kwargs):
    """ Instantiate an MLP-Mixer family model from its registered default config.

    Raises RuntimeError for feature-extraction requests, which these
    models do not support.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for MLP-Mixer models.')

    return build_model_with_cfg(
        MlpMixer, variant, pretrained,
        default_cfg=default_cfgs[variant],
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
305356

@@ -458,11 +509,82 @@ def resmlp_36_224(pretrained=False, **kwargs):
458509
"""
459510
model_args = dict(
460511
patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
461-
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
512+
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
462513
model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
463514
return model
464515

465516

517+
@register_model
def resmlp_big_24_224(pretrained=False, **kwargs):
    """ ResMLP-B-24 (big, depth 24, patch 8).
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    args = dict(
        patch_size=8,
        num_blocks=24,
        embed_dim=768,
        mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6),
        norm_layer=Affine,
        **kwargs)
    return _create_mixer('resmlp_big_24_224', pretrained=pretrained, **args)
527+
528+
529+
@register_model
def resmlp_12_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-12, distilled weights.
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    args = dict(
        patch_size=16,
        num_blocks=12,
        embed_dim=384,
        mlp_ratio=4,
        block_layer=ResBlock,
        norm_layer=Affine,
        **kwargs)
    return _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **args)
538+
539+
540+
@register_model
def resmlp_24_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-24, distilled weights.
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    args = dict(
        patch_size=16,
        num_blocks=24,
        embed_dim=384,
        mlp_ratio=4,
        # deeper variants use smaller LayerScale init values
        block_layer=partial(ResBlock, init_values=1e-5),
        norm_layer=Affine,
        **kwargs)
    return _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **args)
550+
551+
552+
@register_model
def resmlp_36_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-36, distilled weights.
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    args = dict(
        patch_size=16,
        num_blocks=36,
        embed_dim=384,
        mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6),
        norm_layer=Affine,
        **kwargs)
    return _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **args)
562+
563+
564+
@register_model
def resmlp_big_24_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-B-24, distilled weights.
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    args = dict(
        patch_size=8,
        num_blocks=24,
        embed_dim=768,
        mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6),
        norm_layer=Affine,
        **kwargs)
    return _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **args)
574+
575+
576+
@register_model
def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
    """ ResMLP-B-24, ImageNet-22k pretrained, fine-tuned on ImageNet-1k.
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    args = dict(
        patch_size=8,
        num_blocks=24,
        embed_dim=768,
        mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6),
        norm_layer=Affine,
        **kwargs)
    return _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **args)
586+
587+
466588
@register_model
467589
def gmlp_ti16_224(pretrained=False, **kwargs):
468590
""" gMLP-Tiny

0 commit comments

Comments
 (0)