Skip to content

Commit 764e4fb

Browse files
support scaled dot product attention
1 parent 8714a49 commit 764e4fb

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ NNOpsDynamicShapeTest::test_categorical_crossentropy
264264
NNOpsDynamicShapeTest::test_multi_hot_dtype_
265265
NNOpsCorrectnessTest::test_conv_transpose_
266266
NNOpsCorrectnessTest::test_ctc_decode
267-
NNOpsCorrectnessTest::test_dot_product_attention_
268267
NNOpsCorrectnessTest::test_multi_hot_
269268
NNOpsCorrectnessTest::test_binary_crossentropy
270269
NNOpsCorrectnessTest::test_categorical_crossentropy
@@ -282,7 +281,6 @@ NNOpsCorrectnessTest::test_rms_normalization_10.0
282281
NNOpsDtypeTest::test_ctc_decode
283282
NNOpsDtypeTest::test_glu_
284283
NNOpsDtypeTest::test_polar_
285-
NNOpsDtypeTest::test_dot_product_attention_
286284
NNOpsDynamicShapeTest::test_glu
287285
NNOpsBehaviorTest::test_invalid_strategy_ctc_decode
288286
NNOpsBehaviorTest::test_logit_recovery_binary_crossentropy

keras/src/backend/openvino/nn.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,9 +693,47 @@ def dot_product_attention(
693693
flash_attention=None,
694694
attn_logits_soft_cap=None,
695695
):
696-
raise NotImplementedError(
697-
"`dot_product_attention` is not supported with openvino backend"
696+
if bias is not None:
697+
raise NotImplementedError(
698+
"`dot_product_attention` with `bias` is not supported "
699+
"with openvino backend"
700+
)
701+
if flash_attention is not None:
702+
raise NotImplementedError(
703+
"`dot_product_attention` with `flash_attention` is not supported "
704+
"with openvino backend"
705+
)
706+
if attn_logits_soft_cap is not None:
707+
raise NotImplementedError(
708+
"`dot_product_attention` with `attn_logits_soft_cap` is not "
709+
"supported with openvino backend"
710+
)
711+
query = get_ov_output(query)
712+
key = get_ov_output(key)
713+
value = get_ov_output(value)
714+
if query.get_element_type() != key.get_element_type():
715+
ov_type = OPENVINO_DTYPES[backend.floatx()]
716+
query = ov_opset.convert(query, ov_type)
717+
key = ov_opset.convert(key, ov_type)
718+
if value.get_element_type() != query.get_element_type():
719+
ov_type = OPENVINO_DTYPES[backend.floatx()]
720+
value = ov_opset.convert(value, ov_type)
721+
axes_const = ov_opset.constant([0, 2, 1, 3], Type.i32).output(0)
722+
723+
query = ov_opset.transpose(query, axes_const)
724+
key = ov_opset.transpose(key, axes_const)
725+
value = ov_opset.transpose(value, axes_const)
726+
mask = get_ov_output(mask) if mask is not None else None
727+
scale = (
728+
get_ov_output(scale, query.get_element_type())
729+
if scale is not None
730+
else None
731+
)
732+
dpa = ov_opset.scaled_dot_product_attention(
733+
query, key, value, attention_mask=mask, scale=scale, causal=is_causal
698734
)
735+
dpa = ov_opset.transpose(dpa, axes_const)
736+
return OpenVINOKerasTensor(dpa.output(0))
699737

700738

701739
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):

keras/src/ops/nn_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2448,7 +2448,7 @@ def test_dot_product_attention(
24482448
mask = mask[None, None, ...]
24492449
mask = np.tile(mask, (2, 4, 1, 1))
24502450
if bias is not None:
2451-
if backend.backend() == "torch":
2451+
if backend.backend() in ("torch", "openvino"):
24522452
self.skipTest(
24532453
"torch does not support `bias` with `dot_product_attention`"
24542454
)

0 commit comments

Comments
 (0)