+# pylance: disable=overridden method
 from collections import OrderedDict
 from functools import partial
 from typing import Any, Callable, Optional, TypeVar, Union

-import torch.nn as nn
+import torch
 from pydantic import BaseModel, validator
+from torch import nn

+from .helpers import nn_seq
 from .layers import ConvBnAct, SEModule, SimpleSelfAttention, get_act

 __all__ = [
     "init_cnn",
-    "ResBlock",
+    # "ResBlock",
     "ModelConstructor",
     "XResNet34",
     "XResNet50",
 ]

 TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")

+ListStrMod = list[tuple[str, nn.Module]]
+

 def init_cnn(module: nn.Module) -> None:
     "Init module - kaiming_normal for Conv2d and 0 for biases."
@@ -29,16 +34,16 @@ def init_cnn(module: nn.Module) -> None:
         init_cnn(layer)


-class ResBlock(nn.Module):
-    """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
+class BasicBlock(nn.Module):
+    """Basic Resnet block."""

     def __init__(
         self,
-        expansion: int,
+        # expansion: int,
         in_channels: int,
-        mid_channels: int,
+        out_channels: int,
         stride: int = 1,
-        conv_layer: type[nn.Module] = ConvBnAct,
+        conv_layer: type[ConvBnAct] = ConvBnAct,
         act_fn: type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
@@ -51,85 +56,142 @@ def __init__(
     ):
         super().__init__()
         # pool defined at ModelConstructor.
-        out_channels, in_channels = mid_channels * expansion, in_channels * expansion
+        # out_channels, in_channels = mid_channels * expansion, in_channels * expansion
         if div_groups is not None:  # check if groups != 1 and div_groups
-            groups = int(mid_channels / div_groups)
-        if expansion == 1:
-            layers = [
-                (
-                    "conv_0",
-                    conv_layer(
-                        in_channels,
-                        mid_channels,
-                        3,
-                        stride=stride,  # type: ignore
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                        groups=in_channels if dw else groups,
-                    ),
+            groups = int(out_channels / div_groups)
+        layers: ListStrMod = [
+            (
+                "conv_0",
+                conv_layer(
+                    in_channels,
+                    out_channels,
+                    3,
+                    stride=stride,  # type: ignore
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=in_channels if dw else groups,
                 ),
-                (
-                    "conv_1",
-                    conv_layer(
-                        mid_channels,
-                        out_channels,
-                        3,
-                        zero_bn=zero_bn,
-                        act_fn=False,
-                        bn_1st=bn_1st,
-                        groups=mid_channels if dw else groups,
-                    ),
+            ),
+            (
+                "conv_1",
+                conv_layer(
+                    out_channels,
+                    out_channels,
+                    3,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                    groups=out_channels if dw else groups,
                 ),
-            ]
+            ),
+        ]
+        if se:
+            layers.append(("se", se(out_channels)))
+        if sa:
+            layers.append(("sa", sa(out_channels)))
+        self.convs = nn_seq(layers)
+        if stride != 1 or in_channels != out_channels:
+            id_layers: ListStrMod = []
+            if (
+                stride != 1 and pool is not None
+            ):  # if pool - reduce by pool else stride 2 at id_conv
+                id_layers.append(("pool", pool()))
+            if in_channels != out_channels or (stride != 1 and pool is None):
+                id_layers.append(
+                    (
+                        "id_conv",
+                        conv_layer(
+                            in_channels,
+                            out_channels,
+                            1,
+                            stride=1 if pool else stride,
+                            act_fn=False,
+                        ),
+                    )
+                )
+            self.id_conv = nn_seq(id_layers)
         else:
-            layers = [
-                (
-                    "conv_0",
-                    conv_layer(
-                        in_channels,
-                        mid_channels,
-                        1,
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                    ),
+            self.id_conv = None
+        self.act_fn = get_act(act_fn)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        identity = self.id_conv(x) if self.id_conv is not None else x
+        return self.act_fn(self.convs(x) + identity)
+
+
+class BottleneckBlock(nn.Module):
+    """Bottleneck Resnet block."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int = 1,
+        expansion: int = 4,
+        conv_layer: type[ConvBnAct] = ConvBnAct,
+        act_fn: type[nn.Module] = nn.ReLU,
+        zero_bn: bool = True,
+        bn_1st: bool = True,
+        groups: int = 1,
+        dw: bool = False,
+        div_groups: Union[None, int] = None,
+        pool: Union[Callable[[], nn.Module], None] = None,
+        se: Union[nn.Module, None] = None,
+        sa: Union[nn.Module, None] = None,
+    ):
+        super().__init__()
+        # pool defined at ModelConstructor.
+        mid_channels = out_channels // expansion
+        if div_groups is not None:  # check if groups != 1 and div_groups
+            groups = int(mid_channels / div_groups)
+        layers: ListStrMod = [
+            (
+                "conv_0",
+                conv_layer(
+                    in_channels,
+                    mid_channels,
+                    1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
                 ),
-                (
-                    "conv_1",
-                    conv_layer(
-                        mid_channels,
-                        mid_channels,
-                        3,
-                        stride=stride,
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                        groups=mid_channels if dw else groups,
-                    ),
+            ),
+            (
+                "conv_1",
+                conv_layer(
+                    mid_channels,
+                    mid_channels,
+                    3,
+                    stride=stride,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
                 ),
-                (
-                    "conv_2",
-                    conv_layer(
-                        mid_channels,
-                        out_channels,
-                        1,
-                        zero_bn=zero_bn,
-                        act_fn=False,
-                        bn_1st=bn_1st,
-                    ),
-                ),  # noqa E501
-            ]
+            ),
+            (
+                "conv_2",
+                conv_layer(
+                    mid_channels,
+                    out_channels,
+                    1,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                ),
+            ),  # noqa E501
+        ]
         if se:
             layers.append(("se", se(out_channels)))
         if sa:
             layers.append(("sa", sa(out_channels)))
-        self.convs = nn.Sequential(OrderedDict(layers))
+        self.convs = nn_seq(layers)
         if stride != 1 or in_channels != out_channels:
-            id_layers = []
+            id_layers: ListStrMod = []
             if (
                 stride != 1 and pool is not None
             ):  # if pool - reduce by pool else stride 2 at id_conv
                 id_layers.append(("pool", pool()))
             if in_channels != out_channels or (stride != 1 and pool is None):
-                id_layers += [
+                id_layers.append(
                     (
                         "id_conv",
                         conv_layer(
@@ -140,21 +202,21 @@ def __init__(
                             act_fn=False,
                         ),
                     )
-                ]
-            self.id_conv = nn.Sequential(OrderedDict(id_layers))
+                )
+            self.id_conv = nn_seq(id_layers)
         else:
             self.id_conv = None
         self.act_fn = get_act(act_fn)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         identity = self.id_conv(x) if self.id_conv is not None else x
         return self.act_fn(self.convs(x) + identity)


 def make_stem(cfg: TModelCfg) -> nn.Sequential:  # type: ignore
155217 """Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
     len_stem = len(cfg.stem_sizes)
-    stem: list[tuple[str, nn.Module]] = [
+    stem: ListStrMod = [
         (
             f"conv_{i}",
             cfg.conv_layer(
@@ -180,36 +242,32 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
     # if no pool on stem - stride = 2 for first layer block in body
     stride = 1 if cfg.stem_pool and layer_num == 0 else 2
     num_blocks = cfg.layers[layer_num]
-    block_chs = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
-    return nn.Sequential(
-        OrderedDict(
-            [
-                (
-                    f"bl_{block_num}",
-                    cfg.block(
-                        cfg.expansion,  # type: ignore
-                        block_chs[layer_num]
-                        if block_num == 0
-                        else block_chs[layer_num + 1],
-                        block_chs[layer_num + 1],
-                        stride if block_num == 0 else 1,
-                        sa=cfg.sa
-                        if (block_num == num_blocks - 1) and layer_num == 0
-                        else None,
-                        conv_layer=cfg.conv_layer,
-                        act_fn=cfg.act_fn,
-                        pool=cfg.pool,
-                        zero_bn=cfg.zero_bn,
-                        bn_1st=cfg.bn_1st,
-                        groups=cfg.groups,
-                        div_groups=cfg.div_groups,
-                        dw=cfg.dw,
-                        se=cfg.se,
-                    ),
-                )
-                for block_num in range(num_blocks)
-            ]
+    block_chs = [cfg.stem_sizes[-1]] + cfg.block_sizes
+    return nn_seq(
+        (
+            f"bl_{block_num}",
+            cfg.block(
+                # cfg.expansion,  # type: ignore
+                block_chs[layer_num]
+                if block_num == 0
+                else block_chs[layer_num + 1],
+                block_chs[layer_num + 1],
+                stride if block_num == 0 else 1,
+                sa=cfg.sa
+                if (block_num == num_blocks - 1) and layer_num == 0
+                else None,
+                conv_layer=cfg.conv_layer,
+                act_fn=cfg.act_fn,
+                pool=cfg.pool,
+                zero_bn=cfg.zero_bn,
+                bn_1st=cfg.bn_1st,
+                groups=cfg.groups,
+                div_groups=cfg.div_groups,
+                dw=cfg.dw,
+                se=cfg.se,
+            )
         )
+        for block_num in range(num_blocks)
     )

@@ -230,7 +288,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
     head = [
         ("pool", nn.AdaptiveAvgPool2d(1)),
         ("flat", nn.Flatten()),
-        ("fc", nn.Linear(cfg.block_sizes[-1] * cfg.expansion, cfg.num_classes)),
+        ("fc", nn.Linear(cfg.block_sizes[-1], cfg.num_classes)),
     ]
     return nn.Sequential(OrderedDict(head))

@@ -241,7 +299,7 @@ class ModelCfg(BaseModel):
     name: Optional[str] = None
     in_chans: int = 3
     num_classes: int = 1000
-    block: type[nn.Module] = ResBlock
+    block: type[nn.Module] = BasicBlock
     conv_layer: type[nn.Module] = ConvBnAct
     block_sizes: list[int] = [64, 128, 256, 512]
     layers: list[int] = [2, 2, 2, 2]
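
The commit above splits the old `ResBlock` (which switched between basic and bottleneck behavior via a leading `expansion` argument) into explicit `BasicBlock` and `BottleneckBlock` classes, and drops the `expansion` multiplier from the channel math in `make_layer` and `make_head`. Below is a minimal sketch of how the new blocks might be exercised; the import path `model_constructor.model_constructor` is an assumption (the diff shows the module's contents, not its filename), and only the signatures visible in the diff are relied on.

import torch

# Assumed import path -- adjust to wherever this module lives in the package.
from model_constructor.model_constructor import BasicBlock, BottleneckBlock

x = torch.randn(2, 64, 32, 32)

# BasicBlock: two 3x3 convs. With stride != 1 and changed channels,
# the identity path gets a 1x1 id_conv with stride 2 (no pool passed here).
basic = BasicBlock(in_channels=64, out_channels=128, stride=2)
print(basic(x).shape)  # torch.Size([2, 128, 16, 16])

# BottleneckBlock: 1x1 -> 3x3 -> 1x1 with mid_channels = out_channels // expansion,
# so out_channels already includes the expansion (no separate multiplier as in ResBlock).
bottleneck = BottleneckBlock(in_channels=64, out_channels=256, expansion=4)
print(bottleneck(x).shape)  # torch.Size([2, 256, 32, 32])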