
Commit 756ceca

BasicBlock, Bottleneck
1 parent ae180ba commit 756ceca

7 files changed: +229 -116 lines changed


src/model_constructor/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 from model_constructor.convmixer import ConvMixer  # noqa F401
 from model_constructor.model_constructor import (
     ModelConstructor,
-    ResBlock,
     ModelCfg,
 )  # noqa F401

src/model_constructor/helpers.py

Lines changed: 2 additions & 1 deletion
@@ -1,8 +1,9 @@
 from collections import OrderedDict
+from typing import Iterable
 
 from torch import nn
 
 
-def nn_seq(list_of_tuples: list[tuple[str, nn.Module]]) -> nn.Sequential:
+def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:
     """return nn.Sequential from OrderedDict from list of tuples"""
     return nn.Sequential(OrderedDict(list_of_tuples))
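
Note: widening the annotation from list to Iterable means nn_seq now accepts any
iterable of (name, module) pairs, including a generator expression, which is how
make_layer below uses it. A minimal sketch of the new usage (layer shapes are
illustrative):

    from torch import nn
    from model_constructor.helpers import nn_seq

    # A generator expression now type-checks; no intermediate list needed.
    stack = nn_seq((f"conv_{i}", nn.Conv2d(16, 16, 3, padding=1)) for i in range(2))
    print(stack)  # Sequential with named children conv_0, conv_1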

src/model_constructor/model_constructor.py

Lines changed: 163 additions & 105 deletions
@@ -1,15 +1,18 @@
+# pylance: disable=overridden method
 from collections import OrderedDict
 from functools import partial
 from typing import Any, Callable, Optional, TypeVar, Union
 
-import torch.nn as nn
+import torch
 from pydantic import BaseModel, validator
+from torch import nn
 
+from .helpers import nn_seq
 from .layers import ConvBnAct, SEModule, SimpleSelfAttention, get_act
 
 __all__ = [
     "init_cnn",
-    "ResBlock",
+    # "ResBlock",
     "ModelConstructor",
     "XResNet34",
     "XResNet50",
@@ -18,6 +21,8 @@
 
 TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")
 
+ListStrMod = list[tuple[str, nn.Module]]
+
 
 def init_cnn(module: nn.Module) -> None:
     "Init module - kaiming_normal for Conv2d and 0 for biases."
@@ -29,16 +34,16 @@ def init_cnn(module: nn.Module) -> None:
         init_cnn(layer)
 
 
-class ResBlock(nn.Module):
-    """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
+class BasicBlock(nn.Module):
+    """Basic Resnet block."""
 
     def __init__(
         self,
-        expansion: int,
+        # expansion: int,
         in_channels: int,
-        mid_channels: int,
+        out_channels: int,
         stride: int = 1,
-        conv_layer: type[nn.Module] = ConvBnAct,
+        conv_layer: type[ConvBnAct] = ConvBnAct,
         act_fn: type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
@@ -51,85 +56,142 @@ def __init__(
     ):
         super().__init__()
         # pool defined at ModelConstructor.
-        out_channels, in_channels = mid_channels * expansion, in_channels * expansion
+        # out_channels, in_channels = mid_channels * expansion, in_channels * expansion
         if div_groups is not None:  # check if groups != 1 and div_groups
-            groups = int(mid_channels / div_groups)
-        if expansion == 1:
-            layers = [
-                (
-                    "conv_0",
-                    conv_layer(
-                        in_channels,
-                        mid_channels,
-                        3,
-                        stride=stride,  # type: ignore
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                        groups=in_channels if dw else groups,
-                    ),
+            groups = int(out_channels / div_groups)
+        layers: ListStrMod = [
+            (
+                "conv_0",
+                conv_layer(
+                    in_channels,
+                    out_channels,
+                    3,
+                    stride=stride,  # type: ignore
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=in_channels if dw else groups,
                 ),
-                (
-                    "conv_1",
-                    conv_layer(
-                        mid_channels,
-                        out_channels,
-                        3,
-                        zero_bn=zero_bn,
-                        act_fn=False,
-                        bn_1st=bn_1st,
-                        groups=mid_channels if dw else groups,
-                    ),
+            ),
+            (
+                "conv_1",
+                conv_layer(
+                    out_channels,
+                    out_channels,
+                    3,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                    groups=out_channels if dw else groups,
                 ),
-            ]
+            ),
+        ]
+        if se:
+            layers.append(("se", se(out_channels)))
+        if sa:
+            layers.append(("sa", sa(out_channels)))
+        self.convs = nn_seq(layers)
+        if stride != 1 or in_channels != out_channels:
+            id_layers: ListStrMod = []
+            if (
+                stride != 1 and pool is not None
+            ):  # if pool - reduce by pool else stride 2 art id_conv
+                id_layers.append(("pool", pool()))
+            if in_channels != out_channels or (stride != 1 and pool is None):
+                id_layers.append(
+                    (
+                        "id_conv",
+                        conv_layer(
+                            in_channels,
+                            out_channels,
+                            1,
+                            stride=1 if pool else stride,
+                            act_fn=False,
+                        ),
+                    )
+                )
+            self.id_conv = nn_seq(id_layers)
         else:
-            layers = [
-                (
-                    "conv_0",
-                    conv_layer(
-                        in_channels,
-                        mid_channels,
-                        1,
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                    ),
+            self.id_conv = None
+        self.act_fn = get_act(act_fn)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        identity = self.id_conv(x) if self.id_conv is not None else x
+        return self.act_fn(self.convs(x) + identity)
+
+
+class BottleneckBlock(nn.Module):
+    """Bottleneck Resnet block."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int = 1,
+        expansion: int = 4,
+        conv_layer: type[ConvBnAct] = ConvBnAct,
+        act_fn: type[nn.Module] = nn.ReLU,
+        zero_bn: bool = True,
+        bn_1st: bool = True,
+        groups: int = 1,
+        dw: bool = False,
+        div_groups: Union[None, int] = None,
+        pool: Union[Callable[[], nn.Module], None] = None,
+        se: Union[nn.Module, None] = None,
+        sa: Union[nn.Module, None] = None,
+    ):
+        super().__init__()
+        # pool defined at ModelConstructor.
+        mid_channels = out_channels // expansion
+        if div_groups is not None:  # check if groups != 1 and div_groups
+            groups = int(mid_channels / div_groups)
+        layers: ListStrMod = [
+            (
+                "conv_0",
+                conv_layer(
+                    in_channels,
+                    mid_channels,
+                    1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
                 ),
-                (
-                    "conv_1",
-                    conv_layer(
-                        mid_channels,
-                        mid_channels,
-                        3,
-                        stride=stride,
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                        groups=mid_channels if dw else groups,
-                    ),
+            ),
+            (
+                "conv_1",
+                conv_layer(
+                    mid_channels,
+                    mid_channels,
+                    3,
+                    stride=stride,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
                 ),
-                (
-                    "conv_2",
-                    conv_layer(
-                        mid_channels,
-                        out_channels,
-                        1,
-                        zero_bn=zero_bn,
-                        act_fn=False,
-                        bn_1st=bn_1st,
-                    ),
-                ),  # noqa E501
-            ]
+            ),
+            (
+                "conv_2",
+                conv_layer(
+                    mid_channels,
+                    out_channels,
+                    1,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                ),
+            ),  # noqa E501
+        ]
         if se:
             layers.append(("se", se(out_channels)))
         if sa:
             layers.append(("sa", sa(out_channels)))
-        self.convs = nn.Sequential(OrderedDict(layers))
+        self.convs = nn_seq(layers)
         if stride != 1 or in_channels != out_channels:
-            id_layers = []
+            id_layers: ListStrMod = []
             if (
                 stride != 1 and pool is not None
             ):  # if pool - reduce by pool else stride 2 art id_conv
                 id_layers.append(("pool", pool()))
             if in_channels != out_channels or (stride != 1 and pool is None):
-                id_layers += [
+                id_layers.append(
                     (
                         "id_conv",
                         conv_layer(
@@ -140,21 +202,21 @@ def __init__(
                             act_fn=False,
                         ),
                     )
-                ]
-            self.id_conv = nn.Sequential(OrderedDict(id_layers))
+                )
+            self.id_conv = nn_seq(id_layers)
         else:
             self.id_conv = None
         self.act_fn = get_act(act_fn)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         identity = self.id_conv(x) if self.id_conv is not None else x
         return self.act_fn(self.convs(x) + identity)
 
 
 def make_stem(cfg: TModelCfg) -> nn.Sequential:  # type: ignore
     """Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
     len_stem = len(cfg.stem_sizes)
-    stem: list[tuple[str, nn.Module]] = [
+    stem: ListStrMod = [
         (
             f"conv_{i}",
             cfg.conv_layer(
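
Note: with ResBlock split in two, the channel bookkeeping moves inside each block.
BasicBlock takes in_channels and out_channels directly; BottleneckBlock derives its
inner width as out_channels // expansion. A quick smoke test of the two blocks as
defined above (shapes chosen for illustration):

    import torch
    from model_constructor.model_constructor import BasicBlock, BottleneckBlock

    x = torch.randn(2, 64, 32, 32)

    # Two 3x3 convs; the identity path gets an id_conv because 64 != 128.
    basic = BasicBlock(64, 128, stride=2)
    print(basic(x).shape)  # torch.Size([2, 128, 16, 16])

    # 1x1 -> 3x3 -> 1x1 with inner width 256 // 4 = 64.
    bottleneck = BottleneckBlock(64, 256, expansion=4)
    print(bottleneck(x).shape)  # torch.Size([2, 256, 32, 32])
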
@@ -180,36 +242,32 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential:  # type: ignore
     # if no pool on stem - stride = 2 for first layer block in body
     stride = 1 if cfg.stem_pool and layer_num == 0 else 2
     num_blocks = cfg.layers[layer_num]
-    block_chs = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
-    return nn.Sequential(
-        OrderedDict(
-            [
-                (
-                    f"bl_{block_num}",
-                    cfg.block(
-                        cfg.expansion,  # type: ignore
-                        block_chs[layer_num]
-                        if block_num == 0
-                        else block_chs[layer_num + 1],
-                        block_chs[layer_num + 1],
-                        stride if block_num == 0 else 1,
-                        sa=cfg.sa
-                        if (block_num == num_blocks - 1) and layer_num == 0
-                        else None,
-                        conv_layer=cfg.conv_layer,
-                        act_fn=cfg.act_fn,
-                        pool=cfg.pool,
-                        zero_bn=cfg.zero_bn,
-                        bn_1st=cfg.bn_1st,
-                        groups=cfg.groups,
-                        div_groups=cfg.div_groups,
-                        dw=cfg.dw,
-                        se=cfg.se,
-                    ),
-                )
-                for block_num in range(num_blocks)
-            ]
+    block_chs = [cfg.stem_sizes[-1]] + cfg.block_sizes
+    return nn_seq(
+        (
+            f"bl_{block_num}",
+            cfg.block(
+                # cfg.expansion,  # type: ignore
+                block_chs[layer_num]
+                if block_num == 0
+                else block_chs[layer_num + 1],
+                block_chs[layer_num + 1],
+                stride if block_num == 0 else 1,
+                sa=cfg.sa
+                if (block_num == num_blocks - 1) and layer_num == 0
+                else None,
+                conv_layer=cfg.conv_layer,
+                act_fn=cfg.act_fn,
+                pool=cfg.pool,
+                zero_bn=cfg.zero_bn,
+                bn_1st=cfg.bn_1st,
+                groups=cfg.groups,
+                div_groups=cfg.div_groups,
+                dw=cfg.dw,
+                se=cfg.se,
+            )
         )
+        for block_num in range(num_blocks)
     )
 
 
@@ -230,7 +288,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential:  # type: ignore
     head = [
         ("pool", nn.AdaptiveAvgPool2d(1)),
         ("flat", nn.Flatten()),
-        ("fc", nn.Linear(cfg.block_sizes[-1] * cfg.expansion, cfg.num_classes)),
+        ("fc", nn.Linear(cfg.block_sizes[-1], cfg.num_classes)),
     ]
     return nn.Sequential(OrderedDict(head))

@@ -241,7 +299,7 @@ class ModelCfg(BaseModel):
     name: Optional[str] = None
     in_chans: int = 3
     num_classes: int = 1000
-    block: type[nn.Module] = ResBlock
+    block: type[nn.Module] = BasicBlock
     conv_layer: type[nn.Module] = ConvBnAct
     block_sizes: list[int] = [64, 128, 256, 512]
     layers: list[int] = [2, 2, 2, 2]
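
Note: because make_head now reads cfg.block_sizes[-1] as-is (no expansion factor),
a bottleneck config must list the already-expanded widths itself. A hypothetical
ResNet50-style config under these new defaults:

    from model_constructor.model_constructor import BottleneckBlock, ModelCfg

    # block_sizes carries the expanded output widths; each BottleneckBlock
    # computes its inner width as out_channels // expansion (2048 // 4 = 512).
    cfg = ModelCfg(
        block=BottleneckBlock,
        block_sizes=[256, 512, 1024, 2048],
        layers=[3, 4, 6, 3],
    )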
