
Commit 2c079e5

Map more elementwise ops to LLGA (#874)
1 parent c31d354 commit 2c079e5
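
This commit maps additional aten elementwise ops (round, exp, clamp, hardswish, log) to oneDNN Graph (LLGA) op kinds so they can join LLGA fusion groups, and it converts hardtanh/clamp bounds through a new Operator::ScalarToFloat helper that accepts integer as well as floating-point scalars. The sketch below is illustrative only and is not part of the commit; it assumes an IPEX build that exposes ipex.enable_onednn_fusion() and the usual trace-and-freeze inference workflow.

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex

class ConvRound(nn.Module):
    """conv2d followed by one of the newly mapped eltwise ops (aten::round)."""
    def __init__(self):
        super(ConvRound, self).__init__()
        self.conv = nn.Conv2d(32, 32, 3, padding=1, bias=True)

    def forward(self, x):
        return torch.round(self.conv(x))

ipex.enable_onednn_fusion(True)          # route TorchScript graphs through LLGA
m = ConvRound().eval()
x = torch.rand(1, 32, 28, 28)
with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(m, x))
    traced(x)                            # warm-up runs let the fuser rewrite the graph
    print(traced.graph_for(x))           # conv + round is expected to appear as one LLGA fusion group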

File tree

4 files changed: +101 −5 lines

  intel_extension_for_pytorch/csrc/jit/codegen/onednn/graph_helper.cpp
  intel_extension_for_pytorch/csrc/jit/codegen/onednn/operator.h
  tests/cpu/test_ao_jit_llga_quantization_fuser.py
  tests/cpu/test_jit_llga_fuser.py


intel_extension_for_pytorch/csrc/jit/codegen/onednn/graph_helper.cpp

Lines changed: 25 additions & 2 deletions
@@ -229,16 +229,39 @@ Operator LlgaGraphHelper::createOperator(Node* node) const {
     return makeEltwiseOp(node, opkind::Sigmoid);
   } else if (nodeKind == Symbol::aten("gelu")) {
     return makeEltwiseOp(node, opkind::GELU);
+  } else if (nodeKind == Symbol::aten("round")) {
+    return makeEltwiseOp(node, opkind::Round);
+  } else if (nodeKind == Symbol::aten("exp")) {
+    return makeEltwiseOp(node, opkind::Exp);
   } else if (nodeKind == Symbol::aten("sqrt")) {
     return makeEltwiseOp(node, opkind::Sqrt);
   } else if (nodeKind == Symbol::aten("abs")) {
     return makeEltwiseOp(node, opkind::Abs);
   } else if (nodeKind == Symbol::aten("square")) {
     return makeEltwiseOp(node, opkind::Square);
+  } else if (nodeKind == Symbol::aten("clamp")) {
+    // PyTorch API already checks that both min & max are not None.
+    // But we can check it nevertheless.
+    auto clamp_min = toIValue(node->input(1));
+    auto clamp_max = toIValue(node->input(2));
+    REQ(!(clamp_max->isNone() && clamp_min->isNone()));
+    auto clamp_min_value = (clamp_min->isNone())
+        ? -std::numeric_limits<float>::infinity()
+        : Operator::ScalarToFloat(node, 1);
+    auto clamp_max_value = (clamp_max->isNone())
+        ? std::numeric_limits<float>::infinity()
+        : Operator::ScalarToFloat(node, 2);
+    return makeEltwiseOp(node, opkind::Clamp)
+        .setAttr("min", clamp_min_value)
+        .setAttr("max", clamp_max_value);
   } else if (nodeKind == Symbol::aten("hardtanh")) {
     return makeEltwiseOp(node, opkind::HardTanh)
-        .setAttr("min", Operator::Float, 1)
-        .setAttr("max", Operator::Float, 2);
+        .setAttr("min", Operator::ScalarToFloat, 1)
+        .setAttr("max", Operator::ScalarToFloat, 2);
+  } else if (nodeKind == Symbol::aten("hardswish")) {
+    return makeEltwiseOp(node, opkind::HardSwish);
+  } else if (nodeKind == Symbol::aten("log")) {
+    return makeEltwiseOp(node, opkind::Log);
   } else if (nodeKind == Symbol::aten("leaky_relu")) {
     return makeEltwiseOp(node, opkind::LeakyReLU)
         .setAttr("alpha", Operator::Float, 1);

intel_extension_for_pytorch/csrc/jit/codegen/onednn/operator.h

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,10 @@ class Operator {
     return static_cast<float>(toIValue(node->input(offset))->toDouble());
   }
 
+  static float ScalarToFloat(const Node* node, size_t offset) {
+    return toIValue(node->input(offset))->toScalar().to<float>();
+  }
+
   static std::vector<float> FloatValueToVector(float value) {
     return {value};
   }
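
The existing Operator::Float helper visible in the surrounding context converts through toDouble(), so it only handles scalars stored as doubles. The new ScalarToFloat converts through toScalar(), which also covers integer scalars such as the int literals used for clamp/hardtanh bounds in the tests below. A minimal illustration (assumed example, not from the commit) of why a bound may arrive as an int rather than a double:

import torch

@torch.jit.script
def clamped(x: torch.Tensor):
    # int literals stay int-typed Scalar constants in the TorchScript graph,
    # so the LLGA bridge cannot assume the min/max IValues are doubles
    return torch.clamp(x, min=-5, max=2)

print(clamped.graph)   # the min/max inputs show up as int constants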

tests/cpu/test_ao_jit_llga_quantization_fuser.py

Lines changed: 36 additions & 1 deletion
@@ -41,6 +41,8 @@ def get_eltwise_fn(name):
     elif hasattr(F, name):
         return getattr(F, name)
     else:
+        if name == 'hardswish_':
+            return torch.nn.Hardswish(inplace=True);
         raise NameError('Eltwise function %s not found' % name)
 
 class TestOp(JitLlgaTestCase):
@@ -350,7 +352,8 @@ def forward(self, x):
                 x = self.conv2(x)
                 return x
 
-        for eltwise in ['relu', 'leaky_relu']: # TODO: ['sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']
+        for eltwise in ['relu', 'leaky_relu', 'sigmoid', 'round', 'abs', 'square',
+                        'abs', 'round', 'exp', 'hardswish', 'tanh', 'hardtanh']:
             for inplace in [False, True]:
                 for memory_format in [torch.contiguous_format, torch.channels_last]:
                     eltwise_fn_name = eltwise + '_' if inplace else eltwise
@@ -369,6 +372,38 @@ def forward(self, x):
                     self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_channel', 'aten::dequantize'])
                     self.checkPatterns(graph, patterns)
 
+    def test_conv2d_clamp(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv5 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = torch.clamp(x, min=float('-inf'))
+                x = self.conv2(x)
+                x = torch.clamp(x, min=-5)
+                x = self.conv3(x)
+                x = torch.clamp(x, min=0, max=float('inf'))
+                x = self.conv4(x)
+                x = torch.clamp(x, min=1, max=5)
+                x = self.conv5(x)
+                x = torch.clamp(x, max=2)
+                return x
+
+        for inplace in [False, True]:
+            for memory_format in [torch.contiguous_format, torch.channels_last]:
+                x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
+                m = M()
+                for qconfig in static_qconfig:
+                    graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
+                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
+                    self.assertFused(graph, ['aten::_convolution', 'aten::' + "clamp", 'aten::quantize_per_channel', 'aten::dequantize'])
+
     def test_ensure_tensor_is_rewrapped(self):
         class M(nn.Module):
             def __init__(self, eltwise_fn):

tests/cpu/test_jit_llga_fuser.py

Lines changed: 36 additions & 2 deletions
@@ -22,6 +22,8 @@ def get_eltwise_fn(name):
     elif hasattr(F, name):
         return getattr(F, name)
     else:
+        if name == 'hardswish_':
+            return torch.nn.Hardswish(inplace=True);
         raise NameError('Eltwise function %s not found' % name)
 
 
@@ -414,8 +416,8 @@ def forward(self, x):
                 x = self.eltwise(x)
                 return x
 
-        # for eltwise in ['relu', 'sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']:
-        for eltwise in ['relu']:
+        for eltwise in ['relu', 'leaky_relu', 'sigmoid', 'round', 'abs', 'square',
+                        'abs', 'round', 'exp', 'hardswish', 'tanh', 'hardtanh']:
             for inplace in [True, False]:
                 eltwise_fn_name = eltwise + '_' if inplace else eltwise
                 eltwise_fn = get_eltwise_fn(eltwise_fn_name)
@@ -429,6 +431,38 @@ def forward(self, x):
                # test if relu is fused into the fusion group
                self.assertFused(graph, ['aten::' + eltwise])
 
+    @llga_fp32_bf16_test_env
+    def test_conv2d_clamp(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.conv5 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = torch.clamp(x, min=float('-inf'))
+                x = self.conv2(x)
+                x = torch.clamp(x, min=-5)
+                x = self.conv3(x)
+                x = torch.clamp(x, min=0, max=float('inf'))
+                x = self.conv4(x)
+                x = torch.clamp(x, min=1, max=5)
+                x = self.conv5(x)
+                x = torch.clamp(x, max=2)
+                return x
+
+        for inplace in [False, True]:
+            for memory_format in [torch.contiguous_format, torch.channels_last]:
+                x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
+                m = M()
+                graph, _ = self.checkTrace(m, [x])
+                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
+                self.assertFused(graph, ['aten::_convolution', "aten::clamp"])
+
     @llga_fp32_bf16_test_env
     def test_ensure_tensor_is_rewrapped(self):
         class M(nn.Module):
