Skip to content

Commit 8b02d62

Browse files
authored
update docs for graph_opt and add examples (#1584)
1 parent 02449cc commit 8b02d62

File tree

5 files changed

+98
-68
lines changed

5 files changed

+98
-68
lines changed

docs/tutorials/features/graph_optimization.md

Lines changed: 9 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -12,63 +12,16 @@ ipex.enable_onednn_fusion(False)
1212
```
1313

1414
### FP32 and BF16 models
15-
```
16-
import torch
17-
import torchvision.models as models
18-
19-
# Import the Intel Extension for PyTorch
20-
import intel_extension_for_pytorch as ipex
21-
22-
model = models.__dict__["resnet50 "](pretrained=True)
23-
model.eval()
2415

25-
# Apply some fusions at the front end
26-
model = ipex.optimize(model, dtype=torch.float32)
16+
[//]: # (marker_feature_graph_optimization_fp32_bf16)
17+
[//]: # (marker_feature_graph_optimization_fp32_bf16)
2718

28-
x = torch.randn(args.batch_size, 3, 224, 224)
29-
with torch.no_grad():
30-
model = torch.jit.trace(model, x, check_trace=False).eval()
31-
# Fold the BatchNormalization and propagate constant
32-
torch.jit.freeze(model)
33-
# Print the graph
34-
print(model.graph_for(x))
35-
```
3619
Compared to the original code, the model launcher needs to add a few lines of code and the extension will automatically accelerate the model. Regarding the RN50, the extension will automatically fuse the Conv + ReLU and Conv + Sum + ReLU as ConvReLU and ConvSumReLU. If you check the output of `graph_for`, you will observe the fused operators.
3720

3821
### INT8 models
39-
```
40-
import torch
41-
import intel_extension_for_pytorch as ipex
42-
43-
44-
# First-time quantization flow
45-
# define the model
46-
def MyModel(torch.nn.Module):
47-
...
48-
49-
# construct the model
50-
model = MyModel(...)
51-
qconfig = ipex.quantization.default_static_qconfig
52-
model.eval()
53-
example_inputs = ..
54-
prepared_model = prepare(user_model, qconfig, example_inputs=example_inputs, inplace=False)
55-
with torch.no_grad():
56-
for images in calibration_data_loader():
57-
prepared_model(images)
58-
59-
convert_model = convert(prepared_model)
60-
with torch.no_grad():
61-
traced_model = torch.jit.trace(convert_model, example_input)
62-
traced_model = torch.jit.freeze(traced_model)
63-
64-
traced_model.save("quantized_model.pt")
65-
# Deployment
66-
import intel_extension_for_pytorch as ipex
67-
quantized_model = torch.jit.load("quantized_model.pt")
68-
quantized_model = torch.jit.freeze(quantized_model.eval())
69-
with torch.no_grad():
70-
output = quantized_model(images)
71-
```
22+
23+
[//]: # (marker_feature_graph_optimization_int8)
24+
[//]: # (marker_feature_graph_optimization_int8)
7225

7326
## Methodology
7427
### Fusion
@@ -175,17 +128,8 @@ Here listed all the currently supported int8 patterns in Intel® Extension for P
175128

176129
### Folding
177130
Stock PyTorch provids constant propagation and BatchNormalization folding. These optimizations are automatically applied to the jit model by invoking `torch.jit.freeze`. Take the Resnet50 as an example:
178-
```
179-
import torch
180-
import torchvision.models as models
181-
model = models.__dict__["resnet50 "](pretrained=True)
182-
model.eval()
183-
x = torch.randn(args.batch_size, 3, 224, 224)
184-
with torch.no_grad():
185-
model = torch.jit.trace(model, x, check_trace=False).eval()
186-
# Fold the BatchNormalization and propagate constant
187-
torch.jit.freeze(model)
188-
# Print the graph
189-
print(model.graph_for(x))
190-
```
131+
132+
[//]: # (marker_feature_graph_optimization_folding)
133+
[//]: # (marker_feature_graph_optimization_folding)
134+
191135
If the model owner does not invoke the `torch.jit.freeze`, the `BatchNormalization` still exists on the graph. Otheriwse, the `BatchNormalization` will be folded on the graph to save the compuation and then improve the performance. Refer to the [Constant Folding Wikipedia page](https://en.wikipedia.org/wiki/Constant_folding) for more details.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
import torchvision.models as models
3+
4+
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
5+
model.eval()
6+
x = torch.randn(4, 3, 224, 224)
7+
8+
with torch.no_grad():
9+
model = torch.jit.trace(model, x, check_trace=False).eval()
10+
# Fold the BatchNormalization and propagate constant
11+
torch.jit.freeze(model)
12+
# Print the graph
13+
print(model.graph_for(x))
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
import torchvision.models as models
3+
4+
# Import the Intel Extension for PyTorch
5+
import intel_extension_for_pytorch as ipex
6+
7+
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
8+
model.eval()
9+
10+
# Apply some fusions at the front end
11+
model = ipex.optimize(model, dtype=torch.float32)
12+
13+
x = torch.randn(4, 3, 224, 224)
14+
with torch.no_grad():
15+
model = torch.jit.trace(model, x, check_trace=False).eval()
16+
# Fold the BatchNormalization and propagate constant
17+
torch.jit.freeze(model)
18+
# Print the graph
19+
print(model.graph_for(x))
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torch
2+
import torchvision.models as models
3+
import intel_extension_for_pytorch as ipex
4+
from intel_extension_for_pytorch.quantization import prepare, convert
5+
6+
# construct the model
7+
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
8+
qconfig = ipex.quantization.default_static_qconfig
9+
model.eval()
10+
example_inputs = torch.rand(1, 3, 224, 224)
11+
prepared_model = prepare(model, qconfig, example_inputs=example_inputs, inplace=False)
12+
13+
##### Example Dataloader #####
14+
import torchvision
15+
DOWNLOAD = True
16+
DATA = 'datasets/cifar10/'
17+
18+
transform = torchvision.transforms.Compose([
19+
torchvision.transforms.Resize((224, 224)),
20+
torchvision.transforms.ToTensor(),
21+
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
22+
])
23+
train_dataset = torchvision.datasets.CIFAR10(
24+
root=DATA,
25+
train=True,
26+
transform=transform,
27+
download=DOWNLOAD,
28+
)
29+
calibration_data_loader = torch.utils.data.DataLoader(
30+
dataset=train_dataset,
31+
batch_size=128
32+
)
33+
34+
with torch.no_grad():
35+
for batch_idx, (d, target) in enumerate(calibration_data_loader):
36+
print(f'calibrated on batch {batch_idx} out of {len(calibration_data_loader)}')
37+
prepared_model(d)
38+
##############################
39+
40+
convert_model = convert(prepared_model)
41+
with torch.no_grad():
42+
traced_model = torch.jit.trace(convert_model, example_inputs)
43+
traced_model = torch.jit.freeze(traced_model)
44+
45+
traced_model.save("quantized_model.pt")
46+
47+
# Deployment
48+
quantized_model = torch.jit.load("quantized_model.pt")
49+
quantized_model = torch.jit.freeze(quantized_model.eval())
50+
images = torch.rand(1, 3, 244, 244)
51+
with torch.no_grad():
52+
output = quantized_model(images)
53+
print('fin')

examples/cpu/inference/python/int8_calibration_static.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@
4040
batch_size=128
4141
)
4242

43-
for batch_idx, (d, target) in enumerate(calibration_data_loader):
44-
print(f'calibrated on batch {batch_idx} out of {len(calibration_data_loader)}')
45-
prepared_model(d)
43+
with torch.no_grad():
44+
for batch_idx, (d, target) in enumerate(calibration_data_loader):
45+
print(f'calibrated on batch {batch_idx} out of {len(calibration_data_loader)}')
46+
prepared_model(d)
4647
##############################
4748

4849
converted_model = convert(prepared_model)

0 commit comments

Comments
 (0)