
Commit c5e0d1c

Add dilation support to convnext, allows output_stride=8 and 16 use. Fix #1341
1 parent 5e7d47c commit c5e0d1c


timm/models/convnext.py

Lines changed: 20 additions & 7 deletions
@@ -109,6 +109,7 @@ def __init__(
             dim,
             dim_out=None,
             stride=1,
+            dilation=1,
             mlp_ratio=4,
             conv_mlp=False,
             conv_bias=True,
@@ -124,7 +125,8 @@ def __init__(
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp

-        self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
+        self.conv_dw = create_conv2d(
+            dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
         self.norm = norm_layer(dim_out)
         self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
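
For intuition, a minimal plain-PyTorch sketch of the padding arithmetic behind the dilated depthwise conv (not part of the commit; the dim and dilation values are illustrative). With stride 1 and an odd kernel, padding = dilation * (kernel_size - 1) // 2 preserves spatial size, which should match what create_conv2d resolves to here:

    import torch
    import torch.nn as nn

    dim, dilation = 96, 2
    # groups=dim makes the conv depthwise; this padding keeps H x W
    # unchanged for any odd kernel at stride 1, dilation included.
    conv_dw = nn.Conv2d(
        dim, dim, kernel_size=7, stride=1, dilation=dilation,
        padding=dilation * (7 - 1) // 2, groups=dim)
    x = torch.randn(1, dim, 28, 28)
    assert conv_dw(x).shape == x.shape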
@@ -156,6 +158,7 @@ def __init__(
             out_chs,
             stride=2,
             depth=2,
+            dilation=(1, 1),
             drop_path_rates=None,
             ls_init_value=1.0,
             conv_mlp=False,
@@ -166,10 +169,14 @@ def __init__(
         super().__init__()
         self.grad_checkpointing = False

-        if in_chs != out_chs or stride > 1:
+        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
+            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
+            pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
-                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
+                create_conv2d(
+                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
+                    dilation=dilation[0], padding=pad, bias=conv_bias),
             )
             in_chs = out_chs
         else:
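
A quick sketch of the shape behavior the new downsample path relies on (illustrative values, not from the commit): when a stage's stride has been traded for dilation, the shortcut keeps a 2x2 kernel (ds_ks) at stride 1, and 'same' padding preserves resolution. In plain PyTorch:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 96, 28, 28)
    # Stride-1 stand-in for the usual 2x2/stride-2 downsample conv;
    # padding='same' (PyTorch >= 1.9) keeps H x W when stride == 1.
    ds = nn.Conv2d(96, 192, kernel_size=2, stride=1, padding='same')
    assert ds(x).shape == (1, 192, 28, 28)

Note that timm's create_conv2d implements its own same-padding logic, so the nn.Conv2d call above only approximates the patched code path.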
@@ -181,6 +188,7 @@ def __init__(
             stage_blocks.append(ConvNeXtBlock(
                 dim=in_chs,
                 dim_out=out_chs,
+                dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
@@ -235,7 +243,7 @@ def __init__(
             drop_path_rate=0.,
     ):
         super().__init__()
-        assert output_stride == 32
+        assert output_stride in (8, 16, 32)
         if norm_layer is None:
             norm_layer = partial(LayerNorm2d, eps=1e-6)
         norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
@@ -263,22 +271,27 @@ def __init__(
                 padding=stem_kernel_size // 2, bias=conv_bias),
             norm_layer(dims[0]),
         )
-        prev_chs = dims[0]
-        curr_stride = stem_stride

         self.stages = nn.Sequential()
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         stages = []
+        prev_chs = dims[0]
+        curr_stride = stem_stride
+        dilation = 1
         # 4 feature resolution stages, each consisting of multiple residual blocks
         for i in range(4):
             stride = 2 if curr_stride == 2 or i > 0 else 1
-            # FIXME support dilation / output_stride
+            if curr_stride >= output_stride and stride > 1:
+                dilation *= stride
+                stride = 1
             curr_stride *= stride
+            first_dilation = 1 if dilation in (1, 2) else 2
             out_chs = dims[i]
             stages.append(ConvNeXtStage(
                 prev_chs,
                 out_chs,
                 stride=stride,
+                dilation=(first_dilation, dilation),
                 depth=depths[i],
                 drop_path_rates=dp_rates[i],
                 ls_init_value=ls_init_value,
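
To make the stride-to-dilation bookkeeping concrete, here is a standalone trace of the loop above for output_stride=8 with the standard stem_stride of 4 (a sketch, not part of the commit):

    stem_stride, output_stride = 4, 8
    curr_stride, dilation = stem_stride, 1
    for i in range(4):
        stride = 2 if curr_stride == 2 or i > 0 else 1
        if curr_stride >= output_stride and stride > 1:
            dilation *= stride  # trade this stage's stride for dilation
            stride = 1
        curr_stride *= stride
        first_dilation = 1 if dilation in (1, 2) else 2
        print(i, stride, curr_stride, dilation, first_dilation)
    # stage 0: stride 1, curr_stride 4, dilation 1, first_dilation 1
    # stage 1: stride 2, curr_stride 8, dilation 1, first_dilation 1
    # stage 2: stride 1, curr_stride 8, dilation 2, first_dilation 1
    # stage 3: stride 1, curr_stride 8, dilation 4, first_dilation 2

The network never downsamples past 8x; later stages grow their receptive field through dilation instead.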

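With the change in place, something like the following should work (a hedged usage sketch; model name and input size are illustrative):

    import timm
    import torch

    # output_stride is forwarded to the ConvNeXt constructor; at 8 the
    # deepest feature map is 224/8 = 28 px instead of 224/32 = 7 px.
    model = timm.create_model('convnext_tiny', features_only=True, output_stride=8)
    feats = model(torch.randn(1, 3, 224, 224))
    print([f.shape for f in feats])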