[llvm] [LV] Support argmin/argmax with strict predicates. (PR #170223)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 13 04:47:19 PST 2026


================
@@ -1432,7 +1433,143 @@ bool VPlanTransforms::handleFindLastReductions(VPlan &Plan) {
   return true;
 }
 
-bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan) {
+/// Given a first argmin/argmax pattern with strict predicate consisting of
+/// 1) a MinOrMax reduction \p MinOrMaxPhiR producing \p MinOrMaxResult,
+/// 2) a wide induction \p WideIV,
+/// 3) a FindLastIV reduction \p FindLastIVPhiR,
+/// return the smallest index of the FindLastIV reduction result using UMin,
+/// unless \p MinOrMaxResult equals the start value of its MinOrMax reduction.
+/// In that case, return the start value of the FindLastIV reduction instead.
+/// If \p WideIV is not canonical, a new canonical wide IV is added, and the
+/// final result is scaled back to the non-canonical \p WideIV.
+/// The final value of the FindLastIV reduction was originally computed using
+/// \p FindIVSelect, \p FindIVCmp, and \p FindIVRdxResult, which are replaced
+/// and removed.
+/// Returns true if the pattern was handled successfully, false otherwise.
+static bool handleFirstArgMinOrMax(
+    VPlan &Plan, VPReductionPHIRecipe *MinOrMaxPhiR,
+    VPReductionPHIRecipe *FindLastIVPhiR, VPWidenIntOrFpInductionRecipe *WideIV,
+    VPInstruction *MinOrMaxResult, VPInstruction *FindIVSelect,
+    VPRecipeBase *FindIVCmp, VPInstruction *FindIVRdxResult) {
+  Type *Ty = Plan.getVectorLoopRegion()->getCanonicalIVType();
+  // TODO: Support non (i.e., narrower than) canonical IV types.
+  if (Ty != VPTypeAnalysis(Plan).inferScalarType(WideIV))
+    return false;
+
+  auto *FindIVSelectR = cast<VPSingleDefRecipe>(
+      FindLastIVPhiR->getBackedgeValue()->getDefiningRecipe());
+  assert(
+      match(FindIVSelectR, m_Select(m_VPValue(), m_VPValue(), m_VPValue())) &&
+      "backedge value must be a select");
+  if (FindIVSelectR->getOperand(1) != WideIV &&
+      FindIVSelectR->getOperand(2) != WideIV)
+    return false;
+
+  // If the original wide IV is not canonical, create a new one. The canonical
+  // wide IV is guaranteed to not wrap for all lanes that are active in the
+  // vector loop.
+  if (!WideIV->isCanonical()) {
+    VPIRValue *Zero = Plan.getConstantInt(Ty, 0);
+    VPIRValue *One = Plan.getConstantInt(Ty, 1);
+    auto *WidenCanIV = new VPWidenIntOrFpInductionRecipe(
+        nullptr, Zero, One, WideIV->getVFValue(),
+        WideIV->getInductionDescriptor(),
+        VPIRFlags::WrapFlagsTy(/*HasNUW=*/true, /*HasNSW=*/false),
+        WideIV->getDebugLoc());
+    WidenCanIV->insertBefore(WideIV);
+
+    // Update the select to use the wide canonical IV.
+    FindIVSelectR->setOperand(FindIVSelectR->getOperand(1) == WideIV ? 1 : 2,
+                              WidenCanIV);
+  }
+
+  assert(!FindLastIVPhiR->isInLoop() && !FindLastIVPhiR->isOrdered() &&
+         "inloop and ordered reductions not supported");
+  assert(FindLastIVPhiR->getVFScaleFactor() == 1 &&
+         "FindIV reduction must not be scaled");
+  // Set the starting value of FindLastIV reduction to be the upper bound.
+  VPValue *MaxIV =
+      Plan.getConstantInt(APInt::getMaxValue(Ty->getIntegerBitWidth()));
+  FindLastIVPhiR->setOperand(0, MaxIV);
+
+  // The reduction using MinOrMaxPhiR needs adjusting to compute the correct
+  // result:
+  //  1. Find the first canonical indices corresponding to partial min/max
+  //     values, using loop reductions.
+  //  2. Find which of the partial min/max values are equal to the overall
+  //     min/max value.
+  //  3. Select among the canonical indices those corresponding to the overall
+  //     min/max value.
+  //  4. Find the first canonical index of overall min/max and scale it back to
+  //     the original IV using VPDerivedIVRecipe.
+  //  5. If the overall min/max equals the starting min/max, the condition in
+  //     the loop was always false, due to being strict; return the start value
+  //     of FindLastIVPhiR in that case.
+  //
+  // For example, this transforms two independent constructs
+  // vp<%min.result> = compute-reduction-result (min/max) ir<%min.val.next>
----------------
ayalz wrote:

```suggestion
  // vp<%min.result> = compute-reduction-result (min) ir<%min.val.next>
```
The example computes an argmin, as its value names `%min.result` and `%min.val.next` indicate, so the comment should say `(min)` rather than `(min/max)`.

https://github.com/llvm/llvm-project/pull/170223


More information about the llvm-commits mailing list