 from collections import OrderedDict
 from typing import Callable, Union
 
-import torch.nn as nn
+import torch
+from torch import nn
 from torch.nn import Mish
 
+from model_constructor.helpers import nn_seq
+
 from .layers import ConvBnAct, get_act
-from .model_constructor import ModelConstructor
+from .model_constructor import ListStrMod, ModelConstructor
 
 __all__ = [
-    "YaResBlock",
+    "YaBasicBlock",
+    "YaBottleneckBlock",
+    "YaResNet",
     "YaResNet34",
     "YaResNet50",
 ]
 
 
-class YaResBlock(nn.Module):
-    """YaResBlock. Reduce by pool instead of stride 2"""
+class YaBasicBlock(nn.Module):
+    """Ya Basic block.
+    Reduce by pool instead of stride 2"""
 
     def __init__(
         self,
-        expansion: int,
         in_channels: int,
-        mid_channels: int,
+        out_channels: int,
         stride: int = 1,
-        conv_layer=ConvBnAct,
+        conv_layer: type[ConvBnAct] = ConvBnAct,
         act_fn: type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
         dw: bool = False,
         div_groups: Union[None, int] = None,
         pool: Union[Callable[[], nn.Module], None] = None,
-        se: Union[type[nn.Module], None] = None,
-        sa: Union[type[nn.Module], None] = None,
+        se: Union[nn.Module, None] = None,
+        sa: Union[nn.Module, None] = None,
     ):
         super().__init__()
         # pool defined at ModelConstructor.
-        out_channels, in_channels = mid_channels * expansion, in_channels * expansion
         if div_groups is not None:  # check if groups != 1 and div_groups
-            groups = int(mid_channels / div_groups)
+            groups = int(out_channels / div_groups)
 
         if stride != 1:
             if pool is None:
@@ -51,74 +55,133 @@ def __init__(
                 self.reduce = pool()
         else:
             self.reduce = None
-        if expansion == 1:
-            layers = [
-                (
-                    "conv_0",
-                    conv_layer(
-                        in_channels,
-                        mid_channels,
-                        3,
-                        stride=1,
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                        groups=in_channels if dw else groups,
-                    ),
+
+        layers: ListStrMod = [
+            (
+                "conv_0",
+                conv_layer(
+                    in_channels,
+                    out_channels,
+                    3,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=in_channels if dw else groups,
                 ),
-                (
-                    "conv_1",
-                    conv_layer(
-                        mid_channels,
-                        out_channels,
-                        3,
-                        zero_bn=zero_bn,
-                        act_fn=False,
-                        bn_1st=bn_1st,
-                        groups=mid_channels if dw else groups,
-                    ),
+            ),
+            (
+                "conv_1",
+                conv_layer(
+                    out_channels,
+                    out_channels,
+                    3,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                    groups=out_channels if dw else groups,
                 ),
-            ]
+            ),
+        ]
+        if se:
+            layers.append(("se", se(out_channels)))
+        if sa:
+            layers.append(("sa", sa(out_channels)))
+        self.convs = nn_seq(layers)
+
+        if in_channels != out_channels:
+            self.id_conv = conv_layer(
+                in_channels,
+                out_channels,
+                1,
+                stride=1,
+                act_fn=False,
+            )
         else:
-            layers = [
-                (
-                    "conv_0",
-                    conv_layer(
-                        in_channels,
-                        mid_channels,
-                        1,
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                    ),
+            self.id_conv = None
+        self.merge = get_act(act_fn)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
+        if self.reduce:
+            x = self.reduce(x)
+        identity = self.id_conv(x) if self.id_conv is not None else x
+        return self.merge(self.convs(x) + identity)
+
+
+class YaBottleneckBlock(nn.Module):
+    """Ya Bottleneck block.
+    Reduce by pool instead of stride 2"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int = 1,
+        expansion: int = 4,
+        conv_layer: type[ConvBnAct] = ConvBnAct,
+        act_fn: type[nn.Module] = nn.ReLU,
+        zero_bn: bool = True,
+        bn_1st: bool = True,
+        groups: int = 1,
+        dw: bool = False,
+        div_groups: Union[None, int] = None,
+        pool: Union[Callable[[], nn.Module], None] = None,
+        se: Union[nn.Module, None] = None,
+        sa: Union[nn.Module, None] = None,
+    ):
+        super().__init__()
+        # pool defined at ModelConstructor.
+        mid_channels = out_channels // expansion
+        if div_groups is not None:  # check if groups != 1 and div_groups
+            groups = int(mid_channels / div_groups)
+
+        if stride != 1:
+            if pool is None:
+                self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
+                # warnings.warn("pool not passed")  # need to warn?
+            else:
+                self.reduce = pool()
+        else:
+            self.reduce = None
+
+        layers: ListStrMod = [
+            (
+                "conv_0",
+                conv_layer(
+                    in_channels,
+                    mid_channels,
+                    1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                ),
+            ),
+            (
+                "conv_1",
+                conv_layer(
+                    mid_channels,
+                    mid_channels,
+                    3,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
                 ),
-                (
-                    "conv_1",
-                    conv_layer(
-                        mid_channels,
-                        mid_channels,
-                        3,
-                        stride=1,
-                        act_fn=act_fn,
-                        bn_1st=bn_1st,
-                        groups=mid_channels if dw else groups,
-                    ),
+            ),
+            (
+                "conv_2",
+                conv_layer(
+                    mid_channels,
+                    out_channels,
+                    1,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
                 ),
-                (
-                    "conv_2",
-                    conv_layer(
-                        mid_channels,
-                        out_channels,
-                        1,
-                        zero_bn=zero_bn,
-                        act_fn=False,
-                        bn_1st=bn_1st,
-                    ),
-                ),  # noqa E501
-            ]
+            ),
+        ]
         if se:
-            layers.append(("se", se(out_channels)))  # type: ignore
+            layers.append(("se", se(out_channels)))
         if sa:
-            layers.append(("sa", sa(out_channels)))  # type: ignore
-        self.convs = nn.Sequential(OrderedDict(layers))
+            layers.append(("sa", sa(out_channels)))
+        self.convs = nn_seq(layers)
+
         if in_channels != out_channels:
             self.id_conv = conv_layer(
                 in_channels,
@@ -131,20 +194,23 @@ def __init__(
             self.id_conv = None
         self.merge = get_act(act_fn)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
         if self.reduce:
             x = self.reduce(x)
         identity = self.id_conv(x) if self.id_conv is not None else x
         return self.merge(self.convs(x) + identity)
 
 
-class YaResNet34(ModelConstructor):
-    block: type[nn.Module] = YaResBlock
-    expansion: int = 1
-    layers: list[int] = [3, 4, 6, 3]
+class YaResNet(ModelConstructor):
+    block: type[nn.Module] = YaBasicBlock
     stem_sizes: list[int] = [3, 32, 64, 64]
     act_fn: type[nn.Module] = Mish
 
 
+class YaResNet34(YaResNet):
+    stem_sizes: list[int] = [3, 32, 64, 64]
+
+
 class YaResNet50(YaResNet34):
-    expansion: int = 4
+    block: type[nn.Module] = YaBottleneckBlock
+    block_sizes: list[int] = [256, 512, 1024, 2048]
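
Below the diff, a minimal standalone sketch (not part of the commit) of the idea named in the block docstrings, "reduce by pool instead of stride 2": when stride != 1 and a pool factory is passed, the block downsamples once via self.reduce before both the conv path and the identity path, so the residual sum sees matching spatial sizes. The import path and the AvgPool2d choice are assumptions for illustration; in normal use the pool is supplied by ModelConstructor.

import torch
from torch import nn

from model_constructor.yaresnet import YaBasicBlock  # assumed module path

block = YaBasicBlock(
    in_channels=64,
    out_channels=128,
    stride=2,  # stride != 1 activates self.reduce
    pool=lambda: nn.AvgPool2d(2, ceil_mode=True),  # pooling instead of a stride-2 conv
)
x = torch.randn(1, 64, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([1, 128, 16, 16]): reduced once, before both branches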
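And a minimal usage sketch for the new constructors, under the assumption that a ModelConstructor subclass instance is a config object whose call assembles the network (with a 1000-class head by default); not verified against this exact revision.

import torch

from model_constructor.yaresnet import YaResNet50  # assumed module path

mc = YaResNet50()  # config: YaBottleneckBlock, block_sizes [256, 512, 1024, 2048]
model = mc()       # assumed API: calling the instance builds the model
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    out = model(x)
print(out.shape)   # expected torch.Size([2, 1000]) with the default head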