Skip to content

Commit 4ca52d7

Browse files
committed
Add separate set and update method to ModelEmaV2
1 parent 2ed8f24 commit 4ca52d7

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

timm/utils/model_ema.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,15 @@ def __init__(self, model, decay=0.9999, device=None):
112112
if self.device is not None:
113113
self.module.to(device=device)
114114

115-
def update(self, model):
115+
def _update(self, model, update_fn):
116116
with torch.no_grad():
117117
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
118118
if self.device is not None:
119119
model_v = model_v.to(device=self.device)
120-
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
120+
ema_v.copy_(update_fn(ema_v, model_v))
121+
122+
def update(self, model):
123+
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
124+
125+
def set(self, model):
126+
self._update(model, update_fn=lambda e, m: m)

0 commit comments

Comments
 (0)