
Commit b2c305c

Move Mlp and PatchEmbed modules into layers. Being used in lots of models now...

1 parent 3ba6b55

10 files changed: +140 -162 lines

timm/models/cait.py
Lines changed: 1 addition & 2 deletions

@@ -15,8 +15,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
-from .layers import trunc_normal_, DropPath
-from .vision_transformer import Mlp, PatchEmbed
+from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
 from .registry import register_model
 
 

timm/models/coat.py
Lines changed: 17 additions & 57 deletions

@@ -15,7 +15,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import load_pretrained
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from timm.models.registry import register_model
 
 from functools import partial

@@ -54,26 +54,6 @@ def _cfg_coat(url='', **kwargs):
 }
 
 
-class Mlp(nn.Module):
-    """ Feed-forward network (FFN, a.k.a. MLP) class. """
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
 class ConvRelPosEnc(nn.Module):
     """ Convolutional relative position encoding. """
     def __init__(self, Ch, h, window):

@@ -348,34 +328,6 @@ def forward(self, x1, x2, x3, x4, sizes):
         return x1, x2, x3, x4
 
 
-class PatchEmbed(nn.Module):
-    """ Image to Patch Embedding """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
-        super().__init__()
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-
-        self.img_size = img_size
-        self.patch_size = patch_size
-        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
-            f"img_size {img_size} should be divided by patch_size {patch_size}."
-        # Note: self.H, self.W and self.num_patches are not used
-        # since the image size may change on the fly.
-        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
-        self.num_patches = self.H * self.W
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-        self.norm = nn.LayerNorm(embed_dim)
-
-    def forward(self, x):
-        _, _, H, W = x.shape
-        out_H, out_W = H // self.patch_size[0], W // self.patch_size[1]
-
-        x = self.proj(x).flatten(2).transpose(1, 2)
-        out = self.norm(x)
-
-        return out, (out_H, out_W)
-
-
 class CoaT(nn.Module):
     """ CoaT class. """
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0],

@@ -391,13 +343,17 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
 
         # Patch embeddings.
         self.patch_embed1 = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+            embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
         self.patch_embed2 = PatchEmbed(
-            img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
+            img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
+            embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
         self.patch_embed3 = PatchEmbed(
-            img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
+            img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
+            embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
         self.patch_embed4 = PatchEmbed(
-            img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
+            img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
+            embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
 
         # Class tokens.
         self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))

@@ -533,31 +489,35 @@ def forward_features(self, x0):
         B = x0.shape[0]
 
         # Serial blocks 1.
-        x1, (H1, W1) = self.patch_embed1(x0)
+        x1 = self.patch_embed1(x0)
+        H1, W1 = self.patch_embed1.out_size
         x1 = self.insert_cls(x1, self.cls_token1)
         for blk in self.serial_blocks1:
             x1 = blk(x1, size=(H1, W1))
         x1_nocls = self.remove_cls(x1)
         x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 2.
-        x2, (H2, W2) = self.patch_embed2(x1_nocls)
+        x2 = self.patch_embed2(x1_nocls)
+        H2, W2 = self.patch_embed2.out_size
         x2 = self.insert_cls(x2, self.cls_token2)
         for blk in self.serial_blocks2:
             x2 = blk(x2, size=(H2, W2))
         x2_nocls = self.remove_cls(x2)
         x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 3.
-        x3, (H3, W3) = self.patch_embed3(x2_nocls)
+        x3 = self.patch_embed3(x2_nocls)
+        H3, W3 = self.patch_embed3.out_size
         x3 = self.insert_cls(x3, self.cls_token3)
         for blk in self.serial_blocks3:
             x3 = blk(x3, size=(H3, W3))
         x3_nocls = self.remove_cls(x3)
         x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 4.
-        x4, (H4, W4) = self.patch_embed4(x3_nocls)
+        x4 = self.patch_embed4(x3_nocls)
+        H4, W4 = self.patch_embed4.out_size
         x4 = self.insert_cls(x4, self.cls_token4)
         for blk in self.serial_blocks4:
             x4 = blk(x4, size=(H4, W4))
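A note on the call-site change above: CoaT's old embed derived (H, W) from the runtime input and returned them alongside the tokens, while the shared layer asserts a fixed img_size and exposes the grid as a static out_size attribute. A minimal sketch of the new pattern, assuming this commit's timm is importable (embed_dim=64 is a placeholder value, not CoaT's actual config):

import torch
from torch import nn
from timm.models.layers import PatchEmbed

# Stage-1 embed as CoaT now constructs it; 64 stands in for embed_dims[0].
patch_embed1 = PatchEmbed(
    img_size=224, patch_size=16, in_chans=3, embed_dim=64, norm_layer=nn.LayerNorm)

x1 = patch_embed1(torch.randn(2, 3, 224, 224))  # tokens only: (2, 196, 64)
H1, W1 = patch_embed1.out_size                  # static grid: (14, 14)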

timm/models/layers/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -20,9 +20,11 @@
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
+from .mlp import Mlp, GluMlp
 from .norm import GroupNorm
 from .norm_act import BatchNormAct2d, GroupNormAct
 from .padding import get_padding, get_same_padding, pad_same
+from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .se import SEModule
 from .selective_kernel import SelectiveKernelConv
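With these exports in place, model files can pull all three blocks from the shared namespace in one line, e.g.:

from timm.models.layers import Mlp, GluMlp, PatchEmbed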

timm/models/layers/mlp.py
Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+""" MLP module w/ dropout and configurable activation layer
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+
+class Mlp(nn.Module):
+    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class GluMlp(nn.Module):
+    """ MLP w/ GLU style gating
+    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features * 2)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x, gates = x.chunk(2, dim=-1)
+        x = x * self.act(gates)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
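A quick sketch contrasting the two blocks, assuming this commit is installed: GluMlp's fc1 projects to twice the hidden width and chunks the result into values and gates, so both modules map back to in_features.

import torch
from timm.models.layers import Mlp, GluMlp

x = torch.randn(2, 196, 512)             # (batch, tokens, features)
mlp = Mlp(512, hidden_features=2048)     # fc1: 512 -> 2048, fc2: 2048 -> 512
glu = GluMlp(512, hidden_features=1536)  # fc1: 512 -> 3072, chunked into 1536 values + 1536 gates

print(mlp(x).shape)  # torch.Size([2, 196, 512])
print(glu(x).shape)  # torch.Size([2, 196, 512])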

timm/models/layers/patch_embed.py
Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+""" Image to Patch Embedding using Conv2d
+
+A convolution based approach to patchifying a 2D image w/ embedding projection.
+
+Based on the impl in https://github.com/google-research/vision_transformer
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+from torch import nn as nn
+
+from .helpers import to_2tuple
+
+
+class PatchEmbed(nn.Module):
+    """ 2D Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.out_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches = self.out_size[0] * self.out_size[1]
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        x = self.norm(x)
+        return x
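A minimal usage sketch with the ViT-style defaults: the module returns only the (B, N, C) token sequence, and the grid and patch count are attributes.

import torch
from timm.models.layers import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = embed(torch.randn(1, 3, 224, 224))

print(tokens.shape)       # torch.Size([1, 196, 768])
print(embed.out_size)     # (14, 14)
print(embed.num_patches)  # 196

# Note: a 256x256 input would trip the assert; unlike CoaT's old embed,
# the input size is fixed at construction time.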

timm/models/mlp_mixer.py
Lines changed: 18 additions & 53 deletions

@@ -25,7 +25,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
-from .layers import DropPath, to_2tuple, lecun_normal_
+from .layers import PatchEmbed, Mlp, GluMlp, DropPath, lecun_normal_
 from .registry import register_model
 
 

@@ -43,6 +43,7 @@ def _cfg(url='', **kwargs):
 default_cfgs = dict(
     mixer_s32_224=_cfg(),
     mixer_s16_224=_cfg(),
+    mixer_s16_glu_224=_cfg(),
     mixer_b32_224=_cfg(),
     mixer_b16_224=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',

@@ -62,65 +63,17 @@ def _cfg(url='', **kwargs):
 )
 
 
-class Mlp(nn.Module):
-    """ MLP Block
-    NOTE: same impl as ViT, move to common location
-    """
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
-class PatchEmbed(nn.Module):
-    """ Image to Patch Embedding
-    NOTE: same impl as ViT, move to common location
-    """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
-        super().__init__()
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
-        self.num_patches = self.patch_grid[0] * self.patch_grid[1]
-
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        x = self.proj(x).flatten(2).transpose(1, 2)
-        x = self.norm(x)
-        return x
-
-
 class MixerBlock(nn.Module):
 
     def __init__(
             self, dim, seq_len, tokens_dim, channels_dim,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
+            mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
+        self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
-        self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)
+        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
 
     def forward(self, x):
         x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))

@@ -140,6 +93,7 @@ def __init__(
             hidden_dim=512,
             tokens_dim=256,
             channels_dim=2048,
+            mlp_layer=Mlp,
             norm_layer=partial(nn.LayerNorm, eps=1e-6),
             act_layer=nn.GELU,
             drop=0.,

@@ -154,7 +108,7 @@ def __init__(
         self.blocks = nn.Sequential(*[
             MixerBlock(
                 hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
-                norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
+                mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
             for _ in range(num_blocks)])
         self.norm = norm_layer(hidden_dim)
         self.head = nn.Linear(hidden_dim, self.num_classes)  # zero init

@@ -238,6 +192,17 @@ def mixer_s16_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def mixer_s16_glu_224(pretrained=False, **kwargs):
+    """ Mixer-S/16 224x224
+    """
+    model_args = dict(
+        patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=1536,
+        mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
+    model = _create_mixer('mixer_s16_glu_224', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def mixer_b32_224(pretrained=False, **kwargs):
     """ Mixer-B/32 224x224

timm/models/resnet.py
Lines changed: 12 additions & 0 deletions

@@ -49,6 +49,9 @@ def _cfg(url='', **kwargs):
     'resnet26d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
         interpolation='bicubic', first_conv='conv1.0'),
+    'resnet26t': _cfg(
+        url='',
+        interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
     'resnet50': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
         interpolation='bicubic'),

@@ -723,6 +726,15 @@ def resnet26(pretrained=False, **kwargs):
     return _create_resnet('resnet26', pretrained, **model_args)
 
 
+@register_model
+def resnet26t(pretrained=False, **kwargs):
+    """Constructs a ResNet-26-T model.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
+    return _create_resnet('resnet26t', pretrained, **model_args)
+
+
 @register_model
 def resnet26d(pretrained=False, **kwargs):
     """Constructs a ResNet-26-D model.
