
Commit 879df47

Support BatchNormAct2d for sync-bn use. Fix #1254

1 parent 7cedc8d · commit 879df47

File tree

4 files changed: +97 additions, -16 deletions

timm/models/__init__.py
timm/models/layers/__init__.py
timm/models/layers/norm_act.py
train.py

timm/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@
 from .factory import create_model, parse_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
-from .layers import convert_splitbn_model
+from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
     is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, LayerNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct
+from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d

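A quick note on the two export changes above (a hedged sketch, not part of the commit): the new helper is re-exported at both the timm.models and timm.models.layers level, so either import path should resolve to the same function in a timm build that includes this commit.

# hedged import check; assumes a timm install containing commit 879df47
from timm.models import convert_splitbn_model, convert_sync_batchnorm
from timm.models.layers import convert_sync_batchnorm as convert_sync_batchnorm_layers

assert convert_sync_batchnorm is convert_sync_batchnorm_layers  # one function, two exported paths
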
timm/models/layers/norm_act.py

Lines changed: 85 additions & 5 deletions
@@ -1,10 +1,15 @@
 """ Normalization + Activation Layers
 """
-from typing import Union, List
+from typing import Union, List, Optional, Any

 import torch
 from torch import nn as nn
 from torch.nn import functional as F
+try:
+    from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm
+    FULL_SYNC_BN = True
+except ImportError:
+    FULL_SYNC_BN = False

 from .trace_utils import _assert
 from .create_act import get_act_layer
@@ -18,10 +23,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
     instead of composing it as a .bn member.
     """
     def __init__(
-            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
-        super(BatchNormAct2d, self).__init__(
-            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+            self,
+            num_features,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+            device=None,
+            dtype=None
+    ):
+        try:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
+                **factory_kwargs
+            )
+        except TypeError:
+            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
         if act_layer is not None and apply_act:
@@ -81,6 +105,62 @@ def forward(self, x):
         return x


+class SyncBatchNormAct(nn.SyncBatchNorm):
+    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
+    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
+    # but ONLY when used in conjunction with the timm conversion function below.
+    # Do not create this module directly or use the PyTorch conversion function.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
+        if hasattr(self, "drop"):
+            x = self.drop(x)
+        if hasattr(self, "act"):
+            x = self.act(x)
+        return x
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        if isinstance(module, BatchNormAct2d):
+            # convert timm norm + act layer
+            module_output = SyncBatchNormAct(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group=process_group,
+            )
+            # set act and drop attr from the original module
+            module_output.act = module.act
+            module_output.drop = module.drop
+        else:
+            # convert standard BatchNorm layers
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, "qconfig"):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+    del module
+    return module_output
+
+
 def _num_groups(num_channels, num_groups, group_size):
     if group_size:
         assert num_channels % group_size == 0
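
The widened BatchNormAct2d.__init__ above forwards the new device / dtype factory arguments when the installed PyTorch accepts them and falls back to the legacy super() call when it does not. A minimal hedged sketch of what that means for callers (assuming a timm build containing this commit):

import torch
from timm.models.layers import BatchNormAct2d

# On PyTorch builds whose nn.BatchNorm2d accepts factory kwargs, device/dtype are
# forwarded; on older releases the TypeError fallback simply ignores them, so the
# same constructor call works either way.
bn_act = BatchNormAct2d(32, act_layer=torch.nn.ReLU, device='cpu', dtype=torch.float32)
print(bn_act.weight.dtype, bn_act.running_mean.device)  # torch.float32 cpu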

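To show how SyncBatchNormAct and convert_sync_batchnorm fit together, here is a self-contained, hedged sketch (not part of the commit); the toy Sequential stands in for a timm model that mixes plain BatchNorm2d with the fused BatchNormAct2d layer. The conversion itself does not need an initialized process group, so this runs as-is on CPU:

import torch
from torch import nn
from timm.models.layers import BatchNormAct2d, convert_sync_batchnorm

# stand-in for a timm model: one plain BN layer and one fused norm + act layer
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.Conv2d(16, 16, 3, padding=1),
    BatchNormAct2d(16, act_layer=nn.ReLU),
)

net = convert_sync_batchnorm(net)  # timm helper, not torch.nn.SyncBatchNorm.convert_sync_batchnorm

# the plain BN becomes torch.nn.SyncBatchNorm, while the fused layer becomes
# SyncBatchNormAct (a SyncBatchNorm subclass) and keeps its .act / .drop modules
print([type(m).__name__ for m in net.modules() if isinstance(m, nn.SyncBatchNorm)])
# -> ['SyncBatchNorm', 'SyncBatchNormAct']
print(type(net[3].act).__name__)  # -> 'ReLU'
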
train.py

Lines changed: 10 additions & 9 deletions
@@ -15,25 +15,25 @@
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
-import time
-import yaml
-import os
 import logging
+import os
+import time
 from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime

 import torch
 import torch.nn as nn
 import torchvision.utils
+import yaml
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
-    convert_splitbn_model, model_parameters
 from timm import utils
-from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
     LabelSmoothingCrossEntropy
+from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
+    convert_splitbn_model, convert_sync_batchnorm, model_parameters
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -440,10 +440,11 @@ def main():
     if args.distributed and args.sync_bn:
         assert not args.split_bn
         if has_apex and use_amp == 'apex':
-            # Apex SyncBN preferred unless native amp is activated
+            # Apex SyncBN used with Apex AMP
+            # WARNING this won't currently work with models using BatchNormAct2d
             model = convert_syncbn_model(model)
         else:
-            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = convert_sync_batchnorm(model)
         if args.local_rank == 0:
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
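
The train.py hunk above swaps torch.nn.SyncBatchNorm.convert_sync_batchnorm for the timm helper and warns that the Apex converter still has the BatchNormAct2d limitation. A small hedged sketch (not from the commit) of why the swap matters; the toy Sequential is only an illustration:

import copy
import torch
from torch import nn
from timm.models.layers import BatchNormAct2d, convert_sync_batchnorm

net = nn.Sequential(BatchNormAct2d(8, act_layer=nn.ReLU))

# PyTorch's converter only sees the _BatchNorm base class, so the fused activation
# (and drop layer) of BatchNormAct2d is silently lost, the behaviour behind #1254
torch_converted = torch.nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(net))
print(type(torch_converted[0]).__name__, hasattr(torch_converted[0], 'act'))  # SyncBatchNorm False

# the timm converter emits SyncBatchNormAct and carries .act / .drop across
timm_converted = convert_sync_batchnorm(copy.deepcopy(net))
print(type(timm_converted[0]).__name__, hasattr(timm_converted[0], 'act'))    # SyncBatchNormAct True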
