[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)
Sam Tebbs via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 18 09:14:20 PST 2024
================
@@ -8663,6 +8663,113 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
return Recipe;
}
+/// Collects the reductions in the loop that can be treated as partial (scaled)
+/// reductions and records each valid one in ScaledReductionExitInstrs as a
+/// (Chain.Reduction, Chain) pair.
+///
+/// \p Range is forwarded unchanged to getScaledReduction for every reduction
+/// variable reported by Legal.
+void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
+  // Find all possible partial reductions.
+  SmallVector<PartialReductionChain, 1> PartialReductionChains;
+  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
+    if (std::optional<PartialReductionChain> Chain =
+            getScaledReduction(Phi, RdxDesc, Range))
+      PartialReductionChains.push_back(*Chain);
+
+  // A partial reduction is invalid if any of its extends are used by
+  // something that isn't another partial reduction. This is because the
+  // extends are intended to be lowered along with the reduction itself.
+
+  // Build up a set of partial reduction bin ops for efficient use checking.
+  // Iterate by const reference: the chains are only read here, so there is no
+  // need to copy each one.
+  SmallSet<User *, 4> PartialReductionBinOps;
+  for (const PartialReductionChain &Chain : PartialReductionChains)
+    if (Chain.BinOp)
+      PartialReductionBinOps.insert(Chain.BinOp);
+
+  auto ExtendIsOnlyUsedByPartialReductions =
+      [&PartialReductionBinOps](Instruction *Extend) {
+        return all_of(Extend->users(), [&](const User *U) {
+          return PartialReductionBinOps.contains(U);
+        });
+      };
+
+  // Check that every user of a chain's two extends is a partial reduction,
+  // and only record the chains whose extends have no other users.
+  for (const PartialReductionChain &Chain : PartialReductionChains)
+    if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
+        ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+      ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Chain));
+}
+
+/// Examines reduction operations to see if the target can use a cheaper
+/// operation with a wider per-iteration input VF and narrower PHI VF.
+/// Returns a struct containing the ratio between the two VFs and other cached
+/// information, or std::nullopt if no scaled reduction was found.
+std::optional<PartialReductionChain> VPRecipeBuilder::getScaledReduction(
+ PHINode *PHI, const RecurrenceDescriptor &Rdx, VFRange &Range) {
+ // TODO: Allow scaling reductions when predicating. The select at
+ // the end of the loop chooses between the phi value and most recent
+ // reduction result, both of which have different VFs to the active lane
+ // mask when scaling.
+ if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
+ return std::nullopt;
+
+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
+ if (!Update)
+ return std::nullopt;
+
+ Value *Op = Update->getOperand(0);
+ if (Op == PHI)
+ Op = Update->getOperand(1);
+
+ // Match dot product pattern
+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
+ if (!BinOp || !BinOp->hasOneUse())
+ return std::nullopt;
+
+ auto IsSextOrZext = [](Instruction *I) {
+ return I && (I->getOpcode() == Instruction::ZExt ||
+ I->getOpcode() == Instruction::SExt);
+ };
+
+ auto *ExtA = dyn_cast<Instruction>(BinOp->getOperand(0));
+ auto *ExtB = dyn_cast<Instruction>(BinOp->getOperand(1));
+ if (!IsSextOrZext(ExtA) || !IsSextOrZext(ExtB))
+ return std::nullopt;
+
+ Value *A = ExtA->getOperand(0);
+ Value *B = ExtB->getOperand(0);
+ // Check that the extends extend from the same type
+ if (A->getType() != B->getType())
+ return std::nullopt;
+
+ unsigned TargetScaleFactor =
+ PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
+ A->getType()->getPrimitiveSizeInBits());
+
+ TTI::PartialReductionExtendKind OpAExtend =
+ TargetTransformInfo::getPartialReductionExtendKind(ExtA);
+ TTI::PartialReductionExtendKind OpBExtend =
+ TargetTransformInfo::getPartialReductionExtendKind(ExtB);
+
+ PartialReductionChain Chain;
+ Chain.Reduction = Rdx.getLoopExitInstr();
+ Chain.ExtendA = ExtA;
+ Chain.ExtendB = ExtB;
+ Chain.ScaleFactor = TargetScaleFactor;
+ Chain.BinOp = dyn_cast<Instruction>(Op);
----------------
SamTebbs33 wrote:
Done.
https://github.com/llvm/llvm-project/pull/92418
More information about the llvm-commits
mailing list