 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import logging
-from collections import OrderedDict
 from copy import deepcopy
 
 import torch
+import torch.nn as nn
 
-_logger = logging.getLogger(__name__)
 
-
-class ModelEma:
+class ModelEma(nn.Module):
     """ Model Exponential Moving Average
     Keep a moving average of everything in the model state_dict (parameters and buffers).
 
@@ -32,46 +29,20 @@ class ModelEma:
     GPU assignment and distributed training wrappers.
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
-    def __init__(self, model, decay=0.9999, device='', resume=''):
+    def __init__(self, model, decay=0.9999, device=None):
+        super(ModelEma, self).__init__()
         # make a copy of the model for accumulating moving average of weights
-        self.ema = deepcopy(model)
-        self.ema.eval()
+        self.module = deepcopy(model)
+        self.module.eval()
         self.decay = decay
         self.device = device  # perform ema on different device from model if set
-        if device:
-            self.ema.to(device=device)
-        self.ema_has_module = hasattr(self.ema, 'module')
-        if resume:
-            self._load_checkpoint(resume)
-        for p in self.ema.parameters():
-            p.requires_grad_(False)
-
-    def _load_checkpoint(self, checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        assert isinstance(checkpoint, dict)
-        if 'state_dict_ema' in checkpoint:
-            new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict_ema'].items():
-                # ema model may have been wrapped by DataParallel, and need module prefix
-                if self.ema_has_module:
-                    name = 'module.' + k if not k.startswith('module') else k
-                else:
-                    name = k
-                new_state_dict[name] = v
-            self.ema.load_state_dict(new_state_dict)
-            _logger.info("Loaded state_dict_ema")
-        else:
-            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+        if device is not None:
+            self.module.to(device=device)
 
     def update(self, model):
-        # correct a mismatch in state dict keys
-        needs_module = hasattr(model, 'module') and not self.ema_has_module
         with torch.no_grad():
-            msd = model.state_dict()
-            for k, ema_v in self.ema.state_dict().items():
-                if needs_module:
-                    k = 'module.' + k
-                model_v = msd[k].detach()
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                assert ema_v.shape == model_v.shape
                 if self.device:
                     model_v = model_v.to(device=self.device)
                 ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
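
The rewritten `update()` walks the two `state_dict()`s in parallel with `zip`, so it assumes the EMA copy and the live model expose identical state dicts in the same key order (the shape `assert` is the only guard); the old `module.`-prefix fix-up and the `resume` checkpoint loading are gone, so restoring EMA weights is now the caller's job. Each parameter and buffer follows `ema = decay * ema + (1 - decay) * model` on every call. Below is a minimal usage sketch, not part of this commit: the import path and the toy model/training loop are assumptions for illustration only.

```python
import torch
import torch.nn as nn

from timm.utils import ModelEma  # assumed import path; point this at wherever the class above lives

model = nn.Linear(10, 2)                    # toy stand-in for a real network
model_ema = ModelEma(model, decay=0.9999)   # holds a deepcopy of `model`, kept in eval() mode

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

for step in range(100):                     # stand-in for iterating a real data loader
    x = torch.randn(8, 10)
    y = torch.randint(0, 2, (8,))
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model)                 # EMA update after each optimizer step

# evaluate or checkpoint the averaged weights through model_ema.module
with torch.no_grad():
    ema_logits = model_ema.module(torch.randn(1, 10))
```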