@@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
148148 and object detection models.
149149 """
150150
151- def __init__ (self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'pre_pwl ' ,
151+ def __init__ (self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'bottleneck ' ,
152152 in_chans = 3 , stem_size = 16 , channel_multiplier = 1.0 , output_stride = 32 , pad_type = '' ,
153153 act_layer = nn .ReLU , drop_rate = 0. , drop_path_rate = 0. , se_kwargs = None ,
154154 norm_layer = nn .BatchNorm2d , norm_kwargs = None ):
@@ -174,34 +174,47 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pr
174174 channel_multiplier , 8 , None , output_stride , pad_type , act_layer , se_kwargs ,
175175 norm_layer , norm_kwargs , drop_path_rate , feature_location = feature_location , verbose = _DEBUG )
176176 self .blocks = nn .Sequential (* builder (self ._in_chs , block_args ))
177- self .feature_info = builder .features # builder provides info about feature channels for each block
177+ self ._feature_info = builder .features # builder provides info about feature channels for each block
178+ self ._stage_to_feature_idx = {
179+ v ['stage_idx' ]: fi for fi , v in self ._feature_info .items () if fi in self .out_indices }
178180 self ._in_chs = builder .in_chs
179181
180182 efficientnet_init_weights (self )
181183 if _DEBUG :
182- for k , v in self .feature_info .items ():
184+ for k , v in self ._feature_info .items ():
183185 print ('Feature idx: {}: Name: {}, Channels: {}' .format (k , v ['name' ], v ['num_chs' ]))
184186
185187 # Register feature extraction hooks with FeatureHooks helper
186- hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
187- hooks = [dict (name = self .feature_info [idx ]['name' ], type = hook_type ) for idx in out_indices ]
188- self .feature_hooks = FeatureHooks (hooks , self .named_modules ())
188+ self .feature_hooks = None
189+ if feature_location != 'bottleneck' :
190+ hooks = [dict (
191+ name = self ._feature_info [idx ]['module' ],
192+ type = self ._feature_info [idx ]['hook_type' ]) for idx in out_indices ]
193+ self .feature_hooks = FeatureHooks (hooks , self .named_modules ())
189194
190195 def feature_channels (self , idx = None ):
191196 """ Feature Channel Shortcut
192197 Returns feature channel count for each output index if idx == None. If idx is an integer, will
193198 return feature channel count for that feature block index (independent of out_indices setting).
194199 """
195200 if isinstance (idx , int ):
196- return self .feature_info [idx ]['num_chs' ]
197- return [self .feature_info [i ]['num_chs' ] for i in self .out_indices ]
201+ return self ._feature_info [idx ]['num_chs' ]
202+ return [self ._feature_info [i ]['num_chs' ] for i in self .out_indices ]
198203
199204 def forward (self , x ):
200205 x = self .conv_stem (x )
201206 x = self .bn1 (x )
202207 x = self .act1 (x )
203- self .blocks (x )
204- return self .feature_hooks .get_output (x .device )
208+ if self .feature_hooks is None :
209+ features = []
210+ for i , b in enumerate (self .blocks ):
211+ x = b (x )
212+ if i in self ._stage_to_feature_idx :
213+ features .append (x )
214+ return features
215+ else :
216+ self .blocks (x )
217+ return self .feature_hooks .get_output (x .device )
205218
206219
207220def _create_model (model_kwargs , default_cfg , pretrained = False ):
0 commit comments