11""" Normalization layers and wrappers
22"""
3+
34import torch
45import torch .nn as nn
56import torch .nn .functional as F
67
8+ from .fast_norm import is_fast_norm , fast_group_norm , fast_layer_norm
9+
710
class GroupNorm(nn.GroupNorm):
    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x):
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)


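# Usage sketch (illustrative only, example channel count assumed): because num_channels is the
# first argument, GroupNorm can be dropped into model configs that call `norm_layer(num_features)`
# the way they would for BatchNorm2d:
#
#   norm_layer = GroupNorm                  # instead of nn.BatchNorm2d
#   norm = norm_layer(64)                   # 64 channels, default 32 groups
#   y = norm(torch.randn(2, 64, 7, 7))      # NCHW in, same shape out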
class GroupNorm1(nn.GroupNorm):

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)


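# Note (illustrative): with a single group, GroupNorm1 normalizes each sample over all of
# C and the spatial dims together, e.g. GroupNorm1(64) applied to a (2, 64, 7, 7) tensor
# computes one mean/variance per sample across the full (64, 7, 7) block.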
class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ fast norm option
    """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x


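# Usage sketch (illustrative, example sizes assumed): the wrapper behaves like a stock
# nn.LayerNorm over the trailing feature dim, with the fast-path flag captured once at
# construction so the module stays scriptable (no global lookup in forward):
#
#   ln = LayerNorm(64)                  # note eps defaults to 1e-6 here vs 1e-5 in nn.LayerNorm
#   y = ln(torch.randn(2, 7, 64))       # normalized over the last (64-dim) axis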
class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x


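# Equivalence sketch (illustrative, example sizes assumed): LayerNorm2d normalizes over C at
# each spatial position by round-tripping NCHW -> NHWC -> NCHW, matching a plain layer_norm
# applied in NHWC layout:
#
#   x = torch.randn(2, 64, 7, 7)                                       # NCHW
#   ref = F.layer_norm(x.permute(0, 2, 3, 1), (64,), eps=1e-6).permute(0, 3, 1, 2)
#   out = LayerNorm2d(64, affine=False)(x)                             # close to `ref`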
def _is_contiguous(tensor: torch.Tensor) -> bool:
    # jit is oh so lovely :/
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)


def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight[:, None, None] + bias[:, None, None]
    return x


def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    u = x.mean(dim=1, keepdim=True)
    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
    return x


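# The `_sqm` (square-of-moments) variant above computes the population variance as
# E[x^2] - E[x]^2 instead of calling var_mean; the clamp(0) guards against small negative
# values from floating point cancellation. Quick numerical check (illustrative only):
#
#   x = torch.randn(2, 8, 4, 4)
#   u = x.mean(dim=1, keepdim=True)
#   var_sqm = ((x * x).mean(dim=1, keepdim=True) - u * u).clamp(0)
#   assert torch.allclose(var_sqm, x.var(dim=1, unbiased=False, keepdim=True), atol=1e-6)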
class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
