1- __all__ = ['DownsampleLayer' , 'XResBlock' , 'xresnet18' , 'xresnet34' , 'xresnet50' ]
2-
3- # Cell
41import torch .nn as nn
52from collections import OrderedDict
63
7- # Cell
84from .base_constructor import Net
95from .layers import ConvLayer , Noop , act_fn
106
11- # Cell
7+
8+ __all__ = ['DownsampleLayer' , 'XResBlock' , 'xresnet18' , 'xresnet34' , 'xresnet50' ]
9+
10+
1211class DownsampleLayer (nn .Sequential ):
1312 """Downsample layer for Xresnet Resblock"""
13+
1414 def __init__ (self , conv_layer , ni , nf , stride , act ,
1515 pool = nn .AvgPool2d (2 , ceil_mode = True ), pool_1st = True ,
1616 ** kwargs ):
@@ -20,8 +20,10 @@ def __init__(self, conv_layer, ni, nf, stride, act,
2020 layers .reverse ()
2121 super ().__init__ (OrderedDict (layers ))
2222
23- # Cell
23+
2424class XResBlock (nn .Module ):
25+ '''XResnet block'''
26+
2527 def __init__ (self , ni , nh , expansion = 1 , stride = 1 , zero_bn = True ,
2628 conv_layer = ConvLayer , act_fn = act_fn , ** kwargs ):
2729 super ().__init__ ()
@@ -42,15 +44,17 @@ def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
4244 def forward (self , x ):
4345 return self .act_fn (self .merge (self .convs (x ) + self .identity (x )))
4446
45- # Cell
47+
4648def xresnet18 (** kwargs ):
4749 """Constructs a xresnet-18 model. """
4850 return Net (stem_sizes = [32 , 32 ], block = XResBlock , blocks = [2 , 2 , 2 , 2 ], expansion = 1 , ** kwargs )
4951
52+
5053def xresnet34 (** kwargs ):
5154 """Constructs axresnet-34 model. """
5255 return Net (stem_sizes = [32 , 32 ], block = XResBlock , blocks = [3 , 4 , 6 , 3 ], expansion = 1 , ** kwargs )
5356
57+
5458def xresnet50 (** kwargs ):
5559 """Constructs axresnet-34 model. """
56- return Net (stem_sizes = [32 , 32 ], block = XResBlock , blocks = [3 , 4 , 6 , 3 ], expansion = 4 , ** kwargs )
60+ return Net (stem_sizes = [32 , 32 ], block = XResBlock , blocks = [3 , 4 , 6 , 3 ], expansion = 4 , ** kwargs )
0 commit comments