Commit 60d0796

XiaobingSuper and Jiong Gong authored

quantization: introduce default qconfig for static and dynamic quantization (#886)

* quantization: introduce default qconfig for static and dynamic quantization
* update README
* change reduce_range to False
* Update intel_extension_for_pytorch/ao/quantization/README.md

Co-authored-by: Jiong Gong <jiong.gong@intel.com>

1 parent 9dcdb56 commit 60d0796

File tree

5 files changed: +65 -29 lines changed

intel_extension_for_pytorch/ao/quantization/README.md

Lines changed: 38 additions & 11 deletions
@@ -11,19 +11,31 @@ import intel_extension_for_pytorch as ipex
 from intel_extension_for_pytorch.quantization import prepare, convert
 ```

-### Define QConfig
+### Define qconfig
+
+Using the default qconfig (recommended):
+
+```python
+qconfig = ipex.quantization.default_static_qconfig
+# equal to
+# QConfig(activation=HistogramObserver.with_args(reduce_range=False),
+#         weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
+```

-Define a **QConfig** which sets the activation and weight observer methods:
+or define your own qconfig as:

 ```python
 from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
 qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-        weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
+                  weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
 ```

 Note: we make full use of the PyTorch [observer methods](https://pytorch.org/docs/stable/quantization-support.html#torch-quantization-observer), so you can use a different PyTorch observer method to define the [QConfig](https://pytorch.org/docs/1.11/generated/torch.quantization.qconfig.QConfig.html). For the weight observer, we only support the **torch.qint8** dtype for now.

-**Suggestion**: For the activation observer, if you set **qscheme** to **torch.per_tensor_affine**, **torch.quint8** is preferred; if you set **qscheme** to **torch.per_tensor_symmetric**, **torch.qint8** is preferred. For the weight observer, setting **qscheme** to **torch.per_channel_symmetric** can give better accuracy.
+**Suggestion**:
+
+1. For the activation observer, if you set **qscheme** to **torch.per_tensor_affine**, **torch.quint8** is preferred; if you set **qscheme** to **torch.per_tensor_symmetric**, **torch.qint8** is preferred. For the weight observer, setting **qscheme** to **torch.per_channel_symmetric** can give better accuracy.
+2. If your CPU doesn't support VNNI (e.g. Skylake), setting the observer's **reduce_range** to **True** can give better accuracy.
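Editor's note: the **reduce_range** suggestion comes down to simple integer-range arithmetic. Below is a minimal, framework-free sketch in plain Python (not the IPEX or PyTorch observer API; `quant_range` and `affine_scale_zp` are hypothetical helper names) of the integer ranges involved and of how a per-tensor affine observer turns observed min/max into scale and zero_point:

```python
# Illustration only: with reduce_range=True an observer targets 7 bits instead
# of 8, halving the integer range, which avoids potential overflow in int8
# multiply-accumulate on CPUs without VNNI.

def quant_range(dtype: str, reduce_range: bool):
    """Return the (qmin, qmax) integer range an observer would target."""
    if dtype == "qint8":
        return (-64, 63) if reduce_range else (-128, 127)
    if dtype == "quint8":
        return (0, 127) if reduce_range else (0, 255)
    raise ValueError(dtype)

def affine_scale_zp(min_val: float, max_val: float, dtype: str, reduce_range: bool):
    """Per-tensor affine scale/zero_point from observed min/max (MinMaxObserver-style)."""
    qmin, qmax = quant_range(dtype, reduce_range)
    # the observed range must include zero so that 0.0 is exactly representable
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = round(qmin - min_val / scale)
    return scale, zero_point

print(quant_range("quint8", False))  # (0, 255)
print(quant_range("quint8", True))   # (0, 127)
print(affine_scale_zp(0.0, 2.55, "quint8", False))  # scale ~= 0.01, zero_point 0
```

Halving the range doubles the quantization step, so it trades a little precision for overflow safety; that is why the README only suggests it for non-VNNI CPUs.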

 ### Prepare Model and Do Calibration
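For context on what "do calibration" means here, a toy framework-free sketch (assumed behavior in plain Python; `ToyMinMaxObserver` is a hypothetical name, not the IPEX implementation): observers watch activations over representative calibration batches, record running min/max, and those statistics later determine the scale and zero_point used at inference.

```python
# Toy model of calibration: feed batches, accumulate min/max, derive qparams.

class ToyMinMaxObserver:
    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        # running min/max across all calibration batches
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def qparams(self, qmin=0, qmax=255):  # quint8, per-tensor affine
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (qmax - qmin)
        zero_point = round(qmin - lo / scale)
        return scale, zero_point

obs = ToyMinMaxObserver()
for batch in [[0.1, 1.2, -0.3], [2.2, 0.5, -0.1]]:  # calibration data
    obs.observe(batch)
print(obs.qparams())  # scale covers the full observed range [-0.3, 2.2]
```

This is why calibration data should be representative: the qparams are frozen from these statistics before inference.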

@@ -74,14 +86,29 @@ from intel_extension_for_pytorch.quantization import prepare, convert

 ### Define QConfig

+Using the default qconfig (recommended):
+
+```python
+dynamic_qconfig = ipex.quantization.default_dynamic_qconfig
+# equal to
+# QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
+#         weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
+```
+
+or define your own qconfig as:
+
 ```python
 from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig
-dynamic_qconfig = QConfig(
-    activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
-    weight = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
+dynamic_qconfig = QConfig(activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
+                          weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
 ```

-Note: For the weight observer, only dtype **torch.qint8** is supported, and the qscheme can only be **torch.per_tensor_symmetric** or **torch.per_channel_symmetric**.
+Note: For the weight observer, only dtype **torch.qint8** is supported, and the qscheme can only be **torch.per_tensor_symmetric** or **torch.per_channel_symmetric**. For the activation observer, only dtype **torch.float** is supported, and the compute_dtype can be **torch.quint8** or **torch.qint8**.
+
+**Suggestion**:
+
+1. For the weight observer, setting **qscheme** to **torch.per_channel_symmetric** can give better accuracy.
+2. If your CPU doesn't support VNNI (e.g. Skylake), setting the observer's **reduce_range** to **True** can give better accuracy.
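Editor's note on the idea behind this qconfig: the `PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8)` pairing says activations stay float until runtime, when each input is quantized with a scale computed from its own min/max, while weights are quantized once, offline and symmetrically. A framework-free sketch of that idea (assumed semantics in plain Python, not IPEX internals; `symmetric_qparams` and `quantize` are hypothetical helpers):

```python
# Dynamic quantization, conceptually: weight scale fixed at convert time,
# activation scale recomputed per batch.

def symmetric_qparams(values, qmax=127):
    """Symmetric qint8: zero_point is fixed at 0, scale comes from the absolute max."""
    amax = max(abs(v) for v in values)
    return amax / qmax if amax else 1.0

def quantize(values, scale, qmin=-128, qmax=127):
    return [max(qmin, min(qmax, round(v / scale))) for v in values]

weights = [0.5, -1.27, 0.02]
w_scale = symmetric_qparams(weights)        # computed once, offline
w_q = quantize(weights, w_scale)

for activations in ([0.1, 0.9], [10.0, -3.0]):
    a_scale = symmetric_qparams(activations)  # computed at runtime, per input
    print(quantize(activations, a_scale))
```

Because activation scales adapt to each input, dynamic quantization needs no calibration step, which is why the flow below goes straight from prepare to convert.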

 ### Prepare Model

@@ -106,13 +133,13 @@ convert_model = convert(prepared_model)
 # ...
 # for inference
 y = convert_model(x)
-
 ```

 Note: we only support the following ops for dynamic quantization:
+
 - torch.nn.Linear
-- torch.nn.LSTM
-- torch.nn.GRU
+- torch.nn.LSTM
+- torch.nn.GRU
 - torch.nn.LSTMCell
 - torch.nn.RNNCell
 - torch.nn.GRUCell
Lines changed: 1 addition & 0 deletions

@@ -1 +1,2 @@
 from ._quantize import prepare, convert
+from ._qconfig import default_static_qconfig, default_dynamic_qconfig
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+import torch
+from torch.ao.quantization import PlaceholderObserver, PerChannelMinMaxObserver, HistogramObserver, QConfig
+
+
+_default_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
+
+default_static_qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
+                                 weight=_default_weight_observer)
+"""
+Default qconfig configuration for static quantization.
+"""
+
+default_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
+                                  weight=_default_weight_observer)
+"""
+Default qconfig configuration for dynamic quantization.
+"""
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from ..ao.quantization import prepare, convert
+from ..ao.quantization import prepare, convert, default_static_qconfig, default_dynamic_qconfig

tests/cpu/test_ao_jit_ipex_quantization.py

Lines changed: 8 additions & 17 deletions

@@ -21,28 +21,19 @@
 default_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)

 static_qconfig = [
-    QConfig(
-        activation = MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+    QConfig(activation = MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
             weight = default_weight_observer),
-    QConfig(
-        activation = MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
+    QConfig(activation = MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
             weight = default_weight_observer),
-    QConfig(
-        activation = HistogramObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8, reduce_range=True),
+    QConfig(activation = HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8, reduce_range=True),
             weight = default_weight_observer),
-    QConfig(
-        activation = HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8, reduce_range=True),
-        weight = default_weight_observer),
-]
+    ipex.quantization.default_static_qconfig]

 dynamic_qconfig = [
-    QConfig(
-        activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
-        weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)),
-    QConfig(
-        activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
-        weight = default_weight_observer),
-]
+    QConfig(activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
+            weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)),
+    ipex.quantization.default_dynamic_qconfig]


 class TestIpexOps(JitLlgaTestCase):
     def test_adaptive_avg_pool2d(self):
