Skip to content

Commit d4c21b9

Browse files
thohemprwightman
authored andcommitted
Update repghost.py
1 parent 7eb7d13 commit d4c21b9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/models/repghost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
289289
# cannot meaningfully change pooling of efficient head after creation
290290
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
291291
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
292-
self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity()
292+
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
293293

294294
def forward_features(self, x):
295295
x = self.conv_stem(x)

0 commit comments

Comments
 (0)