11""" Normalization + Activation Layers
22"""
3- from typing import Union , List
3+ from typing import Union , List , Optional , Any
44
55import torch
66from torch import nn as nn
@@ -18,10 +18,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
    instead of composing it as a .bn member.
    """
    def __init__(
-            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
-        super(BatchNormAct2d, self).__init__(
-            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+            self,
+            num_features,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+            device=None,
+            dtype=None
+    ):
+        try:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
+                **factory_kwargs
+            )
+        except TypeError:
+            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
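
For reference, a minimal usage sketch of the new device/dtype factory kwargs (illustrative, not part of this commit; the import path is an assumption). On PyTorch versions whose nn.BatchNorm2d accepts device/dtype the kwargs are forwarded, otherwise the TypeError fallback above constructs the layer without them:

# Illustrative only -- not part of the diff; import path assumed for illustration
import torch
from timm.models.layers import BatchNormAct2d

# 64-channel BatchNorm + default ReLU in one module; device/dtype pass through to nn.BatchNorm2d
bn_act = BatchNormAct2d(64, device='cpu', dtype=torch.float32)
y = bn_act(torch.randn(2, 64, 8, 8))  # batchnorm -> (optional) drop -> act
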
@@ -81,6 +100,62 @@ def forward(self, x):
        return x


+class SyncBatchNormAct(nn.SyncBatchNorm):
+    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
+    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
+    # but ONLY when used in conjunction with the timm conversion function below.
+    # Do not create this module directly or use the PyTorch conversion function.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
+        if hasattr(self, "drop"):
+            x = self.drop(x)
+        if hasattr(self, "act"):
+            x = self.act(x)
+        return x
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        if isinstance(module, BatchNormAct2d):
+            # convert timm norm + act layer
+            module_output = SyncBatchNormAct(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group=process_group,
+            )
+            # set act and drop attr from the original module
+            module_output.act = module.act
+            module_output.drop = module.drop
+        else:
+            # convert standard BatchNorm layers
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, "qconfig"):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+    del module
+    return module_output
+
+
def _num_groups(num_channels, num_groups, group_size):
    if group_size:
        assert num_channels % group_size == 0
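
A minimal usage sketch for the conversion helper (illustrative, not part of this commit): per the comments above, convert_sync_batchnorm is meant to be used in place of torch.nn.SyncBatchNorm.convert_sync_batchnorm when a model contains BatchNormAct2d layers, typically right before wrapping the model for distributed training. The import path and model choice below are assumptions:

# Illustrative only -- not part of the diff; import path and model name assumed for illustration
import timm
from timm.models.layers import convert_sync_batchnorm

model = timm.create_model('efficientnet_b0', pretrained=False)  # assumed to contain BatchNormAct2d layers
model = convert_sync_batchnorm(model)  # BatchNormAct2d -> SyncBatchNormAct, plain BN -> SyncBatchNorm
# then wrap for DDP once a process group has been initialized, e.g.:
# model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])
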