@@ -134,7 +134,8 @@ def forward(self, x, pre_logits: bool = False):
 
 
 class NormMlpClassifierHead(nn.Module):
-
+    """ A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
+    """
     def __init__(
             self,
             in_features: int,
@@ -204,3 +205,79 @@ def forward(self, x, pre_logits: bool = False):
             return x
         x = self.fc(x)
         return x
+
+
+class ClNormMlpClassifierHead(nn.Module):
+    """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
+    """
+    def __init__(
+            self,
+            in_features: int,
+            num_classes: int,
+            hidden_size: Optional[int] = None,
+            pool_type: str = 'avg',
+            drop_rate: float = 0.,
+            norm_layer: Union[str, Callable] = 'layernorm',
+            act_layer: Union[str, Callable] = 'gelu',
+            input_fmt: str = 'NHWC',
+    ):
+        """
+        Args:
+            in_features: The number of input features.
+            num_classes: The number of classes for the final classifier layer (output).
+            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
+            pool_type: Global pooling type, pooling disabled if empty string ('').
+            drop_rate: Pre-classifier dropout rate.
+            norm_layer: Normalization layer type.
+            act_layer: MLP activation layer type (only used if hidden_size is not None).
+        """
+        super().__init__()
+        self.in_features = in_features
+        self.hidden_size = hidden_size
+        self.num_features = in_features
+        assert pool_type in ('', 'avg', 'max', 'avgmax')
+        self.pool_type = pool_type
+        assert input_fmt in ('NHWC', 'NLC')
+        self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
+        norm_layer = get_norm_layer(norm_layer)
+        act_layer = get_act_layer(act_layer)
+
+        self.norm = norm_layer(in_features)
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(in_features, hidden_size)),
+                ('act', act_layer()),
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.pre_logits = nn.Identity()
+        self.drop = nn.Dropout(drop_rate)
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
+        if pool_type is not None:
+            self.pool_type = pool_type
+        if reset_other:
+            self.pre_logits = nn.Identity()
+            self.norm = nn.Identity()
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def _global_pool(self, x):
+        if self.pool_type:
+            if self.pool_type == 'avg':
+                x = x.mean(dim=self.pool_dim)
+            elif self.pool_type == 'max':
+                x = x.amax(dim=self.pool_dim)
+            elif self.pool_type == 'avgmax':
+                x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
+        return x
+
+    def forward(self, x, pre_logits: bool = False):
+        x = self._global_pool(x)
+        x = self.norm(x)
+        x = self.pre_logits(x)
+        x = self.drop(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
+        return x
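
For review context, a minimal usage sketch of the new head (not part of the diff). It assumes ClNormMlpClassifierHead ends up exported from timm.layers alongside the existing NormMlpClassifierHead; the channels-last input shapes follow directly from the pool_dim logic above (dims (1, 2) for 'NHWC', dim 1 for 'NLC').

import torch
# Assumption: the new class is re-exported from timm.layers like its NCHW sibling.
from timm.layers import ClNormMlpClassifierHead

# NHWC feature map from a channels-last backbone: (batch, H, W, C)
x = torch.randn(2, 7, 7, 768)

head = ClNormMlpClassifierHead(
    in_features=768,
    num_classes=1000,
    hidden_size=1024,   # enables the pre-logits fc + act MLP, num_features becomes 1024
    pool_type='avg',    # mean over dims (1, 2) for NHWC input
    input_fmt='NHWC',
)

logits = head(x)                  # (2, 1000)
feats = head(x, pre_logits=True)  # (2, 1024), returns pre-classifier features

# Sequence features (batch, length, C) use input_fmt='NLC'; pooling is over dim 1.
seq_head = ClNormMlpClassifierHead(768, 0, input_fmt='NLC')  # num_classes=0 -> fc is nn.Identity
pooled = seq_head(torch.randn(2, 196, 768))                  # (2, 768)

Note the ordering difference from typical NCHW heads: pooling happens before the norm layer, so the LayerNorm runs once on the pooled (batch, C) vector rather than per spatial position.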