
Commit 9611458

Throw in some FBNetV3 code I had lying around, some refactoring of SE reduction channel calcs for all EffNet archs.
1 parent 01b9108 commit 9611458

7 files changed, 135 additions and 31 deletions

timm/models/byobnet.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def _cfg(url='', **kwargs):
     # experimental configs
     'resnet51q': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
-        first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8),
+        first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
         test_input_size=(3, 288, 288), crop_pct=1.0),
     'resnet61q': _cfg(
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),

timm/models/efficientnet_blocks.py

Lines changed: 6 additions & 10 deletions
@@ -22,18 +22,16 @@ class SqueezeExcite(nn.Module):
         se_ratio (float): ratio of squeeze reduction
         act_layer (nn.Module): activation layer of containing block
         gate_fn (Callable): attention gate function
-        block_in_chs (int): input channels of containing block (for calculating reduction from)
-        reduce_from_block (bool): calculate reduction from block input channels if True
         force_act_layer (nn.Module): override block's activation fn if this is set/bound
-        divisor (int): make reduction channels divisible by this
+        round_chs_fn (Callable): specify a fn to calculate rounding of reduced chs
     """

     def __init__(
             self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
-            block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1):
+            force_act_layer=None, round_chs_fn=None):
         super(SqueezeExcite, self).__init__()
-        reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs
-        reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
+        round_chs_fn = round_chs_fn or round
+        reduced_chs = round_chs_fn(in_chs * se_ratio)
         act_layer = force_act_layer or act_layer
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = create_act_layer(act_layer, inplace=True)
@@ -168,8 +166,7 @@ def __init__(
         self.act2 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        self.se = se_layer(
-            mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
+        self.se = se_layer(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()

         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@@ -292,8 +289,7 @@ def __init__(
         self.act1 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        self.se = SqueezeExcite(
-            mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
+        self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()

         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
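For reference (not part of the commit), here is a minimal sketch of the reduced-channel calculation after this refactor: a single round_chs_fn callable takes over from the removed block_in_chs / reduce_from_block / divisor arguments. The make_divisible body is reproduced from timm/models/layers/helpers.py; the other names below are illustrative only.

from functools import partial

def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    # rounding helper that SE layers can be pointed at via round_chs_fn
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < round_limit * v:
        new_v += divisor
    return new_v

def se_reduced_chs(in_chs, se_ratio=0.25, round_chs_fn=None):
    # mirrors the new SqueezeExcite.__init__ logic shown in the diff above
    round_chs_fn = round_chs_fn or round
    return round_chs_fn(in_chs * se_ratio)

print(se_reduced_chs(120))                                                   # 30 (plain round, the new default)
print(se_reduced_chs(120, round_chs_fn=partial(make_divisible, divisor=8)))  # 32 (rounded to a multiple of 8)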

timm/models/efficientnet_builder.py

Lines changed: 12 additions & 9 deletions
@@ -265,11 +265,12 @@ class EfficientNetBuilder:
     https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

     """
-    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
+    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
                  act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
         self.output_stride = output_stride
         self.pad_type = pad_type
         self.round_chs_fn = round_chs_fn
+        self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
         self.act_layer = act_layer
         self.norm_layer = norm_layer
         self.se_layer = se_layer
@@ -301,6 +302,8 @@ def _make_block(self, ba, block_idx, block_count):
         ba['norm_layer'] = self.norm_layer
         if bt != 'cn':
             ba['se_layer'] = self.se_layer
+            if not self.se_from_exp and ba['se_ratio']:
+                ba['se_ratio'] /= ba.get('exp_ratio', 1.0)
         ba['drop_path_rate'] = drop_path_rate

         if bt == 'ir':
@@ -418,28 +421,28 @@ def _init_weight_goog(m, n='', fix_group_fanout=True):
         if fix_group_fanout:
             fan_out //= m.groups
         init_weight_fn = get_condconv_initializer(
-            lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
+            lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
         init_weight_fn(m.weight)
         if m.bias is not None:
-            m.bias.data.zero_()
+            nn.init.zeros_(m.bias)
     elif isinstance(m, nn.Conv2d):
         fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
         if fix_group_fanout:
             fan_out //= m.groups
-        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
         if m.bias is not None:
-            m.bias.data.zero_()
+            nn.init.zeros_(m.bias)
     elif isinstance(m, nn.BatchNorm2d):
-        m.weight.data.fill_(1.0)
-        m.bias.data.zero_()
+        nn.init.ones_(m.weight)
+        nn.init.zeros_(m.bias)
     elif isinstance(m, nn.Linear):
         fan_out = m.weight.size(0)  # fan-out
         fan_in = 0
         if 'routing_fn' in n:
             fan_in = m.weight.size(1)
         init_range = 1.0 / math.sqrt(fan_in + fan_out)
-        m.weight.data.uniform_(-init_range, init_range)
-        m.bias.data.zero_()
+        nn.init.uniform_(m.weight, -init_range, init_range)
+        nn.init.zeros_(m.bias)


 def efficientnet_init_weights(model: nn.Module, init_fn=None):
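A rough arithmetic sketch of what the new se_from_exp flag changes (variable names here are illustrative, not builder code): with se_from_exp=False the builder pre-divides se_ratio by the block's expansion ratio, so the SE reduction stays relative to the block's input channels even though SqueezeExcite now only sees the expanded (mid) channels.

# assumed example values for one inverted-residual block
in_chs, exp_ratio, se_ratio = 40, 4.0, 0.25
mid_chs = int(in_chs * exp_ratio)        # 160; this is the in_chs that SqueezeExcite receives

# se_from_exp=False path: ba['se_ratio'] /= ba.get('exp_ratio', 1.0)
adjusted_ratio = se_ratio / exp_ratio    # 0.0625

# reducing mid_chs with the adjusted ratio matches the old "reduce from block input" behaviour
assert round(mid_chs * adjusted_ratio) == round(in_chs * se_ratio) == 10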

timm/models/ghostnet.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def _cfg(url='', **kwargs):
 }


-_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4)
+_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', round_chs_fn=partial(make_divisible, divisor=4))


 class GhostModule(nn.Module):

timm/models/hardcorenas.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .efficientnet_blocks import SqueezeExcite
-from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args
+from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
 from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import get_act_fn
 from .mobilenetv3 import MobileNetV3, MobileNetV3Features
@@ -40,7 +40,7 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
     """
     num_features = 1280
     se_layer = partial(
-        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
+        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         num_features=num_features,

timm/models/layers/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -28,4 +28,4 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
     # Make sure that round down does not go down by more than 10%.
     if new_v < round_limit * v:
         new_v += divisor
-    return new_v
+    return new_v
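The identical removed/added line above suggests a whitespace-only change (most likely a newline added at the end of the file). For context, here is a small usage sketch of make_divisible as it is now wired into the SE layers via round_chs_fn; the values are worked out by hand from the function body, and the timm.models.layers import path is assumed.

from functools import partial
from timm.models.layers import make_divisible  # assumed export location

make_div4 = partial(make_divisible, divisor=4)  # the GhostNet SE rounding shown above

print(make_divisible(17))  # 16: nearest multiple of 8, within the ~10% round-down limit
print(make_divisible(9))   # 16: rounding down to 8 would lose >10%, so bump up one divisor step
print(make_div4(18))       # 20: nearest multiple of 4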

timm/models/mobilenetv3.py

Lines changed: 112 additions & 7 deletions
@@ -72,6 +72,10 @@ def _cfg(url='', **kwargs):
     'tf_mobilenetv3_small_minimal_100': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+
+    'fbnetv3_b': _cfg(),
+    'fbnetv3_d': _cfg(),
+    'fbnetv3_g': _cfg(),
 }

@@ -86,7 +90,7 @@ class MobileNetV3(nn.Module):
     """

     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
-                 pad_type='', act_layer=None, norm_layer=None, se_layer=None,
+                 pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
                  round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
         super(MobileNetV3, self).__init__()
         act_layer = act_layer or nn.ReLU
@@ -104,7 +108,7 @@ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_f
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn,
+            output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
             act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
@@ -161,8 +165,8 @@ class MobileNetV3Features(nn.Module):
     and object detection models.
     """

-    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
-                 in_chans=3, stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels,
+    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
+                 stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
                  act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
         super(MobileNetV3Features, self).__init__()
         act_layer = act_layer or nn.ReLU
@@ -178,7 +182,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bo
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
+            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
             act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
             drop_path_rate=drop_path_rate, feature_location=feature_location)
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
@@ -262,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         act_layer=resolve_act_layer(kwargs, 'hard_swish'),
-        se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), reduce_from_block=False),
+        se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid')),
         **kwargs,
     )
     model = _create_mnv3(variant, pretrained, **model_kwargs)
@@ -351,7 +355,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
         ['cn_r1_k1_s1_c960'],  # hard-swish
     ]
     se_layer = partial(
-        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
+        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, round_chs_fn=round_channels)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         num_features=num_features,
@@ -366,6 +370,86 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
     return model


+def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+    """ FBNetV3
+    FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
+    """
+    vl = variant.split('_')[-1]
+    if vl in ('a', 'b'):
+        stem_size = 16
+        arch_def = [
+            # stage 0, 112x112 in
+            ['ds_r2_k3_s1_e1_c16'],
+            # stage 1, 112x112 in
+            ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'],
+            # stage 2, 56x56 in
+            ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'],
+            # stage 3, 28x28 in
+            ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
+            # stage 4, 14x14 in
+            ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'],
+            # stage 5, 14x14 in
+            ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'],
+            # stage 6, 7x7 in
+            ['cn_r1_k1_s1_c1344'],
+        ]
+    elif vl == 'd':
+        stem_size = 24
+        arch_def = [
+            # stage 0, 112x112 in
+            ['ds_r2_k3_s1_e1_c16'],
+            # stage 1, 112x112 in
+            ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'],
+            # stage 2, 56x56 in
+            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'],
+            # stage 3, 28x28 in
+            ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
+            # stage 4, 14x14 in
+            ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'],
+            # stage 5, 14x14 in
+            ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'],
+            # stage 6, 7x7 in
+            ['cn_r1_k1_s1_c1440'],
+        ]
+    elif vl == 'g':
+        stem_size = 32
+        arch_def = [
+            # stage 0, 112x112 in
+            ['ds_r3_k3_s1_e1_c24'],
+            # stage 1, 112x112 in
+            ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'],
+            # stage 2, 56x56 in
+            ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'],
+            # stage 3, 28x28 in
+            ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'],
+            # stage 4, 14x14 in
+            ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'],
+            # stage 5, 14x14 in
+            ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'],
+            # stage 6, 7x7 in
+            ['cn_r1_k1_s1_c1728'],  # hard-swish
+        ]
+    else:
+        raise NotImplementedError
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95)
+    se_layer = partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), round_chs_fn=round_chs_fn)
+    act_layer = resolve_act_layer(kwargs, 'hard_swish')
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def),
+        num_features=1984,
+        head_bias=False,
+        stem_size=stem_size,
+        round_chs_fn=round_chs_fn,
+        se_from_exp=False,
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+        act_layer=act_layer,
+        se_layer=se_layer,
+        **kwargs,
+    )
+    model = _create_mnv3(variant, pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def mobilenetv3_large_075(pretrained=False, **kwargs):
     """ MobileNet V3 """
@@ -474,3 +558,24 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
     kwargs['pad_type'] = 'same'
     model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
     return model
+
+
+@register_model
+def fbnetv3_b(pretrained=False, **kwargs):
+    """ FBNetV3-B """
+    model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def fbnetv3_d(pretrained=False, **kwargs):
+    """ FBNetV3-D """
+    model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def fbnetv3_g(pretrained=False, **kwargs):
+    """ FBNetV3-G """
+    model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
+    return model
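The commit itself flags the FBNetV3 definitions as untested (see the FIXME above) and the new _cfg() entries carry no pretrained weight URLs, so only randomly initialized construction is expected to work at this point. A hypothetical smoke test might look like:

import torch
import timm

model = timm.create_model('fbnetv3_b', pretrained=False)  # variant registered by this commit
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])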
