@@ -68,6 +68,7 @@ def __setstate__(self, state):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)

+    @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
             for p in group['params']:
@@ -77,14 +78,15 @@ def reset(self):
                 # State initialization
                 state['step'] = 0
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p)

                 # Exponential moving average of squared gradient values
-                state['exp_avg_var'] = torch.zeros_like(p.data)
+                state['exp_avg_var'] = torch.zeros_like(p)
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p)

+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
         Arguments:
@@ -93,50 +95,47 @@ def step(self, closure=None):
9395 """
9496 loss = None
9597 if closure is not None :
96- loss = closure ()
98+ with torch .enable_grad ():
99+ loss = closure ()
97100
98101 for group in self .param_groups :
99102 for p in group ['params' ]:
100103 if p .grad is None :
101104 continue
102-
103- # cast data type
104- half_precision = False
105- if p .data .dtype == torch .float16 :
106- half_precision = True
107- p .data = p .data .float ()
108- p .grad = p .grad .float ()
109-
110- grad = p .grad .data
105+ grad = p .grad
106+ if grad .dtype in {torch .float16 , torch .bfloat16 }:
107+ grad = grad .float ()
111108 if grad .is_sparse :
112109 raise RuntimeError (
113110 'AdaBelief does not support sparse gradients, please consider SparseAdam instead' )
114- amsgrad = group ['amsgrad' ]
115111
116- state = self .state [p ]
112+ p_fp32 = p
113+ if p .dtype in {torch .float16 , torch .bfloat16 }:
114+ p_fp32 = p_fp32 .float ()
117115
116+ amsgrad = group ['amsgrad' ]
118117 beta1 , beta2 = group ['betas' ]
119-
118+ state = self . state [ p ]
120119 # State initialization
121120 if len (state ) == 0 :
122121 state ['step' ] = 0
123122 # Exponential moving average of gradient values
124- state ['exp_avg' ] = torch .zeros_like (p . data )
123+ state ['exp_avg' ] = torch .zeros_like (p_fp32 )
125124 # Exponential moving average of squared gradient values
126- state ['exp_avg_var' ] = torch .zeros_like (p . data )
125+ state ['exp_avg_var' ] = torch .zeros_like (p_fp32 )
127126 if amsgrad :
128127 # Maintains max of all exp. moving avg. of sq. grad. values
129- state ['max_exp_avg_var' ] = torch .zeros_like (p . data )
128+ state ['max_exp_avg_var' ] = torch .zeros_like (p_fp32 )
130129
131130 # perform weight decay, check if decoupled weight decay
132131 if group ['decoupled_decay' ]:
133132 if not group ['fixed_decay' ]:
134- p . data .mul_ (1.0 - group ['lr' ] * group ['weight_decay' ])
133+ p_fp32 .mul_ (1.0 - group ['lr' ] * group ['weight_decay' ])
135134 else :
136- p . data .mul_ (1.0 - group ['weight_decay' ])
135+ p_fp32 .mul_ (1.0 - group ['weight_decay' ])
137136 else :
138137 if group ['weight_decay' ] != 0 :
139- grad .add_ (p . data , alpha = group ['weight_decay' ])
138+ grad .add_ (p_fp32 , alpha = group ['weight_decay' ])
140139
141140 # get current state variable
142141 exp_avg , exp_avg_var = state ['exp_avg' ], state ['exp_avg_var' ]
@@ -164,7 +163,7 @@ def step(self, closure=None):
                 if not group['rectify']:
                     # Default update
                     step_size = group['lr'] / bias_correction1
-                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                 else:
                     # Rectified update, forked from RAdam
                     buffered = group['buffer'][int(state['step'] % 10)]
@@ -192,12 +191,11 @@ def step(self, closure=None):

                     if num_sma >= 5:
                         denom = exp_avg_var.sqrt().add_(group['eps'])
-                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                        p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     elif step_size > 0:
-                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])
+                        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

-                if half_precision:
-                    p.data = p.data.half()
-                    p.grad = p.grad.half()
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_fp32)

         return loss
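The net effect of this patch: `step()` and `reset()` run under `@torch.no_grad()`, closures are re-run with gradients enabled, and float16/bfloat16 parameters are updated through a temporary float32 copy (`p_fp32`) that is copied back at the end, while optimizer state is always allocated in float32. A minimal sketch of the behavior from the caller's side; the `timm.optim` import path and the default hyper-parameters are assumptions, not part of this diff:

```python
import torch

from timm.optim import AdaBelief  # assumed import path; adjust to this repo's layout

# A single float16 parameter exercises the new code path: state tensors are
# created from the float32 copy (p_fp32), and the updated values are copied
# back into the half-precision parameter at the end of step().
p = torch.nn.Parameter(torch.randn(10, dtype=torch.float16))
opt = AdaBelief([p], lr=1e-2)  # remaining hyper-parameters left at their defaults

loss = (p.float() ** 2).sum()  # loss computed in float32; p.grad still comes out float16
loss.backward()

opt.step()       # decorated with @torch.no_grad(); the update happens on a float32 copy
opt.zero_grad()

print(p.dtype)                        # torch.float16 -- parameter keeps its dtype
print(opt.state[p]['exp_avg'].dtype)  # torch.float32 -- optimizer state in full precision
```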