Skip to content

Commit dc51334

Browse files
committed
Fix pruned-model adaptation for EfficientNet models that now use BatchNormAct layers
1 parent 024fc4d commit dc51334

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

timm/models/helpers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
2121
from .fx_features import FeatureGraphNet
2222
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
23-
from .layers import Conv2dSame, Linear
23+
from .layers import Conv2dSame, Linear, BatchNormAct2d
2424
from .registry import get_pretrained_cfg
2525

2626

@@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string):
374374
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
375375
groups=g, stride=old_module.stride)
376376
set_layer(new_module, n, new_conv)
377-
if isinstance(old_module, nn.BatchNorm2d):
377+
elif isinstance(old_module, BatchNormAct2d):
378+
new_bn = BatchNormAct2d(
379+
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
380+
affine=old_module.affine, track_running_stats=True)
381+
new_bn.drop = old_module.drop
382+
new_bn.act = old_module.act
383+
set_layer(new_module, n, new_bn)
384+
elif isinstance(old_module, nn.BatchNorm2d):
378385
new_bn = nn.BatchNorm2d(
379386
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
380387
affine=old_module.affine, track_running_stats=True)
381388
set_layer(new_module, n, new_bn)
382-
if isinstance(old_module, nn.Linear):
389+
elif isinstance(old_module, nn.Linear):
383390
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
384391
num_features = state_dict[n + '.weight'][1]
385392
new_fc = Linear(

0 commit comments

Comments
 (0)