[llvm] [LV] Support argmin/argmax with strict predicates. (PR #170223)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 24 13:58:11 PST 2026


================
@@ -1390,7 +1392,136 @@ bool VPlanTransforms::handleFindLastReductions(VPlan &Plan) {
   return true;
 }
 
-bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan) {
+/// For argmin/argmax reductions with strict predicates, convert the existing
+/// FindLastIV reduction to a new UMin reduction of a wide canonical IV. If the
+/// original IV was not canonical, a new canonical wide IV is added, and the
+/// final result is scaled back to the original IV.
+static bool handleFirstArgMinOrMax(VPlan &Plan,
+                                   VPReductionPHIRecipe *MinOrMaxPhiR,
+                                   VPReductionPHIRecipe *FindLastIVPhiR,
+                                   VPWidenIntOrFpInductionRecipe *WideIV,
+                                   VPInstruction *MinOrMaxResult) {
+  Type *Ty = Plan.getVectorLoopRegion()->getCanonicalIVType();
+  // TODO: Support non (i.e., narrower than) canonical IV types.
+  if (Ty != VPTypeAnalysis(Plan).inferScalarType(WideIV))
+    return false;
+
+  // If the original wide IV is not canonical, create a new one. The wide IV is
+  // guaranteed to not wrap for all lanes that are active in the vector loop.
+  auto *FindIVSelectR = cast<VPSingleDefRecipe>(
+      FindLastIVPhiR->getBackedgeValue()->getDefiningRecipe());
+  if (!WideIV->isCanonical()) {
+    VPIRValue *Zero = Plan.getConstantInt(Ty, 0);
+    VPIRValue *One = Plan.getConstantInt(Ty, 1);
+    auto *WidenCanIV = new VPWidenIntOrFpInductionRecipe(
+        nullptr, Zero, One, WideIV->getVFValue(),
+        WideIV->getInductionDescriptor(),
+        VPIRFlags::WrapFlagsTy(/*HasNUW=*/true, /*HasNSW=*/false),
+        WideIV->getDebugLoc());
+    WidenCanIV->insertBefore(WideIV);
+
+    // Update the select to use the wide canonical IV.
+    assert(
+        match(FindIVSelectR, m_Select(m_VPValue(), m_VPValue(), m_VPValue())) &&
+        "backedge value must be a select");
+    WideIV->replaceUsesWithIf(WidenCanIV,
+                              [FindIVSelectR](const VPUser &U, unsigned) {
+                                return FindIVSelectR == &U;
+                              });
+  }
+
+  // Create the new UMin reduction recipe to track the minimum index.
+  assert(!FindLastIVPhiR->isInLoop() && !FindLastIVPhiR->isOrdered() &&
+         "inloop and ordered reductions not supported");
+  VPValue *MaxIV =
+      Plan.getConstantInt(APInt::getMaxValue(Ty->getIntegerBitWidth()));
+  assert(FindLastIVPhiR->getVFScaleFactor() == 1 &&
+         "FindIV reduction must not be scaled");
+  ReductionStyle Style = RdxUnordered{1};
+  auto *FirstIdxPhiR = new VPReductionPHIRecipe(
+      dyn_cast_or_null<PHINode>(FindLastIVPhiR->getUnderlyingValue()),
+      RecurKind::UMin, *MaxIV, *FindIVSelectR, Style,
+      FindLastIVPhiR->hasUsesOutsideReductionChain());
+  FirstIdxPhiR->insertBefore(FindLastIVPhiR);
+
+  VPInstruction *FindLastIVResult =
+      findUserOf<VPInstruction::ComputeFindIVResult>(
+          FindLastIVPhiR->getBackedgeValue());
+  MinOrMaxResult->moveBefore(*FindLastIVResult->getParent(),
+                             FindLastIVResult->getIterator());
+
+  // The reduction using MinOrMaxPhiR needs adjusting to compute the correct
+  // result:
+  //  1. Find the first canonical indices corresponding to partial min/max
+  //     values, using loop reductions.
+  //  2. Find which of the partial min/max values are equal to the overall
+  //     min/max value.
+  //  3. Select among the canonical indices those corresponding to the overall
+  //     min/max value.
+  //  4. Find the first canonical index of overall min/max and scale it back to
+  //     the original IV using VPDerivedIVRecipe.
+  //  5. If the overall min/max is equal to the start value, the condition in
+  //     the loop was always false, due to being strict; return the start value
+  //     in that case.
+  //
+  // The original reductions need adjusting:
+  // For example, this transforms two independent constructs
+  // vp<%min.result> = compute-reduction-result ir<%min.val.next>
+  // vp<%find.iv.result> = compute-find-iv-result ir<0>, ir<Sentinel>,
+  //                                              vp<%min.idx.next>
+  //
+  // into:
+  //  vp<%min.result> = compute-reduction-result ir<%min.val.next>
+  //  vp<%final.min.cmp> = icmp eq ir<%min.val.next>, vp<%min.result>
+  //  vp<%final.min.idx> = select vp<%final.min.cmp>, ir<%min.idx.next>,
+  //                              ir<MaxUInt>
+  //  vp<%final.can.iv> = compute-reduction-result vp<%final.min.idx>
+  //  vp<%scaled.result.iv> = DERIVED-IV ir<20> + vp<%final.can.iv> *
+  //                                                        ir<1>
+  //  vp<%always.false> = icmp eq vp<%min.result>, ir<%sentinel.min.start>
+  //  vp<%final.result> = select vp<%always.false>, ir<%original.start>,
+  //                             vp<%scaled.result.iv>
+
+  VPBuilder Builder(FindLastIVResult);
+  VPValue *MinOrMaxExiting = MinOrMaxResult->getOperand(0);
+  auto *FinalMinOrMaxCmp =
+      Builder.createICmp(CmpInst::ICMP_EQ, MinOrMaxExiting, MinOrMaxResult);
+  VPValue *LastIVExiting = FindLastIVResult->getOperand(2);
+  auto *FinalIVSelect =
+      Builder.createSelect(FinalMinOrMaxCmp, LastIVExiting, MaxIV);
+  VPIRFlags RdxFlags(RecurKind::UMin, false, false, FastMathFlags());
+  VPSingleDefRecipe *FinalCanIV = Builder.createNaryOp(
+      VPInstruction::ComputeReductionResult, {FinalIVSelect}, RdxFlags,
+      FindLastIVResult->getDebugLoc());
+
+  // If we used a new wide canonical IV convert the reduction result back to the
+  // original IV scale before the final select.
+  if (!WideIV->isCanonical()) {
+    auto *DerivedIVRecipe =
+        new VPDerivedIVRecipe(InductionDescriptor::IK_IntInduction,
+                              nullptr, // No FPBinOp for integer induction
+                              WideIV->getStartValue(), FinalCanIV,
+                              WideIV->getStepValue(), "derived.iv.result");
+    DerivedIVRecipe->insertBefore(&*Builder.getInsertPoint());
+    FinalCanIV = DerivedIVRecipe;
+  }
+
+  // If the final min/max value matches the start value, the condition in the
----------------
fhahn wrote:

Updated, thanks.

https://github.com/llvm/llvm-project/pull/170223


More information about the llvm-commits mailing list