 import math
 import torch
 from torch.optim.optimizer import Optimizer
-from tabulate import tabulate
-from colorama import Fore, Back, Style
 
-version_higher = ( torch.__version__ >= "1.5.0" )
 
 class AdaBelief(Optimizer):
     r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
+
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
@@ -33,39 +31,17 @@ class AdaBelief(Optimizer):
             update similar to RAdam
         degenerated_to_sgd (boolean, optional) (default: True) If set as True, then perform SGD update
             when variance of gradient is high
-        print_change_log (boolean, optional) (default: True) If set as True, print the modification to
-            default hyper-parameters
     reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
+
+    For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer
+    For an example of training script and args for EfficientNet, see these gists:
+    - link to train_script: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
+    - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                  weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
-                 degenerated_to_sgd=True, print_change_log=True):
-
-        # ------------------------------------------------------------------------------
-        # Print modifications to default arguments
-        if print_change_log:
-            print(Fore.RED + 'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.')
-            print(Fore.RED + 'Modifications to default arguments:')
-            default_table = tabulate([
-                ['adabelief-pytorch=0.0.5', '1e-8', 'False', 'False'],
-                ['>=0.1.0 (Current 0.2.0)', '1e-16', 'True', 'True']],
-                headers=['eps', 'weight_decouple', 'rectify'])
-            print(Fore.RED + default_table)
-
-            recommend_table = tabulate([
-                ['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
-            ],
-                headers=['SGD better than Adam (e.g. CNN for Image Classification)', 'Adam better than SGD (e.g. Transformer, GAN)'])
-            print(Fore.BLUE + recommend_table)
-
-            print(Fore.BLUE + 'For a complete table of recommended hyperparameters, see')
-            print(Fore.BLUE + 'https://github.com/juntang-zhuang/Adabelief-Optimizer')
-
-            print(Fore.GREEN + 'You can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.')
-
-            print(Style.RESET_ALL)
-        # ------------------------------------------------------------------------------
+                 degenerated_to_sgd=True):
 
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
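
With `print_change_log` removed, the console warning tables are gone; the same hyperparameter guidance now lives in the docstring. A minimal usage sketch under the new signature, assuming the class is importable as above (the `nn.Linear` model is a hypothetical placeholder, not part of this commit):

```python
import torch.nn as nn

model = nn.Linear(10, 2)  # hypothetical stand-in for a real network
optimizer = AdaBelief(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-16,             # default since adabelief-pytorch 0.1.0 (was 1e-8 in 0.0.5)
    weight_decouple=True,  # decoupled (AdamW-style) weight decay, default since 0.1.0
    rectify=True,          # RAdam-style rectification, default since 0.1.0
)
```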
@@ -90,14 +66,6 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
         self.weight_decouple = weight_decouple
         self.rectify = rectify
         self.fixed_decay = fixed_decay
-        if self.weight_decouple:
-            print('Weight decoupling enabled in AdaBelief')
-            if self.fixed_decay:
-                print('Weight decay fixed')
-        if self.rectify:
-            print('Rectification enabled in AdaBelief')
-        if amsgrad:
-            print('AMSGrad enabled in AdaBelief')
 
     def __setstate__(self, state):
         super(AdaBelief, self).__setstate__(state)
@@ -113,17 +81,13 @@ def reset(self):
                 # State initialization
                 state['step'] = 0
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
-                    if version_higher else torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p.data)
 
                 # Exponential moving average of squared gradient values
-                state['exp_avg_var'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
-                    if version_higher else torch.zeros_like(p.data)
-
+                state['exp_avg_var'] = torch.zeros_like(p.data)
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p.data)
 
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -161,15 +125,12 @@ def step(self, closure=None):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_var'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['exp_avg_var'] = torch.zeros_like(p.data)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_var'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
-                            if version_higher else torch.zeros_like(p.data)
+                        state['max_exp_avg_var'] = torch.zeros_like(p.data)
 
                 # perform weight decay, check if decoupled weight decay
                 if self.weight_decouple:
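
For context on the buffers initialized above: `exp_avg` is Adam's first moment, while `exp_avg_var` tracks the squared deviation of the gradient from that moment, the "belief" of the NeurIPS 2020 paper, instead of the raw squared gradient. A stripped-down sketch of the update these buffers feed, omitting bias correction, rectification, AMSGrad, and weight decay (so a simplification, not the exact code in this file):

```python
import torch

def adabelief_update(p, grad, exp_avg, exp_avg_var, lr, beta1, beta2, eps):
    # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    # s_t = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2
    grad_residual = grad - exp_avg
    exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
    # p <- p - lr * m_t / (sqrt(s_t) + eps)
    p.addcdiv_(exp_avg, exp_avg_var.sqrt().add_(eps), value=-lr)
```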