    model = torch.jit.trace(model, x, check_trace=False).eval()
# Fold the BatchNormalization and propagate constant
model = torch.jit.freeze(model)
# Print the graph
print(model.graph_for(x))
```
Compared to the original code, the model launcher only needs to add a few lines of code, and the extension then accelerates the model automatically. For ResNet-50 (RN50), the extension automatically fuses Conv + ReLU and Conv + Sum + ReLU into ConvReLU and ConvSumReLU. If you check the output of `graph_for`, you will observe the fused operators.
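As a quick sanity check, a minimal sketch (assuming the traced and frozen `model` and the example input `x` from the snippet above) is to scan the textual form of the optimized graph for the fused kernels; the exact operator names printed by the extension may vary between versions:

```
# Minimal sketch: scan the optimized graph for the fused kernels mentioned above.
# Assumes `model` and `x` from the preceding example; the exact operator names
# printed by the extension may vary between versions.
graph_text = str(model.graph_for(x))
for keyword in ("ConvReLU", "ConvSumReLU"):
    status = "found" if keyword.lower() in graph_text.lower() else "not found"
    print(f"{keyword}: {status}")
```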
### Folding
Stock PyTorch provides constant propagation and BatchNormalization folding. These optimizations are automatically applied to the JIT model by invoking `torch.jit.freeze`. Take ResNet-50 as an example:
```
import torch
import torchvision.models as models

model = models.__dict__["resnet50"](pretrained=True)
model.eval()
x = torch.randn(1, 3, 224, 224)  # batch size 1; a launcher would use args.batch_size here
with torch.no_grad():
    model = torch.jit.trace(model, x, check_trace=False).eval()
# Fold the BatchNormalization and propagate constant
model = torch.jit.freeze(model)
```
If the model owner does not invoke `torch.jit.freeze`, the `BatchNormalization` operators still exist in the graph. Otherwise, the `BatchNormalization` operators are folded on the graph, which saves computation and improves performance. Refer to the [Constant Folding Wikipedia page](https://en.wikipedia.org/wiki/Constant_folding) for more details.
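To see the effect of folding, a minimal sketch (assuming the traced ResNet-50 from the example above, captured before freezing under the hypothetical name `traced_model`, and the same input `x`) is to compare the graphs before and after `torch.jit.freeze`:

```
# Minimal sketch: compare the graph before and after torch.jit.freeze.
# `traced_model` is a hypothetical name for the output of torch.jit.trace above,
# captured before freezing; `x` is the same example input.
before = str(traced_model.graph_for(x))
print("batch_norm before freeze:", "batch_norm" in before)  # typically True

frozen_model = torch.jit.freeze(traced_model)
after = str(frozen_model.graph_for(x))
print("batch_norm after freeze:", "batch_norm" in after)    # typically False, folded away
```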