@@ -29,11 +29,11 @@ def __init__(self, expansion, ni, nh, stride=1,
2929 pool = nn .AvgPool2d (2 , ceil_mode = True ), sa = False ,sym = False ):
3030 super ().__init__ ()
3131 nf ,ni = nh * expansion ,ni * expansion
32- layers = [(f"conv_0" , conv_layer (ni , nh , 3 , stride = stride )),
32+ layers = [(f"conv_0" , conv_layer (ni , nh , 3 , stride = stride , act_fn = act_fn )),
3333 (f"conv_1" , conv_layer (nh , nf , 3 , zero_bn = zero_bn , act = False ))
3434 ] if expansion == 1 else [
35- (f"conv_0" ,conv_layer (ni , nh , 1 )),
36- (f"conv_1" ,conv_layer (nh , nh , 3 , stride = stride )),
35+ (f"conv_0" ,conv_layer (ni , nh , 1 , act_fn = act_fn )),
36+ (f"conv_1" ,conv_layer (nh , nh , 3 , stride = stride , act_fn = act_fn )),
3737 (f"conv_2" ,conv_layer (nh , nf , 1 , zero_bn = zero_bn , act = False ))
3838 ]
3939 if sa : layers .append (('sa' , SimpleSelfAttention (nf ,ks = 1 ,sym = sym )))
@@ -86,8 +86,9 @@ def body(self):
8686
8787 def _make_stem (self ):
8888 stem = [(f"conv_{ i } " , self .conv_layer (self .stem_sizes [i ], self .stem_sizes [i + 1 ],
89- stride = 2 if i == 0 else 1 ,
90- norm = (not self .stem_bn_end ) if i == (len (self .stem_sizes )- 2 ) else True ))
89+ stride = 2 if i == 0 else 1 ,
90+ norm = (not self .stem_bn_end ) if i == (len (self .stem_sizes )- 2 ) else True ,
91+ act_fn = self .act_fn , bn_1st = self .bn_1st ))
9192 for i in range (len (self .stem_sizes )- 1 )]
9293 stem .append (('stem_pool' , self .stem_pool ))
9394 if self .stem_bn_end : stem .append (('norm' , self .norm (self .stem_sizes [- 1 ])))
0 commit comments