# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/81_Net.ipynb (unless otherwise specified).

__all__ = ['init_cnn', 'act_fn', 'ResBlock', 'NewResBlock', 'NewConvLayer', 'Net', 'me', 'xresnet50']

# Cell
import torch.nn as nn
import sys, torch
from functools import partial
from collections import OrderedDict

from .layers import *

# Cell
act_fn = nn.ReLU(inplace=True)

def init_cnn(m):
    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(m.weight)
    for l in m.children(): init_cnn(l)

# Cell
class ResBlock(nn.Module):
    def __init__(self, expansion, ni, nh, stride=1, sa=False,
                 conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True,
                 pool=nn.AvgPool2d(2, ceil_mode=True), sym=False):
        super().__init__()
        nf, ni = nh * expansion, ni * expansion
        layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride)),
                  ("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False))
                  ] if expansion == 1 else [
                  ("conv_0", conv_layer(ni, nh, 1)),
                  ("conv_1", conv_layer(nh, nh, 3, stride=stride)),
                  ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False))
                  ]
        if sa: layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
        self.convs = nn.Sequential(OrderedDict(layers))
        self.pool = noop if stride == 1 else pool
        self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)
        self.act_fn = act_fn

    def forward(self, x): return self.act_fn(self.convs(x) + self.idconv(self.pool(x)))

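# Usage sketch (not from the original notebook): build a bottleneck ResBlock and run a
# dummy batch through it. The channel counts are illustrative; note that a block with
# expansion=4 and ni=64 expects 64 * 4 = 256 input channels.
if __name__ == "__main__":
    block = ResBlock(4, 64, 64, stride=2)          # 256 channels in, 256 channels out
    out = block(torch.randn(2, 256, 32, 32))       # stride 2 halves the spatial size
    print(out.shape)                               # expected: torch.Size([2, 256, 16, 16])
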
# Cell
# class version
class _ResBlock(Constructor):
    def __init__(self, expansion=1, conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True,
                 pool=nn.AvgPool2d(2, ceil_mode=True), sym=False):
        super().__init__()
        self.__dict__.update(locals())
        # self.__dict__.update(self.__dict__.pop('kwargs'))

    def __call__(self, ni, nh, stride=1, sa=False):
        return ResBlock(self.expansion, ni, nh, stride, sa,
                        self.conv_layer, self.act_fn, self.zero_bn, self.pool, self.sym)

    def __getattr__(self, k):
        # delegate unknown attributes to the registered model; raise instead of
        # recursing through hasattr when no model is registered yet
        if '_model' in self.__dict__:
            return getattr(self.__dict__['_model'], k)
        raise AttributeError(k)

# Cell
# No final name for this block yet - just the "New" ResBlock for now.
class NewResBlock(nn.Module):
    def __init__(self, expansion, ni, nh, stride=1,
                 conv_layer=ConvLayer, act_fn=act_fn,
                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, zero_bn=True):
        super().__init__()
        nf, ni = nh * expansion, ni * expansion
        self.reduce = noop if stride == 1 else pool
        layers = [("conv_0", conv_layer(ni, nh, 3, stride=1)),  # stride 1: downsampling is done by self.reduce
                  ("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False))
                  ] if expansion == 1 else [
                  ("conv_0", conv_layer(ni, nh, 1)),
                  ("conv_1", conv_layer(nh, nh, 3, stride=1)),  # stride 1: downsampling is done by self.reduce
                  ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False))  # no activation before the merge
                  ]
        if sa: layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
        self.convs = nn.Sequential(OrderedDict(layers))
        self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)
        self.merge = act_fn

    def forward(self, x):
        o = self.reduce(x)
        return self.merge(self.convs(o) + self.idconv(o))

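# Usage sketch (illustrative, not from the notebook): NewResBlock pools the input first
# and feeds the reduced tensor to both the conv path and the identity conv, instead of
# striding inside the convs as ResBlock does.
if __name__ == "__main__":
    block = NewResBlock(4, 64, 128, stride=2)      # 64 * 4 = 256 channels in, 128 * 4 = 512 out
    out = block(torch.randn(2, 256, 32, 32))
    print(out.shape)                               # expected: torch.Size([2, 512, 16, 16])
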
# Cell
class NewConvLayer(Constructor):
    """Basic conv layers block"""
    def __init__(self, bn_1st=True, act_fn=act_fn, norm=nn.BatchNorm2d,
                 padding=None, bias=False, groups=1, **kwargs):
        super().__init__()
        self.__dict__.update(locals())
        self.__dict__.update(self.__dict__.pop('kwargs'))

    def __call__(self, ni, nf, ks=3, stride=1,
                 act=True, bn_layer=True, zero_bn=False, **kwargs):  # todo: check kwargs
        padding = ks // 2 if self.padding is None else self.padding
        layers = [('conv', nn.Conv2d(ni, nf, ks, stride=stride, padding=padding,
                                     bias=self.bias, groups=self.groups))]
        act_bn = [('act_fn', self.act_fn)] if act else []
        if bn_layer:
            bn = self.norm(nf)
            nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
            act_bn += [('bn', bn)]
        if self.bn_1st: act_bn.reverse()
        layers += act_bn
        return nn.Sequential(OrderedDict(layers))

    def __getattr__(self, k):
        # delegate unknown attributes to the registered model; raise instead of
        # recursing through hasattr when no model is registered yet
        if '_model' in self.__dict__:
            return getattr(self.__dict__['_model'], k)
        raise AttributeError(k)

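# Usage sketch (assumed): NewConvLayer is a factory - configure it once, then call it to
# produce conv/bn/act blocks. The Constructor base class comes from .layers and its
# behaviour is not shown here.
if __name__ == "__main__":
    conv_layer = NewConvLayer(bn_1st=True)
    layer = conv_layer(3, 32, ks=3, stride=2)       # Sequential: conv -> bn -> act
    print(layer(torch.randn(2, 3, 64, 64)).shape)   # expected: torch.Size([2, 32, 32, 32])
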
# Cell
# v9
class Net:
    def __init__(self, expansion=1, layers=[2, 2, 2, 2], c_in=3, c_out=1000, name='Net'):
        super().__init__()
        self.name = name
        self.c_in, self.c_out, self.expansion, self.layers = c_in, c_out, expansion, layers  # todo: setter for expansion
        self.stem_sizes = [c_in, 32, 32, 64]
        self.stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stem_bn_end = False
        self.norm = nn.BatchNorm2d
        self.act_fn = nn.ReLU(inplace=True)
        self.pool = nn.AvgPool2d(2, ceil_mode=True)
        self.sa = False
        self.bn_1st = True
        self.zero_bn = True
        self._init_cnn = init_cnn
        self.block = ResBlock

        # self.conv_layer = ConvLayer
        self.conv_layer = NewConvLayer

    @property
    def conv_layer(self): return self._conv_layer
    @conv_layer.setter
    def conv_layer(self, f):
        self._conv_layer = f()
        self._conv_layer.register(self)

    @property
    def block_szs(self):
        return [64 // self.expansion, 64, 128, 256, 512] + [256] * (len(self.layers) - 4)

    @property
    def stem(self):
        return self._make_stem()
    @property
    def head(self):
        return self._make_head()
    @property
    def body(self):
        return self._make_body()

    def _block(self, ni, nh, stride=1, sa=False, sym=False):
        # pass conv_layer and act_fn as keywords so this call works with both the
        # ResBlock and NewResBlock argument orders
        return self.block(self.expansion, ni, nh, stride,
                          conv_layer=self.conv_layer, act_fn=self.act_fn,
                          pool=self.pool, sa=sa, sym=sym, zero_bn=self.zero_bn)

    def _make_stem(self):
        stem = [(f"conv_{i}", self._conv_layer(self.stem_sizes[i], self.stem_sizes[i + 1],
                                               stride=2 if i == 0 else 1,
                                               bn_layer=(not self.stem_bn_end) if i == (len(self.stem_sizes) - 2) else True))
                for i in range(len(self.stem_sizes) - 1)]
        stem.append(('stem_pool', self.stem_pool))
        if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))
        return nn.Sequential(OrderedDict(stem))

    def _make_head(self):
        head = [('pool', nn.AdaptiveAvgPool2d(1)),
                ('flat', Flatten()),
                ('fc', nn.Linear(self.block_szs[-1] * self.expansion, self.c_out))]
        return nn.Sequential(OrderedDict(head))

    def _make_body(self):
        blocks = [(f"l_{i}", self._make_layer(self.block_szs[i], self.block_szs[i + 1], l,
                                              1 if i == 0 else 2, self.sa if i == 0 else False))
                  for i, l in enumerate(self.layers)]
        return nn.Sequential(OrderedDict(blocks))

    def _make_layer(self, ni, nf, blocks, stride, sa):
        return nn.Sequential(OrderedDict(
            [(f"bl_{i}", self._block(ni if i == 0 else nf, nf,
                                     stride if i == 0 else 1, sa=sa if i == blocks - 1 else False))
             for i in range(blocks)]))

    def __call__(self):
        model = nn.Sequential(OrderedDict([
            ('stem', self.stem),
            ('body', self.body),
            ('head', self.head)
        ]))
        self._init_cnn(model)
        model.extra_repr = lambda: f"model {self.name}"
        return model

    def __repr__(self):
        return f"constr {self.name}"

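# Usage sketch (assumed): Net is a model *constructor* - adjust its attributes, then call
# the instance to build the actual nn.Sequential model.
if __name__ == "__main__":
    constr = Net(expansion=1, layers=[2, 2, 2, 2], c_out=10, name='net18_demo')
    constr.block = NewResBlock                      # e.g. swap the residual block implementation
    model = constr()                                # builds stem/body/head and applies init_cnn
    print(model(torch.randn(2, 3, 64, 64)).shape)   # expected: torch.Size([2, 10])
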
# Cell
me = sys.modules[__name__]
for n, e, l in [[18,  1, [2, 2, 2,  2]],
                [34,  1, [3, 4, 6,  3]],
                [50,  4, [3, 4, 6,  3]],
                [101, 4, [3, 4, 23, 3]],
                [152, 4, [3, 8, 36, 3]]]:
    name = f'net{n}'
    setattr(me, name, partial(Net, expansion=e, layers=l, name=name))
xresnet50 = partial(Net, expansion=4, layers=[3, 4, 6, 3], name='xresnet50')
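
# Usage sketch (assumed): the loop above registers net18/net34/net50/net101/net152 on this
# module as partials of Net; calling one returns a constructor, and calling the constructor
# builds the model.
if __name__ == "__main__":
    constr = xresnet50(c_out=10)                    # Net(expansion=4, layers=[3, 4, 6, 3])
    model = constr()
    print(model(torch.randn(2, 3, 128, 128)).shape) # expected: torch.Size([2, 10])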