Commit 55676ed

Fix REQ checks & scale/zps value retrieval when FX is used with JIT (#1420) (#1432)
* Fix REQ checks & scale/zps value retrieval when FX is used with JIT
* Revise comments
* Refactor code
* Fix lint

Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
1 parent 0aa4aa1 commit 55676ed
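
Background for the fix: when an FX reference-quantized model is subsequently traced with TorchScript, the quantization parameters (scale/zero_point) can appear in the traced graph as 0-dim tensor constants instead of Python float/int scalars, which broke the previous REQ checks and value retrieval. A minimal sketch of that composition (illustrative repro, not code from this commit; FX import paths and signatures vary across PyTorch releases):

import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

m = M().eval()
example = torch.randn(1, 8)
qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
m = prepare_fx(m, qconfig_dict, example_inputs=example)
m(example)  # calibrate observers
m = convert_to_reference_fx(m)
# In the traced TorchScript graph, quantize_per_tensor's scale/zero_point
# inputs may be 0-dim Tensor constants rather than float/int scalars.
traced = torch.jit.trace(m, example)
print(traced.graph)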

5 files changed: +77 -8 lines changed

csrc/jit/codegen/onednn/graph_helper.cpp

Lines changed: 5 additions & 4 deletions

@@ -396,11 +396,12 @@ Operator LlgaGraphHelper::createOperator(Node* node) const {
     // ---/-----/-----\-----\---
     // dequant q_scale q_zp dtype
     // REQ(node->output(0)->uses().size() <= 2);
-    auto scale = toIValue(node->input(1));
-    REQ(scale.has_value() && scale->isDouble());
+    auto scale = node->input(1);
+    REQ(utils::isScaleSupported(scale));
+
+    auto zero_point = node->input(2);
+    REQ(utils::isZeroPointSupported(zero_point));
 
-    auto zero_point = toIValue(node->input(2));
-    REQ(zero_point.has_value() && zero_point->isInt());
     return Operator(node, opkind::Quantize)
         .setInput(0)
         .setOutput(0)

csrc/jit/codegen/onednn/operator.h

Lines changed: 14 additions & 3 deletions

@@ -64,12 +64,23 @@ class Operator {
   }
 
   static int64_t Int(const torch::jit::Node* node, size_t offset) {
-    return torch::jit::toIValue(node->input(offset))->toInt();
+    if (node->input(offset)->type()->isSubtypeOf(
+            torch::jit::TensorType::get())) {
+      // Composing FX with JIT tracing may cause scale/zps to be 0-dim tensors
+      return toIValue(node->input(offset)).value().toTensor().item().toInt();
+    } else {
+      return static_cast<int64_t>(toIValue(node->input(offset))->toInt());
+    }
   }
 
   static float Float(const torch::jit::Node* node, size_t offset) {
-    return static_cast<float>(
-        torch::jit::toIValue(node->input(offset))->toDouble());
+    if (node->input(offset)->type()->isSubtypeOf(
+            torch::jit::TensorType::get())) {
+      // Composing FX with JIT tracing may cause scale/zps to be 0-dim tensors
+      return toIValue(node->input(offset)).value().toTensor().item().toFloat();
+    } else {
+      return static_cast<float>(toIValue(node->input(offset))->toDouble());
+    }
   }
 
   static float ScalarToFloat(const torch::jit::Node* node, size_t offset) {
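
The tensor branch above relies on standard 0-dim tensor semantics; a small plain-PyTorch illustration (not part of this commit):

import torch

# A 0-dim tensor wraps a single scalar; item() extracts the Python value,
# which is what toTensor().item().toInt()/.toFloat() do on the C++ side.
zero_point = torch.tensor(128)  # 0-dim, dtype torch.int64 (Long)
scale = torch.tensor(0.25)      # 0-dim, dtype torch.float32 (Float)
assert zero_point.dim() == 0 and scale.dim() == 0
print(zero_point.item(), scale.item())  # 128 0.25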

csrc/jit/codegen/onednn/utils.cpp

Lines changed: 19 additions & 0 deletions

@@ -77,6 +77,25 @@ double getScale(Node* input_node) {
   return scale;
 }
 
+bool isZeroPointSupported(Value* zps) {
+  auto zps_value = toIValue(zps);
+  return (
+      zps_value.has_value() &&
+      (zps_value->isInt() ||
+       (zps_value->isTensor() &&
+        (zps_value.value().toTensor().scalar_type() == at::ScalarType::Long))));
+}
+
+bool isScaleSupported(Value* scale) {
+  auto scale_value = toIValue(scale);
+  return (
+      scale_value.has_value() &&
+      (scale_value->isDouble() ||
+       (scale_value->isTensor() &&
+        (scale_value.value().toTensor().scalar_type() ==
+         at::ScalarType::Float))));
+}
+
 } // namespace utils
 } // namespace onednn
 } // namespace fuser
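
To make the accepted encodings explicit, a rough Python analogue of the two new checks (a sketch only; the real logic is the C++ above):

import torch

def is_zero_point_supported(zp):
    # Python int, or a Long (int64) tensor as produced by FX + JIT tracing
    return isinstance(zp, int) or (
        torch.is_tensor(zp) and zp.dtype == torch.int64)

def is_scale_supported(scale):
    # Python float, or a Float (float32) tensor
    return isinstance(scale, float) or (
        torch.is_tensor(scale) and scale.dtype == torch.float32)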

csrc/jit/codegen/onednn/utils.h

Lines changed: 4 additions & 0 deletions

@@ -20,6 +20,10 @@ double getScale(torch::jit::Node* input_node);
 
 std::vector<int64_t> getZPSVector(torch::jit::Node* input_node);
 
+bool isZeroPointSupported(torch::jit::Value* zps);
+
+bool isScaleSupported(torch::jit::Value* scale);
+
 } // namespace utils
 } // namespace onednn
 } // namespace fuser

tests/cpu/test_ao_jit_llga_quantization_fuser.py

Lines changed: 35 additions & 1 deletion

@@ -1745,7 +1745,41 @@ def forward(self, x):
         m = convert_to_reference_fx(m)
         graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
         # dequant -> linear should be mapped to LLGA
-        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+
+    @unittest.skipIf(True, "Poor accuracy")
+    @skipIfNoTorchVision
+    def test_fx_ao_qat_model(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.eltwise = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.eltwise(x)
+                x = self.conv2(x)
+                return x
+        data = torch.randn(1, 32, 224, 224).to(memory_format=torch.channels_last)
+        m = M()
+        m.eval()
+        #
+        # quantization aware training for static quantization
+        #
+        qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')}
+        m.train()
+        model_prepared = prepare_qat_fx(m, qconfig_dict, example_inputs=data)
+        model_quantized = convert_to_reference_fx(model_prepared)
+        model_quantized = model_quantized.eval()
+        model = model_quantized.to(memory_format=torch.channels_last)
+        graph = self.checkQuantizeTrace(model, [data], atol=2e-1)
+        self.checkPatterns(graph, [['aten::dequantize', 'aten::quantize_per_channel', 'aten::_convolution',
+                                    'aten::relu', 'aten::quantize_per_tensor'],
+                                   ['aten::dequantize', 'aten::quantize_per_channel', 'aten::_convolution',
+                                    'aten::quantize_per_tensor']])
+
 
     def test_ffn_residual(self):
         class FFN_Residual(nn.Module):