Skip to content

Commit 1717c44

Browse files
authored
Fix for fx (#894) (#901)
* override is_leaf_module for fx tracer * add ut for ipex.optimize with fx and fix ut for traced_model_serialization * minor changs * at an empty line at the end of the file
1 parent 1af1449 commit 1717c44

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

intel_extension_for_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
pass # skip if torchvision is not available
66

77
from .version import __version__
8-
from .utils import _cpu_isa
8+
from .utils import _cpu_isa, _custom_fx_tracer
99
_cpu_isa.check_minimal_isa_support()
1010

1111
torch_version = ''
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
import torch.fx as fx
3+
import types
4+
5+
def override_is_leaf_module():
6+
fx_tracer = fx.Tracer
7+
orig_is_leaf_module_fn = fx_tracer.is_leaf_module
8+
def ipex_is_leaf_module_fn(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
9+
is_ipex = m.__module__.startswith('intel_extension_for_pytorch.nn')
10+
return is_ipex or orig_is_leaf_module_fn(self, m, module_qualified_name)
11+
setattr(fx_tracer, 'is_leaf_module', types.MethodType(ipex_is_leaf_module_fn, fx_tracer))
12+
13+
override_is_leaf_module()

tests/cpu/test_ipex_optimize.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torch.fx.experimental.optimization as optimization
23
import intel_extension_for_pytorch as ipex
34
import intel_extension_for_pytorch._C as core
45
from intel_extension_for_pytorch.nn.utils._weight_prepack import _IPEXLinear as _IPEXLinear, _IPEXConv2d as _IPEXConv2d
@@ -45,6 +46,7 @@ class ConvTranspose2d(torch.nn.Module):
4546
def __init__(self, ):
4647
super(ConvTranspose2d, self).__init__()
4748
self.conv_transpose2d = torch.nn.ConvTranspose2d(5, 5, (3 ,3))
49+
self.input1 = torch.randn(5, 5, 3, 3)
4850

4951
def forward(self, x):
5052
x = self.conv_transpose2d(x)
@@ -267,17 +269,36 @@ def test_record_shape(self):
267269
self.assertEqual(opt_M.l2.batch_size_collapsed, 3)
268270

269271
def test_traced_model_serialization(self):
270-
for module in [ConvBatchNorm, OneLayerMLP]:
272+
for module in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
271273
for dtype in [torch.float, torch.bfloat16]:
272274
M = module().eval()
273-
opt_M = ipex.optimize(M, dtype=dtype)
275+
input = M.input1.to(dtype)
276+
opt_M = ipex.optimize(M, dtype=dtype, auto_kernel_selection=True)
274277
with torch.no_grad():
275-
traced_M = torch.jit.trace(M, M.input1).eval()
278+
traced_M = torch.jit.trace(opt_M, input).eval()
276279
traced_M.save('traced_m.pt')
277280
loaded_M = torch.jit.load('traced_m.pt')
278-
self.assertEqual(traced_M(M.input1), loaded_M(M.input1))
281+
self.assertEqual(traced_M(input), loaded_M(input))
279282
os.remove('traced_m.pt')
280283

284+
def test_optimized_model_with_fx(self):
285+
for module in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
286+
for dtype in [torch.float, torch.bfloat16]:
287+
M = module().eval()
288+
input = M.input1.to(dtype)
289+
opt_M = ipex.optimize(M, dtype=dtype, auto_kernel_selection=True)
290+
ref_out = opt_M(input)
291+
fx_M = optimization.fuse(opt_M)
292+
fx_out = fx_M(input)
293+
self.assertEqual(ref_out, fx_out)
294+
with torch.no_grad():
295+
traced_M = torch.jit.trace(fx_M, input).eval()
296+
traced_M = torch.jit.freeze(traced_M)
297+
# do graph opt
298+
traced_M(input)
299+
# get optimized results
300+
out = traced_M(input)
301+
self.assertEqual(ref_out, out)
281302

282303
if __name__ == '__main__':
283304
test = unittest.main()

0 commit comments

Comments
 (0)