@@ -693,9 +693,47 @@ def dot_product_attention(
     flash_attention=None,
     attn_logits_soft_cap=None,
 ):
-    raise NotImplementedError(
-        "`dot_product_attention` is not supported with openvino backend"
+    if bias is not None:
+        raise NotImplementedError(
+            "`dot_product_attention` with `bias` is not supported "
+            "with openvino backend"
+        )
+    if flash_attention is not None:
+        raise NotImplementedError(
+            "`dot_product_attention` with `flash_attention` is not supported "
+            "with openvino backend"
+        )
+    if attn_logits_soft_cap is not None:
+        raise NotImplementedError(
+            "`dot_product_attention` with `attn_logits_soft_cap` is not "
+            "supported with openvino backend"
+        )
+    query = get_ov_output(query)
+    key = get_ov_output(key)
+    value = get_ov_output(value)
+    if query.get_element_type() != key.get_element_type():
+        ov_type = OPENVINO_DTYPES[backend.floatx()]
+        query = ov_opset.convert(query, ov_type)
+        key = ov_opset.convert(key, ov_type)
+    if value.get_element_type() != query.get_element_type():
+        ov_type = OPENVINO_DTYPES[backend.floatx()]
+        value = ov_opset.convert(value, ov_type)
+    axes_const = ov_opset.constant([0, 2, 1, 3], Type.i32).output(0)
+
+    query = ov_opset.transpose(query, axes_const)
+    key = ov_opset.transpose(key, axes_const)
+    value = ov_opset.transpose(value, axes_const)
+    mask = get_ov_output(mask) if mask is not None else None
+    scale = (
+        get_ov_output(scale, query.get_element_type())
+        if scale is not None
+        else None
+    )
+    dpa = ov_opset.scaled_dot_product_attention(
+        query, key, value, attention_mask=mask, scale=scale, causal=is_causal
     )
+    dpa = ov_opset.transpose(dpa, axes_const)
+    return OpenVINOKerasTensor(dpa.output(0))
 
 
 def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
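
For reference, a minimal sketch of how the new path can be exercised through the public `keras.ops.dot_product_attention` API. The backend selection via `KERAS_BACKEND`, the input shapes, and the `convert_to_numpy` round trip are illustrative assumptions, not part of this change.

```python
# Hypothetical smoke test for the OpenVINO dot_product_attention path.
# Assumes the openvino backend is installed and selected before Keras loads.
import os

os.environ["KERAS_BACKEND"] = "openvino"

import numpy as np
import keras

# keras.ops.dot_product_attention takes inputs laid out as
# (batch, seq_len, num_heads, head_dim); the backend transposes them to
# (batch, num_heads, seq_len, head_dim) for OpenVINO's
# ScaledDotProductAttention op and transposes the result back.
query = np.random.rand(1, 8, 4, 16).astype("float32")
key = np.random.rand(1, 8, 4, 16).astype("float32")
value = np.random.rand(1, 8, 4, 16).astype("float32")

out = keras.ops.dot_product_attention(query, key, value, is_causal=True)
out = keras.ops.convert_to_numpy(out)
print(out.shape)  # (1, 8, 4, 16)
```

The dtype alignment at the top of the backend function (promoting query, key, and value to `backend.floatx()` when their element types disagree) is there because the fused attention op expects matching element types across its inputs; the final transpose restores the Keras layout before wrapping the node in `OpenVINOKerasTensor`.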