
Commit 6cdeca2

rwightman and chris-ha458 authored and committed
Some cleanup and fixes for initial BlurPool impl. Still some testing and tweaks to go...
1 parent acd1b6c commit 6cdeca2

File tree

2 files changed: +43 −52 lines

timm/models/layers/blurpool.py

Lines changed: 23 additions & 20 deletions
@@ -1,14 +1,17 @@
-'''
+"""
 BlurPool layer inspired by
-Kornia's Max_BlurPool2d
-and
-Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
+ - Kornia's Max_BlurPool2d
+ - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
+
+Hacked together by Chris Ha and Ross Wightman
+"""
 
-'''
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
+from .padding import get_padding
 
 
 class BlurPool2d(nn.Module):
@@ -25,30 +28,30 @@ class BlurPool2d(nn.Module):
     Examples:
     """
 
-    def __init__(self, channels=None, blur_filter_size=3, stride=2) -> None:
+    def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
         super(BlurPool2d, self).__init__()
-        assert blur_filter_size in [3, 5]
+        assert blur_filter_size > 1
         self.channels = channels
         self.blur_filter_size = blur_filter_size
         self.stride = stride
 
-        if blur_filter_size == 3:
-            pad_size = [1] * 4
-            blur_matrix = torch.Tensor([[1., 2., 1]]) / 4 # binomial filter b2
-        else:
-            pad_size = [2] * 4
-            blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16 # binomial filter b4
-
+        pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
         self.padding = nn.ReflectionPad2d(pad_size)
-        blur_filter = blur_matrix * blur_matrix.T
+
+        blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
+        blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
+        # FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
+        # plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
         self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
         if not torch.is_tensor(input_tensor):
-            raise TypeError("Input input type is not a torch.Tensor. Got {}"
-                            .format(type(input_tensor)))
+            raise TypeError("Input input type is not a torch.Tensor. Got {}".format(type(input_tensor)))
         if not len(input_tensor.shape) == 4:
-            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
-                             .format(input_tensor.shape))
+            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape))
         # apply blur_filter on input
-        return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1])
+        return F.conv2d(
+            self.padding(input_tensor),
+            self.blur_filter.type(input_tensor.dtype),
+            stride=self.stride,
+            groups=input_tensor.shape[1])
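
Note: the `np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)` expression generalizes the two previously hardcoded binomial kernels to any filter size > 1. A minimal standalone sketch (NumPy only; the loop and assert are illustrative, not part of the commit) checking that it reproduces the old size-3 and size-5 taps:

import numpy as np

# Coefficients of (0.5 + 0.5x)^(n-1) are the normalized binomial filter taps.
for size in (3, 5):
    taps = (np.poly1d((0.5, 0.5)) ** (size - 1)).coeffs
    print(size, taps)
# 3 -> [0.25 0.5 0.25]                    i.e. [1., 2., 1.] / 4
# 5 -> [0.0625 0.25 0.375 0.25 0.0625]    i.e. [1., 4., 6., 4., 1.] / 16

# The 2d kernel registered by BlurPool2d is the outer product of the 1d taps;
# it sums to 1.0, so the blur preserves overall intensity.
taps = (np.poly1d((0.5, 0.5)) ** 2).coeffs
kernel_2d = taps[:, None] * taps[None, :]
assert abs(kernel_2d.sum() - 1.0) < 1e-6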

timm/models/resnet.py

Lines changed: 20 additions & 32 deletions
@@ -127,21 +127,14 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
         first_planes = planes // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
-        self.blur = blur
 
-        if blur and stride==2:
-            self.conv1 = nn.Conv2d(
-                inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation,
-                dilation=first_dilation, bias=False)
-            self.blurpool=BlurPool2d(channels=first_planes)
-        else:
-            self.conv1 = nn.Conv2d(
-                inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
+        self.conv1 = nn.Conv2d(
+            inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation,
             dilation=first_dilation, bias=False)
-            self.blurpool = None
-
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
+        self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None
+
         self.conv2 = nn.Conv2d(
             first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
         self.bn2 = norm_layer(outplanes)
@@ -165,11 +158,9 @@ def forward(self, x):
         x = self.bn1(x)
         if self.drop_block is not None:
             x = self.drop_block(x)
+        x = self.act1(x)
         if self.blurpool is not None:
-            x = self.act1(x)
             x = self.blurpool(x)
-        else:
-            x = self.act1(x)
 
         x = self.conv2(x)
         x = self.bn2(x)
@@ -209,19 +200,13 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
 
-        if blur and stride==2:
-            self.conv2 = nn.Conv2d(
-                first_planes, width, kernel_size=3, stride=1,
-                padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
-            self.blurpool = BlurPool2d(channels=width)
-        else:
-            self.conv2 = nn.Conv2d(
-                first_planes, width, kernel_size=3, stride=stride,
-                padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
-            self.blurpool = None
-
+        self.conv2 = nn.Conv2d(
+            first_planes, width, kernel_size=3, stride=1 if blur else stride,
+            padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
         self.bn2 = norm_layer(width)
         self.act2 = act_layer(inplace=True)
+        self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None
+
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
 
@@ -251,6 +236,8 @@ def forward(self, x):
         if self.drop_block is not None:
             x = self.drop_block(x)
         x = self.act2(x)
+        if self.blurpool is not None:
+            x = self.blurpool(x)
 
         x = self.conv3(x)
         x = self.bn3(x)
@@ -412,11 +399,12 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
         self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.act1 = act_layer(inplace=True)
-        # Stem Blur
+        # Stem Pooling
        if 'max' in blur :
             self.maxpool = nn.Sequential(*[
-                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                BlurPool2d(channels=self.inplanes)])
+                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+                BlurPool2d(channels=self.inplanes, stride=2)
+            ])
         else :
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
@@ -470,8 +458,8 @@ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=
 
         block_kwargs = dict(
             cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            dilation=dilation, **kwargs)
-        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)]
+            dilation=dilation, blur=self.blur, **kwargs)
+        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
         self.inplanes = planes * block.expansion
         layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
 
@@ -1075,7 +1063,7 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a ResNet-50 model. With assembled-cnn style blur
     """
-    default_cfg = default_cfgs['resnetblur18']
-    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='strided', **kwargs)
+    default_cfg = default_cfgs['resnetblur50']
+    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
     model.default_cfg = default_cfg
     return model
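
For context, a quick smoke test of the corrected constructor — a sketch assuming timm at this commit is importable; the batch and input size (1, 3, 224, 224) are arbitrary choices, not part of the commit:

import torch
from timm.models.resnet import resnetblur50

# Build the blur variant: per this diff, stride-2 convs become stride-1 conv +
# BlurPool2d, and with blur='max_strided' the stem pool becomes
# MaxPool2d(stride=1) followed by BlurPool2d(stride=2).
model = resnetblur50(num_classes=1000)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])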

0 commit comments