@@ -13,35 +13,12 @@
 import torch.nn as nn
 
 
-@torch.jit.script
-def evo_batch_jit(
-        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
-        momentum: float, training: bool, nonlin: bool, eps: float):
-    x_type = x.dtype
-    running_var = running_var.detach()  # FIXME why is this needed, it's a buffer?
-    if training:
-        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # FIXME biased, unbiased?
-        running_var.copy_(momentum * var + (1 - momentum) * running_var)
-    else:
-        var = running_var.clone()
-
-    if nonlin:
-        # FIXME biased, unbiased?
-        d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
-        d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
-        x = x / d
-        return x.mul_(weight).add_(bias)
-    else:
-        return x.mul(weight).add_(bias)
-
-
 class EvoNormBatch2d(nn.Module):
-    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
+    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5):
         super(EvoNormBatch2d, self).__init__()
         self.momentum = momentum
         self.nonlin = nonlin
         self.eps = eps
-        self.jit = jit
         param_shape = (1, num_features, 1, 1)
         self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
         self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
@@ -58,50 +35,29 @@ def reset_parameters(self):
 
     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
-
-        if self.jit:
-            return evo_batch_jit(
-                x, self.v, self.weight, self.bias, self.running_var, self.momentum,
-                self.training, self.nonlin, self.eps)
+        x_type = x.dtype
+        if self.training:
+            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
+            self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
         else:
-            x_type = x.dtype
-            if self.training:
-                var = x.var(dim=(0, 2, 3), keepdim=True)
-                self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
-            else:
-                var = self.running_var.clone()
-
-            if self.nonlin:
-                v = self.v.to(dtype=x_type)
-                d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
-                d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
-                x = x / d
-                return x.mul_(self.weight).add_(self.bias)
-            else:
-                return x.mul(self.weight).add_(self.bias)
+            var = self.running_var.clone()
 
-
-@torch.jit.script
-def evo_sample_jit(
-        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
-        groups: int, nonlin: bool, eps: float):
-    B, C, H, W = x.shape
-    assert C % groups == 0
-    if nonlin:
-        n = (x * v).sigmoid_().reshape(B, groups, -1)
-        x = x.reshape(B, groups, -1)
-        x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
-        x = x.reshape(B, C, H, W)
-        return x.mul_(weight).add_(bias)
+        if self.nonlin:
+            v = self.v.to(dtype=x_type)
+            d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
+            d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type))
+            x = x / d
+            return x.mul_(self.weight).add_(self.bias)
+        else:
+            return x.mul(self.weight).add_(self.bias)
 
 
 class EvoNormSample2d(nn.Module):
-    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
+    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5):
         super(EvoNormSample2d, self).__init__()
         self.nonlin = nonlin
         self.groups = groups
         self.eps = eps
-        self.jit = jit
         param_shape = (1, num_features, 1, 1)
         self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
         self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
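
Note: with the jit path removed, EvoNormBatch2d.forward inlines the EvoNorm-B0 computation directly; the denominator is the elementwise max of the batch std and v * x plus the per-instance std. A minimal standalone sketch of that denominator (illustrative only; the input shape and the per-channel v tensor here are assumptions, not the module's actual buffers):

import torch

x = torch.randn(2, 8, 4, 4)  # assumed NCHW input
v = torch.ones(1, 8, 1, 1)   # assumed per-channel v parameter
eps = 1e-5

batch_std = (x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + eps).sqrt()  # (1, 8, 1, 1)
instance_std = (x.var(dim=(2, 3), unbiased=False, keepdim=True) + eps).sqrt()  # (2, 8, 1, 1)
d = torch.max(x * v + instance_std, batch_std)  # broadcasts batch_std over B, H, W
out = x / d                                     # same shape as x
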
@@ -117,18 +73,13 @@ def reset_parameters(self):
 
     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
-
-        if self.jit:
-            return evo_sample_jit(
-                x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps)
+        B, C, H, W = x.shape
+        assert C % self.groups == 0
+        if self.nonlin:
+            n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
+            x = x.reshape(B, self.groups, -1)
+            x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
+            x = x.reshape(B, C, H, W)
+            return x.mul_(self.weight).add_(self.bias)
         else:
-            B, C, H, W = x.shape
-            assert C % self.groups == 0
-            if self.nonlin:
-                n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
-                x = x.reshape(B, self.groups, -1)
-                x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps)
-                x = x.reshape(B, C, H, W)
-                return x.mul_(self.weight).add_(self.bias)
-            else:
-                return x.mul(self.weight).add_(self.bias)
+            return x.mul(self.weight).add_(self.bias)
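
For reference, a quick smoke test of both refactored modules (a sketch, assuming the file above is importable as evonorm and that reset_parameters(), defined elsewhere in the file per the hunk context, registers self.v and the running_var buffer):

import torch
from evonorm import EvoNormBatch2d, EvoNormSample2d

x = torch.randn(2, 16, 8, 8)

bn = EvoNormBatch2d(16).train()
print(bn(x).shape)  # torch.Size([2, 16, 8, 8])

sn = EvoNormSample2d(16, groups=8).eval()
print(sn(x).shape)  # torch.Size([2, 16, 8, 8])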