@@ -1,4 +1,5 @@
 import torch
+import torch.fx.experimental.optimization as optimization
 import intel_extension_for_pytorch as ipex
 import intel_extension_for_pytorch._C as core
 from intel_extension_for_pytorch.nn.utils._weight_prepack import _IPEXLinear as _IPEXLinear, _IPEXConv2d as _IPEXConv2d
@@ -45,6 +46,7 @@ class ConvTranspose2d(torch.nn.Module):
     def __init__(self, ):
         super(ConvTranspose2d, self).__init__()
         self.conv_transpose2d = torch.nn.ConvTranspose2d(5, 5, (3 ,3))
+        self.input1 = torch.randn(5, 5, 3, 3)
 
     def forward(self, x):
         x = self.conv_transpose2d(x)
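The added `input1` buffer gives ConvTranspose2d the same self-describing-input convention that ConvBatchNorm and OneLayerMLP already follow, so it can join the shared test loops below. A quick sanity sketch of what the module computes, assuming its `forward` returns the transposed-convolution output:

import torch

m = ConvTranspose2d().eval()
with torch.no_grad():
    out = m(m.input1)
# A 3x3 spatial input grows to 5x5 under a 3x3 kernel with stride 1 and no padding:
# H_out = (H_in - 1) * stride + kernel - 2 * padding = (3 - 1) * 1 + 3 - 0 = 5
print(out.shape)  # torch.Size([5, 5, 5, 5])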
@@ -267,17 +269,36 @@ def test_record_shape(self):
         self.assertEqual(opt_M.l2.batch_size_collapsed, 3)
 
     def test_traced_model_serialization(self):
-        for module in [ConvBatchNorm, OneLayerMLP]:
+        for module in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
             for dtype in [torch.float, torch.bfloat16]:
                 M = module().eval()
-                opt_M = ipex.optimize(M, dtype=dtype)
+                input = M.input1.to(dtype)
+                opt_M = ipex.optimize(M, dtype=dtype, auto_kernel_selection=True)
                 with torch.no_grad():
-                    traced_M = torch.jit.trace(M, M.input1).eval()
+                    traced_M = torch.jit.trace(opt_M, input).eval()
                     traced_M.save('traced_m.pt')
                     loaded_M = torch.jit.load('traced_m.pt')
-                    self.assertEqual(traced_M(M.input1), loaded_M(M.input1))
+                    self.assertEqual(traced_M(input), loaded_M(input))
                     os.remove('traced_m.pt')
 
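Outside the harness, the round trip this test now exercises looks roughly like the sketch below (float shown; the test also covers bfloat16):

import os
import torch
import intel_extension_for_pytorch as ipex

M = OneLayerMLP().eval()                   # any of the test modules above
x = M.input1.to(torch.float)
opt_M = ipex.optimize(M, dtype=torch.float, auto_kernel_selection=True)
with torch.no_grad():
    traced = torch.jit.trace(opt_M, x).eval()
    traced.save('traced_m.pt')
loaded = torch.jit.load('traced_m.pt')
# serialization must not change numerics
assert torch.allclose(traced(x), loaded(x))
os.remove('traced_m.pt')

Note that the trace now runs on `opt_M` rather than the raw module, so the saved artifact is the optimized (weight-prepacked) model, which is what the assertion exercises.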
+    def test_optimized_model_with_fx(self):
+        for module in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
+            for dtype in [torch.float, torch.bfloat16]:
+                M = module().eval()
+                input = M.input1.to(dtype)
+                opt_M = ipex.optimize(M, dtype=dtype, auto_kernel_selection=True)
+                ref_out = opt_M(input)
+                fx_M = optimization.fuse(opt_M)
+                fx_out = fx_M(input)
+                self.assertEqual(ref_out, fx_out)
+                with torch.no_grad():
+                    traced_M = torch.jit.trace(fx_M, input).eval()
+                    traced_M = torch.jit.freeze(traced_M)
+                    # warm-up call: the first run triggers JIT graph optimization
+                    traced_M(input)
+                    # subsequent calls run the optimized graph
+                    out = traced_M(input)
+                    self.assertEqual(ref_out, out)
 
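The new test pins down that an ipex-optimized module still composes with the FX fusion pass and with TorchScript freezing. Condensed, the pipeline it exercises is roughly:

import torch
import torch.fx.experimental.optimization as optimization
import intel_extension_for_pytorch as ipex

M = ConvBatchNorm().eval()                 # any of the test modules above
x = M.input1
opt_M = ipex.optimize(M, dtype=torch.float, auto_kernel_selection=True)
fx_M = optimization.fuse(opt_M)            # FX pass, e.g. folds BatchNorm into Conv
with torch.no_grad():
    frozen = torch.jit.freeze(torch.jit.trace(fx_M, x).eval())
    frozen(x)                              # warm-up call triggers JIT graph optimization
    out = frozen(x)                        # later calls run the optimized graph
assert torch.allclose(opt_M(x), out)       # fusion and freezing must preserve numerics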
 if __name__ == '__main__':
     test = unittest.main()