
Commit d793deb

Merge branch 'master' of https://github.com/iamhankai/pytorch-image-models into iamhankai-master

2 parents: e685618 + de445e7

File tree: 3 files changed, +327 −3 lines

tests/test_models.py

Lines changed: 3 additions & 3 deletions

@@ -116,8 +116,8 @@ def test_model_default_cfgs(model_name, batch_size):
     model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
     outputs = model.forward(input_tensor)
     assert len(outputs.shape) == 4
-    if not isinstance(model, timm.models.MobileNetV3):
-        # FIXME mobilenetv3 forward_features vs removed pooling differ
+    if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
+        # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]

     # check classifier name matches default_cfg

@@ -150,7 +150,7 @@ def test_model_features_pretrained(model_name, batch_size):

 EXCLUDE_JIT_FILTERS = [
     '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
-    'dla*', 'hrnet*',  # hopefully fix at some point
+    'dla*', 'hrnet*', 'ghostnet*',  # hopefully fix at some point
 ]
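As a side note on the new exclusion: GhostNet's forward_features() already applies global pooling and the 1x1 conv_head, so with the classifier removed and pooling set to pass-through the output keeps its 7x7 spatial grid while default_cfg['pool_size'] is (1, 1). A minimal sketch of that check, assuming this branch of timm is installed:

import torch
import timm

model = timm.create_model('ghostnet_100', pretrained=False)
model.eval()
model.reset_classifier(0, '')  # no classifier, pass-through pooling
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1280, 7, 7]) -- spatial dims don't match pool_size (1, 1)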

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@
 from .dla import *
 from .dpn import *
 from .efficientnet import *
+from .ghostnet import *
 from .gluon_resnet import *
 from .gluon_xception import *
 from .hardcorenas import *
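This import is what exposes the new @register_model entries in ghostnet.py to timm's model registry; a quick sketch of how one might confirm the registration (assuming this branch of timm is installed):

import timm

print(timm.list_models('ghostnet*'))  # expected: ['ghostnet_050', 'ghostnet_100', 'ghostnet_130']
model = timm.create_model('ghostnet_100', pretrained=False)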

timm/models/ghostnet.py

Lines changed: 323 additions & 0 deletions (new file)
"""
An implementation of the GhostNet model as defined in:
GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907
The training script for this model is similar to that of MobileNetV3.
Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d
from .helpers import build_model_with_cfg
from .registry import register_model


__all__ = ['GhostNet']

def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv_stem', 'classifier': 'classifier',
        **kwargs
    }


default_cfgs = {
    'ghostnet_050': _cfg(url=''),
    'ghostnet_100': _cfg(
        url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'),
    'ghostnet_130': _cfg(url=''),
}

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

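# Illustrative aside (not part of the committed file): _make_divisible rounds a scaled
# channel count to the nearest multiple of `divisor`, but never lets the result drop
# more than 10% below the requested value, e.g.:
#   _make_divisible(16 * 0.5, 4)  -> 8    (already a multiple of 4)
#   _make_divisible(40 * 1.3, 4)  -> 52   (52.0 rounds to 52)
#   _make_divisible(10, 8)        -> 16   (8 would be a >10% drop, so bump up by divisor)
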
def hard_sigmoid(x, inplace: bool = False):
    if inplace:
        return x.add_(3.).clamp_(0., 6.).div_(6.)
    else:
        return F.relu6(x + 3.) / 6.


class SqueezeExcite(nn.Module):
    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                 act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_):
        super(SqueezeExcite, self).__init__()
        self.gate_fn = gate_fn
        reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

    def forward(self, x):
        x_se = self.avg_pool(x)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        x = x * self.gate_fn(x_se)
        return x

class ConvBnAct(nn.Module):
    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, act_layer=nn.ReLU):
        super(ConvBnAct, self).__init__()
        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False)
        self.bn1 = nn.BatchNorm2d(out_chs)
        self.act1 = act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn1(x)
        x = self.act1(x)
        return x

class GhostModule(nn.Module):
    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super(GhostModule, self).__init__()
        self.oup = oup
        init_channels = math.ceil(oup / ratio)
        new_channels = init_channels * (ratio - 1)

        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size // 2, groups=init_channels, bias=False),
            nn.BatchNorm2d(new_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.oup, :, :]

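# Illustrative aside (not part of the committed file): GhostModule produces only
# ceil(oup / ratio) channels with the primary conv and derives the rest with the cheap
# depthwise conv, concatenating and slicing back to `oup` channels, e.g.:
#   m = GhostModule(16, 24, ratio=2)      # primary conv -> 12 chs, cheap dw conv -> 12 chs
#   y = m(torch.randn(1, 16, 56, 56))
#   assert y.shape == (1, 24, 56, 56)
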
class GhostBottleneck(nn.Module):
    """ Ghost bottleneck w/ optional SE"""

    def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
                 stride=1, act_layer=nn.ReLU, se_ratio=0.):
        super(GhostBottleneck, self).__init__()
        has_se = se_ratio is not None and se_ratio > 0.
        self.stride = stride

        # Point-wise expansion
        self.ghost1 = GhostModule(in_chs, mid_chs, relu=True)

        # Depth-wise convolution
        if self.stride > 1:
            self.conv_dw = nn.Conv2d(
                mid_chs, mid_chs, dw_kernel_size, stride=stride,
                padding=(dw_kernel_size - 1) // 2, groups=mid_chs, bias=False)
            self.bn_dw = nn.BatchNorm2d(mid_chs)

        # Squeeze-and-excitation
        if has_se:
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio)
        else:
            self.se = None

        # Point-wise linear projection
        self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)

        # shortcut
        if in_chs == out_chs and self.stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride,
                          padding=(dw_kernel_size - 1) // 2, groups=in_chs, bias=False),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        residual = x

        # 1st ghost bottleneck
        x = self.ghost1(x)

        # Depth-wise convolution
        if self.stride > 1:
            x = self.conv_dw(x)
            x = self.bn_dw(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # 2nd ghost bottleneck
        x = self.ghost2(x)

        x += self.shortcut(residual)
        return x

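# Illustrative aside (not part of the committed file): with stride=2 the depthwise conv
# halves the spatial resolution and the shortcut branch downsamples to match, e.g. the
# first stage-3 block from the cfgs further below:
#   b = GhostBottleneck(24, 72, 40, dw_kernel_size=5, stride=2, se_ratio=0.25)
#   y = b(torch.randn(1, 24, 56, 56))
#   assert y.shape == (1, 40, 28, 28)
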
class GhostNet(nn.Module):
    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3):
        super(GhostNet, self).__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs
        self.num_classes = num_classes
        self.dropout = dropout
        self.feature_info = []

        # building first layer
        output_channel = _make_divisible(16 * width, 4)
        self.conv_stem = nn.Conv2d(in_chans, output_channel, 3, 2, 1, bias=False)
        self.feature_info.append(dict(num_chs=output_channel, reduction=2, module='conv_stem'))
        self.bn1 = nn.BatchNorm2d(output_channel)
        self.act1 = nn.ReLU(inplace=True)
        input_channel = output_channel

        # building inverted residual blocks
        stages = nn.ModuleList([])
        block = GhostBottleneck
        stage_idx = 0
        for cfg in self.cfgs:
            layers = []
            for k, exp_size, c, se_ratio, s in cfg:
                output_channel = _make_divisible(c * width, 4)
                hidden_channel = _make_divisible(exp_size * width, 4)
                layers.append(block(input_channel, hidden_channel, output_channel, k, s,
                                    se_ratio=se_ratio))
                input_channel = output_channel
            if s > 1:
                self.feature_info.append(dict(num_chs=output_channel, reduction=2**(stage_idx + 2),
                                              module=f'blocks.{stage_idx}'))
            stages.append(nn.Sequential(*layers))
            stage_idx += 1

        output_channel = _make_divisible(exp_size * width, 4)
        stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1)))
        self.pool_dim = input_channel = output_channel

        self.blocks = nn.Sequential(*stages)

        # building last several layers
        self.num_features = output_channel = 1280
        self.global_pool = SelectAdaptivePool2d(pool_type='avg')
        self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
        self.classifier = nn.Linear(output_channel, num_classes)

    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        # the classifier operates on the conv_head output, which has self.num_features channels
        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if not self.global_pool.is_identity():
            x = x.view(x.size(0), -1)
        if self.dropout > 0.:
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.classifier(x)
        return x

def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
    """
    Constructs a GhostNet model
    """
    cfgs = [
        # k, t, c, SE, s
        # stage1
        [[3, 16, 16, 0, 1]],
        # stage2
        [[3, 48, 24, 0, 2]],
        [[3, 72, 24, 0, 1]],
        # stage3
        [[5, 72, 40, 0.25, 2]],
        [[5, 120, 40, 0.25, 1]],
        # stage4
        [[3, 240, 80, 0, 2]],
        [[3, 200, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 480, 112, 0.25, 1],
         [3, 672, 112, 0.25, 1]
         ],
        # stage5
        [[5, 672, 160, 0.25, 2]],
        [[5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1],
         [5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1]
         ]
    ]
    model_kwargs = dict(
        cfgs=cfgs,
        width=width,
        **kwargs,
    )
    return build_model_with_cfg(
        GhostNet, variant, pretrained,
        default_cfg=default_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True),
        **model_kwargs)

@register_model
def ghostnet_050(pretrained=False, **kwargs):
    """ GhostNet-0.5x """
    model = _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
    return model


@register_model
def ghostnet_100(pretrained=False, **kwargs):
    """ GhostNet-1.0x """
    model = _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def ghostnet_130(pretrained=False, **kwargs):
    """ GhostNet-1.3x """
    model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
    return model
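As a usage sketch (illustrative, not part of the commit; it assumes this branch of timm is installed, and pretrained=True additionally assumes the ghostnet_1x.pth release asset above is reachable):

import torch
import timm

model = timm.create_model('ghostnet_100', pretrained=False)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])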
