 
Hacked together by / Copyright 2020 Ross Wightman
"""
+import logging
+from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn

+_logger = logging.getLogger(__name__)
+
+
+class ModelEma:
+    """ Model Exponential Moving Average (DEPRECATED)
+
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This version is deprecated, it does not work with scripted models. Will be removed eventually.
+
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
+    smoothing of weights to match results. Pay attention to the decay constant you are using
+    relative to your update count per epoch.
+
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after the training stops converging.
+
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+    def __init__(self, model, decay=0.9999, device='', resume=''):
+        # make a copy of the model for accumulating moving average of weights
+        self.ema = deepcopy(model)
+        self.ema.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if device:
+            self.ema.to(device=device)
+        self.ema_has_module = hasattr(self.ema, 'module')
+        if resume:
+            self._load_checkpoint(resume)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def _load_checkpoint(self, checkpoint_path):
+        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        assert isinstance(checkpoint, dict)
+        if 'state_dict_ema' in checkpoint:
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict_ema'].items():
+                # ema model may have been wrapped by DataParallel, and need module prefix
+                if self.ema_has_module:
+                    name = 'module.' + k if not k.startswith('module') else k
+                else:
+                    name = k
+                new_state_dict[name] = v
+            self.ema.load_state_dict(new_state_dict)
+            _logger.info("Loaded state_dict_ema")
+        else:
+            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+
+    def update(self, model):
+        # correct a mismatch in state dict keys
+        needs_module = hasattr(model, 'module') and not self.ema_has_module
+        with torch.no_grad():
+            msd = model.state_dict()
+            for k, ema_v in self.ema.state_dict().items():
+                if needs_module:
+                    k = 'module.' + k
+                model_v = msd[k].detach()
+                if self.device:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
+
+
+class ModelEmaV2(nn.Module):
+    """ Model Exponential Moving Average V2

-class ModelEma(nn.Module):
-    """ Model Exponential Moving Average
    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    V2 of this module is simpler, it does not match params/buffers based on name but simply
+    iterates in order. It works with torchscript (JIT of full model).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
@@ -27,22 +101,20 @@ class ModelEma(nn.Module):

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
-    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
    """
    def __init__(self, model, decay=0.9999, device=None):
-        super(ModelEma, self).__init__()
+        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
-        if device is not None:
+        if self.device is not None:
            self.module.to(device=device)

    def update(self, model):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                assert ema_v.shape == model_v.shape
-                if self.device:
+                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
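
For context, a minimal sketch of how the new ModelEmaV2 wrapper is typically driven from a training loop. The toy model, optimizer, and data below are hypothetical placeholders, not part of this commit; only the ModelEmaV2 calls follow the API added above.

import torch
import torch.nn as nn

# hypothetical toy model/optimizer purely to illustrate the call pattern
model = nn.Linear(10, 2)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01)

# EMA copy of the model; device='cpu' would keep the averaged weights off the GPU
model_ema = ModelEmaV2(model, decay=0.9999, device=None)

for step in range(100):
    inputs = torch.randn(8, 10)
    targets = torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model)  # one EMA update per optimizer step

# validation or checkpointing would then read the smoothed weights from model_ema.module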
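The docstring's caveat about the decay constant relative to the update count per epoch can be made concrete: each update() computes ema = decay * ema + (1 - decay) * model, so the EMA effectively averages over roughly 1 / (1 - decay) optimizer steps. A back-of-the-envelope check, with the step counts below chosen only for illustration:

decay = 0.9999
horizon_steps = 1 / (1 - decay)         # ~10,000 updates dominate the average
steps_per_epoch = 1250                  # e.g. 1.28M samples at batch size 1024 (illustrative)
print(horizon_steps / steps_per_epoch)  # ~8 epochs of history folded into the EMA weights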