@@ -19,16 +19,20 @@ class RMSpropTF(Optimizer):
             parameter groups
         lr (float, optional): learning rate (default: 1e-2)
         momentum (float, optional): momentum factor (default: 0)
-        alpha (float, optional): smoothing constant (default: 0.99)
+        alpha (float, optional): smoothing (decay) constant (default: 0.9)
         eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
+            numerical stability (default: 1e-10)
         centered (bool, optional) : if ``True``, compute the centered RMSProp,
             the gradient is normalized by an estimation of its variance
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
+        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
+            update as per defaults in Tensorflow

     """

-    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
+    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
+                 decoupled_decay=False, lr_in_momentum=True):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
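
For context, a minimal usage sketch of the new constructor arguments. The stand-in model and the assumption that RMSpropTF is importable from this module are mine, not part of the commit:

    import torch

    model = torch.nn.Linear(10, 2)  # hypothetical stand-in model

    # assumes RMSpropTF is importable from the surrounding module
    optimizer = RMSpropTF(
        model.parameters(),
        lr=1e-2,
        alpha=0.9,             # TF-style smoothing/decay default
        eps=1e-10,             # TF-style epsilon default
        momentum=0.9,
        weight_decay=1e-5,
        decoupled_decay=True,  # decay weights directly, per arXiv:1711.05101
        lr_in_momentum=True,   # fold LR into the momentum buffer, as TF does
    )
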
@@ -40,7 +44,8 @@ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, moment
         if not 0.0 <= alpha:
             raise ValueError("Invalid alpha value: {}".format(alpha))

-        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
+        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
+                        decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
         super(RMSpropTF, self).__init__(params, defaults)

     def __setstate__(self, state):
@@ -72,33 +77,45 @@ def step(self, closure=None):
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = torch.zeros_like(p.data)
+                    state['square_avg'] = torch.ones_like(p.data)  # PyTorch inits to zero
                     if group['momentum'] > 0:
                         state['momentum_buffer'] = torch.zeros_like(p.data)
                     if group['centered']:
                         state['grad_avg'] = torch.zeros_like(p.data)

                 square_avg = state['square_avg']
-                alpha = group['alpha']
+                one_minus_alpha = 1. - group['alpha']

                 state['step'] += 1

                 if group['weight_decay'] != 0:
-                    grad = grad.add(group['weight_decay'], p.data)
+                    if group['decoupled_decay']:
+                        p.data.add_(-group['weight_decay'], p.data)
+                    else:
+                        grad = grad.add(group['weight_decay'], p.data)

-                square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
+                # Tensorflow order of ops for updating squared avg
+                square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
+                # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)  # PyTorch original

                 if group['centered']:
                     grad_avg = state['grad_avg']
-                    grad_avg.mul_(alpha).add_(1 - alpha, grad)
-                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()
+                    grad_avg.add_(one_minus_alpha, grad - grad_avg)
+                    # grad_avg.mul_(alpha).add_(1 - alpha, grad)  # PyTorch original
+                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_()  # eps moved in sqrt
                 else:
-                    avg = square_avg.add(group['eps']).sqrt_()
+                    avg = square_avg.add(group['eps']).sqrt_()  # eps moved in sqrt

                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
-                    buf.mul_(group['momentum']).addcdiv_(grad, avg)
-                    p.data.add_(-group['lr'], buf)
+                    # Tensorflow accumulates the LR scaling in the momentum buffer
+                    if group['lr_in_momentum']:
+                        buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
+                        p.data.add_(-buf)
+                    else:
+                        # PyTorch scales the param update by LR
+                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
+                        p.data.add_(-group['lr'], buf)
                 else:
                     p.data.addcdiv_(-group['lr'], grad, avg)

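
A note on the decoupled weight decay branch: with classic L2 (the else path), the decay term joins the gradient and is later divided by the adaptive RMS denominator; the decoupled path shrinks the weights in place and leaves the gradient untouched. A toy sketch of the two behaviours (values illustrative only; observe that the decoupled branch in this patch does not scale the decay by the learning rate):

    import torch

    wd = 0.01
    p = torch.tensor([2.0])
    grad = torch.tensor([0.5])

    # coupled (classic L2): decay enters the gradient, so it is later
    # rescaled by the adaptive denominator along with the rest of the update
    grad_coupled = grad + wd * p

    # decoupled: weights shrink directly, gradient is untouched,
    # mirroring p.data.add_(-group['weight_decay'], p.data)
    p_decoupled = p - wd * p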
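
The "Tensorflow order of ops" comment concerns floating-point evaluation order only; algebraically the two square-average updates are the same exponential moving average, since sq + (1 - alpha) * (g^2 - sq) = alpha * sq + (1 - alpha) * g^2. A quick check:

    import torch

    torch.manual_seed(0)
    alpha = 0.9
    grad = torch.randn(1000)
    square_avg = torch.rand(1000)

    # TF ordering used in the patch: one fused "move toward grad^2" step
    tf_style = square_avg + (1. - alpha) * (grad.pow(2) - square_avg)

    # original PyTorch ordering: scale, then accumulate
    pt_style = alpha * square_avg + (1. - alpha) * grad * grad

    print(torch.allclose(tf_style, pt_style))  # True up to float rounding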
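
Finally, the two lr_in_momentum paths: with a constant learning rate they produce identical parameter updates (the TF-style buffer is simply the PyTorch buffer pre-scaled by lr); they diverge only when lr changes mid-training, because the TF-style buffer bakes old lr values into its history. A sketch under those assumptions, with grad/avg held constant for simplicity:

    import torch

    m, lr, steps = 0.9, 0.01, 5
    g_over_avg = torch.tensor([1.0])  # stand-in for grad / avg

    buf_tf = torch.zeros(1)  # lr_in_momentum=True: lr folded into the buffer
    buf_pt = torch.zeros(1)  # lr_in_momentum=False: lr applied at the update
    for _ in range(steps):
        buf_tf = m * buf_tf + lr * g_over_avg
        buf_pt = m * buf_pt + g_over_avg
        step_tf = buf_tf          # p -= buf
        step_pt = lr * buf_pt     # p -= lr * buf

    print(torch.allclose(step_tf, step_pt))  # True while lr stays constant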