
Commit a426511

More optimizer cleanup. Change all to no longer use .data. Improve (b)float16 use with adabelief. Add XLA compatible Lars.
1 parent 9541f49 commit a426511

15 files changed: 332 additions, 141 deletions
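Across the touched optimizers the cleanup follows one pattern: step() is decorated with @torch.no_grad(), autograd is re-enabled only around the optional closure, and updates are applied to the parameter tensor itself rather than through the legacy .data alias. A minimal sketch of that pattern (the SketchSGD class and its lr value are illustrative, not part of this commit):

import torch
from torch.optim.optimizer import Optimizer


class SketchSGD(Optimizer):
    """Illustrative optimizer skeleton following the pattern used in this commit."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()  # the whole step runs without autograd tracking
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():  # the closure still needs gradients
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # update the parameter in place; no .data indirection needed under no_grad
                p.add_(p.grad, alpha=-group['lr'])
        return loss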

tests/test_optim.py

Lines changed: 27 additions & 0 deletions
@@ -490,6 +490,33 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))
 
 
+@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
+def test_lars(optimizer):
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict(weight, bias, lr=1e-3),
+            optimizer,
+            lr=1e-1)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict_single(weight, bias, lr=1e-3),
+            optimizer,
+            lr=1e-3)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
+    )
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=1e-3))
+
+
 @pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
     _test_basic_cases(
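The new test covers the 'lars', 'larc', 'nlars', and 'nlarc' strings; the names suggest plain LARS, LARS with trust-ratio clipping (LARC), and Nesterov variants of each, all built through the same factory the tests use. A usage sketch under that assumption (importing create_optimizer_v2 from timm.optim as the test module does; the model and hyper-parameter values below are placeholders, not defaults from this commit):

import torch
from timm.optim import create_optimizer_v2

model = torch.nn.Linear(10, 2)  # placeholder model
# 'lars' selects the new XLA-compatible LARS; lr/momentum/weight_decay values are illustrative
optimizer = create_optimizer_v2(model, 'lars', lr=0.1, momentum=0.9, weight_decay=1e-4)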

timm/optim/adabelief.py

Lines changed: 26 additions & 28 deletions
@@ -68,6 +68,7 @@ def __setstate__(self, state):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
+    @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
             for p in group['params']:
@@ -77,14 +78,15 @@ def reset(self):
                 # State initialization
                 state['step'] = 0
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p)
 
                 # Exponential moving average of squared gradient values
-                state['exp_avg_var'] = torch.zeros_like(p.data)
+                state['exp_avg_var'] = torch.zeros_like(p)
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
         Arguments:
@@ -93,50 +95,47 @@ def step(self, closure=None):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-
-                # cast data type
-                half_precision = False
-                if p.data.dtype == torch.float16:
-                    half_precision = True
-                    p.data = p.data.float()
-                    p.grad = p.grad.float()
-
-                grad = p.grad.data
+                grad = p.grad
+                if grad.dtype in {torch.float16, torch.bfloat16}:
+                    grad = grad.float()
                 if grad.is_sparse:
                     raise RuntimeError(
                         'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
-                amsgrad = group['amsgrad']
 
-                state = self.state[p]
+                p_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_fp32 = p_fp32.float()
 
+                amsgrad = group['amsgrad']
                 beta1, beta2 = group['betas']
-
+                state = self.state[p]
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p_fp32)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_var'] = torch.zeros_like(p.data)
+                    state['exp_avg_var'] = torch.zeros_like(p_fp32)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_var'] = torch.zeros_like(p_fp32)
 
                 # perform weight decay, check if decoupled weight decay
                 if group['decoupled_decay']:
                     if not group['fixed_decay']:
-                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                        p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
-                        p.data.mul_(1.0 - group['weight_decay'])
+                        p_fp32.mul_(1.0 - group['weight_decay'])
                 else:
                     if group['weight_decay'] != 0:
-                        grad.add_(p.data, alpha=group['weight_decay'])
+                        grad.add_(p_fp32, alpha=group['weight_decay'])
 
                 # get current state variable
                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
@@ -164,7 +163,7 @@ def step(self, closure=None):
                 if not group['rectify']:
                     # Default update
                     step_size = group['lr'] / bias_correction1
-                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                 else:
                     # Rectified update, forked from RAdam
                     buffered = group['buffer'][int(state['step'] % 10)]
@@ -192,12 +191,11 @@ def step(self, closure=None):
 
                     if num_sma >= 5:
                         denom = exp_avg_var.sqrt().add_(group['eps'])
-                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                        p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     elif step_size > 0:
-                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])
+                        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
 
-                if half_precision:
-                    p.data = p.data.half()
-                    p.grad = p.grad.half()
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_fp32)
 
         return loss
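The float16-only special case that rewrote p.data in place is replaced by a dtype-generic pattern that also covers bfloat16: do the math on a float32 copy and write the result back into the low-precision parameter at the end. Extracted as a standalone sketch (fp32_update is a hypothetical helper, not code from this commit):

import torch

@torch.no_grad()
def fp32_update(p, lr=1e-3):
    """Apply a plain gradient step to a possibly float16/bfloat16 parameter.

    Illustrative only: AdaBelief wraps its own update math in the same dance.
    """
    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()             # compute in fp32
    p_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_fp32 = p_fp32.float()         # fp32 working copy of the parameter
    p_fp32.add_(grad, alpha=-lr)        # the update itself happens in fp32
    if p.dtype in {torch.float16, torch.bfloat16}:
        p.copy_(p_fp32)                 # write the result back in low precision

When the parameter is already float32, p_fp32 is the parameter itself, so the in-place update applies directly and the final copy_ is skipped.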

timm/optim/adafactor.py

Lines changed: 15 additions & 15 deletions
@@ -76,29 +76,30 @@ def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
         c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
         return torch.mul(r_factor, c_factor)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
         Arguments:
             closure (callable, optional): A closure that reevaluates the model and returns the loss.
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.dtype in {torch.float16, torch.bfloat16}:
                     grad = grad.float()
                 if grad.is_sparse:
                     raise RuntimeError('Adafactor does not support sparse gradients.')
 
                 state = self.state[p]
-                grad_shape = grad.shape
 
-                factored, use_first_moment = self._get_options(group, grad_shape)
+                factored, use_first_moment = self._get_options(group, grad.shape)
                 # State Initialization
                 if len(state) == 0:
                     state['step'] = 0
@@ -107,8 +108,8 @@ def step(self, closure=None):
                         # Exponential moving average of gradient values
                         state['exp_avg'] = torch.zeros_like(grad)
                     if factored:
-                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
-                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+                        state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
+                        state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
                     else:
                         state['exp_avg_sq'] = torch.zeros_like(grad)
 
@@ -122,12 +123,12 @@ def step(self, closure=None):
                     else:
                         state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
 
-                p_data_fp32 = p.data
-                if p.data.dtype in {torch.float16, torch.bfloat16}:
-                    p_data_fp32 = p_data_fp32.float()
+                p_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_fp32 = p_fp32.float()
 
                 state['step'] += 1
-                state['RMS'] = self._rms(p_data_fp32)
+                state['RMS'] = self._rms(p_fp32)
                 lr_t = self._get_lr(group, state)
 
                 beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
@@ -157,11 +158,10 @@ def step(self, closure=None):
                     update = exp_avg
 
                 if group['weight_decay'] != 0:
-                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)
+                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
 
-                p_data_fp32.add_(-update)
-
-                if p.data.dtype in {torch.float16, torch.bfloat16}:
-                    p.data.copy_(p_data_fp32)
+                p_fp32.add_(-update)
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_fp32)
 
         return loss
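The Adafactor changes are mostly renames (p_fp32, grad.shape), but the shapes being initialized are the heart of the algorithm: for a matrix-shaped gradient it keeps a row vector and a column vector of second-moment statistics instead of a full elementwise accumulator, and _approx_sq_grad rebuilds a rank-1 approximation from them. A rough sketch of that idea, following the shapes above (the helper below is illustrative, not the exact timm implementation; the beta2t and eps values are assumed):

import torch

def factored_second_moment(grad, exp_avg_sq_row, exp_avg_sq_col, beta2t=0.999, eps=1e-30):
    """Update row/col statistics and return an approximate rsqrt of the full second moment."""
    update = grad ** 2 + eps
    # row stats have shape grad.shape[:-1], col stats grad.shape[:-2] + grad.shape[-1:]
    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1 - beta2t)
    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1 - beta2t)
    # rank-1 reconstruction, mirroring the _approx_sq_grad context lines above
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    return torch.mul(r_factor, c_factor)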

timm/optim/adamp.py

Lines changed: 11 additions & 8 deletions
@@ -26,12 +26,13 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
     wd = 1.
     expand_size = (-1,) + (1,) * (len(p.shape) - 1)
     for view_func in [_channel_view, _layer_view]:
-        param_view = view_func(p.data)
+        param_view = view_func(p)
         grad_view = view_func(grad)
         cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
 
+        # FIXME this is a problem for PyTorch XLA
         if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
-            p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
+            p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
             perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
             wd = wd_ratio
             return perturb, wd
@@ -47,17 +48,19 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
         super(AdamP, self).__init__(params, defaults)
 
+    @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
 
-                grad = p.grad.data
+                grad = p.grad
                 beta1, beta2 = group['betas']
                 nesterov = group['nesterov']
 
@@ -66,8 +69,8 @@ def step(self, closure=None):
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
 
                 # Adam
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@@ -94,9 +97,9 @@ def step(self, closure=None):
 
                 # Weight decay
                 if group['weight_decay'] > 0:
-                    p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
+                    p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
 
                 # Step
-                p.data.add_(perturb, alpha=-step_size)
+                p.add_(perturb, alpha=-step_size)
 
         return loss
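projection() decides whether the AdamP update should be projected onto the directions orthogonal to the weight vector, using the cosine similarity between weights and gradients over two flattened views. The new FIXME flags the cosine_sim.max() comparison: branching in Python on a tensor value forces a device-to-host sync, which is expensive under PyTorch XLA's lazy execution. A small sketch of the two views the similarity is computed over (the view helpers below follow the usual AdamP definitions and are written here for illustration):

import torch

def _channel_view(x):
    # one row per output channel: (out_channels, everything_else)
    return x.reshape(x.size(0), -1)

def _layer_view(x):
    # the whole tensor as a single row
    return x.reshape(1, -1)

w = torch.randn(8, 3, 3, 3)          # e.g. a conv weight
g = torch.randn_like(w)              # its gradient
cos = torch.nn.functional.cosine_similarity(
    _channel_view(g), _channel_view(w), dim=1).abs()
print(cos.shape)                     # torch.Size([8]): one similarity per channel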

timm/optim/adamw.py

Lines changed: 8 additions & 6 deletions
@@ -55,6 +55,7 @@ def __setstate__(self, state):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -64,7 +65,8 @@ def step(self, closure=None):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
@@ -75,7 +77,7 @@ def step(self, closure=None):
                 p.data.mul_(1 - group['lr'] * group['weight_decay'])
 
                 # Perform optimization step
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                 amsgrad = group['amsgrad']
@@ -86,12 +88,12 @@ def step(self, closure=None):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_sq'] = torch.zeros_like(p)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 if amsgrad:
@@ -115,6 +117,6 @@ def step(self, closure=None):
 
                 step_size = group['lr'] / bias_correction1
 
-                p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
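The AdamW hunks show the shape of the whole update: decoupled weight decay applied directly to the parameter, exponential moving averages of the gradient and its square, bias correction, and the final addcdiv_. A condensed sketch of that arithmetic for a single parameter (a hypothetical free function, not timm's AdamW class; hyper-parameter defaults are illustrative):

import math
import torch

@torch.no_grad()
def adamw_style_step(p, exp_avg, exp_avg_sq, step, lr=1e-3,
                     betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """One AdamW-style update for a single parameter with existing state tensors."""
    beta1, beta2 = betas
    # decoupled weight decay: shrink the weights directly, independent of the gradient
    p.mul_(1 - lr * weight_decay)
    grad = p.grad
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    step_size = lr / bias_correction1
    p.addcdiv_(exp_avg, denom, value=-step_size)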
