
Commit d9a2680

Add _native_multi_head_attention to low precision cast policy of AutocastCPU (#860)
* add _native_multi_head_attention to low precision cast policy of AutocastCPU
* fix code format
1 parent 15678cc commit d9a2680
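
Context: AutocastCPU keeps per-operator cast policies; putting _native_multi_head_attention on the low precision list means its floating point inputs are cast to the autocast dtype (bfloat16 on CPU) inside an autocast region. A minimal sketch of what this enables, assuming an IPEX build that contains this commit — note that whether nn.MultiheadAttention actually dispatches to the fused _native_multi_head_attention kernel depends on the PyTorch version and its fast-path conditions (eval mode, batch_first=True, need_weights=False, etc.), and that bfloat16 is already the CPU autocast default, so the dtype argument is shown only for clarity:

    import torch
    import torch.nn as nn
    import intel_extension_for_pytorch  # registers the AutocastCPU policies

    # Hypothetical module; 768 / 12 mirror the shapes used in the new test entry.
    mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True).eval()
    x = torch.randn(1, 1, 768)  # fp32 input

    with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.bfloat16):
        out, _ = mha(x, x, x, need_weights=False)

    # With the op on the low precision list, the fused path runs in bfloat16.
    print(out.dtype)  # expected: torch.bfloat16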

File tree

3 files changed, +36 -0 lines changed


intel_extension_for_pytorch/csrc/autocast/autocast_mode.cpp

Lines changed: 21 additions & 0 deletions
@@ -126,6 +126,7 @@ struct CPU_WrapFunction_<
   }
 };
 
+#define TUPLE_TWO_TENSORS std::tuple<Tensor, Tensor>
 #define ADD_NS(RAW_OP) at::RAW_OP
 
 #define MAKE_REGISTER_FUNC(FUNC, NAME, SIG, CAST_POLICY) \
@@ -153,9 +154,29 @@ MAKE_REGISTER_FUNC(
     bool),
     user_defined_dtype)
 
+MAKE_REGISTER_FUNC(
+    ADD_NS(_native_multi_head_attention),
+    "_native_multi_head_attention",
+    TUPLE_TWO_TENSORS(
+        const Tensor&,
+        const Tensor&,
+        const Tensor&,
+        int64_t,
+        int64_t,
+        const Tensor&,
+        const Tensor&,
+        const Tensor&,
+        const Tensor&,
+        const c10::optional<Tensor>&,
+        bool,
+        bool),
+    user_defined_dtype)
+
 // fp32 cast policy a.k.a BlackList
 MAKE_REGISTER_FUNC(ADD_NS(mish), "mish", Tensor(const Tensor&), fp32)
 
+#undef TUPLE_TWO_TENSORS
+
 IPEX_TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("aten::_convolution"),
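
Two details of the C++ change are worth noting: user_defined_dtype is the cast policy that downcasts floating point tensor arguments to the autocast dtype, and the TUPLE_TWO_TENSORS macro exists because the comma inside std::tuple<Tensor, Tensor> would otherwise be parsed as a macro argument separator in MAKE_REGISTER_FUNC; it is #undef'd once the registration is done. A hedged sketch of the resulting behavior, calling the private ATen op directly with the same argument layout as the new test entry (assuming a build containing this commit):

    import torch
    import intel_extension_for_pytorch  # assumption: a build containing this commit

    q, k, v = (torch.randn(1, 1, 768) for _ in range(3))  # fp32 inputs
    qkv_weight = torch.randn(2304, 768)   # fused q/k/v projection, 3 * embed_dim rows
    qkv_bias = torch.randn(2304)
    proj_weight = torch.randn(768, 768)   # output projection
    proj_bias = torch.randn(768)

    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        out, attn = torch._native_multi_head_attention(
            q, k, v, 768, 12, qkv_weight, qkv_bias, proj_weight, proj_bias,
            None, False, True)  # mask, need_weights, average_attn_weights

    # user_defined_dtype policy: fp32 tensor arguments are cast to bfloat16
    # before the kernel runs, so the returned tensors should be bfloat16 too.
    print(out.dtype)  # expected: torch.bfloat16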

tests/cpu/autocast_test_lists.py

Lines changed: 10 additions & 0 deletions
@@ -80,6 +80,16 @@ def __init__(self, dev):
         ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
                                 torch.randn((n, n, n), device=dev, dtype=torch.float32))),
     ]
+    self.torch_bf16_multi_output = [
+        ("_native_multi_head_attention", (torch.randn((1, 1, 768), device=dev, dtype=torch.float32),
+                                          torch.randn((1, 1, 768), device=dev, dtype=torch.float32),
+                                          torch.randn((1, 1, 768), device=dev, dtype=torch.float32),
+                                          768, 12, torch.randn((2304, 768), device=dev, dtype=torch.float32),
+                                          torch.randn((2304), device=dev, dtype=torch.float32),
+                                          torch.randn((768, 768), device=dev, dtype=torch.float32),
+                                          torch.randn((768), device=dev, dtype=torch.float32),
+                                          None, False, True)),
+    ]
     self.torch_fp32 = [
         ("conv_transpose1d", conv_args_bf16[0]),
         ("conv_transpose2d", conv_args_bf16[1]),

tests/cpu/test_autocast.py

Lines changed: 5 additions & 0 deletions
@@ -659,6 +659,11 @@ def test_autocast_blacklist_non_float_output(self):
         for op, args in self.autocast_lists.blacklist_non_float_output_pass_test:
             self._run_autocast_pass_test(op, args, torch.float32)
 
+    def test_autocast_torch_bf16_multi_output(self):
+        for op_with_args in self.autocast_lists.torch_bf16_multi_output:
+            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
+            self._run_autocast_pass_test(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
+
     def test_autocast_torch_fp32_multi_output(self):
         for op_with_args in self.autocast_lists.torch_fp32_multi_output:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
