Commit 2ed8f24

A few more changes for 0.3.2 maint release. Linear layer change for mobilenetv3 and inception_v3, support no bias for linear wrapper.
1 parent 6504a42 · commit 2ed8f24

File tree

7 files changed (+13, -10 lines)

tests/test_models.py
timm/models/helpers.py
timm/models/inception_v3.py
timm/models/layers/__init__.py
timm/models/layers/linear.py
timm/models/mobilenetv3.py
train.py

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def test_model_load_pretrained(model_name, batch_size):
     create_model(model_name, pretrained=True, in_chans=in_chans)

 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(pretrained=True))
+@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_features_pretrained(model_name, batch_size):
     """Create that pretrained weights load when features_only==True."""

timm/models/helpers.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 import torch.utils.model_zoo as model_zoo

 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
-from .layers import Conv2dSame
+from .layers import Conv2dSame, Linear


 _logger = logging.getLogger(__name__)
@@ -234,7 +234,7 @@ def adapt_model_from_string(parent_module, model_string):
         if isinstance(old_module, nn.Linear):
             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
             num_features = state_dict[n + '.weight'][1]
-            new_fc = nn.Linear(
+            new_fc = Linear(
                 in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
             set_layer(new_module, n, new_fc)
             if hasattr(new_module, 'num_features'):
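
The switch to the timm Linear wrapper matters here because the rebuilt classifier is created with bias=old_module.bias is not None, so the adapted layer can legitimately have no bias, a case the wrapper's scripting path previously broke on by calling .to() on a None bias. A quick sketch of that bias-free path, assuming the Linear export added to timm.models.layers in this commit:

# Sketch: an adapted classifier built with bias=False must pass through the wrapper.
import torch
from timm.models.layers import Linear

fc = Linear(in_features=768, out_features=1000, bias=False)
assert fc.bias is None
out = fc(torch.randn(2, 768))  # eager path: plain F.linear with bias=None
print(out.shape)               # torch.Size([2, 1000])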

timm/models/inception_v3.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import trunc_normal_, create_classifier
+from .layers import trunc_normal_, create_classifier, Linear


 def _cfg(url='', **kwargs):
@@ -250,7 +250,7 @@ def __init__(self, in_channels, num_classes, conv_block=None):
         self.conv0 = conv_block(in_channels, 128, kernel_size=1)
         self.conv1 = conv_block(128, 768, kernel_size=5)
         self.conv1.stddev = 0.01
-        self.fc = nn.Linear(768, num_classes)
+        self.fc = Linear(768, num_classes)
         self.fc.stddev = 0.001

     def forward(self, x):

timm/models/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
 from .inplace_abn import InplaceAbn
+from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .norm_act import BatchNormAct2d
 from .padding import get_padding

timm/models/layers/linear.py

Lines changed: 3 additions & 2 deletions
@@ -13,6 +13,7 @@ class Linear(nn.Linear):
     """
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if torch.jit.is_scripting():
-            return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype))
+            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
+            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
         else:
-            return F.linear(input, self.weight, self.bias)
+            return F.linear(input, self.weight, self.bias)
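
For reference, the wrapper after this change reads roughly as below (docstring abridged). The scripting branch manually casts weight and bias to the input dtype to work around an AMP + torchscript issue with torch.addmm, and the new guard accounts for nn.Linear storing bias as an Optional[Tensor]:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Linear(nn.Linear):
    """nn.Linear wrapper supporting AMP + torchscript together by casting
    weight and bias to the input dtype inside scripted code (abridged)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            # bias is Optional[Tensor]; only cast it when the layer actually has one
            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
        else:
            return F.linear(input, self.weight, self.bias)

The is-not-None check also doubles as TorchScript's Optional refinement, so the module now scripts cleanly with or without a bias, e.g. torch.jit.script(Linear(16, 8, bias=False)).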

timm/models/mobilenetv3.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
+from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
 from .registry import register_model

 __all__ = ['MobileNetV3']
@@ -105,7 +105,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_f
         num_pooled_chs = head_chs * self.global_pool.feat_mult()
         self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
         self.act2 = act_layer(inplace=True)
-        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

         efficientnet_init_weights(self)

@@ -123,7 +123,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x):
         x = self.conv_stem(x)
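
With the head built from the wrapper in both places, the constructor and reset_classifier stay consistent, and num_classes=0 still swaps in nn.Identity. A small usage sketch; mobilenetv3_large_100 is one of the registered model names, used here only as an example:

# Sketch: both head-creation paths now produce the scriptable Linear wrapper.
import timm

model = timm.create_model('mobilenetv3_large_100', num_classes=10)
print(type(model.classifier).__name__)  # Linear (the timm wrapper)

model.reset_classifier(0)               # head removed
print(type(model.classifier).__name__)  # Identity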

train.py

Lines changed: 1 addition & 0 deletions
@@ -327,6 +327,7 @@ def main():
         bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
+        scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint)

     if args.local_rank == 0:
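
Threading scriptable=args.torchscript into create_model lets the layer factories pick torchscript-compatible implementations up front; train.py separately calls torch.jit.script on the model when --torchscript is set. A rough sketch of the combined flow, using one of the mobilenetv3 models touched above (whether every model scripts cleanly is not guaranteed by this commit):

# Sketch of what the new wiring enables when --torchscript is passed.
import torch
import timm

model = timm.create_model('mobilenetv3_large_100', scriptable=True)
scripted = torch.jit.script(model)  # mirrors the later torch.jit.script call in train.py
print(type(scripted).__name__)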
