@@ -109,6 +109,7 @@ def __init__(
             dim,
             dim_out=None,
             stride=1,
+            dilation=1,
             mlp_ratio=4,
             conv_mlp=False,
             conv_bias=True,
@@ -124,7 +125,8 @@ def __init__(
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
 
-        self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
+        self.conv_dw = create_conv2d(
+            dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
         self.norm = norm_layer(dim_out)
         self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
@@ -156,6 +158,7 @@ def __init__(
             out_chs,
             stride=2,
             depth=2,
+            dilation=(1, 1),
             drop_path_rates=None,
             ls_init_value=1.0,
             conv_mlp=False,
@@ -166,10 +169,14 @@ def __init__(
         super().__init__()
         self.grad_checkpointing = False
 
-        if in_chs != out_chs or stride > 1:
+        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
+            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
+            pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
-                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
+                create_conv2d(
+                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
+                    dilation=dilation[0], padding=pad, bias=conv_bias),
             )
             in_chs = out_chs
         else:
@@ -181,6 +188,7 @@ def __init__(
             stage_blocks.append(ConvNeXtBlock(
                 dim=in_chs,
                 dim_out=out_chs,
+                dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
@@ -235,7 +243,7 @@ def __init__(
             drop_path_rate=0.,
     ):
         super().__init__()
-        assert output_stride == 32
+        assert output_stride in (8, 16, 32)
         if norm_layer is None:
             norm_layer = partial(LayerNorm2d, eps=1e-6)
         norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
@@ -263,22 +271,27 @@ def __init__(
                 padding=stem_kernel_size // 2, bias=conv_bias),
             norm_layer(dims[0]),
         )
-        prev_chs = dims[0]
-        curr_stride = stem_stride
 
         self.stages = nn.Sequential()
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         stages = []
+        prev_chs = dims[0]
+        curr_stride = stem_stride
+        dilation = 1
         # 4 feature resolution stages, each consisting of multiple residual blocks
         for i in range(4):
             stride = 2 if curr_stride == 2 or i > 0 else 1
-            # FIXME support dilation / output_stride
+            if curr_stride >= output_stride and stride > 1:
+                dilation *= stride
+                stride = 1
             curr_stride *= stride
+            first_dilation = 1 if dilation in (1, 2) else 2
             out_chs = dims[i]
             stages.append(ConvNeXtStage(
                 prev_chs,
                 out_chs,
                 stride=stride,
+                dilation=(first_dilation, dilation),
                 depth=depths[i],
                 drop_path_rates=dp_rates[i],
                 ls_init_value=ls_init_value,
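
Minimal usage sketch, not part of this commit: it assumes the ConvNeXt class in timm.models.convnext keeps its default-style depths/dims keyword arguments and a forward_features method. With output_stride=16 the last stage should switch from striding to dilation, so a 224x224 input yields 14x14 final features instead of 7x7.

import torch
from timm.models.convnext import ConvNeXt

# Assumed constructor kwargs; output_stride=16 exercises the new dilation path.
model = ConvNeXt(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), output_stride=16)
model.eval()

with torch.no_grad():
    feats = model.forward_features(torch.randn(1, 3, 224, 224))

# Expect spatial size 224 / 16 = 14, i.e. roughly (1, 768, 14, 14);
# with the default output_stride=32 this would be (1, 768, 7, 7).
print(feats.shape)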