[llvm] [LV] Support argmin/argmax with strict predicates. (PR #170223)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Sun Dec 28 03:29:32 PST 2025
================
@@ -1120,6 +1122,129 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
return true;
}
+/// For argmin/argmax reductions with strict predicates, convert the existing
+/// FindLastIV reduction to a new UMin reduction of a wide canonical IV. If the
+/// original IV was not canonical, a new canonical wide IV is added, and the
+/// final result is scaled back to the original IV.
+static bool handleFirstArgMinArgMax(VPlan &Plan,
+ VPReductionPHIRecipe *MinMaxPhiR,
+ VPReductionPHIRecipe *FindIVPhiR,
+ VPWidenIntOrFpInductionRecipe *WideIV,
+ VPInstruction *MinMaxResult) {
+ Type *Ty = Plan.getVectorLoopRegion()->getCanonicalIVType();
+ // TODO: Support different IV types.
+ if (Ty != VPTypeAnalysis(Plan).inferScalarType(FindIVPhiR))
+ return false;
+
+ // If the original wide IV is not canonical, create a new one. The wide IV is
+ // guaranteed to not wrap for all lanes that are active in the vector loop.
+ if (!WideIV->isCanonical()) {
+ VPValue *Zero = Plan.getConstantInt(Ty, 0);
+ VPValue *One = Plan.getConstantInt(Ty, 1);
+ auto *WidenCanIV = new VPWidenIntOrFpInductionRecipe(
+ nullptr, Zero, One, WideIV->getVFValue(),
+ WideIV->getInductionDescriptor(),
+ VPIRFlags::WrapFlagsTy(/*HasNUW=*/true, /*HasNSW=*/false),
+ WideIV->getDebugLoc());
+ WidenCanIV->insertBefore(WideIV);
+
+ // Update the select to use the wide canonical IV.
+ auto *SelectR = cast<VPSingleDefRecipe>(
+ FindIVPhiR->getBackedgeValue()->getDefiningRecipe());
+ assert(match(SelectR, m_Select(m_VPValue(), m_VPValue(), m_VPValue())) &&
+ "backedge value must be a select");
+ WideIV->replaceUsesWithIf(WidenCanIV, [SelectR](const VPUser &U, unsigned) {
+ return SelectR == &U;
+ });
+ }
+
+ // Create the new UMin reduction recipe to track the minimum index.
+ assert(!FindIVPhiR->isInLoop() && !FindIVPhiR->isOrdered() &&
+ "inloop and ordered reductions not supported");
+ VPValue *MaxInt =
+ Plan.getConstantInt(APInt::getMaxValue(Ty->getIntegerBitWidth()));
+ assert(FindIVPhiR->getVFScaleFactor() == 1 &&
+ "FindIV reduction must not be scaled");
+ ReductionStyle Style = RdxUnordered{1};
+ auto *FirstIdxPhiR = new VPReductionPHIRecipe(
+ dyn_cast_or_null<PHINode>(FindIVPhiR->getUnderlyingValue()),
+ RecurKind::UMin, *MaxInt, *FindIVPhiR->getBackedgeValue(), Style,
+ FindIVPhiR->hasUsesOutsideReductionChain());
+ FirstIdxPhiR->insertBefore(FindIVPhiR);
+
+ VPInstruction *FindLastIVResult =
+ findUserOf<VPInstruction::ComputeFindIVResult>(FindIVPhiR);
+ MinMaxResult->moveBefore(*FindLastIVResult->getParent(),
+ FindLastIVResult->getIterator());
+
+ // The reduction using MinMaxPhiR needs adjusting to compute the correct
+ // result:
+ // 1. Find the first canonical indices corresponding to partial min/max
+ // values, using loop reductions.
+ // 2. Find which of the partial min/max values are equal to the overall
+ // min/max value.
+ // 3. Select among the canonical indices those corresponding to the overall
+ // min/max value.
+ // 4. Find the first canonical index of overall min/max and scale it back to
+ // the original IV using VPDerivedIVRecipe.
+ // 5. If the overall min/max is equal to the start value, the condition in
+ // the loop was always false, due to being strict; return the start value
+ // in that case.
+ //
+ // The original reductions need adjusting:
+ // For example, this transforms
+ // vp<%min.result> = compute-reduction-result ir<%min.val>, ir<%min.val.next>
+ // vp<%find.iv.result> = compute-find-iv-result ir<%min.idx>, ir<0>,
+ // ir<Sentinel>,
+ // vp<%min.idx.next>
+ //
+ // into:
+ // vp<%min.result> = compute-reduction-result ir<%min.val>, ir<%min.val.next>
+ // vp<%final.min.cmp> = icmp eq ir<%min.val.next>, vp<%min.result>
+ // vp<%final.min.idx> = select vp<%final.min.cmp>, ir<%min.idx.next>,
+ // ir<MaxUInt>
+ // vp<%13> = compute-reduction-result ir<%min.idx>, vp<%final.min.idx>
+ // vp<%scaled.result.iv> = DERIVED-IV ir<20> + vp<%13> * ir<1>
+ // vp<%always.false> = icmp eq vp<%min.result>, ir<%original.min.start>
+ // vp<%final.result> = select vp<%always.false>, vp<%scaled.result.iv>,
+ // ir<%original.start>
+
+ VPBuilder Builder(FindLastIVResult);
+ VPValue *MinMaxExiting = MinMaxResult->getOperand(1);
+ auto *FinalMinMaxCmp =
+ Builder.createICmp(CmpInst::ICMP_EQ, MinMaxExiting, MinMaxResult);
+ VPValue *LastIVExiting = FindLastIVResult->getOperand(3);
+ auto *FinalIVSelect =
+ Builder.createSelect(FinalMinMaxCmp, LastIVExiting, MaxInt);
+ VPSingleDefRecipe *FinalResult = Builder.createNaryOp(
+ VPInstruction::ComputeReductionResult, {FirstIdxPhiR, FinalIVSelect}, {},
+ FindLastIVResult->getDebugLoc());
+
+ // If we used a new wide canonical IV convert the reduction result back to the
+ // original IV scale before the final select.
+ if (!WideIV->isCanonical()) {
+ auto *DerivedIVRecipe =
+ new VPDerivedIVRecipe(InductionDescriptor::IK_IntInduction,
+ nullptr, // No FPBinOp for integer induction
+ WideIV->getStartValue(), FinalResult,
+ WideIV->getStepValue(), "derived.iv.result");
+ DerivedIVRecipe->insertBefore(&*Builder.getInsertPoint());
+ FinalResult = DerivedIVRecipe;
+ }
+
+ // If the final min/max value matches the start value, the condition in the
+ // loop was always false, i.e. no induction value has been selected. If that's
+ // the case, use the original start value.
+ VPValue *AlwaysFalse = Builder.createICmp(
+ CmpInst::ICMP_EQ, MinMaxPhiR->getStartValue(), MinMaxResult);
+ VPValue *Res = Builder.createSelect(AlwaysFalse, FinalResult,
+ FindLastIVResult->getOperand(1));
----------------
fhahn wrote:
Updated to `FinalIV`, derived from `FinalCanIV`.
https://github.com/llvm/llvm-project/pull/170223
More information about the llvm-commits
mailing list