Skip to content

Commit 9addc48

Browse files
zhuhaozhexuhancn
andauthored
Init amx for amx is enabled and override onednn isa flag (#1184) (#1227)
* init amx for amx is enabled and override onednn isa flag * optimize code. * avoid double amx initial. * update code. * add ut for ipex control onednn isa * add more ut * use ideep.hpp * fix headers * rename function onednn_isa_to_string * skip isa override ut if not sync ipex isa to onednn * optimize dnnl dependency & fix build issue. * disable sync_isa_level_to_dnnl * add env switch to enable dnnl isa level sync * fix dyndisp UT issue. * sync onednn isa level as default. * fix no return issue. * rename debug only env var. * generalize manual setup onednn isa code. * throw error when enable amx on non-linux platform and throw warning for unexpected isa level Co-authored-by: Han, Xu <xu.han@intel.com> Co-authored-by: Han, Xu <xu.han@intel.com>
1 parent b8a49bf commit 9addc48

File tree

8 files changed

+187
-2
lines changed

8 files changed

+187
-2
lines changed

csrc/cpu/aten/utils/isa_help.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
#include "isa_help.h"
22

3+
#include <dnnl.hpp>
4+
35
namespace torch_ipex {
46
namespace cpu {
57

8+
using namespace dnnl;
9+
610
DEFINE_DISPATCH(get_current_isa_level_kernel_stub);
711

812
// get_current_isa_level_kernel_impl
@@ -23,5 +27,31 @@ std::string get_highest_binary_support_isa_level() {
2327
return CPUCapabilityToString(level);
2428
}
2529

30+
// Translate a dnnl::cpu_isa value into the ISA name string reported by IPEX.
// Any value without an explicit mapping below yields "WrongLevel".
const char* OneDNNIsaLevelToString(cpu_isa isa) {
  if (isa == cpu_isa::avx2) {
    return "AVX2";
  }
  if (isa == cpu_isa::avx2_vnni) {
    return "AVX2_VNNI";
  }
  if (isa == cpu_isa::avx512_core) {
    return "AVX512";
  }
  if (isa == cpu_isa::avx512_core_vnni) {
    return "AVX512_VNNI";
  }
  if (isa == cpu_isa::avx512_core_bf16) {
    return "AVX512_BF16";
  }
  if (isa == cpu_isa::avx512_core_amx) {
    return "AMX";
  }

  // No known name for this level.
  return "WrongLevel";
}
50+
51+
std::string get_current_onednn_isa_level() {
52+
cpu_isa onednn_isa_level = get_effective_cpu_isa();
53+
return OneDNNIsaLevelToString(onednn_isa_level);
54+
}
55+
2656
} // namespace cpu
2757
} // namespace torch_ipex

csrc/cpu/aten/utils/isa_help.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
namespace torch_ipex {
99
namespace cpu {
1010

11+
std::string get_current_onednn_isa_level();
12+
1113
std::string get_current_isa_level();
1214
std::string get_highest_cpu_support_isa_level();
1315
std::string get_highest_binary_support_isa_level();

csrc/cpu/dyndisp/DispatchStub.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44

55
#include "../cpu/isa/cpu_feature.hpp"
66

7+
#include <dnnl.hpp>
8+
79
#include <algorithm>
810
#include <cstdlib>
911
#include <cstring>
1012

1113
namespace torch_ipex {
1214
namespace cpu {
1315

16+
using namespace dnnl;
17+
1418
const char* CPUCapabilityToString(CPUCapability isa) {
1519
switch (isa) {
1620
case CPUCapability::DEFAULT:
@@ -79,6 +83,62 @@ CPUCapability _get_highest_binary_support_isa_level() {
7983
return CPUCapability::DEFAULT;
8084
}
8185

86+
// Read the debug-only environment variable _IPEX_NOT_SYNC_ONEDNN_ISA.
//
// Returns true only when the variable is set to exactly "1"; an unset
// variable or any other value returns false. Note the polarity: a true
// result means the "not sync" switch is turned ON.
static bool _load_sync_onednn_isa_setting() {
  const char* envar = std::getenv("_IPEX_NOT_SYNC_ONEDNN_ISA");
  // Use the std-qualified strcmp: <cstring> only guarantees the std:: name.
  return envar != nullptr && std::strcmp(envar, "1") == 0;
}
96+
97+
bool check_not_sync_onednn_isa_level() {
98+
static bool b_not_sync = _load_sync_onednn_isa_setting();
99+
100+
return b_not_sync;
101+
}
102+
103+
// Forward the requested ISA level to oneDNN through set_max_cpu_isa().
// When syncing is switched off, nothing is forwarded and
// status::unimplemented is returned instead.
status set_current_cpu_isa_level_to_onednn(cpu_isa isa) {
  const bool sync_disabled = check_not_sync_onednn_isa_level();
  return sync_disabled ? status::unimplemented : set_max_cpu_isa(isa);
}
110+
111+
// Map an IPEX CPUCapability level onto the matching dnnl::cpu_isa value.
//
// DEFAULT maps to cpu_isa::all. NUM_OPTIONS and any value without a
// dedicated case emit a warning and also fall back to cpu_isa::all.
cpu_isa ipex_isa_to_onednn_isa(CPUCapability ipex_isa) {
  switch (ipex_isa) {
    case CPUCapability::DEFAULT:
      return cpu_isa::all;
    case CPUCapability::AVX2:
      return cpu_isa::avx2;
    case CPUCapability::AVX2_VNNI:
      return cpu_isa::avx2_vnni;
    case CPUCapability::AVX512:
      return cpu_isa::avx512_core;
    case CPUCapability::AVX512_VNNI:
      return cpu_isa::avx512_core_vnni;
    case CPUCapability::AVX512_BF16:
      return cpu_isa::avx512_core_bf16;
    case CPUCapability::AMX:
      return cpu_isa::avx512_core_amx;
    case CPUCapability::NUM_OPTIONS:
      // NUM_OPTIONS is handled as an out-of-range level, not a real ISA.
      TORCH_WARN("DispatchStub: OutOfBoundaryISALevel for IPEX");
      return cpu_isa::all;

    default: {
      // Braces give the locals below their own scope inside the switch.
      auto ipex_isa_str = CPUCapabilityToString(ipex_isa);
      // The ". " separator keeps the ISA name from running into the
      // following sentence in the emitted warning.
      auto msg = c10::str(
          "DispatchStub: No corresponding onednn isa for ",
          ipex_isa_str,
          ". Please consider checking whether this ISA is supported by oneDNN");
      TORCH_WARN(msg);
      return cpu_isa::all;
    }
  }
}
141+
82142
static CPUCapability compute_cpu_capability() {
83143
CPUCapability highest_cpu_supported_isa_level =
84144
_get_highest_cpu_support_isa_level();
@@ -119,6 +179,9 @@ static CPUCapability compute_cpu_capability() {
119179
CPUCapability max_support_isa_level = std::min(
120180
highest_cpu_supported_isa_level, highest_binary_supported_isa_level);
121181
if (b_manual_setup) {
182+
cpu_isa manual_onednn_isa = ipex_isa_to_onednn_isa(manual_setup_isa_level);
183+
set_current_cpu_isa_level_to_onednn(manual_onednn_isa);
184+
122185
if (manual_setup_isa_level <= max_support_isa_level) {
123186
return manual_setup_isa_level;
124187
}

csrc/cpu/dyndisp/DispatchStub.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ const char* CPUCapabilityToString(CPUCapability isa);
6363
CPUCapability _get_highest_cpu_support_isa_level();
6464
CPUCapability _get_highest_binary_support_isa_level();
6565

66+
bool check_not_sync_onednn_isa_level();
67+
6668
CPUCapability get_cpu_capability();
6769

6870
template <typename FnPtr, typename T>

csrc/cpu/isa/cpu_feature.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
#include <stdio.h>
33
#include "embedded_function.h"
44

5+
#ifdef __linux__
6+
#include <sys/syscall.h>
7+
#include <unistd.h>
8+
#endif
9+
10+
#include <ATen/ATen.h>
11+
512
namespace torch_ipex {
613
namespace cpu {
714
CPUFeature::CPUFeature() {
@@ -277,6 +284,44 @@ bool CPUFeature::os_amx() {
277284
return false;
278285
}
279286

287+
#ifdef __linux__
288+
289+
#define XFEATURE_XTILECFG 17
290+
#define XFEATURE_XTILEDATA 18
291+
#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
292+
#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
293+
#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)
294+
#define ARCH_GET_XCOMP_PERM 0x1022
295+
#define ARCH_REQ_XCOMP_PERM 0x1023
296+
297+
bool CPUFeature::init_amx() {
298+
unsigned long bitmask = 0;
299+
long status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
300+
if (0 != status)
301+
return false;
302+
if (bitmask & XFEATURE_MASK_XTILEDATA)
303+
return true;
304+
305+
status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
306+
if (0 != status)
307+
return false; // XFEATURE_XTILEDATA setup is failed, TMUL usage is not
308+
// allowed
309+
status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
310+
311+
// XFEATURE_XTILEDATA setup is failed, can't use TMUL
312+
if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA))
313+
return false;
314+
315+
// XFEATURE_XTILEDATA set successfully, TMUL usage is allowed
316+
return true;
317+
}
318+
#else
319+
// Non-Linux variant: AMX initialization is not implemented here, so the
// call is reported through AT_ERROR.
bool CPUFeature::init_amx() {
  AT_ERROR("DispatchStub: only support init amx on Linux now");
  // Kept so that every path of this bool function has a return value.
  return false;
}
323+
#endif
324+
280325
bool CPUFeature::isa_level_avx2() {
281326
static bool b_is_support = os_avx2() && cpuid_avx2() && cpuid_fma();
282327
return b_is_support;
@@ -304,9 +349,19 @@ bool CPUFeature::isa_level_avx512_bf16() {
304349
return b_is_support;
305350
}
306351

352+
bool CPUFeature::_do_check_and_init_amx() {
353+
bool b_is_support = isa_level_avx512_bf16() && os_amx() && cpuid_amx_bf16() &&
354+
cpuid_amx_int8() && cpuid_amx_tile();
355+
if (b_is_support) {
356+
b_is_support = init_amx();
357+
}
358+
return b_is_support;
359+
}
360+
307361
bool CPUFeature::isa_level_amx() {
308-
static bool b_is_support = isa_level_avx512_bf16() && os_amx() &&
309-
cpuid_amx_bf16() && cpuid_amx_int8() && cpuid_amx_tile();
362+
// Check and init in a single function to avoid double initialization.
363+
static bool b_is_support = _do_check_and_init_amx();
364+
310365
return b_is_support;
311366
}
312367

csrc/cpu/isa/cpu_feature.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class CPUFeature {
113113
MICRO_CLASS_MEMBER_DECL(amx_bf16);
114114
MICRO_CLASS_MEMBER_DECL(amx_tile);
115115
MICRO_CLASS_MEMBER_DECL(amx_int8);
116+
bool init_amx();
117+
bool _do_check_and_init_amx();
116118

117119
public:
118120
MICRO_CLASS_CHECK_FUNC(amx_bf16);

intel_extension_for_pytorch/csrc/python/cpu/init_python_bindings.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ void InitIpexModuleBindings(py::module m) {
6161
return get_current_isa_level();
6262
});
6363

64+
m.def("_get_current_onednn_isa_level", []() {
65+
using namespace torch_ipex::cpu;
66+
return get_current_onednn_isa_level();
67+
});
68+
69+
m.def("_check_not_sync_onednn_isa_level", []() {
70+
using namespace torch_ipex::cpu;
71+
return check_not_sync_onednn_isa_level();
72+
});
73+
6474
m.def("_get_highest_cpu_support_isa_level", []() {
6575
using namespace torch_ipex::cpu;
6676
return get_highest_cpu_support_isa_level();

tests/cpu/test_dyndisp.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import os
3+
import subprocess
34

45
import intel_extension_for_pytorch._C as core
56

@@ -36,6 +37,9 @@ def get_highest_binary_support_isa_level():
3637
def get_highest_cpu_support_isa_level():
3738
return core._get_highest_cpu_support_isa_level().lower()
3839

40+
def check_not_sync_onednn_isa_level():
    """Return the value of the extension's `_check_not_sync_onednn_isa_level` binding."""
    not_sync = core._check_not_sync_onednn_isa_level()
    return not_sync
42+
3943
class TestDynDisp(unittest.TestCase):
4044

4145
def test_manual_select_kernel(self):
@@ -69,5 +73,22 @@ def test_dyndisp_in_supported_set(self):
6973
self.assertTrue(expected_isa)
7074
return
7175

76+
@unittest.skipIf(check_not_sync_onednn_isa_level(), 'skip this if not sync onednn isa level')
def test_ipex_set_onednn_isa_level(self):
    """Run a child interpreter with ATEN_CPU_CAPABILITY=avx2 and expect the
    oneDNN ISA level it prints to be 'AVX2'."""
    command = 'ATEN_CPU_CAPABILITY=avx2 python -c "import torch; import intel_extension_for_pytorch._C as core; print(core._get_current_onednn_isa_level())" '
    with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
        out = p.stdout.readlines()
        # Fail with a readable message instead of an IndexError on out[-1].
        self.assertTrue(out, 'child process produced no output')
        # The ISA level is the last line printed by the child process.
        onednn_isa_level = str(out[-1], 'utf-8').strip()
        # assertEqual shows both values when the check fails.
        self.assertEqual(onednn_isa_level, 'AVX2')
83+
84+
@unittest.skipIf(check_not_sync_onednn_isa_level(), 'skip this if not sync onednn isa level')
def test_onednn_do_not_set_isa_level(self):
    """Run a child interpreter with ONEDNN_MAX_CPU_ISA=avx2 and expect the
    IPEX ISA level it prints to equal the one seen in this process."""
    command = 'ONEDNN_MAX_CPU_ISA=avx2 python -c "import torch; import intel_extension_for_pytorch._C as core; print(core._get_current_isa_level().lower())" '
    cur_ipex_isa = get_currnet_isa_level()
    with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
        out = p.stdout.readlines()
        # Fail with a readable message instead of an IndexError on out[-1].
        self.assertTrue(out, 'child process produced no output')
        # The ISA level is the last line printed by the child process.
        cur_ipex_isa_1 = str(out[-1], 'utf-8').strip()
        # assertEqual shows both values when the check fails.
        self.assertEqual(cur_ipex_isa, cur_ipex_isa_1)
92+
7293
if __name__ == '__main__':
7394
unittest.main()

0 commit comments

Comments
 (0)