Skip to content

Commit 16334e4

Browse files
committed
Fix two fastvit issues
1 parent 5242ba6 commit 16334e4

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

timm/models/fastvit.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,7 @@ def __init__(
11711171
self.add_module(layer_name, layer)
11721172
else:
11731173
# Classifier head
1174-
final_features = int(embed_dims[-1] * cls_ratio)
1174+
self.num_features = final_features = int(embed_dims[-1] * cls_ratio)
11751175
self.final_conv = MobileOneBlock(
11761176
in_chs=embed_dims[-1],
11771177
out_chs=final_features,
@@ -1182,7 +1182,6 @@ def __init__(
11821182
use_se=True,
11831183
num_conv_branches=1,
11841184
)
1185-
self.num_features = final_features
11861185
self.head = ClassifierHead(
11871186
final_features,
11881187
num_classes,
@@ -1241,11 +1240,10 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
12411240
if self.fork_feat:
12421241
# output the features of four stages for dense prediction
12431242
return outs
1244-
# output only the features of last layer for image classification
1243+
x = self.final_conv(x)
12451244
return x
12461245

12471246
def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
1248-
x = self.final_conv(x)
12491247
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
12501248

12511249
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1266,6 +1264,7 @@ def _cfg(url="", **kwargs):
12661264
"interpolation": "bicubic",
12671265
"mean": IMAGENET_DEFAULT_MEAN,
12681266
"std": IMAGENET_DEFAULT_STD,
1267+
'first_conv': 'stem.0.conv_kxk.0.conv',
12691268
"classifier": "head.fc",
12701269
**kwargs,
12711270
}

0 commit comments

Comments
 (0)