 import torch.nn as nn
 from torch.testing._internal.common_utils import TestCase
 import intel_extension_for_pytorch  # noqa
+import torchvision.models as models
 import pytest
 import os
 
 cpu_device = torch.device("cpu")
 xpu_device = torch.device("xpu")
 
 batch_size = 128
-class_num = 1000
-input_channel = 512
-hidden_channel = 2048
-num_iter = 10
+input_channel = 3
+train_num_iter = 5
+eval_num_iter = 3
 lr = 0.01
 checkpoint_path_str = './_checkpoint.test.case.test_xpu_checkpoint_save_load_integrity_and_accuracy.pth.tar'
 
-class TrainingModel(nn.Module):
-    def __init__(self):
-        super(TrainingModel, self).__init__()
-        self.m = nn.Sequential(
-            nn.Conv2d(input_channel, hidden_channel, kernel_size=(1, 1), stride=(1, 1), bias=False),
-            nn.BatchNorm2d(hidden_channel, eps=1e-05, momentum=0.1),
-            nn.ReLU(inplace=True),
-            nn.AvgPool2d(kernel_size=7, stride=1, padding=0),
-        )
-        self.fc = nn.Linear(in_features=hidden_channel, out_features=class_num, bias=True)
-
-    def forward(self, x, indentity_for_mul, indentity_for_add):
-        x = self.m(x)
-        x = x * indentity_for_mul
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-        x = x + indentity_for_add
-        return x
-
 class TestTorchMethod(TestCase):
     @pytest.mark.skipif(not torch.xpu.utils.has_fp64_dtype(), reason="fp64 not support by this device")
     def test_save_load(self):
@@ -65,31 +46,17 @@ def test_serialization_multi_map_location(self):
         self.assertEqual(b.device.__str__(), 'xpu:1')
 
     @pytest.mark.skipif(not torch.xpu.utils.has_fp64_dtype(), reason="fp64 not support by this device")
-    def test_xpu_checkpoint_save_load_integrity_and_accuracy(self, dtype=torch.bfloat16):
-        # create model
+    def test_xpu_checkpoint_save_load_integrity_and_accuracy(self):
         device = 'xpu'
-        model_xpu = TrainingModel()
-        model_xpu = model_xpu.to(device=device).train()
-        optimizer_xpu = torch.optim.SGD(model_xpu.parameters(), lr=lr)
-        criterion = nn.CrossEntropyLoss()
-
-        if os.path.exists(checkpoint_path_str):
-            os.remove(checkpoint_path_str)
-
-        # process torch.xpu.optimize
-        model_xpu, optimizer_xpu = torch.xpu.optimize(model=model_xpu, dtype=dtype, optimizer=optimizer_xpu)
-
-        def training_step(model_xpu, optimizer_xpu, criterion):
-            input = torch.randn(batch_size, input_channel, 7, 7)
-            target = torch.empty(batch_size, dtype=torch.long).random_(class_num)
+        def training_step(model_xpu, optimizer_xpu, criterion, dtype):
+            input = torch.randn(batch_size, input_channel, 224, 224)
+            target = torch.empty(batch_size, dtype=torch.long).random_(1000)
             input_xpu = input.clone().to(device=device).requires_grad_()
             target_xpu = target.to(device)
-            indentity_for_mul = torch.randn(batch_size, hidden_channel, 1, 1).to(device=device)
-            indentity_for_add = torch.randn(batch_size, class_num).to(device=device)
 
             # forward
             with torch.xpu.amp.autocast(enabled=True, dtype=dtype):
-                output_xpu = model_xpu(input_xpu, indentity_for_mul, indentity_for_add)
+                output_xpu = model_xpu(input_xpu)
                 loss_xpu = criterion(output_xpu, target_xpu)
 
             # optimizer
@@ -103,35 +70,70 @@ def training_step(model_xpu, optimizer_xpu, criterion):
             loss_xpu = loss_xpu.cpu()
             output_xpu = output_xpu.cpu()
 
-        def save_checkpoint(state, filename=checkpoint_path_str):
-            torch.save(state, filename)
+        def eval_step(model_xpu, dtype):
+            input = torch.randn(batch_size, input_channel, 224, 224)
+            target = torch.empty(batch_size, dtype=torch.long).random_(1000)
+            input_xpu = input.clone().to(device=device).requires_grad_()
+            target_xpu = target.to(device)
 
-        for _ in range(num_iter):
-            training_step(model_xpu, optimizer_xpu, criterion)
+            # forward
+            with torch.xpu.amp.autocast(enabled=True, dtype=dtype):
+                output_xpu = model_xpu(input_xpu)
+                loss_xpu = criterion(output_xpu, target_xpu)
+
+            loss_xpu = loss_xpu.cpu()
+            output_xpu = output_xpu.cpu()
 
-        save_checkpoint({'model_state_dict': model_xpu.state_dict(), 'optimizer_state_dict': optimizer_xpu.state_dict()})
-        if os.path.isfile(checkpoint_path_str):
-            # load checkpoint
-            checkpoint = torch.load(checkpoint_path_str, map_location='xpu')
-            print('load checkpoint')
+        def save_checkpoint(state, filename=checkpoint_path_str):
+            torch.save(state, filename)
 
+        for dtype in [torch.float32, torch.bfloat16]:
+            print('dtype = ', dtype)
             # create model
-            new_model = TrainingModel()
-            new_model = new_model.to(device=device).train()
-            print('create model')
-
-            # create optimizer
-            new_optimizer = torch.optim.SGD(new_model.parameters(), lr=lr)
-            print('create model')
-
-            # load state dict
-            new_model.load_state_dict(checkpoint['model_state_dict'])
-            new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-            print('load state dict')
-
-            # check
-            print('checking...')
-            self.assertEqual(model_xpu.state_dict(), new_model.state_dict(), atol=1e-6, rtol=1e-6)
-            self.assertEqual(optimizer_xpu.state_dict(), new_optimizer.state_dict(), atol=1e-6, rtol=1e-6)
-        else:
-            assert False, "save checkpoint failed for xpu model"  # noqa B011
+            model_xpu = models.__dict__['resnet18'](pretrained=True).to(device=device).train()
+            optimizer_xpu = torch.optim.SGD(model_xpu.parameters(), lr=lr)
+            criterion = nn.CrossEntropyLoss()
+
+            if os.path.exists(checkpoint_path_str):
+                os.remove(checkpoint_path_str)
+
+            # process torch.xpu.optimize
+            model_xpu, optimizer_xpu = torch.xpu.optimize(model=model_xpu, dtype=dtype, optimizer=optimizer_xpu)
+
+            # mimic model train, then eval
+            for _ in range(train_num_iter):
+                training_step(model_xpu, optimizer_xpu, criterion, dtype)
+            model_xpu.eval()
+            for _ in range(eval_num_iter):
+                eval_step(model_xpu, dtype)
+            torch.xpu.synchronize()
+
+            save_checkpoint({'model_state_dict': model_xpu.state_dict(), 'optimizer_state_dict': optimizer_xpu.state_dict()})
+            if os.path.isfile(checkpoint_path_str):
+                # load checkpoint
+                checkpoint = torch.load(checkpoint_path_str, map_location=device)
+                print('load checkpoint')
+
+                # create model
+                new_model = models.__dict__['resnet18'](pretrained=False).to(device=device).train()
+                print('create model')
+
+                # create optimizer
+                new_optimizer = torch.optim.SGD(new_model.parameters(), lr=lr)
+                print('create optimizer')
+
+                # optimize
+                new_model, new_optimizer = torch.xpu.optimize(model=new_model, dtype=dtype, optimizer=new_optimizer)
+
+                # load state dict
+                new_model.load_state_dict(checkpoint['model_state_dict'])
+                new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                print('load state dict')
+
+                # check
+                print('checking...')
+                self.assertEqual(model_xpu.state_dict(), new_model.state_dict(), atol=1e-6, rtol=1e-6)
+                self.assertEqual(optimizer_xpu.state_dict(), new_optimizer.state_dict(), atol=1e-6, rtol=1e-6)
+                os.remove(checkpoint_path_str)
+            else:
+                assert False, "save checkpoint failed for xpu model"  # noqa B011
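
Note: stripped of the XPU specifics, the round trip this test verifies is the standard PyTorch checkpoint pattern: train briefly, save both state dicts, rebuild the model and optimizer the same way, load, and compare. A minimal CPU-only sketch of that pattern, assuming stock PyTorch and illustrative names (no intel_extension_for_pytorch involved):

import torch
import torch.nn as nn

# Stand-in model/optimizer; the test above uses resnet18 + SGD on XPU.
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# One training step so there is non-trivial state to checkpoint.
loss = criterion(model(torch.randn(4, 8)), torch.randint(0, 2, (4,)))
loss.backward()
optimizer.step()

# Save both state dicts under the same keys the test uses.
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, '_ckpt.pth.tar')

# Rebuild identically, then load; the state dicts should match exactly.
new_model = nn.Linear(8, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.01)
checkpoint = torch.load('_ckpt.pth.tar')
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

The XPU-specific wrinkle the diff adds is the "# optimize" step: the freshly built model and optimizer are passed through torch.xpu.optimize before load_state_dict, so the loaded state is compared against a model in the same optimized (and, presumably for the bfloat16 pass, dtype-converted) representation as the one that was saved.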
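The training_step/eval_step helpers also show the usual AMP shape: only the forward pass and loss run under autocast, while backward and the optimizer update happen outside the context. A short sketch of the same structure, assuming stock CPU autocast in place of torch.xpu.amp.autocast:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Autocast wraps only the forward pass and loss, mirroring training_step;
# torch.cpu.amp.autocast stands in for torch.xpu.amp.autocast (assumption).
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    output = model(torch.randn(4, 8))
    loss = criterion(output, torch.randint(0, 2, (4,)))

# Backward and the optimizer update run outside the autocast region.
loss.backward()
optimizer.step()
optimizer.zero_grad()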