1010
1111import torch
1212from torch .testing ._internal .common_utils import TestCase
13- from torch .autograd import Variable
13+ from torch .nn import Parameter
1414from timm .scheduler import PlateauLRScheduler
1515
1616from timm .optim import create_optimizer_v2
2121
2222
2323def _test_basic_cases_template (weight , bias , input , constructor , scheduler_constructors ):
24- weight = Variable (weight , requires_grad = True )
25- bias = Variable (bias , requires_grad = True )
26- input = Variable (input )
24+ weight = Parameter (weight )
25+ bias = Parameter (bias )
26+ input = Parameter (input )
2727 optimizer = constructor (weight , bias )
2828 schedulers = []
2929 for scheduler_constructor in scheduler_constructors :
@@ -55,9 +55,9 @@ def fn():
5555
5656
5757def _test_state_dict (weight , bias , input , constructor ):
58- weight = Variable (weight , requires_grad = True )
59- bias = Variable (bias , requires_grad = True )
60- input = Variable (input )
58+ weight = Parameter (weight )
59+ bias = Parameter (bias )
60+ input = Parameter (input )
6161
6262 def fn_base (optimizer , weight , bias ):
6363 optimizer .zero_grad ()
@@ -73,8 +73,9 @@ def fn_base(optimizer, weight, bias):
7373 for _i in range (20 ):
7474 optimizer .step (fn )
7575 # Clone the weights and construct new optimizer for them
76- weight_c = Variable (weight .data .clone (), requires_grad = True )
77- bias_c = Variable (bias .data .clone (), requires_grad = True )
76+ with torch .no_grad ():
77+ weight_c = Parameter (weight .clone ().detach ())
78+ bias_c = Parameter (bias .clone ().detach ())
7879 optimizer_c = constructor (weight_c , bias_c )
7980 fn_c = functools .partial (fn_base , optimizer_c , weight_c , bias_c )
8081 # Load state dict
@@ -86,12 +87,8 @@ def fn_base(optimizer, weight, bias):
8687 for _i in range (20 ):
8788 optimizer .step (fn )
8889 optimizer_c .step (fn_c )
89- #assert torch.equal(weight, weight_c)
90- #assert torch.equal(bias, bias_c)
9190 torch_tc .assertEqual (weight , weight_c )
9291 torch_tc .assertEqual (bias , bias_c )
93- # Make sure state dict wasn't modified
94- torch_tc .assertEqual (state_dict , state_dict_c )
9592 # Make sure state dict is deterministic with equal but not identical parameters
9693 torch_tc .assertEqual (optimizer .state_dict (), optimizer_c .state_dict ())
9794 # Make sure repeated parameters have identical representation in state dict
@@ -103,9 +100,10 @@ def fn_base(optimizer, weight, bias):
103100 if not torch .cuda .is_available ():
104101 return
105102
106- input_cuda = Variable (input .data .float ().cuda ())
107- weight_cuda = Variable (weight .data .float ().cuda (), requires_grad = True )
108- bias_cuda = Variable (bias .data .float ().cuda (), requires_grad = True )
103+ with torch .no_grad ():
104+ input_cuda = Parameter (input .clone ().detach ().float ().cuda ())
105+ weight_cuda = Parameter (weight .clone ().detach ().cuda ())
106+ bias_cuda = Parameter (bias .clone ().detach ().cuda ())
109107 optimizer_cuda = constructor (weight_cuda , bias_cuda )
110108 fn_cuda = functools .partial (fn_base , optimizer_cuda , weight_cuda , bias_cuda )
111109
@@ -216,21 +214,21 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
216214 scheduler_constructors = []
217215 params_t = torch .tensor ([1.5 , 1.5 ])
218216
219- params = Variable (params_t , requires_grad = True )
217+ params = Parameter (params_t )
220218 optimizer = constructor ([params ])
221219 schedulers = []
222220 for scheduler_constructor in scheduler_constructors :
223221 schedulers .append (scheduler_constructor (optimizer ))
224222
225223 solution = torch .tensor ([1 , 1 ])
226- initial_dist = params .data .dist (solution )
224+ initial_dist = params .clone (). detach () .dist (solution )
227225
228226 def eval (params , w ):
229227 # Depending on w, provide only the x or y gradient
230228 optimizer .zero_grad ()
231229 loss = rosenbrock (params )
232230 loss .backward ()
233- grad = drosenbrock (params .data )
231+ grad = drosenbrock (params .clone (). detach () )
234232 # NB: We torture test the optimizer by returning an
235233 # uncoalesced sparse tensor
236234 if w :
@@ -256,7 +254,7 @@ def eval(params, w):
256254 else :
257255 scheduler .step ()
258256
259- torch_tc .assertLessEqual (params .data .dist (solution ), initial_dist )
257+ torch_tc .assertLessEqual (params .clone (). detach () .dist (solution ), initial_dist )
260258
261259
262260def _build_params_dict (weight , bias , ** kwargs ):
0 commit comments