Commit beef62e

Merge pull request #1317 from rwightman/fixes-syncbn_pretrain_cfg_resolve
Fix SyncBatchNorm for BatchNormAct2d, improve resolve_pretrained_cfg behaviour, other misc fixes.
2 parents 7cedc8d + e6d7df4

File tree

10 files changed: +122, -32 lines

timm/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@
 from .factory import create_model, parse_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
-from .layers import convert_splitbn_model
+from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
     is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
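With this export, the new conversion helper is importable from the top-level models namespace; a minimal check, assuming timm at this commit:

    from timm.models import convert_splitbn_model, convert_sync_batchnorm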

timm/models/helpers.py

Lines changed: 17 additions & 9 deletions
@@ -455,18 +455,26 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     filter_kwargs(kwargs, names=kwargs_filter)


-def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
+def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
     if pretrained_cfg and isinstance(pretrained_cfg, dict):
-        # highest priority, pretrained_cfg available and passed explicitly
+        # highest priority, pretrained_cfg available and passed as arg
         return deepcopy(pretrained_cfg)
-    if kwargs and 'pretrained_cfg' in kwargs:
-        # next highest, pretrained_cfg in a kwargs dict, pop and return
-        pretrained_cfg = kwargs.pop('pretrained_cfg', {})
-        if pretrained_cfg:
-            return deepcopy(pretrained_cfg)
-    # lookup pretrained cfg in model registry by variant
+    # fallback to looking up pretrained cfg in model registry by variant identifier
     pretrained_cfg = get_pretrained_cfg(variant)
-    assert pretrained_cfg
+    if not pretrained_cfg:
+        _logger.warning(
+            f"No pretrained configuration specified for {variant} model. Using a default."
+            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
+        pretrained_cfg = dict(
+            url='',
+            num_classes=1000,
+            input_size=(3, 224, 224),
+            pool_size=None,
+            crop_pct=.9,
+            interpolation='bicubic',
+            first_conv='',
+            classifier='',
+        )
     return pretrained_cfg

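With the kwargs parameter removed, callers either pass pretrained_cfg explicitly or rely on the registry lookup by variant name. A minimal usage sketch, assuming timm at this commit is installed:

    from timm.models.helpers import resolve_pretrained_cfg

    cfg = resolve_pretrained_cfg('resnet50')  # registry lookup by variant name

    # an explicitly passed dict takes priority over the registry entry
    cfg = resolve_pretrained_cfg(
        'resnet50', pretrained_cfg=dict(url='', num_classes=10, input_size=(3, 32, 32)))
    print(cfg['num_classes'])  # 10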

timm/models/inception_v3.py

Lines changed: 1 addition & 1 deletion
@@ -428,7 +428,7 @@ def forward(self, x):


 def _create_inception_v3(variant, pretrained=False, **kwargs):
-    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
     aux_logits = kwargs.pop('aux_logits', False)
     if aux_logits:
         assert not kwargs.pop('features_only', False)
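The other model builders in this PR follow the same shape; a hedged sketch of the general pattern (MyModel and _create_mymodel are placeholder names, not part of the diff):

    import torch.nn as nn
    from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg

    class MyModel(nn.Module):  # placeholder model class for illustration
        def __init__(self, num_classes=1000, **kwargs):
            super().__init__()
            self.head = nn.Linear(8, num_classes)

    def _create_mymodel(variant, pretrained=False, **kwargs):
        # pop a caller-supplied pretrained_cfg if present; otherwise resolve_pretrained_cfg
        # falls back to the registry entry (or, after this PR, a warned default)
        pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
        return build_model_with_cfg(
            MyModel, variant, pretrained,
            pretrained_cfg=pretrained_cfg,
            **kwargs)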

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, LayerNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct
+from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d

timm/models/layers/drop.py

Lines changed: 3 additions & 0 deletions
@@ -164,3 +164,6 @@ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):

     def forward(self, x):
         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
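With extra_repr added, the drop probability now shows up when printing a model; a quick check, assuming timm at this commit:

    from timm.models.layers import DropPath

    print(DropPath(0.1))  # DropPath(drop_prob=0.100)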

timm/models/layers/evo_norm.py

Lines changed: 6 additions & 4 deletions
@@ -256,8 +256,9 @@ def forward(self, x):
 class EvoNorm2dS1(nn.Module):
     def __init__(
             self, num_features, groups=32, group_size=None,
-            apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+            apply_act=True, act_layer=None, eps=1e-5, **_):
         super().__init__()
+        act_layer = act_layer or nn.SiLU
         self.apply_act = apply_act  # apply activation (non-linearity)
         if act_layer is not None and apply_act:
             self.act = create_act_layer(act_layer)
@@ -290,7 +291,7 @@ def forward(self, x):
 class EvoNorm2dS1a(EvoNorm2dS1):
     def __init__(
             self, num_features, groups=32, group_size=None,
-            apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
+            apply_act=True, act_layer=None, eps=1e-3, **_):
         super().__init__(
             num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)

@@ -305,8 +306,9 @@ def forward(self, x):
 class EvoNorm2dS2(nn.Module):
     def __init__(
             self, num_features, groups=32, group_size=None,
-            apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+            apply_act=True, act_layer=None, eps=1e-5, **_):
         super().__init__()
+        act_layer = act_layer or nn.SiLU
         self.apply_act = apply_act  # apply activation (non-linearity)
         if act_layer is not None and apply_act:
             self.act = create_act_layer(act_layer)
@@ -338,7 +340,7 @@ def forward(self, x):
 class EvoNorm2dS2a(EvoNorm2dS2):
     def __init__(
             self, num_features, groups=32, group_size=None,
-            apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
+            apply_act=True, act_layer=None, eps=1e-3, **_):
         super().__init__(
             num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)

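A consequence visible in the diff: an explicit act_layer=None now resolves to nn.SiLU inside __init__ rather than disabling the non-linearity (use apply_act=False for that). A minimal sketch, assuming timm at this commit:

    import torch
    from timm.models.layers.evo_norm import EvoNorm2dS1

    norm = EvoNorm2dS1(64, act_layer=None)  # act_layer=None falls back to nn.SiLU
    print(norm(torch.randn(2, 64, 8, 8)).shape)  # torch.Size([2, 64, 8, 8])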

timm/models/layers/norm_act.py

Lines changed: 80 additions & 5 deletions
@@ -1,6 +1,6 @@
 """ Normalization + Activation Layers
 """
-from typing import Union, List
+from typing import Union, List, Optional, Any

 import torch
 from torch import nn as nn
@@ -18,10 +18,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
     instead of composing it as a .bn member.
     """
     def __init__(
-            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
-        super(BatchNormAct2d, self).__init__(
-            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+            self,
+            num_features,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+            device=None,
+            dtype=None
+    ):
+        try:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
+                **factory_kwargs
+            )
+        except TypeError:
+            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
         if act_layer is not None and apply_act:
@@ -81,6 +100,62 @@ def forward(self, x):
         return x


+class SyncBatchNormAct(nn.SyncBatchNorm):
+    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
+    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
+    # but ONLY when used in conjunction with the timm conversion function below.
+    # Do not create this module directly or use the PyTorch conversion function.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
+        if hasattr(self, "drop"):
+            x = self.drop(x)
+        if hasattr(self, "act"):
+            x = self.act(x)
+        return x
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        if isinstance(module, BatchNormAct2d):
+            # convert timm norm + act layer
+            module_output = SyncBatchNormAct(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group=process_group,
+            )
+            # set act and drop attr from the original module
+            module_output.act = module.act
+            module_output.drop = module.drop
+        else:
+            # convert standard BatchNorm layers
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, "qconfig"):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+    del module
+    return module_output
+
+
 def _num_groups(num_channels, num_groups, group_size):
     if group_size:
         assert num_channels % group_size == 0
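A minimal usage sketch of the timm-aware conversion (assuming timm at this commit; the chosen model is assumed, not verified, to contain BatchNormAct2d layers):

    import timm
    from timm.models.layers import convert_sync_batchnorm

    model = timm.create_model('efficientnet_b0')  # assumed to use BatchNormAct2d internally
    model = convert_sync_batchnorm(model)  # BatchNormAct2d -> SyncBatchNormAct, plain BN -> nn.SyncBatchNorm
    # NOTE: SyncBatchNorm layers need an initialized torch.distributed process group to train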

timm/models/vision_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')

-    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
     model = build_model_with_cfg(
         VisionTransformer, variant, pretrained,
         pretrained_cfg=pretrained_cfg,

timm/models/vision_transformer_relpos.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
+from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
 from .registry import register_model

train.py

Lines changed: 11 additions & 9 deletions
@@ -15,25 +15,25 @@
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
-import time
-import yaml
-import os
 import logging
+import os
+import time
 from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime

 import torch
 import torch.nn as nn
 import torchvision.utils
+import yaml
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
-    convert_splitbn_model, model_parameters
 from timm import utils
-from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
     LabelSmoothingCrossEntropy
+from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
+    convert_splitbn_model, convert_sync_batchnorm, model_parameters
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -438,12 +438,14 @@ def main():

     # setup synchronized BatchNorm for distributed training
     if args.distributed and args.sync_bn:
+        args.dist_bn = ''  # disable dist_bn when sync BN active
         assert not args.split_bn
         if has_apex and use_amp == 'apex':
-            # Apex SyncBN preferred unless native amp is activated
+            # Apex SyncBN used with Apex AMP
+            # WARNING this won't currently work with models using BatchNormAct2d
             model = convert_syncbn_model(model)
         else:
-            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = convert_sync_batchnorm(model)
         if args.local_rank == 0:
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
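A hedged sanity-check sketch (summarize_bn_conversion is a hypothetical helper, not part of timm): after conversion there should be no BatchNormAct2d modules left, and the converted layers should be SyncBatchNormAct rather than plain SyncBatchNorm, which would silently drop the fused activation:

    from timm.models.layers.norm_act import BatchNormAct2d, SyncBatchNormAct

    def summarize_bn_conversion(model):
        remaining = sum(isinstance(m, BatchNormAct2d) for m in model.modules())
        converted = sum(isinstance(m, SyncBatchNormAct) for m in model.modules())
        return remaining, converted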
