[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 24 09:07:09 PST 2024


================
@@ -8663,6 +8663,110 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
   return Recipe;
 }
 
+/// Find all possible partial reductions in the loop (via getScaledReduction)
+/// and record the valid ones in ScaledReductionExitInstrs for later recipes.
+void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
+  // Gather a candidate for each reduction variable getScaledReduction accepts.
+  SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
+      PartialReductionChains;
+  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
+    if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
+            getScaledReduction(Phi, RdxDesc, Range))
+      PartialReductionChains.push_back(*Pair);
+
+  // A partial reduction is invalid if any of its extends are used by
+  // something that isn't another partial reduction. This is because the
+  // extends are intended to be lowered along with the reduction itself.
+
+  // Build a set of the candidates' non-null bin ops for efficient use checking.
+  SmallSet<User *, 4> PartialReductionBinOps;
+  for (auto It : PartialReductionChains) {
+    if (It.first.BinOp)
+      PartialReductionBinOps.insert(It.first.BinOp);
+  }
+
+  auto ExtendIsOnlyUsedByPartialReductions =
+      [&PartialReductionBinOps](Instruction *Extend) {
+        return all_of(Extend->users(), [&](const User *U) {
+          return PartialReductionBinOps.contains(U);
+        });
+      };
+
+  // Check that every user of a chain's two extends is a partial reduction bin
+  // op, and only record the chains whose extends have no other users.
+  for (auto Pair : PartialReductionChains) {
+    PartialReductionChain Chain = Pair.first;
+    if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
+        ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+      ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
+  }
+}
+
+/// Examines reduction operations to see if the target can use a cheaper
+/// operation with a wider per-iteration input VF and narrower PHI VF.
+/// Returns the ratio between the two VFs together with other cached
+/// information, or std::nullopt if no scaled reduction was found.
+std::optional<std::pair<PartialReductionChain, unsigned>>
+VPRecipeBuilder::getScaledReduction(PHINode *PHI,
+                                    const RecurrenceDescriptor &Rdx,
+                                    VFRange &Range) {
+  // TODO: Allow scaling reductions when predicating. The select at
+  // the end of the loop chooses between the phi value and most recent
+  // reduction result, both of which have different VFs to the active lane
+  // mask when scaling.
+  if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
+    return std::nullopt;
+
+  auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
+  if (!Update)
+    return std::nullopt;
+
+  Value *Op = Update->getOperand(0);
+  if (Op == PHI)
+    Op = Update->getOperand(1);
+
+  // Match dot product pattern
+  auto *BinOp = dyn_cast<BinaryOperator>(Op);
+  if (!BinOp || !BinOp->hasOneUse())
+    return std::nullopt;
+
+  using namespace llvm::PatternMatch;
+  Value *A, *B;
+  if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
+      !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
+    return std::nullopt;
+
+  Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
+  Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+
+  // Check that the extends extend from the same type
----------------
fhahn wrote:

```suggestion
  // Check that the extends extend from the same type.
```

https://github.com/llvm/llvm-project/pull/92418


More information about the llvm-commits mailing list