2828from data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
2929
3030_models = [
31- 'mnasnet_050' , 'mnasnet_075' , 'mnasnet_100' , 'mnasnet_140' , 'semnasnet_050' , 'semnasnet_075' ,
32- 'semnasnet_100' , 'semnasnet_140' , 'mnasnet_small' , 'mobilenetv1_100' , 'mobilenetv2_100' ,
31+ 'mnasnet_050' , 'mnasnet_075' , 'mnasnet_100' , 'mnasnet_b1' , ' mnasnet_140' , 'semnasnet_050' , 'semnasnet_075' ,
32+ 'semnasnet_100' , 'mnasnet_a1' , ' semnasnet_140' , 'mnasnet_small' , 'mobilenetv1_100' , 'mobilenetv2_100' ,
3333 'mobilenetv3_050' , 'mobilenetv3_075' , 'mobilenetv3_100' , 'chamnetv1_100' , 'chamnetv2_100' ,
3434 'fbnetc_100' , 'spnasnet_100' , 'tflite_mnasnet_100' , 'tflite_semnasnet_100' , 'efficientnet_b0' ,
3535 'efficientnet_b1' , 'efficientnet_b2' , 'efficientnet_b3' , 'efficientnet_b4' , 'tf_efficientnet_b0' ,
@@ -50,7 +50,9 @@ def _cfg(url='', **kwargs):
5050default_cfgs = {
5151 'mnasnet_050' : _cfg (url = '' ),
5252 'mnasnet_075' : _cfg (url = '' ),
53- 'mnasnet_100' : _cfg (url = '' ),
53+ 'mnasnet_100' : _cfg (
54+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth' ,
55+ interpolation = 'bicubic' ),
5456 'tflite_mnasnet_100' : _cfg (
5557 url = 'https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet_100-31639cdc.pth?dl=1' ,
5658 interpolation = 'bicubic' ),
@@ -161,8 +163,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
161163 is assumed to indicate the block type.
162164
163165 leading string - block type (
164- ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act,
165- ca = Cascade3x3, and possibly more)
166+ ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
166167 r - number of repeat blocks,
167168 k - kernel size,
168169 s - strides (1-9),
@@ -227,15 +228,6 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
227228 block_args ['pw_group' ] = options ['g' ]
228229 if options ['g' ] > 1 :
229230 block_args ['shuffle_type' ] = 'mid'
230- elif block_type == 'ca' :
231- block_args = dict (
232- block_type = block_type ,
233- kernel_size = int (options ['k' ]),
234- out_chs = int (options ['c' ]),
235- stride = int (options ['s' ]),
236- act_fn = act_fn ,
237- noskip = noskip ,
238- )
239231 elif block_type == 'ds' or block_type == 'dsa' :
240232 block_args = dict (
241233 block_type = block_type ,
@@ -345,8 +337,6 @@ def _make_block(self, ba):
345337 elif bt == 'ds' or bt == 'dsa' :
346338 ba ['drop_connect_rate' ] = self .drop_connect_rate
347339 block = DepthwiseSeparableConv (** ba )
348- elif bt == 'ca' :
349- block = CascadeConv (** ba )
350340 elif bt == 'cn' :
351341 block = ConvBnAct (** ba )
352342 else :
@@ -565,36 +555,6 @@ def forward(self, x):
565555 return x
566556
567557
568- class CascadeConv (nn .Sequential ):
569- # FIXME haven't used yet
570- def __init__ (self , in_chs , out_chs , kernel_size = 3 , stride = 2 , act_fn = F .relu , noskip = False ,
571- bn_momentum = _BN_MOMENTUM_PT_DEFAULT , bn_eps = _BN_EPS_PT_DEFAULT ,
572- folded_bn = False , padding_same = False ):
573- super (CascadeConv , self ).__init__ ()
574- assert stride in [1 , 2 ]
575- self .has_residual = (stride == 1 and in_chs == out_chs ) and not noskip
576- self .act_fn = act_fn
577- padding = _padding_arg (1 , padding_same )
578-
579- self .conv1 = sconv2d (in_chs , in_chs , kernel_size , stride = stride , padding = padding , bias = folded_bn )
580- self .bn1 = None if folded_bn else nn .BatchNorm2d (in_chs , momentum = bn_momentum , eps = bn_eps )
581- self .conv2 = sconv2d (in_chs , out_chs , kernel_size , stride = 1 , padding = padding , bias = folded_bn )
582- self .bn2 = None if folded_bn else nn .BatchNorm2d (out_chs , momentum = bn_momentum , eps = bn_eps )
583-
584- def forward (self , x ):
585- residual = x
586- x = self .conv1 (x )
587- if self .bn1 is not None :
588- x = self .bn1 (x )
589- x = self .act_fn (x )
590- x = self .conv2 (x )
591- if self .bn2 is not None :
592- x = self .bn2 (x )
593- if self .has_residual :
594- x += residual
595- return x
596-
597-
598558class InvertedResidual (nn .Module ):
599559 """ Inverted residual block w/ optional SE"""
600560
@@ -699,7 +659,6 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
699659 super (GenEfficientNet , self ).__init__ ()
700660 self .num_classes = num_classes
701661 self .drop_rate = drop_rate
702- self .drop_connect_rate = drop_connect_rate
703662 self .act_fn = act_fn
704663 self .num_features = num_features
705664
@@ -730,7 +689,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_f
730689 nn .BatchNorm2d (self .num_features , momentum = bn_momentum , eps = bn_eps )
731690
732691 self .global_pool = SelectAdaptivePool2d (pool_type = global_pool )
733- self .classifier = nn .Linear (self .num_features , self .num_classes )
692+ self .classifier = nn .Linear (self .num_features * self . global_pool . feat_mult () , self .num_classes )
734693
735694 for m in self .modules ():
736695 if weight_init == 'goog' :
@@ -1220,6 +1179,11 @@ def mnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
12201179 return model
12211180
12221181
1182+ def mnasnet_b1 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
1183+ """ MNASNet B1, depth multiplier of 1.0. """
1184+ return mnasnet_100 (num_classes , in_chans , pretrained , ** kwargs )
1185+
1186+
12231187def tflite_mnasnet_100 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
12241188 """ MNASNet B1, depth multiplier of 1.0. """
12251189 default_cfg = default_cfgs ['tflite_mnasnet_100' ]
@@ -1273,6 +1237,11 @@ def semnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
12731237 return model
12741238
12751239
1240+ def mnasnet_a1 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
1241+ """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """
1242+ return semnasnet_100 (num_classes , in_chans , pretrained , ** kwargs )
1243+
1244+
12761245def tflite_semnasnet_100 (num_classes , in_chans = 3 , pretrained = False , ** kwargs ):
12771246 """ MNASNet A1, depth multiplier of 1.0. """
12781247 default_cfg = default_cfgs ['tflite_semnasnet_100' ]
0 commit comments