Skip to content

Commit e43e44c

Browse files
committed
fix base constructor
1 parent ff930ec commit e43e44c

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

model_constructor/base_constructor.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
__all__ = ['act_fn', 'Stem', 'DownsampleBlock', 'BasicBlock', 'Bottleneck', 'BasicLayer', 'Body', 'Head', 'init_model',
2-
'Net']
3-
4-
# Cell
51
import torch.nn as nn
62
from collections import OrderedDict
73
from .layers import ConvLayer, Noop, Flatten
84

9-
# Cell
5+
6+
__all__ = ['act_fn', 'Stem', 'DownsampleBlock', 'BasicBlock', 'Bottleneck', 'BasicLayer', 'Body', 'Head', 'init_model',
7+
'Net']
8+
9+
1010
act_fn = nn.ReLU(inplace=True)
1111

12-
# Cell
12+
1313
class Stem(nn.Sequential):
1414
"""Base stem"""
1515

@@ -35,12 +35,12 @@ def __init__(self, c_in=3, stem_sizes=[], stem_out=64,
3535
def extra_repr(self):
3636
return f"sizes: {self.sizes}"
3737

38-
# Cell
38+
3939
def DownsampleBlock(conv_layer, ni, nf, ks, stride, act=False, **kwargs):
4040
'''Base downsample for res-like blocks'''
4141
return conv_layer(ni, nf, ks, stride, act, **kwargs)
4242

43-
# Cell
43+
4444
class BasicBlock(nn.Module):
4545
"""Basic block (simplified) as in pytorch resnet"""
4646
def __init__(self, ni, nf, expansion=1, stride=1, zero_bn=False,
@@ -63,7 +63,7 @@ def forward(self, x):
6363
identity = self.downsample(x)
6464
return self.act_conn(self.merge(out + identity))
6565

66-
# Cell
66+
6767
class Bottleneck(nn.Module):
6868
'''Bottlneck block for resnet models'''
6969
def __init__(self, ni, nh, expansion=4, stride=1, zero_bn=False,
@@ -90,7 +90,7 @@ def forward(self, x):
9090
identity = self.downsample(x)
9191
return self.act_conn(self.merge(out + identity))
9292

93-
# Cell
93+
9494
class BasicLayer(nn.Sequential):
9595
'''Layer from blocks'''
9696
def __init__(self, block, blocks, ni, nf, expansion, stride, sa=False, **kwargs):
@@ -108,7 +108,7 @@ def __init__(self, block, blocks, ni, nf, expansion, stride, sa=False, **kwargs)
108108
def extra_repr(self):
109109
return f'from {self.ni * self.expansion} to {self.nf}, {self.blocks} blocks, expansion {self.expansion}.'
110110

111-
# Cell
111+
112112
class Body(nn.Sequential):
113113
'''Constructor for body'''
114114
def __init__(self, block,
@@ -126,7 +126,7 @@ def __init__(self, block,
126126
for i in range(num_layers)]
127127
super().__init__(OrderedDict(layers))
128128

129-
# Cell
129+
130130
class Head(nn.Sequential):
131131
'''base head'''
132132
def __init__(self, ni, nf, **kwargs):
@@ -136,14 +136,14 @@ def __init__(self, ni, nf, **kwargs):
136136
('fc', nn.Linear(ni, nf)),
137137
]))
138138

139-
# Cell
139+
140140
def init_model(model, nonlinearity='leaky_relu'):
141141
'''Init model'''
142142
for m in model.modules():
143143
if isinstance(m, (nn.Conv2d, nn.Linear)):
144144
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity=nonlinearity)
145145

146-
# Cell
146+
147147
class Net(nn.Sequential):
148148
'''Constructor for model'''
149149
def __init__(self, stem=Stem,
@@ -161,4 +161,4 @@ def __init__(self, stem=Stem,
161161
sa=sa, **kwargs)),
162162
('head', head(body_out * expansion, num_classes, **kwargs))
163163
]))
164-
self.init_model(self)
164+
self.init_model(self)

0 commit comments

Comments
 (0)