11""" Normalization + Activation Layers
22"""
3- from typing import Union , List
3+ from typing import Union , List , Optional , Any
44
55import torch
66from torch import nn as nn
77from torch .nn import functional as F
8+ try :
9+ from torch .nn .modules ._functions import SyncBatchNorm as sync_batch_norm
10+ FULL_SYNC_BN = True
11+ except ImportError :
12+ FULL_SYNC_BN = False
813
914from .trace_utils import _assert
1015from .create_act import get_act_layer
@@ -18,10 +23,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
     instead of composing it as a .bn member.
     """
     def __init__(
-            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
-        super(BatchNormAct2d, self).__init__(
-            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+            self,
+            num_features,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+            device=None,
+            dtype=None
+    ):
+        try:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
+                **factory_kwargs
+            )
+        except TypeError:
+            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
         if act_layer is not None and apply_act:
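
The hunk above wraps the `super().__init__` call in a try/except so that the new `device`/`dtype` factory kwargs are passed through on PyTorch builds that accept them, and silently dropped (per the `TypeError` fallback) on older releases. A minimal usage sketch, not part of the diff; the import path and `'silu'` string (resolved via `get_act_layer`, as shown above) are illustrative assumptions:

```python
import torch
from timm.models.layers.norm_act import BatchNormAct2d  # import path assumed

# fused BN -> (optional) dropout -> activation in one module
bn_act = BatchNormAct2d(64, act_layer='silu', device='cpu', dtype=torch.float32)

x = torch.randn(2, 64, 8, 8)
y = bn_act(x)
print(y.shape)  # torch.Size([2, 64, 8, 8])
```

On an older PyTorch without factory kwarg support, the same call still works; `device`/`dtype` are simply ignored by the backwards-compat branch.
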
@@ -81,6 +105,62 @@ def forward(self, x):
         return x


+class SyncBatchNormAct(nn.SyncBatchNorm):
+    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
+    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
+    # but ONLY when used in conjunction with the timm conversion function below.
+    # Do not create this module directly or use the PyTorch conversion function.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
+        if hasattr(self, "drop"):
+            x = self.drop(x)
+        if hasattr(self, "act"):
+            x = self.act(x)
+        return x
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        if isinstance(module, BatchNormAct2d):
+            # convert timm norm + act layer
+            module_output = SyncBatchNormAct(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group=process_group,
+            )
+            # set act and drop attr from the original module
+            module_output.act = module.act
+            module_output.drop = module.drop
+        else:
+            # convert standard BatchNorm layers
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, "qconfig"):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+    del module
+    return module_output
+
+
 def _num_groups(num_channels, num_groups, group_size):
     if group_size:
         assert num_channels % group_size == 0
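As the comments on `SyncBatchNormAct` note, the class is not meant to be constructed directly: `convert_sync_batchnorm` is the intended entry point, and it copies the original layer's `.act` and `.drop` submodules onto the converted module. A minimal sketch of how the helper would be used in distributed training; the model name, import path, and DDP wiring are illustrative assumptions, not part of this diff:

```python
import torch
import timm
from timm.models.layers.norm_act import convert_sync_batchnorm  # path assumed for this PR

# assumes torch.distributed has already been initialized (e.g. via torchrun)
model = timm.create_model('resnet50', pretrained=False)

# Use the timm-aware converter instead of torch.nn.SyncBatchNorm.convert_sync_batchnorm:
# plain BatchNorm layers become nn.SyncBatchNorm, while BatchNormAct2d layers become
# SyncBatchNormAct and keep their .act / .drop behavior.
model = convert_sync_batchnorm(model)
model = model.cuda()
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[torch.cuda.current_device()])
```

Using the stock PyTorch converter instead would replace a `BatchNormAct2d` with a bare `nn.SyncBatchNorm`, dropping the fused activation, which is exactly the problem this workaround addresses.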