|
4 | 4 |
|
5 | 5 | #include "../cpu/isa/cpu_feature.hpp" |
6 | 6 |
|
| 7 | +#include <dnnl.hpp> |
| 8 | + |
7 | 9 | #include <algorithm> |
8 | 10 | #include <cstdlib> |
9 | 11 | #include <cstring> |
10 | 12 |
|
11 | 13 | namespace torch_ipex { |
12 | 14 | namespace cpu { |
13 | 15 |
|
| 16 | +using namespace dnnl; |
| 17 | + |
14 | 18 | const char* CPUCapabilityToString(CPUCapability isa) { |
15 | 19 | switch (isa) { |
16 | 20 | case CPUCapability::DEFAULT: |
@@ -79,6 +83,62 @@ CPUCapability _get_highest_binary_support_isa_level() { |
79 | 83 | return CPUCapability::DEFAULT; |
80 | 84 | } |
81 | 85 |
|
// Reads the debug-only environment variable _IPEX_NOT_SYNC_ONEDNN_ISA.
// Returns true only when the variable is set and its value is exactly "1";
// an unset variable or any other value yields false.
static bool _load_sync_onednn_isa_setting() {
  const char* env_value = std::getenv("_IPEX_NOT_SYNC_ONEDNN_ISA");
  return env_value != nullptr && std::strcmp(env_value, "1") == 0;
}
| 96 | + |
| 97 | +bool check_not_sync_onednn_isa_level() { |
| 98 | + static bool b_not_sync = _load_sync_onednn_isa_setting(); |
| 99 | + |
| 100 | + return b_not_sync; |
| 101 | +} |
| 102 | + |
| 103 | +status set_current_cpu_isa_level_to_onednn(cpu_isa isa) { |
| 104 | + if (check_not_sync_onednn_isa_level()) { |
| 105 | + return status::unimplemented; |
| 106 | + } |
| 107 | + |
| 108 | + return set_max_cpu_isa(isa); |
| 109 | +} |
| 110 | + |
| 111 | +cpu_isa ipex_isa_to_onednn_isa(CPUCapability ipex_isa) { |
| 112 | + switch (ipex_isa) { |
| 113 | + case CPUCapability::DEFAULT: |
| 114 | + return cpu_isa::all; |
| 115 | + case CPUCapability::AVX2: |
| 116 | + return cpu_isa::avx2; |
| 117 | + case CPUCapability::AVX2_VNNI: |
| 118 | + return cpu_isa::avx2_vnni; |
| 119 | + case CPUCapability::AVX512: |
| 120 | + return cpu_isa::avx512_core; |
| 121 | + case CPUCapability::AVX512_VNNI: |
| 122 | + return cpu_isa::avx512_core_vnni; |
| 123 | + case CPUCapability::AVX512_BF16: |
| 124 | + return cpu_isa::avx512_core_bf16; |
| 125 | + case CPUCapability::AMX: |
| 126 | + return cpu_isa::avx512_core_amx; |
| 127 | + case CPUCapability::NUM_OPTIONS: |
| 128 | + TORCH_WARN("DispatchStub: OutOfBoundaryISALevel for IPEX"); |
| 129 | + return cpu_isa::all; |
| 130 | + |
| 131 | + default: |
| 132 | + auto ipex_isa_str = CPUCapabilityToString(ipex_isa); |
| 133 | + auto msg = c10::str( |
| 134 | + "DispatchStub: No corresponding onednn isa for ", |
| 135 | + ipex_isa_str, |
| 136 | + "Please consider check whether this ISA is supported by oneDNN"); |
| 137 | + TORCH_WARN(msg); |
| 138 | + return cpu_isa::all; |
| 139 | + } |
| 140 | +} |
| 141 | + |
82 | 142 | static CPUCapability compute_cpu_capability() { |
83 | 143 | CPUCapability highest_cpu_supported_isa_level = |
84 | 144 | _get_highest_cpu_support_isa_level(); |
@@ -119,6 +179,9 @@ static CPUCapability compute_cpu_capability() { |
119 | 179 | CPUCapability max_support_isa_level = std::min( |
120 | 180 | highest_cpu_supported_isa_level, highest_binary_supported_isa_level); |
121 | 181 | if (b_manual_setup) { |
| 182 | + cpu_isa manual_onednn_isa = ipex_isa_to_onednn_isa(manual_setup_isa_level); |
| 183 | + set_current_cpu_isa_level_to_onednn(manual_onednn_isa); |
| 184 | + |
122 | 185 | if (manual_setup_isa_level <= max_support_isa_level) { |
123 | 186 | return manual_setup_isa_level; |
124 | 187 | } |
|
0 commit comments