@@ -1000,48 +1000,135 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
10001000 return true ;
10011001}
10021002
1003- // / Try to convert FindLastIV to FindFirstIV reduction when using a strict
1004- // / predicate. Returns the new FindFirstIVPhiR on success, nullptr on failure.
1005- static VPReductionPHIRecipe *
1006- tryConvertToFindFirstIV (VPlan &Plan, VPReductionPHIRecipe *FindLastIVPhiR,
1007- VPValue *IVOp, ScalarEvolution &SE, const Loop *L) {
1008- Type *Ty = VPTypeAnalysis (Plan).inferScalarType (FindLastIVPhiR);
1009- unsigned NumBits = Ty->getIntegerBitWidth ();
1010-
1011- // Determine the reduction kind and sentinel based on the IV range.
1012- RecurKind NewKind;
1013- VPValue *NewSentinel;
1014- auto *AR = cast<SCEVAddRecExpr>(vputils::getSCEVExprForVPValue (IVOp, SE, L));
1015- if (RecurrenceDescriptor::isValidIVRangeForFindIV (
1016- AR, /* IsSigned=*/ true , /* IsFindFirstIV=*/ true , SE)) {
1017- NewKind = RecurKind::FindFirstIVSMin;
1018- NewSentinel = Plan.getConstantInt (APInt::getSignedMaxValue (NumBits));
1019- } else if (RecurrenceDescriptor::isValidIVRangeForFindIV (
1020- AR, /* IsSigned=*/ false , /* IsFindFirstIV=*/ true , SE)) {
1021- NewKind = RecurKind::FindFirstIVUMin;
1022- NewSentinel = Plan.getConstantInt (APInt::getMaxValue (NumBits));
1023- } else {
1024- return nullptr ;
1003+ // / For argmin/argmax reductions with strict predicates, convert the existing
1004+ // / FindLastIV reduction to a new UMin reduction of a wide canonical IV. If the
1005+ // / original IV was not canonical, a new canonical wide IV is added, and the
1006+ // / final result is scaled back to the original IV.
1007+ static bool handleStrictArgMinArgMax (VPlan &Plan,
1008+ VPReductionPHIRecipe *MinMaxPhiR,
1009+ VPReductionPHIRecipe *FindIVPhiR,
1010+ VPWidenIntOrFpInductionRecipe *WideIV,
1011+ VPInstruction *MinMaxResult) {
1012+ Type *Ty = Plan.getVectorLoopRegion ()->getCanonicalIVType ();
1013+ if (Ty != VPTypeAnalysis (Plan).inferScalarType (FindIVPhiR))
1014+ return false ;
1015+
1016+ // If the original wide IV is not canonical, create a new one. The wide IV is
1017+ // guaranteed to not wrap for all lanes that are active in the vector loop.
1018+ if (!WideIV->isCanonical ()) {
1019+ VPValue *Zero = Plan.getOrAddLiveIn (ConstantInt::get (Ty, 0 ));
1020+ VPValue *One = Plan.getOrAddLiveIn (ConstantInt::get (Ty, 1 ));
1021+ auto *WidenCanIV = new VPWidenIntOrFpInductionRecipe (
1022+ nullptr , Zero, One, WideIV->getVFValue (),
1023+ WideIV->getInductionDescriptor (), VPIRFlags (), WideIV->getDebugLoc ());
1024+ WidenCanIV->insertBefore (WideIV);
1025+
1026+ // Update the select to use the wide canonical IV.
1027+ auto *SelectRecipe = cast<VPSingleDefRecipe>(
1028+ FindIVPhiR->getBackedgeValue ()->getDefiningRecipe ());
1029+ if (SelectRecipe->getOperand (1 ) == WideIV)
1030+ SelectRecipe->setOperand (1 , WidenCanIV);
1031+ else if (SelectRecipe->getOperand (2 ) == WideIV)
1032+ SelectRecipe->setOperand (2 , WidenCanIV);
10251033 }
10261034
1027- // Create the new FindFirstIV reduction recipe.
1028- assert (!FindLastIVPhiR->isInLoop () && !FindLastIVPhiR->isOrdered ());
1029- ReductionStyle Style = RdxUnordered{FindLastIVPhiR->getVFScaleFactor ()};
1030- auto *FindFirstIVPhiR =
1031- new VPReductionPHIRecipe (nullptr , NewKind, *NewSentinel, Style,
1032- FindLastIVPhiR->hasUsesOutsideReductionChain ());
1033- FindFirstIVPhiR->addOperand (FindLastIVPhiR->getBackedgeValue ());
1035+ // Create the new UMin reduction recipe to track the minimum index.
1036+ assert (!FindIVPhiR->isInLoop () && !FindIVPhiR->isOrdered () &&
1037+ " inloop and ordered reductions not supported" );
1038+ VPValue *MaxInt =
1039+ Plan.getConstantInt (APInt::getMaxValue (Ty->getIntegerBitWidth ()));
1040+ ReductionStyle Style = RdxUnordered{FindIVPhiR->getVFScaleFactor ()};
1041+ auto *MinIdxPhiR = new VPReductionPHIRecipe (
1042+ dyn_cast_or_null<PHINode>(FindIVPhiR->getUnderlyingValue ()),
1043+ RecurKind::UMin, *MaxInt, Style,
1044+ FindIVPhiR->hasUsesOutsideReductionChain ());
1045+ MinIdxPhiR->addOperand (FindIVPhiR->getBackedgeValue ());
1046+ MinIdxPhiR->insertBefore (FindIVPhiR);
10341047
1035- FindFirstIVPhiR->insertBefore (FindLastIVPhiR);
10361048 VPInstruction *FindLastIVResult =
1037- findUserOf<VPInstruction::ComputeFindIVResult>(FindLastIVPhiR);
1038- FindLastIVPhiR->replaceAllUsesWith (FindFirstIVPhiR);
1039- FindLastIVResult->setOperand (2 , NewSentinel);
1040- return FindFirstIVPhiR;
1049+ findUserOf<VPInstruction::ComputeFindIVResult>(FindIVPhiR);
1050+ MinMaxResult->moveBefore (*FindLastIVResult->getParent (),
1051+ FindLastIVResult->getIterator ());
1052+
1053+ // The reduction using MinMaxPhiR needs adjusting to compute the correct
1054+ // result:
1055+ // 1. We need to find the first canonical IV for which the condition based
1056+ // on the min/max recurrence is true,
1057+ // 2. Compare the partial min/max reduction result to its final value and,
1058+ // 3. Select the lanes of the partial UMin reduction of the canonical wide
1059+ // IV which correspond to the lanes matching the min/max reduction result.
1060+ // 4. Scale the final selected canonical IV back to the original IV using
1061+ // VPDerivedIVRecipe.
1062+ // 5. If the minimum value matches the start value, the condition in the
1063+ // loop was never true; return the start value in that case.
1064+ //
1065+ // The original reductions need adjusting:
1066+ // For example, this transforms
1067+ // vp<%min.result> = compute-reduction-result ir<%min.val>,
1068+ // ir<%min.val.next>
1069+ // vp<%find.iv.result> = compute-find-iv-result ir<%min.idx>, ir<0>,
1070+ // SENTINEL, vp<%min.idx.next>
1071+ //
1072+ // into:
1073+ // vp<%min.result> = compute-reduction-result ir<%min.val>, ir<%min.val.next>
1074+ // vp<%final.min.cmp> = icmp eq ir<%min.val.next>, vp<%min.result>
1075+ // vp<%final.min.iv> = select vp<%final.min.cmp>, ir<%min.idx.next>, ir<-1>
1076+ // vp<%13> = compute-reduction-result ir<%min.idx>, vp<%final.min.iv>
1077+ // vp<%scaled.result.iv> = DERIVED-IV ir<20> + vp<%13> * ir<1>
1078+ // vp<%threshold.cmp> = icmp slt vp<%min.result>, ir<0>
1079+ // vp<%final.result> = select vp<%threshold.cmp>, vp<%scaled.result.iv>,
1080+ // ir<%original.start>
1081+
1082+ VPBuilder Builder (FindLastIVResult);
1083+ VPValue *MinMaxExiting = MinMaxResult->getOperand (1 );
1084+ auto *FinalMinMaxCmp =
1085+ Builder.createICmp (CmpInst::ICMP_EQ, MinMaxExiting, MinMaxResult);
1086+ VPValue *LastIVExiting = FindLastIVResult->getOperand (3 );
1087+ auto *FinalIVSelect =
1088+ Builder.createSelect (FinalMinMaxCmp, LastIVExiting, MaxInt);
1089+ VPSingleDefRecipe *FinalResult = Builder.createNaryOp (
1090+ VPInstruction::ComputeReductionResult, {MinIdxPhiR, FinalIVSelect}, {},
1091+ FindLastIVResult->getDebugLoc ());
1092+
1093+ // If a new wide canonical IV was created, convert the reduction result back
1094+ // to the original IV scale before the final select.
1095+ if (!WideIV->isCanonical ()) {
1096+ auto *DerivedIVRecipe =
1097+ new VPDerivedIVRecipe (InductionDescriptor::IK_IntInduction,
1098+ nullptr , // No FPBinOp for integer induction
1099+ WideIV->getStartValue (), FinalResult,
1100+ WideIV->getStepValue (), " derived.iv.result" );
1101+ DerivedIVRecipe->insertBefore (&*Builder.getInsertPoint ());
1102+ FinalResult = DerivedIVRecipe;
1103+ }
1104+
1105+ auto GetPred = [&MinMaxPhiR]() {
1106+ switch (MinMaxPhiR->getRecurrenceKind ()) {
1107+ case RecurKind::UMin:
1108+ return CmpInst::ICMP_ULT;
1109+ case RecurKind::SMin:
1110+ return CmpInst::ICMP_SLT;
1111+ case RecurKind::UMax:
1112+ return CmpInst::ICMP_UGT;
1113+ case RecurKind::SMax:
1114+ return CmpInst::ICMP_SGT;
1115+ default :
1116+ llvm_unreachable (" must be an integer min/max recurrence kind" );
1117+ }
1118+ };
1119+ // If the final min/max value matches the start value, the condition in the
1120+ // loop was always false, i.e. no induction value has been selected. If that's
1121+ // the case, use the original start value.
1122+ VPValue *MinMaxLT =
1123+ Builder.createICmp (GetPred (), MinMaxResult, MinMaxPhiR->getStartValue ());
1124+ VPValue *Res = Builder.createSelect (MinMaxLT, FinalResult,
1125+ FindLastIVResult->getOperand (1 ));
1126+ FindIVPhiR->replaceAllUsesWith (MinIdxPhiR);
1127+ FindLastIVResult->replaceAllUsesWith (Res);
1128+ return true ;
10411129}
10421130
1043- bool VPlanTransforms::handleMultiUseReductions (VPlan &Plan, ScalarEvolution &SE,
1044- const Loop *L) {
1131+ bool VPlanTransforms::handleMultiUseReductions (VPlan &Plan) {
10451132 for (auto &PhiR : make_early_inc_range (
10461133 Plan.getVectorLoopRegion ()->getEntryBasicBlock ()->phis ())) {
10471134 auto *MinMaxPhiR = dyn_cast<VPReductionPHIRecipe>(&PhiR);
@@ -1052,7 +1139,7 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
10521139 // MinMaxPhiR has users outside the reduction cycle in the loop. Check if
10531140 // the only other user is a FindLastIV reduction. MinMaxPhiR must have
10541141 // exactly 3 users: 1) the min/max operation, the compare of a FindLastIV
1055- // reduction and ComputeReductionResult. The comparisom must compare
1142+ // reduction and ComputeReductionResult. The comparison must compare
10561143 // MinMaxPhiR against the min/max operand used for the min/max reduction
10571144 // and only be used by the select of the FindLastIV reduction.
10581145 RecurKind RdxKind = MinMaxPhiR->getRecurrenceKind ();
@@ -1151,13 +1238,14 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
11511238 if (!IsValidPredicate)
11521239 return false ;
11531240
1154- // For strict predicates, transform try to convert FindLastIV to
1155- // FindFirstIV.
1241+ // For strict predicates, use a UMin reduction to find the minimum index.
1242+ // Canonical IVs (0, 1, 2, ...) are guaranteed not to wrap in the vector
1243+ // loop, so UMin can always be used.
11561244 bool IsStrictPredicate = ICmpInst::isLT (Pred) || ICmpInst::isGT (Pred);
11571245 if (IsStrictPredicate) {
1158- FindIVPhiR = tryConvertToFindFirstIV (Plan, FindIVPhiR, IVOp, SE, L);
1159- if (!FindIVPhiR)
1160- return false ;
1246+ return handleStrictArgMinArgMax (Plan, MinMaxPhiR, FindIVPhiR,
1247+ cast<VPWidenIntOrFpInductionRecipe>(IVOp),
1248+ MinMaxResult) ;
11611249 }
11621250
11631251 // The reduction using MinMaxPhiR needs adjusting to compute the correct
0 commit comments