@@ -18,27 +18,20 @@ class StdConv2d(nn.Conv2d):
     https://arxiv.org/abs/1903.10520v2
     """
     def __init__(
-            self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
-            groups=1, bias=False, eps=1e-5, use_layernorm=True):
+            self, in_channel, out_channels, kernel_size, stride=1, padding=None,
+            dilation=1, groups=1, bias=False, eps=1e-6):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
             in_channel, out_channels, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=groups, bias=bias)
         self.eps = eps
-        self.use_layernorm = use_layernorm
-
-    def get_weight(self):
-        if self.use_layernorm:
-            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
-            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
-        else:
-            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = (self.weight - mean) / (std + self.eps)
-        return weight

     def forward(self, x):
-        x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        weight = F.batch_norm(
+            self.weight.view(1, self.out_channels, -1), None, None,
+            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+        x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return x


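The change above drops the `get_weight()` indirection (and the `use_layernorm` escape hatch) in favor of a single fused `F.batch_norm` call: with `running_mean`/`running_var` passed as `None` and `training=True`, batch norm standardizes the reshaped weight per output filter, i.e. `(w - mean) / sqrt(var + eps)`. Note the semantics differ slightly from the removed `torch.std_mean` branch, which divided by `(std + eps)` rather than `sqrt(var + eps)`; the `eps` default also moves from `1e-5` to `1e-6`. A minimal standalone sketch (not part of the diff) to check the equivalence:

```python
import torch
import torch.nn.functional as F

# Sketch: verify the F.batch_norm trick matches explicit per-filter
# standardization over each output filter's (in/groups * kh * kw) elements.
torch.manual_seed(0)
out_channels, eps = 8, 1e-6
w = torch.randn(out_channels, 4, 3, 3)

fused = F.batch_norm(
    w.view(1, out_channels, -1), None, None,
    eps=eps, training=True, momentum=0.).reshape_as(w)

# batch_norm uses biased (population) variance, so match it here.
var, mean = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
manual = (w - mean) / torch.sqrt(var + eps)

print(torch.allclose(fused, manual, atol=1e-6))  # True
```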
@@ -49,29 +42,22 @@ class StdConv2dSame(nn.Conv2d):
     https://arxiv.org/abs/1903.10520v2
     """
     def __init__(
-            self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
-            groups=1, bias=False, eps=1e-5, use_layernorm=True):
+            self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
+            dilation=1, groups=1, bias=False, eps=1e-6):
         padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
             in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.same_pad = is_dynamic
         self.eps = eps
-        self.use_layernorm = use_layernorm
-
-    def get_weight(self):
-        if self.use_layernorm:
-            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
-            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
-        else:
-            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = (self.weight - mean) / (std + self.eps)
-        return weight

     def forward(self, x):
         if self.same_pad:
             x = pad_same(x, self.kernel_size, self.stride, self.dilation)
-        x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        weight = F.batch_norm(
+            self.weight.view(1, self.out_channels, -1), None, None,
+            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+        x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return x


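`StdConv2dSame` gets the identical treatment; the only difference is the dynamic 'SAME' padding applied before the conv. `pad_same` is defined elsewhere in timm and not shown in this diff; roughly, a TF-style dynamic SAME-pad helper looks like the sketch below (names `pad_same_sketch` and `_same_pad_amount` are illustrative, assumed to mirror timm's padding helpers):

```python
import math
import torch.nn.functional as F

def _same_pad_amount(size: int, k: int, s: int, d: int) -> int:
    # Total padding so the output covers ceil(size / s) positions.
    return max((math.ceil(size / s) - 1) * s + (k - 1) * d + 1 - size, 0)

def pad_same_sketch(x, k, s, d=(1, 1)):
    # Pad H and W asymmetrically (extra pixel on the bottom/right),
    # matching TensorFlow 'SAME' behavior.
    ih, iw = x.shape[-2:]
    pad_h = _same_pad_amount(ih, k[0], s[0], d[0])
    pad_w = _same_pad_amount(iw, k[1], s[1], d[1])
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
    return x
```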
@@ -85,8 +71,8 @@ class ScaledStdConv2d(nn.Conv2d):
     """

     def __init__(
-            self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
-            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
+            self, in_channels, out_channels, kernel_size, stride=1, padding=None,
+            dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
@@ -95,19 +81,13 @@ def __init__(
         self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
         self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
         self.eps = eps
-        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel
-
-    def get_weight(self):
-        if self.use_layernorm:
-            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
-            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
-        else:
-            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = (self.weight - mean) / (std + self.eps)
-        return weight.mul_(self.gain * self.scale)

     def forward(self, x):
-        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        weight = F.batch_norm(
+            self.weight.view(1, self.out_channels, -1), None, None,
+            weight=(self.gain * self.scale).view(-1),
+            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


 class ScaledStdConv2dSame(nn.Conv2d):
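For the scaled variants, the learned per-filter `gain` and the fixed fan-in `scale` are folded into the same fused op via batch norm's per-channel affine `weight` argument, replacing the removed `weight.mul_(self.gain * self.scale)` step. A standalone sketch (not part of the diff) checking that against the explicit computation:

```python
import torch
import torch.nn.functional as F

# Sketch: batch_norm's affine `weight` applies (gain * scale) inside
# the fused kernel, so scaling needs no separate multiply.
torch.manual_seed(0)
out_channels, eps = 8, 1e-6
w = torch.randn(out_channels, 4, 3, 3)
gain = torch.randn(out_channels, 1, 1, 1)
scale = w[0].numel() ** -0.5  # gamma = 1.0

fused = F.batch_norm(
    w.view(1, out_channels, -1), None, None,
    weight=(gain * scale).view(-1),
    eps=eps, training=True, momentum=0.).reshape_as(w)

var, mean = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
manual = (gain * scale) * (w - mean) / torch.sqrt(var + eps)

print(torch.allclose(fused, manual, atol=1e-6))  # True
```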
@@ -120,8 +100,8 @@ class ScaledStdConv2dSame(nn.Conv2d):
     """

     def __init__(
-            self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
-            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
+            self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
+            dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
         padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
@@ -130,18 +110,12 @@ def __init__(
         self.scale = gamma * self.weight[0].numel() ** -0.5
         self.same_pad = is_dynamic
         self.eps = eps
-        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel
-
-    def get_weight(self):
-        if self.use_layernorm:
-            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
-            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
-        else:
-            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = (self.weight - mean) / (std + self.eps)
-        return weight.mul_(self.gain * self.scale)

     def forward(self, x):
         if self.same_pad:
             x = pad_same(x, self.kernel_size, self.stride, self.dilation)
-        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        weight = F.batch_norm(
+            self.weight.view(1, self.out_channels, -1), None, None,
+            weight=(self.gain * self.scale).view(-1),
+            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
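`ScaledStdConv2dSame` combines both changes: dynamic SAME padding plus the fused, gain-scaled standardization. Since `get_weight()` is removed from all four classes, any external code that called it must now standardize inline the way `forward` does. A hypothetical smoke test, assuming the layers are imported from this module:

```python
import torch

# Odd spatial size exercises the asymmetric dynamic SAME padding.
conv = ScaledStdConv2dSame(16, 32, kernel_size=3, stride=2)
x = torch.randn(2, 16, 9, 9)
y = conv(x)
print(y.shape)  # torch.Size([2, 32, 5, 5]), i.e. ceil(9 / 2) per dim
```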