@@ -20,7 +20,7 @@
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .fx_features import FeatureGraphNet
 from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
-from .layers import Conv2dSame, Linear
+from .layers import Conv2dSame, Linear, BatchNormAct2d
 from .registry import get_pretrained_cfg


@@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string):
                 bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                 groups=g, stride=old_module.stride)
             set_layer(new_module, n, new_conv)
-        if isinstance(old_module, nn.BatchNorm2d):
+        elif isinstance(old_module, BatchNormAct2d):
+            new_bn = BatchNormAct2d(
+                state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
+                affine=old_module.affine, track_running_stats=True)
+            new_bn.drop = old_module.drop
+            new_bn.act = old_module.act
+            set_layer(new_module, n, new_bn)
+        elif isinstance(old_module, nn.BatchNorm2d):
             new_bn = nn.BatchNorm2d(
                 num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                 affine=old_module.affine, track_running_stats=True)
             set_layer(new_module, n, new_bn)
-        if isinstance(old_module, nn.Linear):
+        elif isinstance(old_module, nn.Linear):
             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
             num_features = state_dict[n + '.weight'][1]
             new_fc = Linear(
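
Reviewer note on the BatchNormAct2d branch above: BatchNormAct2d subclasses nn.BatchNorm2d, so its isinstance check must run before the plain nn.BatchNorm2d one, which is why the chain switches from repeated if statements to elif; otherwise a fused norm module would match the parent-class branch and silently lose its drop and act submodules during adaptation. Copying old_module.drop and old_module.act onto the new instance preserves the originally configured dropout/activation rather than the constructor defaults. A minimal sketch of the ordering concern, assuming the public import path of this timm era and an arbitrary 32-channel example:

import torch.nn as nn
from timm.models.layers import BatchNormAct2d  # moved to timm.layers in later releases

bn_act = BatchNormAct2d(32)  # hypothetical 32-channel layer, default ReLU activation
print(isinstance(bn_act, nn.BatchNorm2d))   # True: the parent-class branch would match first
print(isinstance(bn_act, BatchNormAct2d))   # True: must be tested before nn.BatchNorm2d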