1414 year={2021}
1515}
1616
17- Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more...
17+ Also supporting ResMlp, and a preliminary (not verified) implementations of gMLP
1818
19+ Code: https://github.com/facebookresearch/deit
1920Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
2021@misc{touvron2021resmlp,
2122 title={ResMLP: Feedforward networks for image classification with data-efficient training},
@@ -94,11 +95,36 @@ def _cfg(url='', **kwargs):
9495 gmixer_12_224 = _cfg (mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
9596 gmixer_24_224 = _cfg (mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
9697
97- resmlp_12_224 = _cfg (mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
98+ resmlp_12_224 = _cfg (
99+ url = 'https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth' ,
100+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
98101 resmlp_24_224 = _cfg (
99- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth' ,
100- mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD , crop_pct = 0.89 ),
101- resmlp_36_224 = _cfg (mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
102+ url = 'https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth' ,
103+ #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
104+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
105+ resmlp_36_224 = _cfg (
106+ url = 'https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth' ,
107+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
108+ resmlp_big_24_224 = _cfg (
109+ url = 'https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth' ,
110+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
111+
112+ resmlp_12_distilled_224 = _cfg (
113+ url = 'https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth' ,
114+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
115+ resmlp_24_distilled_224 = _cfg (
116+ url = 'https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth' ,
117+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
118+ resmlp_36_distilled_224 = _cfg (
119+ url = 'https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth' ,
120+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
121+ resmlp_big_24_distilled_224 = _cfg (
122+ url = 'https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth' ,
123+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
124+
125+ resmlp_big_24_224_in22ft1k = _cfg (
126+ url = 'https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth' ,
127+ mean = IMAGENET_DEFAULT_MEAN , std = IMAGENET_DEFAULT_STD ),
102128
103129 gmlp_ti16_224 = _cfg (),
104130 gmlp_s16_224 = _cfg (),
@@ -266,20 +292,27 @@ def forward(self, x):
266292 return x
267293
268294
269- def _init_weights (module : nn .Module , name : str , head_bias : float = 0. ):
295+ def _init_weights (module : nn .Module , name : str , head_bias : float = 0. , flax = False ):
270296 """ Mixer weight initialization (trying to match Flax defaults)
271297 """
272298 if isinstance (module , nn .Linear ):
273299 if name .startswith ('head' ):
274300 nn .init .zeros_ (module .weight )
275301 nn .init .constant_ (module .bias , head_bias )
276302 else :
277- nn .init .xavier_uniform_ (module .weight )
278- if module .bias is not None :
279- if 'mlp' in name :
280- nn .init .normal_ (module .bias , std = 1e-6 )
281- else :
303+ if flax :
304+ # Flax defaults
305+ lecun_normal_ (module .weight )
306+ if module .bias is not None :
282307 nn .init .zeros_ (module .bias )
308+ else :
309+ # like MLP init in vit (my original init)
310+ nn .init .xavier_uniform_ (module .weight )
311+ if module .bias is not None :
312+ if 'mlp' in name :
313+ nn .init .normal_ (module .bias , std = 1e-6 )
314+ else :
315+ nn .init .zeros_ (module .bias )
283316 elif isinstance (module , nn .Conv2d ):
284317 lecun_normal_ (module .weight )
285318 if module .bias is not None :
@@ -293,13 +326,31 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
293326 module .init_weights ()
294327
295328
329+ def checkpoint_filter_fn (state_dict , model ):
330+ """ Remap checkpoints if needed """
331+ if 'patch_embed.proj.weight' in state_dict :
332+ # Remap FB ResMlp models -> timm
333+ out_dict = {}
334+ for k , v in state_dict .items ():
335+ k = k .replace ('patch_embed.' , 'stem.' )
336+ k = k .replace ('attn.' , 'linear_tokens.' )
337+ k = k .replace ('mlp.' , 'mlp_channels.' )
338+ k = k .replace ('gamma_' , 'ls' )
339+ if k .endswith ('.alpha' ) or k .endswith ('.beta' ):
340+ v = v .reshape (1 , 1 , - 1 )
341+ out_dict [k ] = v
342+ return out_dict
343+ return state_dict
344+
345+
296346def _create_mixer (variant , pretrained = False , ** kwargs ):
297347 if kwargs .get ('features_only' , None ):
298348 raise RuntimeError ('features_only not implemented for MLP-Mixer models.' )
299349
300350 model = build_model_with_cfg (
301351 MlpMixer , variant , pretrained ,
302352 default_cfg = default_cfgs [variant ],
353+ pretrained_filter_fn = checkpoint_filter_fn ,
303354 ** kwargs )
304355 return model
305356
@@ -458,11 +509,82 @@ def resmlp_36_224(pretrained=False, **kwargs):
458509 """
459510 model_args = dict (
460511 patch_size = 16 , num_blocks = 36 , embed_dim = 384 , mlp_ratio = 4 ,
461- block_layer = partial (ResBlock , init_values = 1e-5 ), norm_layer = Affine , ** kwargs )
512+ block_layer = partial (ResBlock , init_values = 1e-6 ), norm_layer = Affine , ** kwargs )
462513 model = _create_mixer ('resmlp_36_224' , pretrained = pretrained , ** model_args )
463514 return model
464515
465516
517+ @register_model
518+ def resmlp_big_24_224 (pretrained = False , ** kwargs ):
519+ """ ResMLP-B-24
520+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
521+ """
522+ model_args = dict (
523+ patch_size = 8 , num_blocks = 24 , embed_dim = 768 , mlp_ratio = 4 ,
524+ block_layer = partial (ResBlock , init_values = 1e-6 ), norm_layer = Affine , ** kwargs )
525+ model = _create_mixer ('resmlp_big_24_224' , pretrained = pretrained , ** model_args )
526+ return model
527+
528+
529+ @register_model
530+ def resmlp_12_distilled_224 (pretrained = False , ** kwargs ):
531+ """ ResMLP-12
532+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
533+ """
534+ model_args = dict (
535+ patch_size = 16 , num_blocks = 12 , embed_dim = 384 , mlp_ratio = 4 , block_layer = ResBlock , norm_layer = Affine , ** kwargs )
536+ model = _create_mixer ('resmlp_12_distilled_224' , pretrained = pretrained , ** model_args )
537+ return model
538+
539+
540+ @register_model
541+ def resmlp_24_distilled_224 (pretrained = False , ** kwargs ):
542+ """ ResMLP-24
543+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
544+ """
545+ model_args = dict (
546+ patch_size = 16 , num_blocks = 24 , embed_dim = 384 , mlp_ratio = 4 ,
547+ block_layer = partial (ResBlock , init_values = 1e-5 ), norm_layer = Affine , ** kwargs )
548+ model = _create_mixer ('resmlp_24_distilled_224' , pretrained = pretrained , ** model_args )
549+ return model
550+
551+
552+ @register_model
553+ def resmlp_36_distilled_224 (pretrained = False , ** kwargs ):
554+ """ ResMLP-36
555+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
556+ """
557+ model_args = dict (
558+ patch_size = 16 , num_blocks = 36 , embed_dim = 384 , mlp_ratio = 4 ,
559+ block_layer = partial (ResBlock , init_values = 1e-6 ), norm_layer = Affine , ** kwargs )
560+ model = _create_mixer ('resmlp_36_distilled_224' , pretrained = pretrained , ** model_args )
561+ return model
562+
563+
564+ @register_model
565+ def resmlp_big_24_distilled_224 (pretrained = False , ** kwargs ):
566+ """ ResMLP-B-24
567+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
568+ """
569+ model_args = dict (
570+ patch_size = 8 , num_blocks = 24 , embed_dim = 768 , mlp_ratio = 4 ,
571+ block_layer = partial (ResBlock , init_values = 1e-6 ), norm_layer = Affine , ** kwargs )
572+ model = _create_mixer ('resmlp_big_24_distilled_224' , pretrained = pretrained , ** model_args )
573+ return model
574+
575+
576+ @register_model
577+ def resmlp_big_24_224_in22ft1k (pretrained = False , ** kwargs ):
578+ """ ResMLP-B-24
579+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
580+ """
581+ model_args = dict (
582+ patch_size = 8 , num_blocks = 24 , embed_dim = 768 , mlp_ratio = 4 ,
583+ block_layer = partial (ResBlock , init_values = 1e-6 ), norm_layer = Affine , ** kwargs )
584+ model = _create_mixer ('resmlp_big_24_224_in22ft1k' , pretrained = pretrained , ** model_args )
585+ return model
586+
587+
466588@register_model
467589def gmlp_ti16_224 (pretrained = False , ** kwargs ):
468590 """ gMLP-Tiny
0 commit comments