[llvm] 8ea274b - [VPlan] Fix in-loop reduction chains using VPlan def-use chains (NFCI)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 2 09:05:34 PDT 2023
Author: Florian Hahn
Date: 2023-08-02T17:04:29+01:00
New Revision: 8ea274b46beb01756c2b26743fa165d3cbdeb355
URL: https://github.com/llvm/llvm-project/commit/8ea274b46beb01756c2b26743fa165d3cbdeb355
DIFF: https://github.com/llvm/llvm-project/commit/8ea274b46beb01756c2b26743fa165d3cbdeb355.diff
LOG: [VPlan] Fix in-loop reduction chains using VPlan def-use chains (NFCI)
Update adjustRecipesForReductions to directly use the VPlan def-use
chains for in-loop reductions to collect the reduction operations that
need adjusting.
This allows:
* removing ReductionChainMap
* removing the recording of recipes for instructions in the reduction chain
* removing late uses of getVPValue
* removing the need for removeVPValueFor.
Reviewed By: Ayal
Differential Revision: https://reviews.llvm.org/D155845
Added:
Modified:
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/lib/Transforms/Vectorize/VPlan.h
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 96680d16b5f25a..8bbcf7750c8916 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1266,7 +1266,7 @@ class LoopVectorizationCostModel {
void collectElementTypesForWidening();
/// Split reductions into those that happen in the loop, and those that happen
- /// outside. In loop reductions are collected into InLoopReductionChains.
+ /// outside. In loop reductions are collected into InLoopReductions.
void collectInLoopReductions();
/// Returns true if we should use strict in-order reductions for the given
@@ -1602,20 +1602,9 @@ class LoopVectorizationCostModel {
return foldTailByMasking() || Legal->blockNeedsPredication(BB);
}
- /// A SmallMapVector to store the InLoop reduction op chains, mapping phi
- /// nodes to the chain of instructions representing the reductions. Uses a
- /// MapVector to ensure deterministic iteration order.
- using ReductionChainMap =
- SmallMapVector<PHINode *, SmallVector<Instruction *, 4>, 4>;
-
- /// Return the chain of instructions representing an inloop reduction.
- const ReductionChainMap &getInLoopReductionChains() const {
- return InLoopReductionChains;
- }
-
/// Returns true if the Phi is part of an inloop reduction.
bool isInLoopReduction(PHINode *Phi) const {
- return InLoopReductionChains.count(Phi);
+ return InLoopReductions.contains(Phi);
}
/// Estimate cost of an intrinsic call instruction CI if it were vectorized
@@ -1779,15 +1768,12 @@ class LoopVectorizationCostModel {
/// scalarized.
DenseMap<ElementCount, SmallPtrSet<Instruction *, 4>> ForcedScalars;
- /// PHINodes of the reductions that should be expanded in-loop along with
- /// their associated chains of reduction operations, in program order from top
- /// (PHI) to bottom
- ReductionChainMap InLoopReductionChains;
+ /// PHINodes of the reductions that should be expanded in-loop.
+ SmallPtrSet<PHINode *, 4> InLoopReductions;
/// A Map of inloop reduction operations and their immediate chain operand.
/// FIXME: This can be removed once reductions can be costed correctly in
- /// vplan. This was added to allow quick lookup to the inloop operations,
- /// without having to loop through InLoopReductionChains.
+ /// VPlan. This was added to allow quick lookup of the inloop operations.
DenseMap<Instruction *, Instruction *> InLoopReductionImmediateChains;
/// Returns the expected
diff erence in cost from scalarizing the expression
@@ -6623,7 +6609,7 @@ LoopVectorizationCostModel::getReductionPatternCost(
Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) {
using namespace llvm::PatternMatch;
// Early exit for no inloop reductions
- if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty))
+ if (InLoopReductions.empty() || VF.isScalar() || !isa<VectorType>(Ty))
return std::nullopt;
auto *VectorTy = cast<VectorType>(Ty);
@@ -7473,8 +7459,9 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
SmallVector<Instruction *, 4> ReductionOperations =
RdxDesc.getReductionOpChain(Phi, TheLoop);
bool InLoop = !ReductionOperations.empty();
+
if (InLoop) {
- InLoopReductionChains[Phi] = ReductionOperations;
+ InLoopReductions.insert(Phi);
// Add the elements to InLoopReductionImmediateChains for cost modelling.
Instruction *LastChain = Phi;
for (auto *I : ReductionOperations) {
@@ -8866,24 +8853,6 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
// process after constructing the initial VPlan.
// ---------------------------------------------------------------------------
- for (const auto &Reduction : CM.getInLoopReductionChains()) {
- PHINode *Phi = Reduction.first;
- RecurKind Kind =
- Legal->getReductionVars().find(Phi)->second.getRecurrenceKind();
- const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second;
-
- RecipeBuilder.recordRecipeOf(Phi);
- for (const auto &R : ReductionOperations) {
- RecipeBuilder.recordRecipeOf(R);
- // For min/max reductions, where we have a pair of icmp/select, we also
- // need to record the ICmp recipe, so it can be removed later.
- assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) &&
- "Only min/max recurrences allowed for inloop reductions");
- if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
- RecipeBuilder.recordRecipeOf(cast<Instruction>(R->getOperand(0)));
- }
- }
-
// For each interleave group which is relevant for this (possibly trimmed)
// Range, add it to the set of groups to be later applied to the VPlan and add
// placeholders for its members' Recipes which we'll be replacing with a
@@ -9163,86 +9132,118 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
void LoopVectorizationPlanner::adjustRecipesForReductions(
VPBasicBlock *LatchVPBB, VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder,
ElementCount MinVF) {
- for (const auto &Reduction : CM.getInLoopReductionChains()) {
- PHINode *Phi = Reduction.first;
- const RecurrenceDescriptor &RdxDesc =
- Legal->getReductionVars().find(Phi)->second;
- const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second;
-
- if (MinVF.isScalar() && !CM.useOrderedReductions(RdxDesc))
+ SmallVector<VPReductionPHIRecipe *> InLoopReductionPhis;
+ for (VPRecipeBase &R :
+ Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
+ auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+ if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered()))
continue;
+ InLoopReductionPhis.push_back(PhiR);
+ }
+
+ for (VPReductionPHIRecipe *PhiR : InLoopReductionPhis) {
+ const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+ RecurKind Kind = RdxDesc.getRecurrenceKind();
+ assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) &&
+ "select/cmp reductions are not allowed for in-loop reductions");
+
+ // Collect the chain of "link" recipes for the reduction starting at PhiR.
+ SetVector<VPRecipeBase *> Worklist;
+ Worklist.insert(PhiR);
+ for (unsigned I = 0; I != Worklist.size(); ++I) {
+ VPRecipeBase *Cur = Worklist[I];
+ for (VPUser *U : Cur->getVPSingleValue()->users()) {
+ auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
+ if (!UserRecipe)
+ continue;
+ assert(UserRecipe->getNumDefinedValues() == 1 &&
+ "recipes must define exactly one result value");
+ Worklist.insert(UserRecipe);
+ }
+ }
+
+ // Visit operation "Links" along the reduction chain top-down starting from
+ // the phi until LoopExitValue. We keep track of the previous item
+ // (PreviousLink) to tell which of the two operands of a Link will remain
+ // scalar and which will be reduced. For minmax by select(cmp), Link will be
+ // the select instructions.
+ VPRecipeBase *PreviousLink = PhiR; // Aka Worklist[0].
+ for (VPRecipeBase *CurrentLink : Worklist.getArrayRef().drop_front()) {
+ VPValue *PreviousLinkV = PreviousLink->getVPSingleValue();
- // ReductionOperations are orders top-down from the phi's use to the
- // LoopExitValue. We keep a track of the previous item (the Chain) to tell
- // which of the two operands will remain scalar and which will be reduced.
- // For minmax the chain will be the select instructions.
- Instruction *Chain = Phi;
- for (Instruction *R : ReductionOperations) {
- VPRecipeBase *WidenRecipe = RecipeBuilder.getRecipe(R);
- RecurKind Kind = RdxDesc.getRecurrenceKind();
-
- VPValue *ChainOp = Plan->getVPValue(Chain);
- unsigned FirstOpId;
- assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) &&
- "Only min/max recurrences allowed for inloop reductions");
+ Instruction *CurrentLinkI = CurrentLink->getUnderlyingInstr();
+
+ // Index of the first operand which holds a non-mask vector operand.
+ unsigned IndexOfFirstOperand;
// Recognize a call to the llvm.fmuladd intrinsic.
bool IsFMulAdd = (Kind == RecurKind::FMulAdd);
- assert((!IsFMulAdd || RecurrenceDescriptor::isFMulAddIntrinsic(R)) &&
- "Expected instruction to be a call to the llvm.fmuladd intrinsic");
- if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
- assert(isa<VPWidenSelectRecipe>(WidenRecipe) &&
- "Expected to replace a VPWidenSelectSC");
- FirstOpId = 1;
+ VPValue *VecOp;
+ VPBasicBlock *LinkVPBB = CurrentLink->getParent();
+ if (IsFMulAdd) {
+ assert(
+ RecurrenceDescriptor::isFMulAddIntrinsic(CurrentLinkI) &&
+ "Expected instruction to be a call to the llvm.fmuladd intrinsic");
+ assert(((MinVF.isScalar() && isa<VPReplicateRecipe>(CurrentLink)) ||
+ isa<VPWidenCallRecipe>(CurrentLink)) &&
+ CurrentLink->getOperand(2) == PreviousLinkV &&
+ "expected a call where the previous link is the added operand");
+
+ // If the instruction is a call to the llvm.fmuladd intrinsic then we
+ // need to create an fmul recipe (multiplying the first two operands of
+ // the fmuladd together) to use as the vector operand for the fadd
+ // reduction.
+ VPInstruction *FMulRecipe =
+ new VPInstruction(Instruction::FMul, {CurrentLink->getOperand(0),
+ CurrentLink->getOperand(1)});
+ FMulRecipe->setFastMathFlags(CurrentLinkI->getFastMathFlags());
+ LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator());
+ VecOp = FMulRecipe;
} else {
- assert((MinVF.isScalar() || isa<VPWidenRecipe>(WidenRecipe) ||
- (IsFMulAdd && isa<VPWidenCallRecipe>(WidenRecipe))) &&
- "Expected to replace a VPWidenSC");
- FirstOpId = 0;
+ if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
+ if (auto *Cmp = dyn_cast<VPWidenRecipe>(CurrentLink)) {
+ assert(isa<CmpInst>(CurrentLinkI) &&
+ "need to have the compare of the select");
+ continue;
+ }
+ assert(isa<VPWidenSelectRecipe>(CurrentLink) &&
+ "must be a select recipe");
+ IndexOfFirstOperand = 1;
+ } else {
+ assert((MinVF.isScalar() || isa<VPWidenRecipe>(CurrentLink)) &&
+ "Expected to replace a VPWidenSC");
+ IndexOfFirstOperand = 0;
+ }
+ // Note that for non-commutable operands (cmp-selects), the semantics of
+ // the cmp-select are captured in the recurrence kind.
+ unsigned VecOpId =
+ CurrentLink->getOperand(IndexOfFirstOperand) == PreviousLinkV
+ ? IndexOfFirstOperand + 1
+ : IndexOfFirstOperand;
+ VecOp = CurrentLink->getOperand(VecOpId);
+ assert(VecOp != PreviousLinkV &&
+ CurrentLink->getOperand(CurrentLink->getNumOperands() - 1 -
+ (VecOpId - IndexOfFirstOperand)) ==
+ PreviousLinkV &&
+ "PreviousLinkV must be the operand other than VecOp");
}
- unsigned VecOpId =
- R->getOperand(FirstOpId) == Chain ? FirstOpId + 1 : FirstOpId;
- VPValue *VecOp = Plan->getVPValue(R->getOperand(VecOpId));
+ BasicBlock *BB = CurrentLinkI->getParent();
VPValue *CondOp = nullptr;
- if (CM.blockNeedsPredicationForAnyReason(R->getParent())) {
+ if (CM.blockNeedsPredicationForAnyReason(BB)) {
VPBuilder::InsertPointGuard Guard(Builder);
- Builder.setInsertPoint(WidenRecipe->getParent(),
- WidenRecipe->getIterator());
- CondOp = RecipeBuilder.createBlockInMask(R->getParent(), *Plan);
+ Builder.setInsertPoint(LinkVPBB, CurrentLink->getIterator());
+ CondOp = RecipeBuilder.createBlockInMask(BB, *Plan);
}
- if (IsFMulAdd) {
- // If the instruction is a call to the llvm.fmuladd intrinsic then we
- // need to create an fmul recipe to use as the vector operand for the
- // fadd reduction.
- VPInstruction *FMulRecipe = new VPInstruction(
- Instruction::FMul, {VecOp, Plan->getVPValue(R->getOperand(1))});
- FMulRecipe->setFastMathFlags(R->getFastMathFlags());
- WidenRecipe->getParent()->insert(FMulRecipe,
- WidenRecipe->getIterator());
- VecOp = FMulRecipe;
- }
- VPReductionRecipe *RedRecipe =
- new VPReductionRecipe(&RdxDesc, R, ChainOp, VecOp, CondOp, &TTI);
- WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
- Plan->removeVPValueFor(R);
- Plan->addVPValue(R, RedRecipe);
+ VPReductionRecipe *RedRecipe = new VPReductionRecipe(
+ &RdxDesc, CurrentLinkI, PreviousLinkV, VecOp, CondOp, &TTI);
// Append the recipe to the end of the VPBasicBlock because we need to
// ensure that it comes after all of it's inputs, including CondOp.
- WidenRecipe->getParent()->appendRecipe(RedRecipe);
- WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
- WidenRecipe->eraseFromParent();
-
- if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
- VPRecipeBase *CompareRecipe =
- RecipeBuilder.getRecipe(cast<Instruction>(R->getOperand(0)));
- assert(isa<VPWidenRecipe>(CompareRecipe) &&
- "Expected to replace a VPWidenSC");
- assert(cast<VPWidenRecipe>(CompareRecipe)->getNumUsers() == 0 &&
- "Expected no remaining users");
- CompareRecipe->eraseFromParent();
- }
- Chain = R;
+ // Note that this transformation may leave over dead recipes (including
+ // CurrentLink), which will be cleaned by a later VPlan transform.
+ LinkVPBB->appendRecipe(RedRecipe);
+ CurrentLink->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
+ PreviousLink = RedRecipe;
}
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 1e833829fda414..496d7682f1929b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2586,12 +2586,6 @@ class VPlan {
return getVPValue(V);
}
- void removeVPValueFor(Value *V) {
- assert(Value2VPValueEnabled &&
- "IR value to VPValue mapping may be out of date!");
- Value2VPValue.erase(V);
- }
-
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print this VPlan to \p O.
void print(raw_ostream &O) const;
More information about the llvm-commits
mailing list