@@ -452,18 +452,20 @@ def forward(self, x) -> List[torch.Tensor]:
452452 return list (out .values ())
453453
454454
455- def _create_effnet (model_kwargs , variant , pretrained = False ):
455+ def _create_effnet (variant , pretrained = False , ** kwargs ):
456456 features_only = False
457457 model_cls = EfficientNet
458- if model_kwargs .pop ('features_only' , False ):
458+ kwargs_filter = None
459+ if kwargs .pop ('features_only' , False ):
459460 features_only = True
460- model_kwargs .pop ('num_classes' , 0 )
461- model_kwargs .pop ('num_features' , 0 )
462- model_kwargs .pop ('head_conv' , None )
461+ kwargs_filter = ('num_classes' , 'num_features' , 'head_conv' , 'global_pool' )
463462 model_cls = EfficientNetFeatures
464463 model = build_model_with_cfg (
465- model_cls , variant , pretrained , default_cfg = default_cfgs [variant ],
466- pretrained_strict = not features_only , ** model_kwargs )
464+ model_cls , variant , pretrained ,
465+ default_cfg = default_cfgs [variant ],
466+ pretrained_strict = not features_only ,
467+ kwargs_filter = kwargs_filter ,
468+ ** kwargs )
467469 if features_only :
468470 model .default_cfg = default_cfg_for_features (model .default_cfg )
469471 return model
@@ -501,7 +503,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
501503 norm_kwargs = resolve_bn_args (kwargs ),
502504 ** kwargs
503505 )
504- model = _create_effnet (model_kwargs , variant , pretrained )
506+ model = _create_effnet (variant , pretrained , ** model_kwargs )
505507 return model
506508
507509
@@ -537,7 +539,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
537539 norm_kwargs = resolve_bn_args (kwargs ),
538540 ** kwargs
539541 )
540- model = _create_effnet (model_kwargs , variant , pretrained )
542+ model = _create_effnet (variant , pretrained , ** model_kwargs )
541543 return model
542544
543545
@@ -566,7 +568,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
566568 norm_kwargs = resolve_bn_args (kwargs ),
567569 ** kwargs
568570 )
569- model = _create_effnet (model_kwargs , variant , pretrained )
571+ model = _create_effnet (variant , pretrained , ** model_kwargs )
570572 return model
571573
572574
@@ -595,7 +597,7 @@ def _gen_mobilenet_v2(
595597 act_layer = resolve_act_layer (kwargs , 'relu6' ),
596598 ** kwargs
597599 )
598- model = _create_effnet (model_kwargs , variant , pretrained )
600+ model = _create_effnet (variant , pretrained , ** model_kwargs )
599601 return model
600602
601603
@@ -625,7 +627,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
625627 norm_kwargs = resolve_bn_args (kwargs ),
626628 ** kwargs
627629 )
628- model = _create_effnet (model_kwargs , variant , pretrained )
630+ model = _create_effnet (variant , pretrained , ** model_kwargs )
629631 return model
630632
631633
@@ -660,7 +662,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
660662 norm_kwargs = resolve_bn_args (kwargs ),
661663 ** kwargs
662664 )
663- model = _create_effnet (model_kwargs , variant , pretrained )
665+ model = _create_effnet (variant , pretrained , ** model_kwargs )
664666 return model
665667
666668
@@ -706,7 +708,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
706708 norm_kwargs = resolve_bn_args (kwargs ),
707709 ** kwargs ,
708710 )
709- model = _create_effnet (model_kwargs , variant , pretrained )
711+ model = _create_effnet (variant , pretrained , ** model_kwargs )
710712 return model
711713
712714
@@ -735,7 +737,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
735737 act_layer = resolve_act_layer (kwargs , 'relu' ),
736738 ** kwargs ,
737739 )
738- model = _create_effnet (model_kwargs , variant , pretrained )
740+ model = _create_effnet (variant , pretrained , ** model_kwargs )
739741 return model
740742
741743
@@ -765,7 +767,7 @@ def _gen_efficientnet_condconv(
765767 act_layer = resolve_act_layer (kwargs , 'swish' ),
766768 ** kwargs ,
767769 )
768- model = _create_effnet (model_kwargs , variant , pretrained )
770+ model = _create_effnet (variant , pretrained , ** model_kwargs )
769771 return model
770772
771773
@@ -806,7 +808,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
806808 norm_kwargs = resolve_bn_args (kwargs ),
807809 ** kwargs ,
808810 )
809- model = _create_effnet (model_kwargs , variant , pretrained )
811+ model = _create_effnet (variant , pretrained , ** model_kwargs )
810812 return model
811813
812814
@@ -839,7 +841,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
839841 norm_kwargs = resolve_bn_args (kwargs ),
840842 ** kwargs
841843 )
842- model = _create_effnet (model_kwargs , variant , pretrained )
844+ model = _create_effnet (variant , pretrained , ** model_kwargs )
843845 return model
844846
845847
@@ -872,7 +874,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
872874 norm_kwargs = resolve_bn_args (kwargs ),
873875 ** kwargs
874876 )
875- model = _create_effnet (model_kwargs , variant , pretrained )
877+ model = _create_effnet (variant , pretrained , ** model_kwargs )
876878 return model
877879
878880
0 commit comments