Skip to content

Commit 93fd96d

Browse files
Update keras/src/backend/openvino/nn.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 764e4fb commit 93fd96d

File tree

1 file changed

+3
-4
lines changed
  • keras/src/backend/openvino

1 file changed

+3
-4
lines changed

keras/src/backend/openvino/nn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -713,11 +713,10 @@ def dot_product_attention(
713713
value = get_ov_output(value)
714714
if query.get_element_type() != key.get_element_type():
715715
ov_type = OPENVINO_DTYPES[backend.floatx()]
716-
query = ov_opset.convert(query, ov_type)
717-
key = ov_opset.convert(key, ov_type)
716+
query = ov_opset.convert(query, ov_type).output(0)
717+
key = ov_opset.convert(key, ov_type).output(0)
718718
if value.get_element_type() != query.get_element_type():
719-
ov_type = OPENVINO_DTYPES[backend.floatx()]
720-
value = ov_opset.convert(value, ov_type)
719+
value = ov_opset.convert(value, query.get_element_type()).output(0)
721720
axes_const = ov_opset.constant([0, 2, 1, 3], Type.i32).output(0)
722721

723722
query = ov_opset.transpose(query, axes_const)

0 commit comments

Comments
 (0)