@@ -1171,7 +1171,7 @@ def __init__(
11711171 self .add_module (layer_name , layer )
11721172 else :
11731173 # Classifier head
1174- final_features = int (embed_dims [- 1 ] * cls_ratio )
1174+ self . num_features = final_features = int (embed_dims [- 1 ] * cls_ratio )
11751175 self .final_conv = MobileOneBlock (
11761176 in_chs = embed_dims [- 1 ],
11771177 out_chs = final_features ,
@@ -1182,7 +1182,6 @@ def __init__(
11821182 use_se = True ,
11831183 num_conv_branches = 1 ,
11841184 )
1185- self .num_features = final_features
11861185 self .head = ClassifierHead (
11871186 final_features ,
11881187 num_classes ,
@@ -1241,11 +1240,10 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
12411240 if self .fork_feat :
12421241 # output the features of four stages for dense prediction
12431242 return outs
1244- # output only the features of last layer for image classification
1243+ x = self . final_conv ( x )
12451244 return x
12461245
12471246 def forward_head (self , x : torch .Tensor , pre_logits : bool = False ):
1248- x = self .final_conv (x )
12491247 return self .head (x , pre_logits = True ) if pre_logits else self .head (x )
12501248
12511249 def forward (self , x : torch .Tensor ) -> torch .Tensor :
@@ -1266,6 +1264,7 @@ def _cfg(url="", **kwargs):
12661264 "interpolation" : "bicubic" ,
12671265 "mean" : IMAGENET_DEFAULT_MEAN ,
12681266 "std" : IMAGENET_DEFAULT_STD ,
1267+ 'first_conv' : 'stem.0.conv_kxk.0.conv' ,
12691268 "classifier" : "head.fc" ,
12701269 ** kwargs ,
12711270 }
0 commit comments