Commit cd3dc49 (parent: 21812d3)

Fix adabelief imports, remove prints, preserve memory format is the default arg for zeros_like

2 files changed: +14 −53 lines

timm/optim/adabelief.py

Lines changed: 13 additions & 52 deletions
@@ -1,13 +1,11 @@
 import math
 import torch
 from torch.optim.optimizer import Optimizer
-from tabulate import tabulate
-from colorama import Fore, Back, Style
 
-version_higher = ( torch.__version__ >= "1.5.0" )
 
 class AdaBelief(Optimizer):
     r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
+
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
@@ -33,39 +31,17 @@ class AdaBelief(Optimizer):
             update similar to RAdam
         degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update
             when variance of gradient is high
-        print_change_log (boolean, optional) (default: True) If set as True, print the modifcation to
-            default hyper-parameters
     reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
+
+    For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer'
+    For example train/args for EfficientNet see these gists
+      - link to train_scipt: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
+      - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                  weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
-                 degenerated_to_sgd=True, print_change_log = True):
-
-        # ------------------------------------------------------------------------------
-        # Print modifications to default arguments
-        if print_change_log:
-            print(Fore.RED + 'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.')
-            print(Fore.RED + 'Modifications to default arguments:')
-            default_table = tabulate([
-                ['adabelief-pytorch=0.0.5','1e-8','False','False'],
-                ['>=0.1.0 (Current 0.2.0)','1e-16','True','True']],
-                headers=['eps','weight_decouple','rectify'])
-            print(Fore.RED + default_table)
-
-            recommend_table = tabulate([
-                ['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
-                ],
-                headers=['SGD better than Adam (e.g. CNN for Image Classification)','Adam better than SGD (e.g. Transformer, GAN)'])
-            print(Fore.BLUE + recommend_table)
-
-            print(Fore.BLUE +'For a complete table of recommended hyperparameters, see')
-            print(Fore.BLUE + 'https://github.com/juntang-zhuang/Adabelief-Optimizer')
-
-            print(Fore.GREEN + 'You can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.')
-
-            print(Style.RESET_ALL)
-        # ------------------------------------------------------------------------------
+                 degenerated_to_sgd=True):
 
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -90,14 +66,6 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
         self.weight_decouple = weight_decouple
         self.rectify = rectify
         self.fixed_decay = fixed_decay
-        if self.weight_decouple:
-            print('Weight decoupling enabled in AdaBelief')
-            if self.fixed_decay:
-                print('Weight decay fixed')
-        if self.rectify:
-            print('Rectification enabled in AdaBelief')
-        if amsgrad:
-            print('AMSGrad enabled in AdaBelief')
 
     def __setstate__(self, state):
         super(AdaBelief, self).__setstate__(state)
@@ -113,17 +81,13 @@ def reset(self):
                 # State initialization
                 state['step'] = 0
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                    if version_higher else torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p.data)
 
                 # Exponential moving average of squared gradient values
-                state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                    if version_higher else torch.zeros_like(p.data)
-
+                state['exp_avg_var'] = torch.zeros_like(p.data)
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p.data)
 
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -161,15 +125,12 @@ def step(self, closure=None):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['exp_avg_var'] = torch.zeros_like(p.data)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                            if version_higher else torch.zeros_like(p.data)
+                        state['max_exp_avg_var'] = torch.zeros_like(p.data)
 
                 # perform weight decay, check if decoupled weight decay
                 if self.weight_decouple:
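
Note on the memory-format simplification above: on PyTorch 1.5 and later, the memory_format argument of torch.zeros_like already defaults to torch.preserve_format, so both the explicit keyword and the version_higher fallback were redundant. A minimal check, assuming PyTorch >= 1.5 (illustrative only, not code from the diff):

    import torch

    # channels_last input; zeros_like should preserve that layout by default
    p = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)

    implicit = torch.zeros_like(p)                                       # new code path
    explicit = torch.zeros_like(p, memory_format=torch.preserve_format)  # old code path

    assert implicit.is_contiguous(memory_format=torch.channels_last)
    assert explicit.is_contiguous(memory_format=torch.channels_last)
    assert torch.equal(implicit, explicit)  # same zeros, same layout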

timm/optim/optim_factory.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def create_optimizer_v2(
     elif opt_lower == 'adam':
         optimizer = optim.Adam(parameters, **opt_args)
     elif opt_lower == 'adabelief':
-        optimizer = AdaBelief(parameters, rectify = False, print_change_log = False,**opt_args)
+        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
     elif opt_lower == 'adamw':
         optimizer = optim.AdamW(parameters, **opt_args)
     elif opt_lower == 'nadam':
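
With print_change_log gone, the 'adabelief' branch passes plain keyword arguments. A minimal sketch of using the optimizer class directly with the same arguments the factory now forwards; the toy model and hyperparameters are illustrative assumptions, not from the commit:

    import torch
    import torch.nn as nn
    from timm.optim.adabelief import AdaBelief

    # Toy model; the factory would pass model parameters plus opt_args here.
    model = nn.Linear(16, 4)
    optimizer = AdaBelief(model.parameters(), lr=1e-3, weight_decay=1e-4, rectify=False)

    # One training step to show the optimizer in use.
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()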
