Skip to content

Commit d79f3d9

Browse files
committed
Fix torchscript use for sequencer, add group_matcher, forward_head support, minor formatting
1 parent 93a79a3 commit d79f3d9

File tree

1 file changed

+60
-33
lines changed

1 file changed

+60
-33
lines changed

timm/models/sequencer.py

Lines changed: 60 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -71,18 +71,19 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=Fals
7171
module.init_weights()
7272

7373

74-
def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
75-
norm_layer, act_layer, num_layers, bidirectional, union,
76-
with_fc, drop=0., drop_path_rate=0., **kwargs):
74+
def get_stage(
75+
index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
76+
norm_layer, act_layer, num_layers, bidirectional, union,
77+
with_fc, drop=0., drop_path_rate=0., **kwargs):
7778
assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
7879
blocks = []
7980
for block_idx in range(layers[index]):
8081
drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
81-
blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
82-
rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer,
83-
act_layer=act_layer, num_layers=num_layers,
84-
bidirectional=bidirectional, union=union, with_fc=with_fc,
85-
drop=drop, drop_path=drop_path))
82+
blocks.append(block_layer(
83+
embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
84+
rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
85+
num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc,
86+
drop=drop, drop_path=drop_path))
8687

8788
if index < len(embed_dims) - 1:
8889
blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1]))
@@ -101,9 +102,10 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
101102

102103
class RNN2DBase(nn.Module):
103104

104-
def __init__(self, input_size: int, hidden_size: int,
105-
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
106-
union="cat", with_fc=True):
105+
def __init__(
106+
self, input_size: int, hidden_size: int,
107+
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
108+
union="cat", with_fc=True):
107109
super().__init__()
108110

109111
self.input_size = input_size
@@ -115,6 +117,7 @@ def __init__(self, input_size: int, hidden_size: int,
115117
self.with_horizontal = True
116118
self.with_fc = with_fc
117119

120+
self.fc = None
118121
if with_fc:
119122
if union == "cat":
120123
self.fc = nn.Linear(2 * self.output_size, input_size)
@@ -159,33 +162,38 @@ def forward(self, x):
159162
v, _ = self.rnn_v(v)
160163
v = v.reshape(B, W, H, -1)
161164
v = v.permute(0, 2, 1, 3)
165+
else:
166+
v = None
162167

163168
if self.with_horizontal:
164169
h = x.reshape(-1, W, C)
165170
h, _ = self.rnn_h(h)
166171
h = h.reshape(B, H, W, -1)
172+
else:
173+
h = None
167174

168-
if self.with_vertical and self.with_horizontal:
175+
if v is not None and h is not None:
169176
if self.union == "cat":
170177
x = torch.cat([v, h], dim=-1)
171178
else:
172179
x = v + h
173-
elif self.with_vertical:
180+
elif v is not None:
174181
x = v
175-
elif self.with_horizontal:
182+
elif h is not None:
176183
x = h
177184

178-
if self.with_fc:
185+
if self.fc is not None:
179186
x = self.fc(x)
180187

181188
return x
182189

183190

184191
class LSTM2D(RNN2DBase):
185192

186-
def __init__(self, input_size: int, hidden_size: int,
187-
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
188-
union="cat", with_fc=True):
193+
def __init__(
194+
self, input_size: int, hidden_size: int,
195+
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
196+
union="cat", with_fc=True):
189197
super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
190198
if self.with_vertical:
191199
self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
@@ -194,10 +202,10 @@ def __init__(self, input_size: int, hidden_size: int,
194202

195203

196204
class Sequencer2DBlock(nn.Module):
197-
def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
198-
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
199-
num_layers=1, bidirectional=True, union="cat", with_fc=True,
200-
drop=0., drop_path=0.):
205+
def __init__(
206+
self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
207+
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
208+
num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.):
201209
super().__init__()
202210
channels_dim = int(mlp_ratio * dim)
203211
self.norm1 = norm_layer(dim)
@@ -255,6 +263,7 @@ def __init__(
255263
num_classes=1000,
256264
img_size=224,
257265
in_chans=3,
266+
global_pool='avg',
258267
layers=[4, 3, 8, 3],
259268
patch_sizes=[7, 2, 1, 1],
260269
embed_dims=[192, 384, 384, 384],
@@ -275,7 +284,9 @@ def __init__(
275284
stem_norm=False,
276285
):
277286
super().__init__()
287+
assert global_pool in ('', 'avg')
278288
self.num_classes = num_classes
289+
self.global_pool = global_pool
279290
self.num_features = embed_dims[-1] # num_features for consistency with other models
280291
self.embed_dims = embed_dims
281292
self.stem = PatchEmbed(
@@ -301,38 +312,54 @@ def init_weights(self, nlhb=False):
301312
head_bias = -math.log(self.num_classes) if nlhb else 0.
302313
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
303314

315+
@torch.jit.ignore
316+
def group_matcher(self, coarse=False):
317+
return dict(
318+
stem=r'^stem',
319+
blocks=[
320+
(r'^blocks\.(\d+)\..*\.down', (99999,)),
321+
(r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None),
322+
(r'^norm', (99999,))
323+
]
324+
)
325+
326+
@torch.jit.ignore
327+
def set_grad_checkpointing(self, enable=True):
328+
assert not enable, 'gradient checkpointing not supported'
329+
330+
@torch.jit.ignore
304331
def get_classifier(self):
305332
return self.head
306333

307-
def reset_classifier(self, num_classes, global_pool=None):
    """Replace the classification head and optionally change the pooling mode.

    Args:
        num_classes: Output width of the new head. A value ``<= 0`` installs
            ``nn.Identity`` instead of a linear layer.
        global_pool: New pooling mode, ``''`` or ``'avg'``. ``None`` (the
            default) leaves the current ``self.global_pool`` untouched.

    Raises:
        AssertionError: If ``global_pool`` is given and is neither ``''`` nor
            ``'avg'`` (same validation the constructor applies).
    """
    self.num_classes = num_classes
    # Check the *argument*, not ``self.global_pool``: the constructor only ever
    # stores '' or 'avg' there, so testing the attribute was always true and the
    # default ``global_pool=None`` then failed the assert below on every call.
    if global_pool is not None:
        assert global_pool in ('', 'avg')
        self.global_pool = global_pool
    # The constructor records the final stage width as ``num_features``
    # (embed_dims[-1]); it assigns ``embed_dims`` but no ``embed_dim`` attribute.
    self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
310340

311341
def forward_features(self, x):
312342
x = self.stem(x)
313343
x = self.blocks(x)
314344
x = self.norm(x)
315-
x = x.mean(dim=(1, 2))
316345
return x
317346

347+
def forward_head(self, x, pre_logits: bool = False):
348+
if self.global_pool == 'avg':
349+
x = x.mean(dim=(1, 2))
350+
return x if pre_logits else self.head(x)
351+
318352
def forward(self, x):
319353
x = self.forward_features(x)
320-
x = self.head(x)
354+
x = self.forward_head(x)
321355
return x
322356

323357

324-
def checkpoint_filter_fn(state_dict, model):
325-
return state_dict
326-
327-
328358
def _create_sequencer2d(variant, pretrained=False, **kwargs):
329359
if kwargs.get('features_only', None):
330360
raise RuntimeError('features_only not implemented for Sequencer2D models.')
331361

332-
model = build_model_with_cfg(
333-
Sequencer2D, variant, pretrained,
334-
pretrained_filter_fn=checkpoint_filter_fn,
335-
**kwargs)
362+
model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs)
336363
return model
337364

338365

0 commit comments

Comments
 (0)