Skip to content

Commit 8f4b92a

Browse files
committed
ya basic & bottle
1 parent 812950c commit 8f4b92a

File tree

6 files changed

+200
-101
lines changed

6 files changed

+200
-101
lines changed

src/model_constructor/universal_blocks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ def forward(self, x: torch.Tensor): # type: ignore
140140

141141

142142
class YaResBlock(nn.Module):
143-
"""YaResBlock. Reduce by pool instead of stride 2"""
143+
"""YaResBlock. Reduce by pool instead of stride 2.
144+
Universal model, as XResNet.
145+
If `expansion=1` - `Basic` block, else - `Bottleneck`"""
144146

145147
def __init__(
146148
self,

src/model_constructor/yaresnet.py

Lines changed: 145 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,48 @@
44
from collections import OrderedDict
55
from typing import Callable, Union
66

7-
import torch.nn as nn
7+
import torch
8+
from torch import nn
89
from torch.nn import Mish
910

11+
from model_constructor.helpers import nn_seq
12+
1013
from .layers import ConvBnAct, get_act
11-
from .model_constructor import ModelConstructor
14+
from .model_constructor import ListStrMod, ModelConstructor
1215

1316
__all__ = [
14-
"YaResBlock",
17+
"YaBasicBlock",
18+
"YaBottleneckBlock",
19+
"YaResNet",
1520
"YaResNet34",
1621
"YaResNet50",
1722
]
1823

1924

20-
class YaResBlock(nn.Module):
21-
"""YaResBlock. Reduce by pool instead of stride 2"""
25+
class YaBasicBlock(nn.Module):
26+
"""Ya Basic block.
27+
Reduce by pool instead of stride 2"""
2228

2329
def __init__(
2430
self,
25-
expansion: int,
2631
in_channels: int,
27-
mid_channels: int,
32+
out_channels: int,
2833
stride: int = 1,
29-
conv_layer=ConvBnAct,
34+
conv_layer: type[ConvBnAct] = ConvBnAct,
3035
act_fn: type[nn.Module] = nn.ReLU,
3136
zero_bn: bool = True,
3237
bn_1st: bool = True,
3338
groups: int = 1,
3439
dw: bool = False,
3540
div_groups: Union[None, int] = None,
3641
pool: Union[Callable[[], nn.Module], None] = None,
37-
se: Union[type[nn.Module], None] = None,
38-
sa: Union[type[nn.Module], None] = None,
42+
se: Union[nn.Module, None] = None,
43+
sa: Union[nn.Module, None] = None,
3944
):
4045
super().__init__()
4146
# pool defined at ModelConstructor.
42-
out_channels, in_channels = mid_channels * expansion, in_channels * expansion
4347
if div_groups is not None: # check if groups != 1 and div_groups
44-
groups = int(mid_channels / div_groups)
48+
groups = int(out_channels / div_groups)
4549

4650
if stride != 1:
4751
if pool is None:
@@ -51,74 +55,133 @@ def __init__(
5155
self.reduce = pool()
5256
else:
5357
self.reduce = None
54-
if expansion == 1:
55-
layers = [
56-
(
57-
"conv_0",
58-
conv_layer(
59-
in_channels,
60-
mid_channels,
61-
3,
62-
stride=1,
63-
act_fn=act_fn,
64-
bn_1st=bn_1st,
65-
groups=in_channels if dw else groups,
66-
),
58+
59+
layers: ListStrMod = [
60+
(
61+
"conv_0",
62+
conv_layer(
63+
in_channels,
64+
out_channels,
65+
3,
66+
act_fn=act_fn,
67+
bn_1st=bn_1st,
68+
groups=in_channels if dw else groups,
6769
),
68-
(
69-
"conv_1",
70-
conv_layer(
71-
mid_channels,
72-
out_channels,
73-
3,
74-
zero_bn=zero_bn,
75-
act_fn=False,
76-
bn_1st=bn_1st,
77-
groups=mid_channels if dw else groups,
78-
),
70+
),
71+
(
72+
"conv_1",
73+
conv_layer(
74+
out_channels,
75+
out_channels,
76+
3,
77+
zero_bn=zero_bn,
78+
act_fn=False,
79+
bn_1st=bn_1st,
80+
groups=out_channels if dw else groups,
7981
),
80-
]
82+
),
83+
]
84+
if se:
85+
layers.append(("se", se(out_channels)))
86+
if sa:
87+
layers.append(("sa", sa(out_channels)))
88+
self.convs = nn_seq(layers)
89+
90+
if in_channels != out_channels:
91+
self.id_conv = conv_layer(
92+
in_channels,
93+
out_channels,
94+
1,
95+
stride=1,
96+
act_fn=False,
97+
)
8198
else:
82-
layers = [
83-
(
84-
"conv_0",
85-
conv_layer(
86-
in_channels,
87-
mid_channels,
88-
1,
89-
act_fn=act_fn,
90-
bn_1st=bn_1st,
91-
),
99+
self.id_conv = None
100+
self.merge = get_act(act_fn)
101+
102+
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
103+
if self.reduce:
104+
x = self.reduce(x)
105+
identity = self.id_conv(x) if self.id_conv is not None else x
106+
return self.merge(self.convs(x) + identity)
107+
108+
109+
class YaBottleneckBlock(nn.Module):
110+
"""Ya Bottleneck block.
111+
Reduce by pool instead of stride 2"""
112+
113+
def __init__(
114+
self,
115+
in_channels: int,
116+
out_channels: int,
117+
stride: int = 1,
118+
expansion: int = 4,
119+
conv_layer: type[ConvBnAct] = ConvBnAct,
120+
act_fn: type[nn.Module] = nn.ReLU,
121+
zero_bn: bool = True,
122+
bn_1st: bool = True,
123+
groups: int = 1,
124+
dw: bool = False,
125+
div_groups: Union[None, int] = None,
126+
pool: Union[Callable[[], nn.Module], None] = None,
127+
se: Union[nn.Module, None] = None,
128+
sa: Union[nn.Module, None] = None,
129+
):
130+
super().__init__()
131+
# pool defined at ModelConstructor.
132+
mid_channels = out_channels // expansion
133+
if div_groups is not None: # check if groups != 1 and div_groups
134+
groups = int(mid_channels / div_groups)
135+
136+
if stride != 1:
137+
if pool is None:
138+
self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
139+
# warnings.warn("pool not passed") # need to warn?
140+
else:
141+
self.reduce = pool()
142+
else:
143+
self.reduce = None
144+
145+
layers: ListStrMod = [
146+
(
147+
"conv_0",
148+
conv_layer(
149+
in_channels,
150+
mid_channels,
151+
1,
152+
act_fn=act_fn,
153+
bn_1st=bn_1st,
154+
),
155+
),
156+
(
157+
"conv_1",
158+
conv_layer(
159+
mid_channels,
160+
mid_channels,
161+
3,
162+
act_fn=act_fn,
163+
bn_1st=bn_1st,
164+
groups=mid_channels if dw else groups,
92165
),
93-
(
94-
"conv_1",
95-
conv_layer(
96-
mid_channels,
97-
mid_channels,
98-
3,
99-
stride=1,
100-
act_fn=act_fn,
101-
bn_1st=bn_1st,
102-
groups=mid_channels if dw else groups,
103-
),
166+
),
167+
(
168+
"conv_2",
169+
conv_layer(
170+
mid_channels,
171+
out_channels,
172+
1,
173+
zero_bn=zero_bn,
174+
act_fn=False,
175+
bn_1st=bn_1st,
104176
),
105-
(
106-
"conv_2",
107-
conv_layer(
108-
mid_channels,
109-
out_channels,
110-
1,
111-
zero_bn=zero_bn,
112-
act_fn=False,
113-
bn_1st=bn_1st,
114-
),
115-
), # noqa E501
116-
]
177+
),
178+
]
117179
if se:
118-
layers.append(("se", se(out_channels))) # type: ignore
180+
layers.append(("se", se(out_channels)))
119181
if sa:
120-
layers.append(("sa", sa(out_channels))) # type: ignore
121-
self.convs = nn.Sequential(OrderedDict(layers))
182+
layers.append(("sa", sa(out_channels)))
183+
self.convs = nn_seq(layers)
184+
122185
if in_channels != out_channels:
123186
self.id_conv = conv_layer(
124187
in_channels,
@@ -131,20 +194,23 @@ def __init__(
131194
self.id_conv = None
132195
self.merge = get_act(act_fn)
133196

134-
def forward(self, x):
197+
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
135198
if self.reduce:
136199
x = self.reduce(x)
137200
identity = self.id_conv(x) if self.id_conv is not None else x
138201
return self.merge(self.convs(x) + identity)
139202

140203

141-
class YaResNet34(ModelConstructor):
142-
block: type[nn.Module] = YaResBlock
143-
expansion: int = 1
144-
layers: list[int] = [3, 4, 6, 3]
204+
class YaResNet(ModelConstructor):
205+
block: type[nn.Module] = YaBasicBlock
145206
stem_sizes: list[int] = [3, 32, 64, 64]
146207
act_fn: type[nn.Module] = Mish
147208

148209

210+
class YaResNet34(YaResNet):
211+
stem_sizes: list[int] = [3, 32, 64, 64]
212+
213+
149214
class YaResNet50(YaResNet34):
150-
expansion: int = 4
215+
block: type[nn.Module] = YaBottleneckBlock
216+
block_sizes: list[int] = [256, 512, 1024, 2048]

tests/test_blocks.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from model_constructor.layers import SEModule, SimpleSelfAttention
88
from model_constructor.model_constructor import BasicBlock, BottleneckBlock
9+
from model_constructor.yaresnet import YaBasicBlock, YaBottleneckBlock
910

1011
from .parameters import ids_fn
1112

@@ -14,8 +15,12 @@
1415

1516

1617
params = dict(
17-
Block=[BasicBlock, BottleneckBlock],
18-
# expansion=[1, 2],
18+
Block=[
19+
BasicBlock,
20+
BottleneckBlock,
21+
YaBasicBlock,
22+
YaBottleneckBlock,
23+
],
1924
out_channels=[8, 16],
2025
stride=[1, 2],
2126
div_groups=[None, 2],
@@ -34,7 +39,6 @@ def pytest_generate_tests(metafunc):
3439
def test_block(Block, out_channels, stride, div_groups, pool, se, sa):
3540
"""test block"""
3641
in_channels = 8
37-
# out_channels = mid_channels * expansion
3842
block = Block(
3943
in_channels,
4044
out_channels,
@@ -48,3 +52,18 @@ def test_block(Block, out_channels, stride, div_groups, pool, se, sa):
4852
out = block(xb)
4953
out_size = img_size if stride == 1 else img_size // stride
5054
assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])
55+
56+
57+
def test_block_dw(Block, out_channels, stride):
58+
"""test block, dw=1"""
59+
in_channels = 8
60+
block = Block(
61+
in_channels,
62+
out_channels,
63+
stride,
64+
dw=1,
65+
)
66+
xb = torch.randn(bs_test, in_channels, img_size, img_size)
67+
out = block(xb)
68+
out_size = img_size if stride == 1 else img_size // stride
69+
assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])

tests/test_blocks_universal.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,20 @@ def test_block(Block, expansion, mid_channels, stride, div_groups, pool, se, sa)
4949
out = block(xb)
5050
out_size = img_size if stride == 1 else img_size // stride
5151
assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])
52+
53+
54+
def test_block_dw(Block, expansion, mid_channels, stride):
55+
"""test block, dw=1"""
56+
in_channels = 8
57+
out_channels = mid_channels * expansion
58+
block = Block(
59+
expansion,
60+
in_channels,
61+
mid_channels,
62+
stride,
63+
dw=1,
64+
)
65+
xb = torch.randn(bs_test, in_channels * expansion, img_size, img_size)
66+
out = block(xb)
67+
out_size = img_size if stride == 1 else img_size // stride
68+
assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])

tests/test_mc.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
import torch
44

5-
from model_constructor.layers import (SEModule, SEModuleConv,
6-
SimpleSelfAttention)
7-
from model_constructor.model_constructor import (BottleneckBlock,
8-
ModelConstructor)
5+
from model_constructor.layers import SEModule, SEModuleConv, SimpleSelfAttention
6+
from model_constructor.model_constructor import BottleneckBlock, ModelConstructor
97

108
bs_test = 4
119
in_chans = 3

0 commit comments

Comments
 (0)