 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
 from ._builder import build_model_with_cfg
+from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 
 __all__ = ['EfficientFormer']  # model_registry will add each entrypoint fn to this
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
-        'crop_pct': .95, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
-        **kwargs
-    }
-
-
-default_cfgs = dict(
-    efficientformer_l1=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l1_1000d_224-5b08fab0.pth",
-    ),
-    efficientformer_l3=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l3_300d_224-6816624f.pth",
-    ),
-    efficientformer_l7=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l7_300d_224-e957ab75.pth",
-    ),
-)
-
 EfficientFormer_width = {
     'l1': (48, 96, 224, 448),
     'l3': (64, 128, 320, 512),
@@ -99,7 +77,7 @@ def train(self, mode=True):
             self.attention_bias_cache = {}  # clear ab cache
 
     def get_attention_biases(self, device: torch.device) -> torch.Tensor:
-        if self.training:
+        if torch.jit.is_tracing() or self.training:
             return self.attention_biases[:, self.attention_bias_idxs]
         else:
             device_key = str(device)
@@ -279,16 +257,17 @@ def __init__(
     ):
         super().__init__()
         self.token_mixer = Pooling(pool_size=pool_size)
+        self.ls1 = LayerScale2d(dim, layer_scale_init_value)
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
         self.mlp = ConvMlpWithNorm(
             dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop)
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.ls1 = LayerScale2d(dim, layer_scale_init_value)
         self.ls2 = LayerScale2d(dim, layer_scale_init_value)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
-        x = x + self.drop_path(self.ls1(self.token_mixer(x)))
-        x = x + self.drop_path(self.ls2(self.mlp(x)))
+        x = x + self.drop_path1(self.ls1(self.token_mixer(x)))
+        x = x + self.drop_path2(self.ls2(self.mlp(x)))
         return x
 
 
@@ -514,6 +493,30 @@ def _checkpoint_filter_fn(state_dict, model):
     return out_dict
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
+        'crop_pct': .95, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'efficientformer_l1.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientformer_l3.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientformer_l7.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+})
+
+
 def _create_efficientformer(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
@@ -524,30 +527,30 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
 
 @register_model
 def efficientformer_l1(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l1'],
         embed_dims=EfficientFormer_width['l1'],
         num_vit=1,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def efficientformer_l3(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l3'],
         embed_dims=EfficientFormer_width['l3'],
         num_vit=4,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def efficientformer_l7(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l7'],
         embed_dims=EfficientFormer_width['l7'],
         num_vit=8,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **dict(model_args, **kwargs))
 
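
A minimal usage sketch (not part of the diff) of how the hub-backed, tagged default cfgs registered above would be consumed; it assumes a timm build that includes generate_default_cfgs and these 'snap_dist_in1k' pretrained tags:

    import timm

    # The bare architecture name resolves to its default pretrained tag.
    model = timm.create_model('efficientformer_l1', pretrained=True)

    # The pretrained tag can also be selected explicitly; with hf_hub_id='timm/',
    # weights are fetched from the Hugging Face Hub instead of the old GitHub release URLs.
    model = timm.create_model('efficientformer_l1.snap_dist_in1k', pretrained=True)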