Skip to content

Commit 28739bb

Browse files
committed
Merge branch 'VRandme-blur'
2 parents 7a9942a + 2681a8d commit 28739bb

File tree

5 files changed

+130
-22
lines changed

5 files changed

+130
-22
lines changed

timm/models/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def load_state_dict(checkpoint_path, use_ema=False):
3131
raise FileNotFoundError()
3232

3333

34-
def load_checkpoint(model, checkpoint_path, use_ema=False):
34+
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
3535
state_dict = load_state_dict(checkpoint_path, use_ema)
36-
model.load_state_dict(state_dict)
36+
model.load_state_dict(state_dict, strict=strict)
3737

3838

3939
def resume_checkpoint(model, checkpoint_path):

timm/models/layers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
1818
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
1919
from .anti_aliasing import AntiAliasDownsampleLayer
20-
from .space_to_depth import SpaceToDepthModule
20+
from .space_to_depth import SpaceToDepthModule
21+
from .blur_pool import BlurPool2d

timm/models/layers/anti_aliasing.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66

77
class AntiAliasDownsampleLayer(nn.Module):
8-
def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0):
8+
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
99
super(AntiAliasDownsampleLayer, self).__init__()
1010
if no_jit:
11-
self.op = Downsample(filt_size, stride, channels)
11+
self.op = Downsample(channels, filt_size, stride)
1212
else:
13-
self.op = DownsampleJIT(filt_size, stride, channels)
13+
self.op = DownsampleJIT(channels, filt_size, stride)
1414

1515
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
1616

@@ -20,10 +20,10 @@ def forward(self, x):
2020

2121
@torch.jit.script
2222
class DownsampleJIT(object):
23-
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
23+
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
24+
self.channels = channels
2425
self.stride = stride
2526
self.filt_size = filt_size
26-
self.channels = channels
2727
assert self.filt_size == 3
2828
assert stride == 2
2929
self.filt = {} # lazy init by device for DataParallel compat
@@ -32,8 +32,7 @@ def _create_filter(self, like: torch.Tensor):
3232
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
3333
filt = filt[:, None] * filt[None, :]
3434
filt = filt / torch.sum(filt)
35-
filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
36-
return filt
35+
return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
3736

3837
def __call__(self, input: torch.Tensor):
3938
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
@@ -42,11 +41,11 @@ def __call__(self, input: torch.Tensor):
4241

4342

4443
class Downsample(nn.Module):
45-
def __init__(self, filt_size=3, stride=2, channels=None):
44+
def __init__(self, channels=None, filt_size=3, stride=2):
4645
super(Downsample, self).__init__()
46+
self.channels = channels
4747
self.filt_size = filt_size
4848
self.stride = stride
49-
self.channels = channels
5049

5150
assert self.filt_size == 3
5251
filt = torch.tensor([1., 2., 1.])

timm/models/layers/blur_pool.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""
2+
BlurPool layer inspired by
3+
- Kornia's Max_BlurPool2d
4+
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
5+
6+
FIXME merge this impl with those in `anti_aliasing.py`
7+
8+
Hacked together by Chris Ha and Ross Wightman
9+
"""
10+
11+
import torch
12+
import torch.nn as nn
13+
import torch.nn.functional as F
14+
import numpy as np
15+
from typing import Dict
16+
from .padding import get_padding
17+
18+
19+
class BlurPool2d(nn.Module):
20+
r"""Creates a module that computes blurs and downsample a given feature map.
21+
See :cite:`zhang2019shiftinvar` for more details.
22+
Corresponds to the Downsample class, which does blurring and subsampling
23+
24+
Args:
25+
channels = Number of input channels
26+
filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
27+
stride (int): downsampling filter stride
28+
29+
Returns:
30+
torch.Tensor: the transformed tensor.
31+
"""
32+
filt: Dict[str, torch.Tensor]
33+
34+
def __init__(self, channels, filt_size=3, stride=2) -> None:
35+
super(BlurPool2d, self).__init__()
36+
assert filt_size > 1
37+
self.channels = channels
38+
self.filt_size = filt_size
39+
self.stride = stride
40+
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
41+
self.padding = nn.ReflectionPad2d(pad_size)
42+
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat
43+
self.filt = {} # lazy init by device for DataParallel compat
44+
45+
def _create_filter(self, like: torch.Tensor):
46+
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
47+
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
48+
49+
def _apply(self, fn):
50+
# override nn.Module _apply, reset filter cache if used
51+
self.filt = {}
52+
super(BlurPool2d, self)._apply(fn)
53+
54+
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
55+
C = input_tensor.shape[1]
56+
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
57+
return F.conv2d(
58+
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)

timm/models/resnet.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .registry import register_model
1414
from .helpers import load_pretrained, adapt_model_from_file
15-
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
15+
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
1616
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1717

1818
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@@ -118,6 +118,11 @@ def _cfg(url='', **kwargs):
118118
'ecaresnet101d_pruned': _cfg(
119119
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
120120
interpolation='bicubic'),
121+
'resnetblur18': _cfg(
122+
interpolation='bicubic'),
123+
'resnetblur50': _cfg(
124+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
125+
interpolation='bicubic')
121126
}
122127

123128

@@ -131,20 +136,23 @@ class BasicBlock(nn.Module):
131136

132137
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
133138
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
134-
attn_layer=None, drop_block=None, drop_path=None):
139+
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
135140
super(BasicBlock, self).__init__()
136141

137142
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
138143
assert base_width == 64, 'BasicBlock does not support changing base width'
139144
first_planes = planes // reduce_first
140145
outplanes = planes * self.expansion
141146
first_dilation = first_dilation or dilation
147+
use_aa = aa_layer is not None
142148

143149
self.conv1 = nn.Conv2d(
144-
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
150+
inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
145151
dilation=first_dilation, bias=False)
146152
self.bn1 = norm_layer(first_planes)
147153
self.act1 = act_layer(inplace=True)
154+
self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
155+
148156
self.conv2 = nn.Conv2d(
149157
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
150158
self.bn2 = norm_layer(outplanes)
@@ -169,6 +177,8 @@ def forward(self, x):
169177
if self.drop_block is not None:
170178
x = self.drop_block(x)
171179
x = self.act1(x)
180+
if self.aa is not None:
181+
x = self.aa(x)
172182

173183
x = self.conv2(x)
174184
x = self.bn2(x)
@@ -195,22 +205,26 @@ class Bottleneck(nn.Module):
195205

196206
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
197207
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
198-
attn_layer=None, drop_block=None, drop_path=None):
208+
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
199209
super(Bottleneck, self).__init__()
200210

201211
width = int(math.floor(planes * (base_width / 64)) * cardinality)
202212
first_planes = width // reduce_first
203213
outplanes = planes * self.expansion
204214
first_dilation = first_dilation or dilation
215+
use_aa = aa_layer is not None
205216

206217
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
207218
self.bn1 = norm_layer(first_planes)
208219
self.act1 = act_layer(inplace=True)
220+
209221
self.conv2 = nn.Conv2d(
210-
first_planes, width, kernel_size=3, stride=stride,
222+
first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
211223
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
212224
self.bn2 = norm_layer(width)
213225
self.act2 = act_layer(inplace=True)
226+
self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
227+
214228
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
215229
self.bn3 = norm_layer(outplanes)
216230

@@ -240,6 +254,8 @@ def forward(self, x):
240254
if self.drop_block is not None:
241255
x = self.drop_block(x)
242256
x = self.act2(x)
257+
if self.aa is not None:
258+
x = self.aa(x)
243259

244260
x = self.conv3(x)
245261
x = self.bn3(x)
@@ -353,8 +369,9 @@ class ResNet(nn.Module):
353369
Whether to use average pooling for projection skip connection between stages/downsample.
354370
output_stride : int, default 32
355371
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
356-
act_layer : class, activation layer
357-
norm_layer : class, normalization layer
372+
act_layer : nn.Module, activation layer
373+
norm_layer : nn.Module, normalization layer
374+
aa_layer : nn.Module, anti-aliasing layer
358375
drop_rate : float, default 0.
359376
Dropout probability before classifier, for training
360377
global_pool : str, default 'avg'
@@ -363,7 +380,7 @@ class ResNet(nn.Module):
363380
def __init__(self, block, layers, num_classes=1000, in_chans=3,
364381
cardinality=1, base_width=64, stem_width=64, stem_type='',
365382
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
366-
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
383+
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
367384
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
368385
block_args = block_args or dict()
369386
self.num_classes = num_classes
@@ -393,7 +410,14 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
393410
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
394411
self.bn1 = norm_layer(self.inplanes)
395412
self.act1 = act_layer(inplace=True)
396-
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
413+
# Stem Pooling
414+
if aa_layer is not None:
415+
self.maxpool = nn.Sequential(*[
416+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
417+
aa_layer(channels=self.inplanes, stride=2)
418+
])
419+
else:
420+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
397421

398422
# Feature Blocks
399423
dp = DropPath(drop_path_rate) if drop_path_rate else None
@@ -410,7 +434,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
410434
assert output_stride == 32
411435
layer_args = list(zip(channels, layers, strides, dilations))
412436
layer_kwargs = dict(
413-
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
437+
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
414438
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
415439
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
416440
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
@@ -1114,3 +1138,29 @@ def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwarg
11141138
if pretrained:
11151139
load_pretrained(model, default_cfg, num_classes, in_chans)
11161140
return model
1141+
1142+
1143+
@register_model
1144+
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
1145+
"""Constructs a ResNet-18 model with blur anti-aliasing
1146+
"""
1147+
default_cfg = default_cfgs['resnetblur18']
1148+
model = ResNet(
1149+
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
1150+
model.default_cfg = default_cfg
1151+
if pretrained:
1152+
load_pretrained(model, default_cfg, num_classes, in_chans)
1153+
return model
1154+
1155+
1156+
@register_model
1157+
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
1158+
"""Constructs a ResNet-50 model with blur anti-aliasing
1159+
"""
1160+
default_cfg = default_cfgs['resnetblur50']
1161+
model = ResNet(
1162+
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
1163+
model.default_cfg = default_cfg
1164+
if pretrained:
1165+
load_pretrained(model, default_cfg, num_classes, in_chans)
1166+
return model

0 commit comments

Comments
 (0)