Commit 20d66be

Move RMSpropTF another step closer to Tensorflow impl
* init square_avg with one instead of zero as per TF
* match TF order of ops for square_avg accumulation
* move LR scaling to momentum buffer accumulator as per TF
* add decoupled weight decay flag (not in TF)
1 parent 89147a9 commit 20d66be
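
Note: the TF-style accumulator update in this commit is algebraically identical to the PyTorch original, since alpha * v + (1 - alpha) * g^2 == v + (1 - alpha) * (g^2 - v); only the order of floating-point operations (and the rounding it produces) differs. A minimal standalone sketch checking this numerically (not part of the diff; it uses the modern keyword-argument PyTorch API rather than the deprecated positional form seen below):

import torch

torch.manual_seed(0)
alpha = 0.9
grad = torch.randn(4)
v = torch.ones(4)  # square_avg starts at ones, as per TF

# PyTorch original form: v <- alpha * v + (1 - alpha) * grad^2
v_pt = v.clone().mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

# TF order of ops: v <- v + (1 - alpha) * (grad^2 - v)
v_tf = v.clone().add_(grad.pow(2) - v, alpha=1 - alpha)

print(torch.allclose(v_pt, v_tf))  # True (up to rounding)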


optim/rmsprop_tf.py

Lines changed: 30 additions & 13 deletions
@@ -19,16 +19,20 @@ class RMSpropTF(Optimizer):
         parameter groups
         lr (float, optional): learning rate (default: 1e-2)
         momentum (float, optional): momentum factor (default: 0)
-        alpha (float, optional): smoothing constant (default: 0.99)
+        alpha (float, optional): smoothing (decay) constant (default: 0.9)
         eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
+            numerical stability (default: 1e-10)
         centered (bool, optional) : if ``True``, compute the centered RMSProp,
             the gradient is normalized by an estimation of its variance
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
+        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
+            update as per defaults in Tensorflow

     """

-    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
+    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
+                 decoupled_decay=False, lr_in_momentum=True):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -40,7 +44,8 @@ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, moment
         if not 0.0 <= alpha:
             raise ValueError("Invalid alpha value: {}".format(alpha))

-        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
+        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
+                        decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
         super(RMSpropTF, self).__init__(params, defaults)

     def __setstate__(self, state):
@@ -72,33 +77,45 @@ def step(self, closure=None):
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = torch.zeros_like(p.data)
+                    state['square_avg'] = torch.ones_like(p.data)  # PyTorch inits to zero
                     if group['momentum'] > 0:
                         state['momentum_buffer'] = torch.zeros_like(p.data)
                     if group['centered']:
                         state['grad_avg'] = torch.zeros_like(p.data)

                 square_avg = state['square_avg']
-                alpha = group['alpha']
+                one_minus_alpha = 1. - group['alpha']

                 state['step'] += 1

                 if group['weight_decay'] != 0:
-                    grad = grad.add(group['weight_decay'], p.data)
+                    if group['decoupled_decay']:
+                        p.data.add_(-group['weight_decay'], p.data)
+                    else:
+                        grad = grad.add(group['weight_decay'], p.data)

-                square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
+                # Tensorflow order of ops for updating squared avg
+                square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
+                # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)  # PyTorch original

                 if group['centered']:
                     grad_avg = state['grad_avg']
-                    grad_avg.mul_(alpha).add_(1 - alpha, grad)
-                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()
+                    grad_avg.add_(one_minus_alpha, grad - grad_avg)
+                    # grad_avg.mul_(alpha).add_(1 - alpha, grad)  # PyTorch original
+                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()  # eps moved in sqrt
                 else:
-                    avg = square_avg.add(group['eps']).sqrt_()
+                    avg = square_avg.add(group['eps']).sqrt_()  # eps moved in sqrt

                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
-                    buf.mul_(group['momentum']).addcdiv_(grad, avg)
-                    p.data.add_(-group['lr'], buf)
+                    # Tensorflow accumulates the LR scaling in the momentum buffer
+                    if group['lr_in_momentum']:
+                        buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
+                        p.data.add_(-buf)
+                    else:
+                        # PyTorch scales the param update by LR
+                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
+                        p.data.add_(-group['lr'], buf)
                 else:
                     p.data.addcdiv_(-group['lr'], grad, avg)
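
For completeness, a hedged usage sketch of the updated optimizer (the import path is inferred from the file path above and may differ in the actual package):

import torch
import torch.nn.functional as F
from optim.rmsprop_tf import RMSpropTF  # import path assumed from the file path above

model = torch.nn.Linear(10, 2)
optimizer = RMSpropTF(
    model.parameters(),
    lr=1e-2,
    alpha=0.9,              # TF-style decay default
    eps=1e-10,              # TF-style epsilon default
    momentum=0.9,
    weight_decay=1e-5,
    decoupled_decay=True,   # decouple weight decay from the gradient (arXiv:1711.05101)
    lr_in_momentum=True,    # accumulate LR scaling in the momentum buffer, as TF does
)

x, y = torch.randn(8, 10), torch.randn(8, 2)
optimizer.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()

With a constant learning rate, lr_in_momentum=True and False trace the same parameter trajectory (the buffer accumulating lr-scaled terms is equivalent to scaling the whole buffer by lr at update time); the two diverge only if lr changes while a non-zero momentum buffer is carried over.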
