
Commit 9214ca0

Simplifying EMA...
1 parent 80cd31f commit 9214ca0

3 files changed: +12 −44 lines


timm/utils/model.py

Lines changed: 1 addition & 4 deletions
@@ -6,10 +6,7 @@
 
 
 def unwrap_model(model):
-    if isinstance(model, ModelEma):
-        return unwrap_model(model.ema)
-    else:
-        return model.module if hasattr(model, 'module') else model
+    return model.module if hasattr(model, 'module') else model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):
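
This simplification works because the reworked ModelEma (next file) is itself an nn.Module that keeps its copy in a `module` attribute, so the same `hasattr(model, 'module')` check that unwraps DataParallel/DDP also unwraps the EMA wrapper. A minimal sketch of that behaviour, using a toy nn.Linear as a stand-in for a real timm model:

```python
# Minimal sketch (toy nn.Linear stand-in, not a real timm model) showing that
# the simplified unwrap_model treats ModelEma like any other .module wrapper.
import torch.nn as nn

from timm.utils.model import unwrap_model, get_state_dict
from timm.utils.model_ema import ModelEma

net = nn.Linear(4, 2)
ema = ModelEma(net, decay=0.999)        # the reworked ModelEma keeps its copy at ema.module

assert unwrap_model(ema) is ema.module  # unwrapped via the hasattr(model, 'module') branch
assert unwrap_model(net) is net         # plain modules pass through unchanged
ema_state = get_state_dict(ema)         # EMA weights obtained through the same unwrap path
```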

timm/utils/model_ema.py

Lines changed: 10 additions & 39 deletions
@@ -2,16 +2,13 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import logging
-from collections import OrderedDict
 from copy import deepcopy
 
 import torch
+import torch.nn as nn
 
-_logger = logging.getLogger(__name__)
 
-
-class ModelEma:
+class ModelEma(nn.Module):
     """ Model Exponential Moving Average
     Keep a moving average of everything in the model state_dict (parameters and buffers).
@@ -32,46 +29,20 @@ class ModelEma:
     GPU assignment and distributed training wrappers.
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
-    def __init__(self, model, decay=0.9999, device='', resume=''):
+    def __init__(self, model, decay=0.9999, device=None):
+        super(ModelEma, self).__init__()
         # make a copy of the model for accumulating moving average of weights
-        self.ema = deepcopy(model)
-        self.ema.eval()
+        self.module = deepcopy(model)
+        self.module.eval()
         self.decay = decay
         self.device = device  # perform ema on different device from model if set
-        if device:
-            self.ema.to(device=device)
-        self.ema_has_module = hasattr(self.ema, 'module')
-        if resume:
-            self._load_checkpoint(resume)
-        for p in self.ema.parameters():
-            p.requires_grad_(False)
-
-    def _load_checkpoint(self, checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        assert isinstance(checkpoint, dict)
-        if 'state_dict_ema' in checkpoint:
-            new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict_ema'].items():
-                # ema model may have been wrapped by DataParallel, and need module prefix
-                if self.ema_has_module:
-                    name = 'module.' + k if not k.startswith('module') else k
-                else:
-                    name = k
-                new_state_dict[name] = v
-            self.ema.load_state_dict(new_state_dict)
-            _logger.info("Loaded state_dict_ema")
-        else:
-            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+        if device is not None:
+            self.module.to(device=device)
 
     def update(self, model):
-        # correct a mismatch in state dict keys
-        needs_module = hasattr(model, 'module') and not self.ema_has_module
         with torch.no_grad():
-            msd = model.state_dict()
-            for k, ema_v in self.ema.state_dict().items():
-                if needs_module:
-                    k = 'module.' + k
-                model_v = msd[k].detach()
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                assert ema_v.shape == model_v.shape
                 if self.device:
                     model_v = model_v.to(device=self.device)
                 ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
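
For context, a minimal training-loop sketch (toy model, toy data, made-up hyperparameters) of how the simplified class is driven: build the EMA wrapper once, call update() after each optimizer step, and run evaluation on model_ema.module:

```python
# Minimal training-loop sketch; the model, data and hyperparameters are toy
# placeholders, only the ModelEma calls reflect the simplified API.
import torch
import torch.nn as nn

from timm.utils.model_ema import ModelEma

model = nn.Linear(8, 1)                              # toy stand-in for a timm model
model_ema = ModelEma(model, decay=0.9999)            # device=None keeps the EMA copy on the model's device
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

for _ in range(10):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model)                          # EMA weights track the live weights

# validation / export uses the averaged copy, e.g. validate(model_ema.module, ...)
```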

train.py

Lines changed: 1 addition & 1 deletion
@@ -568,7 +568,7 @@ def main():
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
             ema_eval_metrics = validate(
-                model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
+                model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
             eval_metrics = ema_eval_metrics
 
         if lr_scheduler is not None:
