
Commit 27bbc70

Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependent code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training.
1 parent 9214ca0 commit 27bbc70


4 files changed: +153 -76 lines changed


timm/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,5 +6,5 @@
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg
 from .model import unwrap_model, get_state_dict
-from .model_ema import ModelEma
+from .model_ema import ModelEma, ModelEmaV2
 from .summary import update_summary, get_outdir

timm/utils/model.py

Lines changed: 4 additions & 1 deletion
@@ -6,7 +6,10 @@
 
 
 def unwrap_model(model):
-    return model.module if hasattr(model, 'module') else model
+    if isinstance(model, ModelEma):
+        return unwrap_model(model.ema)
+    else:
+        return model.module if hasattr(model, 'module') else model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):

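For context, a minimal sketch (not part of this commit) of how the updated unwrap_model and get_state_dict behave with the two EMA wrappers; the toy nn.Linear model and checkpoint path are illustrative only.

```python
import torch
import torch.nn as nn

from timm.utils import ModelEma, ModelEmaV2, get_state_dict, unwrap_model

model = nn.Linear(8, 2)
ema_v1 = ModelEma(model)     # old wrapper: EMA copy lives in .ema
ema_v2 = ModelEmaV2(model)   # new wrapper: EMA copy lives in .module

# unwrap_model now recurses into ModelEma.ema; ModelEmaV2 is still covered by
# the existing hasattr(model, 'module') branch, so both yield a plain nn.Module.
assert unwrap_model(ema_v1) is ema_v1.ema
assert unwrap_model(ema_v2) is ema_v2.module

# get_state_dict applies unwrap_model first, so saved keys carry no wrapper prefix.
torch.save({'state_dict_ema': get_state_dict(ema_v2)}, 'checkpoint_ema.pth')
```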
timm/utils/model_ema.py

Lines changed: 79 additions & 7 deletions
@@ -2,15 +2,89 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
+from collections import OrderedDict
 from copy import deepcopy
 
 import torch
 import torch.nn as nn
 
+_logger = logging.getLogger(__name__)
+
+
+class ModelEma:
+    """ Model Exponential Moving Average (DEPRECATED)
+
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This version is deprecated, it does not work with scripted models. Will be removed eventually.
+
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
+    smoothing of weights to match results. Pay attention to the decay constant you are using
+    relative to your update count per epoch.
+
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after the training stops converging.
+
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+    def __init__(self, model, decay=0.9999, device='', resume=''):
+        # make a copy of the model for accumulating moving average of weights
+        self.ema = deepcopy(model)
+        self.ema.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if device:
+            self.ema.to(device=device)
+        self.ema_has_module = hasattr(self.ema, 'module')
+        if resume:
+            self._load_checkpoint(resume)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def _load_checkpoint(self, checkpoint_path):
+        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        assert isinstance(checkpoint, dict)
+        if 'state_dict_ema' in checkpoint:
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict_ema'].items():
+                # ema model may have been wrapped by DataParallel, and need module prefix
+                if self.ema_has_module:
+                    name = 'module.' + k if not k.startswith('module') else k
+                else:
+                    name = k
+                new_state_dict[name] = v
+            self.ema.load_state_dict(new_state_dict)
+            _logger.info("Loaded state_dict_ema")
+        else:
+            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+
+    def update(self, model):
+        # correct a mismatch in state dict keys
+        needs_module = hasattr(model, 'module') and not self.ema_has_module
+        with torch.no_grad():
+            msd = model.state_dict()
+            for k, ema_v in self.ema.state_dict().items():
+                if needs_module:
+                    k = 'module.' + k
+                model_v = msd[k].detach()
+                if self.device:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
+
+
+class ModelEmaV2(nn.Module):
+    """ Model Exponential Moving Average V2
 
-class ModelEma(nn.Module):
-    """ Model Exponential Moving Average
     Keep a moving average of everything in the model state_dict (parameters and buffers).
+    V2 of this module is simpler, it does not match params/buffers based on name but simply
+    iterates in order. It works with torchscript (JIT of full model).
 
     This is intended to allow functionality like
     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
@@ -27,22 +101,20 @@ class ModelEma(nn.Module):
 
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
-    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
     def __init__(self, model, decay=0.9999, device=None):
-        super(ModelEma, self).__init__()
+        super(ModelEmaV2, self).__init__()
         # make a copy of the model for accumulating moving average of weights
         self.module = deepcopy(model)
         self.module.eval()
         self.decay = decay
         self.device = device  # perform ema on different device from model if set
-        if device is not None:
+        if self.device is not None:
             self.module.to(device=device)
 
     def update(self, model):
         with torch.no_grad():
             for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                assert ema_v.shape == model_v.shape
-                if self.device:
+                if self.device is not None:
                     model_v = model_v.to(device=self.device)
                 ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

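For reference, a minimal sketch (not from this commit) of how the new ModelEmaV2 wrapper is typically wired into a training loop; the toy model, random data, and hyper-parameters are illustrative only, and the repo's train.py handles the real setup.

```python
import torch
import torch.nn as nn

from timm.utils import ModelEmaV2

model = nn.Linear(8, 2)
# Optionally keep the EMA copy on another device (e.g. device='cpu') to save GPU memory.
model_ema = ModelEmaV2(model, decay=0.9999, device=None)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(100):
    x = torch.randn(16, 8)
    y = torch.randint(0, 2, (16,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Update EMA weights after each optimizer step:
    # ema = decay * ema + (1 - decay) * current
    model_ema.update(model)

# Evaluate or checkpoint the smoothed weights via model_ema.module.
```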