@@ -134,11 +134,15 @@ def __init__(
134134 self .query = nn .Sequential ()
135135 if self .has_query_strides :
136136 # FIXME dilation
137- self .query .add_module ('down_pool' , create_pool2d (
138- 'avg' ,
139- kernel_size = self .query_strides ,
140- padding = padding ,
141- ))
137+ if padding == 'same' :
138+ self .query .add_module ('down_pool' , create_pool2d (
139+ 'avg' ,
140+ kernel_size = self .query_strides ,
141+ padding = 'same' ,
142+ ))
143+ else :
144+ # no pad if not 'same' as kern=stride=even
145+ self .query .add_module ('down_pool' , nn .AvgPool2d (kernel_size = query_strides ))
142146 self .query .add_module ('norm' , norm_layer (dim ))
143147 self .query .add_module ('proj' , create_conv2d (
144148 dim ,
@@ -190,7 +194,7 @@ def __init__(
190194
191195 self .output = nn .Sequential ()
192196 if self .has_query_strides :
193- self .output .add_module ('upsample' , nn .Upsample (self .query_strides , mode = 'bilinear' , align_corners = False ))
197+ self .output .add_module ('upsample' , nn .Upsample (scale_factor = self .query_strides , mode = 'bilinear' , align_corners = False ))
194198 self .output .add_module ('proj' , create_conv2d (
195199 self .value_dim * self .num_heads ,
196200 dim_out ,
0 commit comments