
Commit ac469b5

Optimizer improvements, additions, cleanup
* Add MADGRAD code
* Fix Lamb (non-fused variant) to work w/ PyTorch XLA
* Tweak optimizer factory args (lr/learning_rate and opt/optimizer_name), may break compat
* Use newer fn signatures for all add, addcdiv, addcmul in optimizers
* Use upcoming PyTorch native Nadam if it's available
* Clean up lookahead opt
* Add optimizer tests
* Remove novograd.py impl as it was messy, keep nvnovograd
* Make AdamP/SGDP work in channels_last layout
* Add rectified AdaBelief mode (radabelief)
* Support a few more PyTorch optimizers: adamax, adagrad
1 parent: 3cdaf5e · commit: ac469b5

File tree

16 files changed: 438 additions, 463 deletions


benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -255,8 +255,8 @@ def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
 
         self.optimizer = create_optimizer_v2(
             self.model,
-            optimizer_name=kwargs.pop('opt', 'sgd'),
-            learning_rate=kwargs.pop('lr', 1e-4))
+            opt=kwargs.pop('opt', 'sgd'),
+            lr=kwargs.pop('lr', 1e-4))
 
     def _gen_target(self, batch_size):
         return torch.empty(
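
The factory keyword rename above is the compat-breaking change called out in the commit message. A minimal usage sketch of the new names follows; only the `opt` and `lr` keywords are confirmed by this diff, and the import path is assumed from how benchmark.py uses the factory.

import torch.nn as nn
from timm.optim import create_optimizer_v2  # import path assumed

model = nn.Linear(10, 2)

# New keyword names (previously optimizer_name= / learning_rate=)
optimizer = create_optimizer_v2(model, opt='sgd', lr=1e-4)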

timm/optim/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 from .adahessian import Adahessian
 from .lookahead import Lookahead
 from .nadam import Nadam
-from .novograd import NovoGrad
 from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF

timm/optim/adabelief.py

Lines changed: 26 additions & 28 deletions
@@ -18,7 +18,7 @@ class AdaBelief(Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False)
-        weight_decouple (boolean, optional): ( default: True) If set as True, then
+        decoupled_decay (boolean, optional): ( default: True) If set as True, then
             the optimizer uses decoupled weight decay as in AdamW
         fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
             is set as True.
@@ -39,9 +39,9 @@ class AdaBelief(Optimizer):
     - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
     """
 
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
-                 weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
-                 degenerated_to_sgd=True):
+    def __init__(
+            self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
+            decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
 
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -52,21 +52,17 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
 
-        self.degenerated_to_sgd = degenerated_to_sgd
         if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
             for param in params:
                 if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                     param['buffer'] = [[None, None, None] for _ in range(10)]
 
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay, amsgrad=amsgrad, buffer=[[None, None, None] for _ in range(10)])
+        defaults = dict(
+            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
+            degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
+            fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
         super(AdaBelief, self).__init__(params, defaults)
 
-        self.degenerated_to_sgd = degenerated_to_sgd
-        self.weight_decouple = weight_decouple
-        self.rectify = rectify
-        self.fixed_decay = fixed_decay
-
     def __setstate__(self, state):
         super(AdaBelief, self).__setstate__(state)
         for group in self.param_groups:
@@ -133,8 +129,8 @@ def step(self, closure=None):
                         state['max_exp_avg_var'] = torch.zeros_like(p.data)
 
                 # perform weight decay, check if decoupled weight decay
-                if self.weight_decouple:
-                    if not self.fixed_decay:
+                if group['decoupled_decay']:
+                    if not group['fixed_decay']:
                         p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
                         p.data.mul_(1.0 - group['weight_decay'])
@@ -152,7 +148,7 @@ def step(self, closure=None):
                 # Update first and second moment running average
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 grad_residual = grad - exp_avg
-                exp_avg_var.mul_(beta2).addcmul_( grad_residual, grad_residual, value=1 - beta2)
+                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
 
                 if amsgrad:
                     max_exp_avg_var = state['max_exp_avg_var']
@@ -165,34 +161,36 @@ def step(self, closure=None):
                     denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
 
                 # update
-                if not self.rectify:
+                if not group['rectify']:
                     # Default update
                     step_size = group['lr'] / bias_correction1
-                    p.data.addcdiv_( exp_avg, denom, value=-step_size)
-
-                else:  # Rectified update, forked from RAdam
+                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                else:
+                    # Rectified update, forked from RAdam
                     buffered = group['buffer'][int(state['step'] % 10)]
                     if state['step'] == buffered[0]:
-                        N_sma, step_size = buffered[1], buffered[2]
+                        num_sma, step_size = buffered[1], buffered[2]
                     else:
                         buffered[0] = state['step']
                         beta2_t = beta2 ** state['step']
-                        N_sma_max = 2 / (1 - beta2) - 1
-                        N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
-                        buffered[1] = N_sma
+                        num_sma_max = 2 / (1 - beta2) - 1
+                        num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                        buffered[1] = num_sma
 
                         # more conservative since it's an approximated value
-                        if N_sma >= 5:
+                        if num_sma >= 5:
                             step_size = math.sqrt(
-                                (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
-                                        N_sma_max - 2)) / (1 - beta1 ** state['step'])
-                        elif self.degenerated_to_sgd:
+                                (1 - beta2_t) *
+                                (num_sma - 4) / (num_sma_max - 4) *
+                                (num_sma - 2) / num_sma *
+                                num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
+                        elif group['degenerated_to_sgd']:
                             step_size = 1.0 / (1 - beta1 ** state['step'])
                         else:
                             step_size = -1
                         buffered[2] = step_size
 
-                    if N_sma >= 5:
+                    if num_sma >= 5:
                         denom = exp_avg_var.sqrt().add_(group['eps'])
                         p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     elif step_size > 0:
timm/optim/adafactor.py

Lines changed: 7 additions & 14 deletions
@@ -34,15 +34,13 @@ class Adafactor(torch.optim.Optimizer):
         beta1 (float): coefficient used for computing running averages of gradient (default: None)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
-        relative_step (bool): if True, time-dependent learning rate is computed
-            instead of external learning rate (default: True)
         warmup_init (bool): time-dependent learning rate computation depends on
             whether warm-up initialization is being used (default: False)
     """
 
     def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
                  decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
-        relative_step = lr is None
+        relative_step = not lr
        if warmup_init and not relative_step:
             raise ValueError('warmup_init requires relative_step=True')
 
@@ -138,37 +136,32 @@ def step(self, closure=None):
                     exp_avg_sq_row = state['exp_avg_sq_row']
                     exp_avg_sq_col = state['exp_avg_sq_col']
 
-                    exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
-                    exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
-                    #exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)  # pytorch 1.6+
-                    #exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
+                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
+                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
 
                     # Approximation of exponential moving average of square of gradient
                     update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                     update.mul_(grad)
                 else:
                     exp_avg_sq = state['exp_avg_sq']
 
-                    exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
-                    #exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)  # pytorch 1.6+
+                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                     update = exp_avg_sq.rsqrt().mul_(grad)
 
                 update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
                 update.mul_(lr_t)
 
                 if use_first_moment:
                     exp_avg = state['exp_avg']
-                    exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
-                    #exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])  # pytorch 1.6+
+                    exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
                     update = exp_avg
 
                 if group['weight_decay'] != 0:
-                    p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
-                    #p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)  # pytorch 1.6+
+                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)
 
                 p_data_fp32.add_(-update)
 
                 if p.data.dtype in {torch.float16, torch.bfloat16}:
                     p.data.copy_(p_data_fp32)
 
-        return loss
+        return loss
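
The Adafactor edits above are representative of the commit-wide migration from the deprecated positional Tensor.add_/addcmul_/addcdiv_ overloads to the keyword forms that PyTorch 1.6+ (and XLA) expect. A small standalone illustration:

import torch

grad = torch.ones(4)
exp_avg = torch.zeros(4)
exp_avg_sq = torch.zeros(4)
beta1, beta2 = 0.9, 0.999

# Deprecated positional form (what the old code used):
#   exp_avg.mul_(beta1).add_(1 - beta1, grad)
# Keyword form used throughout this commit:
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

p = torch.randn(4)
denom = exp_avg_sq.sqrt().add_(1e-8)
p.addcdiv_(exp_avg, denom, value=-1e-3)   # was: p.addcdiv_(-1e-3, exp_avg, denom)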

timm/optim/adamp.py

Lines changed: 32 additions & 37 deletions
@@ -9,48 +9,43 @@
 """
 
 import torch
-import torch.nn as nn
-from torch.optim.optimizer import Optimizer, required
+import torch.nn.functional as F
+from torch.optim.optimizer import Optimizer
 import math
 
-class AdamP(Optimizer):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
-                        delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
-        super(AdamP, self).__init__(params, defaults)
-
-    def _channel_view(self, x):
-        return x.view(x.size(0), -1)
 
-    def _layer_view(self, x):
-        return x.view(1, -1)
+def _channel_view(x) -> torch.Tensor:
+    return x.reshape(x.size(0), -1)
 
-    def _cosine_similarity(self, x, y, eps, view_func):
-        x = view_func(x)
-        y = view_func(y)
 
-        x_norm = x.norm(dim=1).add_(eps)
-        y_norm = y.norm(dim=1).add_(eps)
-        dot = (x * y).sum(dim=1)
+def _layer_view(x) -> torch.Tensor:
+    return x.reshape(1, -1)
 
-        return dot.abs() / x_norm / y_norm
 
-    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
-        wd = 1
-        expand_size = [-1] + [1] * (len(p.shape) - 1)
-        for view_func in [self._channel_view, self._layer_view]:
+def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
+    wd = 1.
+    expand_size = (-1,) + (1,) * (len(p.shape) - 1)
+    for view_func in [_channel_view, _layer_view]:
+        param_view = view_func(p.data)
+        grad_view = view_func(grad)
+        cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
 
-            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
+        if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
+            p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
+            perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
+            wd = wd_ratio
+            return perturb, wd
 
-            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
-                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
-                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
-                wd = wd_ratio
+    return perturb, wd
 
-                return perturb, wd
 
-        return perturb, wd
+class AdamP(Optimizer):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
+        defaults = dict(
+            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+            delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
+        super(AdamP, self).__init__(params, defaults)
 
     def step(self, closure=None):
        loss = None
@@ -81,8 +76,8 @@ def step(self, closure=None):
                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']
 
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 step_size = group['lr'] / bias_correction1
@@ -93,15 +88,15 @@ def step(self, closure=None):
                     perturb = exp_avg / denom
 
                 # Projection
-                wd_ratio = 1
+                wd_ratio = 1.
                 if len(p.shape) > 1:
-                    perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
+                    perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
 
                 # Weight decay
                 if group['weight_decay'] > 0:
-                    p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)
+                    p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
 
                 # Step
-                p.data.add_(-step_size, perturb)
+                p.data.add_(perturb, alpha=-step_size)
 
         return loss
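
The switch from .view() to .reshape() in the helper functions is what lets AdamP/SGDP handle parameters kept in channels_last layout: a channels_last tensor is not contiguous in the default memory format, so .view() can fail where .reshape() copies as needed. A small standalone demonstration of the failure mode (not part of the diff):

import torch

w = torch.randn(8, 3, 3, 3).to(memory_format=torch.channels_last)

try:
    w.view(w.size(0), -1)          # stride layout is incompatible -> RuntimeError
except RuntimeError as err:
    print('view failed:', err)

flat = w.reshape(w.size(0), -1)    # reshape copies when needed, works either way
print(flat.shape)                  # torch.Size([8, 27])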

timm/optim/adamw.py

Lines changed: 6 additions & 3 deletions
@@ -1,5 +1,8 @@
 """ AdamW Optimizer
 Impl copied from PyTorch master
+
+NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
+someday
 """
 import math
 import torch
@@ -100,8 +103,8 @@ def step(self, closure=None):
                 bias_correction2 = 1 - beta2 ** state['step']
 
                 # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 if amsgrad:
                     # Maintains the maximum of all 2nd moment running avg. till now
                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
@@ -112,6 +115,6 @@ def step(self, closure=None):
 
                 step_size = group['lr'] / bias_correction1
 
-                p.data.addcdiv_(-step_size, exp_avg, denom)
+                p.data.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
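
The new NOTE mirrors the commit message's direction of preferring built-in PyTorch optimizers (AdamW here, and the upcoming native Nadam) over local reference implementations. The factory logic itself is not shown in this excerpt, so the fallback below is only an illustrative sketch of that idea, not the factory's actual code:

import torch
import torch.nn as nn
from timm.optim.nadam import Nadam as TimmNadam  # local reference impl

model = nn.Linear(4, 2)

# Prefer torch.optim.Nadam when this PyTorch build ships it, otherwise fall back.
NadamImpl = getattr(torch.optim, 'Nadam', TimmNadam)
optimizer = NadamImpl(model.parameters(), lr=2e-3)
print(type(optimizer).__module__, type(optimizer).__name__)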
