Commit 955bcff
quantization: disable deconvolution quantization, replace dropout with identity to enable more fusion (#872)

* quantization: disable deconvolution quantization
* quantization: replace dropout with identity to enable more fusion

1 parent 2c079e5 · commit 955bcff

File tree: 4 files changed, +20 -21 lines


intel_extension_for_pytorch/ao/quantization/_quantization_state_utils.py

Lines changed: 12 additions & 12 deletions
@@ -25,10 +25,10 @@
     F.conv3d,
     torch.conv2d,
     torch.conv3d,
-    F.conv_transpose2d,
-    F.conv_transpose3d,
-    torch.conv_transpose2d,
-    torch.conv_transpose3d,
+    #F.conv_transpose2d, #TODO
+    #F.conv_transpose3d, #TODO
+    #torch.conv_transpose2d, #TODO
+    #torch.conv_transpose3d, #TODO
     torch.relu,
     F.relu,
     #torch.sigmoid, # TODO
@@ -50,8 +50,8 @@
 module_types_supported_by_quantization = set([
     torch.nn.Conv2d,
     torch.nn.Conv3d,
-    torch.nn.ConvTranspose2d,
-    torch.nn.ConvTranspose3d,
+    #torch.nn.ConvTranspose2d,
+    #torch.nn.ConvTranspose3d,
     torch.nn.Linear,
     torch.nn.MaxPool2d,
     torch.nn.MaxPool3d,
@@ -90,10 +90,10 @@
     str(F.conv3d),
     str(torch.conv2d),
     str(torch.conv3d),
-    str(F.conv_transpose2d),
-    str(F.conv_transpose3d),
-    str(torch.conv_transpose2d),
-    str(torch.conv_transpose3d),
+    #str(F.conv_transpose2d),
+    #str(F.conv_transpose3d),
+    #str(torch.conv_transpose2d),
+    #str(torch.conv_transpose3d),
     str(F.linear),
     str(torch._C._nn.linear),
 ]
@@ -102,8 +102,8 @@
     #str(torch.nn.Conv1d) # it will be enabled at next step.
     str(torch.nn.Conv2d),
     str(torch.nn.Conv3d),
-    str(torch.nn.ConvTranspose2d),
-    str(torch.nn.ConvTranspose3d),
+    #str(torch.nn.ConvTranspose2d),
+    #str(torch.nn.ConvTranspose3d),
     str(torch.nn.Linear),
 ]
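These sets act as allow-lists: an op or module is observed and quantized during prepare only if its type (or its str() form) appears in one of them, so commenting out the conv_transpose entries leaves deconvolution in FP32. A minimal sketch of that gate; the set contents mirror the diff above, but the helper name and check are illustrative, not the file's own code:

    import torch.nn as nn

    # Illustrative allow-list mirroring module_types_supported_by_quantization
    # after this commit (the ConvTranspose2d/3d entries are commented out upstream).
    module_types_supported = {nn.Conv2d, nn.Conv3d, nn.Linear,
                              nn.MaxPool2d, nn.MaxPool3d}

    def needs_observer(module: nn.Module) -> bool:
        # Only allow-listed module types get observers during prepare();
        # deconvolution now falls through and stays FP32.
        return type(module) in module_types_supported

    print(needs_observer(nn.Conv2d(3, 8, kernel_size=3)))           # True
    print(needs_observer(nn.ConvTranspose2d(8, 3, kernel_size=3)))  # False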

intel_extension_for_pytorch/ao/quantization/_quantize.py

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,8 @@ def prepare(
     except:
         assert False, "The model's copy is failed, please try set inplace to True to do the prepare"
         warnings.warn("Conv BatchNorm folding failed during the prepare process.")
+    # replace dropout with identity to enable more fusion pattern.
+    nn.utils._model_convert.replace_dropout_with_identity(prepare_model)
     # Special case for common case of passing a single Tensor
     if isinstance(example_inputs, (torch.Tensor, dict)):
         example_inputs = (example_inputs,)
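The helper called here lives in nn.utils._model_convert. Its effect is to swap Dropout modules for Identity, so no aten::dropout node separates otherwise fusible quantized ops in the traced graph. A minimal sketch of that kind of transform, assuming a simple recursive module walk; the actual implementation may differ:

    import torch.nn as nn

    def replace_dropout_with_identity_sketch(model: nn.Module) -> None:
        # Dropout is a no-op in eval mode anyway; replacing it with Identity
        # removes the aten::dropout node that would otherwise sit between
        # fusible quantize/dequantize patterns in the traced graph.
        for name, child in model.named_children():
            if isinstance(child, nn.Dropout):
                setattr(model, name, nn.Identity())
            else:
                replace_dropout_with_identity_sketch(child)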

tests/cpu/test_ao_jit_llga_quantization_fuser.py

Lines changed: 2 additions & 2 deletions
@@ -755,7 +755,7 @@ def forward(self, x, y):
             ["aten::dequantize", "aten::linear"]
         ]
         for qconfig in static_qconfig:
-            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, remove_dropout=True, qconfig=qconfig)
+            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
             self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
             self.assertFused(graph, ['aten::linear', 'aten::add', 'aten::quantize_per_channel', 'aten::dequantize'])
             self.checkPatterns(graph, patterns)
@@ -806,7 +806,7 @@ def forward(self, x, y):
             ["aten::dequantize", "aten::to", "aten::linear", "aten::to", "aten::quantize_per_tensor"],
             ["aten::dequantize", "aten::to", "aten::linear", "aten::add"]
         ]
-        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, remove_dropout=True, int8_bf16=True)
+        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
         self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
         self.assertFused(graph, ['aten::linear', 'aten::add', 'aten::dequantize'])
         self.checkPatterns(graph, patterns)
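With prepare() handling the dropout swap itself, the tests no longer pass remove_dropout=True. An end-to-end sketch of the behavior under test; the model and tensor shapes are made up for illustration, and the qconfig/prepare names follow the IPEX API visible in the diffs:

    import torch
    import torch.nn as nn
    import intel_extension_for_pytorch as ipex

    class TinyMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 8)
            self.dropout = nn.Dropout(0.5)

        def forward(self, x):
            return self.dropout(self.linear(x))

    m = TinyMLP().eval()
    x = torch.randn(2, 8)
    qconfig = ipex.quantization.default_static_qconfig
    # prepare() now swaps self.dropout for nn.Identity internally, so the
    # traced graph has no aten::dropout between linear and its consumers.
    prepared = ipex.quantization.prepare(m, qconfig, x, inplace=False)
    prepared(x)  # calibration pass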

tests/cpu/test_ao_jit_llga_utils.py

Lines changed: 4 additions & 7 deletions
@@ -102,9 +102,9 @@ def assertFused(self, graph, fused_patterns):
         for pat in fused_patterns:
             self.assertGraphContainsExactly(graph, pat, 0)
 
-    def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, remove_dropout=False, x_var=None,
+    def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, x_var=None,
                            qconfig=default_static_qconfig, int8_bf16=False):
-        graph, traced_model, fp32_model = self.prepareModel(model, x, remove_dropout, qconfig, int8_bf16)
+        graph, traced_model, fp32_model = self.prepareModel(model, x, qconfig, int8_bf16)
         with torch.no_grad():
             y = fp32_model(*x)
             y = y.to(torch.bfloat16) if int8_bf16 else y
@@ -120,14 +120,11 @@ def checkQuantizeTrace(self, model, x, atol=1e-3, rtol=1e-2, remove_dropout=Fals
 
         return graph
 
-    def prepareModel(self, model, x, remove_dropout=False, qconfig=default_static_qconfig,
-                     int8_bf16=False, prepare_inplace=True, convert_inplace=True,):
+    def prepareModel(self, model, x, qconfig=default_static_qconfig, int8_bf16=False,
+                     prepare_inplace=True, convert_inplace=True,):
         model.eval()
         fp32_model = copy.deepcopy(model)
         with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
-            # fold conv bn
-            if remove_dropout:
-                ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
             model = ipex.quantization.prepare(model, qconfig, x, inplace=prepare_inplace)
             # do calibration
             y = model(*x)
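Since prepareModel's remove_dropout parameter is gone, positional call sites shift by one argument, as in the checkQuantizeTrace change above. Continuing the TinyMLP sketch from earlier, the remaining convert/trace/freeze steps that prepareModel performs after calibration look roughly like this; the convert signature is assumed from the IPEX quantization API of this era:

    import torch
    # After the calibration pass, conversion and tracing proceed as in prepareModel().
    converted = ipex.quantization.convert(prepared)
    with torch.no_grad():
        traced = torch.jit.trace(converted, x)
        traced = torch.jit.freeze(traced)
    print(traced.graph)  # no aten::dropout nodes remain to block fusion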
