 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._pretrained import generate_default_cfgs
 from ._registry import register_model

 __all__ = ['EfficientFormer']  # model_registry will add each entrypoint fn to this


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
-        'crop_pct': .95, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
-        **kwargs
-    }
-
-
-default_cfgs = dict(
-    efficientformer_l1=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l1_1000d_224-5b08fab0.pth",
-    ),
-    efficientformer_l3=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l3_300d_224-6816624f.pth",
-    ),
-    efficientformer_l7=_cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/efficientformer_l7_300d_224-e957ab75.pth",
-    ),
-)
-
 EfficientFormer_width = {
     'l1': (48, 96, 224, 448),
     'l3': (64, 128, 320, 512),
@@ -99,7 +78,7 @@ def train(self, mode=True):
         self.attention_bias_cache = {}  # clear ab cache

     def get_attention_biases(self, device: torch.device) -> torch.Tensor:
-        if self.training:
+        if torch.jit.is_tracing() or self.training:
             return self.attention_biases[:, self.attention_bias_idxs]
         else:
             device_key = str(device)
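
Why the guard matters: outside training, `get_attention_biases` memoizes the indexed bias tensor in a per-device dict, and a trace would capture that cached tensor as a constant. Below is a self-contained sketch of the pattern as a toy module with illustrative names; it is not the actual `Attention` class in this file.

```python
import torch
import torch.nn as nn


class BiasCacheDemo(nn.Module):
    """Toy module illustrating the device-keyed attention-bias cache."""

    def __init__(self, num_heads=2, num_points=4):
        super().__init__()
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, num_points))
        idxs = torch.arange(num_points).repeat(num_points, 1)
        self.register_buffer('attention_bias_idxs', idxs, persistent=False)
        self.attention_bias_cache = {}

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if torch.jit.is_tracing() or self.training:
            # Recompute every call: safe while training and while tracing,
            # where a cached tensor would be baked in as a constant.
            return self.attention_biases[:, self.attention_bias_idxs]
        device_key = str(device)
        if device_key not in self.attention_bias_cache:
            self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
        return self.attention_bias_cache[device_key]
```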
@@ -279,16 +258,17 @@ def __init__(
     ):
         super().__init__()
         self.token_mixer = Pooling(pool_size=pool_size)
+        self.ls1 = LayerScale2d(dim, layer_scale_init_value)
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
         self.mlp = ConvMlpWithNorm(
             dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop)
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.ls1 = LayerScale2d(dim, layer_scale_init_value)
         self.ls2 = LayerScale2d(dim, layer_scale_init_value)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
-        x = x + self.drop_path(self.ls1(self.token_mixer(x)))
-        x = x + self.drop_path(self.ls2(self.mlp(x)))
+        x = x + self.drop_path1(self.ls1(self.token_mixer(x)))
+        x = x + self.drop_path2(self.ls2(self.mlp(x)))
         return x


@@ -356,7 +336,10 @@ def __init__(

     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x


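
The new branch only fires when `self.grad_checkpointing` is set on the stage; in timm that flag is normally flipped through the model's `set_grad_checkpointing()` method rather than touched directly. A hedged usage sketch, assuming these entrypoints are available through an installed timm:

```python
import timm

# Trade compute for memory while training: each stage re-runs its blocks
# through checkpoint_seq() during the backward pass instead of storing
# all intermediate activations.
model = timm.create_model('efficientformer_l1', pretrained=False)
model.set_grad_checkpointing(True)
```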
@@ -514,6 +497,30 @@ def _checkpoint_filter_fn(state_dict, model):
     return out_dict


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
+        'crop_pct': .95, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv1', 'classifier': ('head', 'head_dist'),
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'efficientformer_l1.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientformer_l3.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientformer_l7.snap_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+})
+
+
 def _create_efficientformer(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
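
The old `default_cfgs = dict(...)` with release URLs is replaced by tagged entries: each key is `<architecture>.<pretrained_tag>`, and `hf_hub_id='timm/'` points weight loading at the matching Hugging Face Hub repo. A usage sketch of how the tags resolve, assuming the weights are published under the `timm` Hub organization:

```python
import timm

# The bare architecture name resolves to its default (first-listed) pretrained tag.
m = timm.create_model('efficientformer_l1', pretrained=True)

# A fully tagged name pins the exact pretrained weights.
m = timm.create_model('efficientformer_l1.snap_dist_in1k', pretrained=True)
```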
@@ -524,30 +531,30 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):

 @register_model
 def efficientformer_l1(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l1'],
         embed_dims=EfficientFormer_width['l1'],
         num_vit=1,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l1', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def efficientformer_l3(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l3'],
         embed_dims=EfficientFormer_width['l3'],
         num_vit=4,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l3', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def efficientformer_l7(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         depths=EfficientFormer_depth['l7'],
         embed_dims=EfficientFormer_width['l7'],
         num_vit=8,
-        **kwargs)
-    return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **model_kwargs)
+    )
+    return _create_efficientformer('efficientformer_l7', pretrained=pretrained, **dict(model_args, **kwargs))

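
The entrypoints now keep per-variant defaults in `model_args` and fold caller overrides in with `dict(model_args, **kwargs)`, so an explicit kwarg wins over the variant default on any key collision. A small illustration with made-up override values:

```python
# Variant defaults (illustrative values in the spirit of the 'l1' config).
model_args = dict(depths=(3, 2, 6, 4), embed_dims=(48, 96, 224, 448), num_vit=1)
# Hypothetical caller overrides passed through **kwargs.
kwargs = dict(num_vit=2, drop_rate=0.1)

merged = dict(model_args, **kwargs)
assert merged['num_vit'] == 2            # caller value wins on collision
assert merged['depths'] == (3, 2, 6, 4)  # untouched defaults pass through
```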