Skip to content

Commit 2180800

Browse files
committed
MQA query_strides bugs fix #2237. No padding for avg_pool2d if not 'same', use scale_factor for Upsample.
1 parent 474c9cf commit 2180800

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

timm/layers/attention2d.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,15 @@ def __init__(
134134
self.query = nn.Sequential()
135135
if self.has_query_strides:
136136
# FIXME dilation
137-
self.query.add_module('down_pool', create_pool2d(
138-
'avg',
139-
kernel_size=self.query_strides,
140-
padding=padding,
141-
))
137+
if padding == 'same':
138+
self.query.add_module('down_pool', create_pool2d(
139+
'avg',
140+
kernel_size=self.query_strides,
141+
padding='same',
142+
))
143+
else:
144+
# no pad if not 'same' as kern=stride=even
145+
self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
142146
self.query.add_module('norm', norm_layer(dim))
143147
self.query.add_module('proj', create_conv2d(
144148
dim,
@@ -190,7 +194,7 @@ def __init__(
190194

191195
self.output = nn.Sequential()
192196
if self.has_query_strides:
193-
self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
197+
self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
194198
self.output.add_module('proj', create_conv2d(
195199
self.value_dim * self.num_heads,
196200
dim_out,

0 commit comments

Comments
 (0)