Skip to content

Commit d387469

Browse files
committed
fix xresnet
1 parent f1f43a2 commit d387469

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

model_constructor/xresnet.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
__all__ = ['DownsampleLayer', 'XResBlock', 'xresnet18', 'xresnet34', 'xresnet50']
2-
3-
# Cell
41
import torch.nn as nn
52
from collections import OrderedDict
63

7-
# Cell
84
from .base_constructor import Net
95
from .layers import ConvLayer, Noop, act_fn
106

11-
# Cell
7+
8+
__all__ = ['DownsampleLayer', 'XResBlock', 'xresnet18', 'xresnet34', 'xresnet50']
9+
10+
1211
class DownsampleLayer(nn.Sequential):
1312
"""Downsample layer for Xresnet Resblock"""
13+
1414
def __init__(self, conv_layer, ni, nf, stride, act,
1515
pool=nn.AvgPool2d(2, ceil_mode=True), pool_1st=True,
1616
**kwargs):
@@ -20,8 +20,10 @@ def __init__(self, conv_layer, ni, nf, stride, act,
2020
layers.reverse()
2121
super().__init__(OrderedDict(layers))
2222

23-
# Cell
23+
2424
class XResBlock(nn.Module):
25+
'''XResnet block'''
26+
2527
def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
2628
conv_layer=ConvLayer, act_fn=act_fn, **kwargs):
2729
super().__init__()
@@ -42,15 +44,17 @@ def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
4244
def forward(self, x):
4345
return self.act_fn(self.merge(self.convs(x) + self.identity(x)))
4446

45-
# Cell
47+
4648
def xresnet18(**kwargs):
4749
"""Constructs a xresnet-18 model. """
4850
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[2, 2, 2, 2], expansion=1, **kwargs)
4951

52+
5053
def xresnet34(**kwargs):
5154
"""Constructs axresnet-34 model. """
5255
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=1, **kwargs)
5356

57+
5458
def xresnet50(**kwargs):
5559
"""Constructs axresnet-34 model. """
56-
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=4, **kwargs)
60+
return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=4, **kwargs)

0 commit comments

Comments
 (0)