Skip to content

Commit b98fb43

Browse files
author
ayasyrev
committed
fix act_fn
1 parent 1d006fa commit b98fb43

File tree

2 files changed

+276
-414
lines changed

2 files changed

+276
-414
lines changed

model_constructor/net.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ def __init__(self, expansion, ni, nh, stride=1,
2929
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False,sym=False):
3030
super().__init__()
3131
nf,ni = nh*expansion,ni*expansion
32-
layers = [(f"conv_0", conv_layer(ni, nh, 3, stride=stride)),
32+
layers = [(f"conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn)),
3333
(f"conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False))
3434
] if expansion == 1 else [
35-
(f"conv_0",conv_layer(ni, nh, 1)),
36-
(f"conv_1",conv_layer(nh, nh, 3, stride=stride)),
35+
(f"conv_0",conv_layer(ni, nh, 1, act_fn=act_fn)),
36+
(f"conv_1",conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn)),
3737
(f"conv_2",conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False))
3838
]
3939
if sa: layers.append(('sa', SimpleSelfAttention(nf,ks=1,sym=sym)))
@@ -86,8 +86,9 @@ def body(self):
8686

8787
def _make_stem(self):
8888
stem = [(f"conv_{i}", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1],
89-
stride=2 if i==0 else 1,
90-
norm=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True))
89+
stride=2 if i==0 else 1,
90+
norm=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,
91+
act_fn=self.act_fn, bn_1st=self.bn_1st))
9192
for i in range(len(self.stem_sizes)-1)]
9293
stem.append(('stem_pool', self.stem_pool))
9394
if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))

0 commit comments

Comments
 (0)