@@ -1834,6 +1834,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18341834 else
18351835 setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
18361836 }
1837+
1838+ if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
1839+ setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
18371840 }
18381841
18391842 setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7742,7 +7745,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
77427745 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
77437746
77447747 assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
7745- assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
7748+ assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
7749+ "Unexpected FMUL VT");
77467750
77477751 auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
77487752 return [&, IID](EVT VT, auto... Ops) {
@@ -7751,37 +7755,56 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
77517755 };
77527756 };
77537757
7754- auto ReinterpretCast = [&](SDValue Value, EVT VT) {
7755- if (VT == Value.getValueType())
7758+ auto Reinterpret = [&](SDValue Value, EVT VT) {
7759+ EVT SrcVT = Value.getValueType();
7760+ if (VT == SrcVT)
77567761 return Value;
7762+ if (SrcVT.isFixedLengthVector())
7763+ return convertToScalableVector(DAG, VT, Value);
7764+ if (VT.isFixedLengthVector())
7765+ return convertFromScalableVector(DAG, VT, Value);
77577766 return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
77587767 };
77597768
7760- // Create helpers for building intrinsic calls.
7761- auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
7762- auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
7769+ bool UseSVEBFMLAL = VT.isScalableVector();
77637770 auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
77647771 auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
77657772
7766- // All intrinsics expect to operate on full bf16 vector types.
7767- SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
7768- SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
7773+ // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
7774+ // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
7775+ auto BFMLALB =
7776+ MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalb
7777+ : Intrinsic::aarch64_neon_bfmlalb);
7778+ auto BFMLALT =
7779+ MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalt
7780+ : Intrinsic::aarch64_neon_bfmlalt);
77697781
7770- SDValue Zero =
7771- DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32 , Op->getFlags());
7772- SDValue Pg = DAG.getConstant(1 , DL, MVT::nxv4i1 );
7782+ EVT AccVT = UseSVEBFMLAL ? MVT::nxv4f32 : MVT::v4f32;
7783+ SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT , Op->getFlags());
7784+ SDValue Pg = getPredicateForVector(DAG , DL, AccVT );
77737785
7774- // Lower bf16 FMUL as a pair (VT == nxv8bf16 ) of BFMLAL top/bottom
7786+ // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
77757787 // instructions. These result in two f32 vectors, which can be converted back
77767788 // to bf16 with FCVT and FCVTNT.
7777- SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
7789+ SDValue LHS = Op.getOperand(0);
7790+ SDValue RHS = Op.getOperand(1);
7791+
7792+ // All SVE intrinsics expect to operate on full bf16 vector types.
7793+ if (UseSVEBFMLAL) {
7794+ LHS = Reinterpret(LHS, MVT::nxv8bf16);
7795+ RHS = Reinterpret(RHS, MVT::nxv8bf16);
7796+ }
7797+
7798+ SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
77787799 SDValue BottomBF16 =
77797800 FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
77807801 // Note: nxv4bf16 only uses even lanes.
77817802 if (VT == MVT::nxv4bf16)
7782- return ReinterpretCast(BottomBF16, VT);
7783- SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
7784- return FCVTNT(VT, BottomBF16, Pg, TopF32);
7803+ return Reinterpret(BottomBF16, VT);
7804+
7805+ SDValue TopF32 = Reinterpret(BFMLALT(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
7806+ SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
7807+ return Reinterpret(TopBF16, VT);
77857808}
77867809
77877810SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
0 commit comments