1- __all__ = ['act_fn' , 'Stem' , 'DownsampleBlock' , 'BasicBlock' , 'Bottleneck' , 'BasicLayer' , 'Body' , 'Head' , 'init_model' ,
2- 'Net' ]
3-
4- # Cell
51import torch .nn as nn
62from collections import OrderedDict
73from .layers import ConvLayer , Noop , Flatten
84
9- # Cell
5+
6+ __all__ = ['act_fn' , 'Stem' , 'DownsampleBlock' , 'BasicBlock' , 'Bottleneck' , 'BasicLayer' , 'Body' , 'Head' , 'init_model' ,
7+ 'Net' ]
8+
9+
1010act_fn = nn .ReLU (inplace = True )
1111
12- # Cell
12+
1313class Stem (nn .Sequential ):
1414 """Base stem"""
1515
@@ -35,12 +35,12 @@ def __init__(self, c_in=3, stem_sizes=[], stem_out=64,
3535 def extra_repr (self ):
3636 return f"sizes: { self .sizes } "
3737
38- # Cell
38+
3939def DownsampleBlock (conv_layer , ni , nf , ks , stride , act = False , ** kwargs ):
4040 '''Base downsample for res-like blocks'''
4141 return conv_layer (ni , nf , ks , stride , act , ** kwargs )
4242
43- # Cell
43+
4444class BasicBlock (nn .Module ):
4545 """Basic block (simplified) as in pytorch resnet"""
4646 def __init__ (self , ni , nf , expansion = 1 , stride = 1 , zero_bn = False ,
@@ -63,7 +63,7 @@ def forward(self, x):
6363 identity = self .downsample (x )
6464 return self .act_conn (self .merge (out + identity ))
6565
66- # Cell
66+
6767class Bottleneck (nn .Module ):
6868 '''Bottlneck block for resnet models'''
6969 def __init__ (self , ni , nh , expansion = 4 , stride = 1 , zero_bn = False ,
@@ -90,7 +90,7 @@ def forward(self, x):
9090 identity = self .downsample (x )
9191 return self .act_conn (self .merge (out + identity ))
9292
93- # Cell
93+
9494class BasicLayer (nn .Sequential ):
9595 '''Layer from blocks'''
9696 def __init__ (self , block , blocks , ni , nf , expansion , stride , sa = False , ** kwargs ):
@@ -108,7 +108,7 @@ def __init__(self, block, blocks, ni, nf, expansion, stride, sa=False, **kwargs)
108108 def extra_repr (self ):
109109 return f'from { self .ni * self .expansion } to { self .nf } , { self .blocks } blocks, expansion { self .expansion } .'
110110
111- # Cell
111+
112112class Body (nn .Sequential ):
113113 '''Constructor for body'''
114114 def __init__ (self , block ,
@@ -126,7 +126,7 @@ def __init__(self, block,
126126 for i in range (num_layers )]
127127 super ().__init__ (OrderedDict (layers ))
128128
129- # Cell
129+
130130class Head (nn .Sequential ):
131131 '''base head'''
132132 def __init__ (self , ni , nf , ** kwargs ):
@@ -136,14 +136,14 @@ def __init__(self, ni, nf, **kwargs):
136136 ('fc' , nn .Linear (ni , nf )),
137137 ]))
138138
139- # Cell
139+
140140def init_model (model , nonlinearity = 'leaky_relu' ):
141141 '''Init model'''
142142 for m in model .modules ():
143143 if isinstance (m , (nn .Conv2d , nn .Linear )):
144144 nn .init .kaiming_normal_ (m .weight , mode = 'fan_in' , nonlinearity = nonlinearity )
145145
146- # Cell
146+
147147class Net (nn .Sequential ):
148148 '''Constructor for model'''
149149 def __init__ (self , stem = Stem ,
@@ -161,4 +161,4 @@ def __init__(self, stem=Stem,
161161 sa = sa , ** kwargs )),
162162 ('head' , head (body_out * expansion , num_classes , ** kwargs ))
163163 ]))
164- self .init_model (self )
164+ self .init_model (self )
0 commit comments