[llvm] 92b2a26 - [NFC][SLP] Cleanup: Moves code that changes the reduction root into a separate function.
Vasileios Porpodas via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 28 10:06:32 PDT 2023
Author: Vasileios Porpodas
Date: 2023-04-28T10:05:32-07:00
New Revision: 92b2a266e99066a085aa9c371efe9ea13b0755cd
URL: https://github.com/llvm/llvm-project/commit/92b2a266e99066a085aa9c371efe9ea13b0755cd
DIFF: https://github.com/llvm/llvm-project/commit/92b2a266e99066a085aa9c371efe9ea13b0755cd.diff
LOG: [NFC][SLP] Cleanup: Moves code that changes the reduction root into a separate function.
This makes `matchAssociativeReduction()` a bit simpler.
Differential Revision: https://reviews.llvm.org/D149452
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ba882fe83a3c..c7685455c117 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -12645,6 +12645,7 @@ class HorizontalReduction {
return Op;
}
+public:
static RecurKind getRdxKind(Value *V) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
@@ -12747,6 +12748,7 @@ class HorizontalReduction {
return isCmpSelMinMax(I) ? 1 : 0;
}
+private:
/// Total number of operands in the reduction operation.
static unsigned getNumberOfOperands(Instruction *I) {
return isCmpSelMinMax(I) ? 3 : 2;
@@ -12795,17 +12797,6 @@ class HorizontalReduction {
}
}
- static Value *getLHS(RecurKind Kind, Instruction *I) {
- if (Kind == RecurKind::None)
- return nullptr;
- return I->getOperand(getFirstOperandIndex(I));
- }
- static Value *getRHS(RecurKind Kind, Instruction *I) {
- if (Kind == RecurKind::None)
- return nullptr;
- return I->getOperand(getFirstOperandIndex(I) + 1);
- }
-
static bool isGoodForReduction(ArrayRef<Value *> Data) {
int Sz = Data.size();
auto *I = dyn_cast<Instruction>(Data.front());
@@ -12817,57 +12808,32 @@ class HorizontalReduction {
HorizontalReduction() = default;
/// Try to find a reduction tree.
- bool matchAssociativeReduction(PHINode *Phi, Instruction *Inst,
+ bool matchAssociativeReduction(Instruction *Root,
ScalarEvolution &SE, const DataLayout &DL,
const TargetLibraryInfo &TLI) {
- assert((!Phi || is_contained(Phi->operands(), Inst)) &&
- "Phi needs to use the binary operator");
- assert((isa<BinaryOperator>(Inst) || isa<SelectInst>(Inst) ||
- isa<IntrinsicInst>(Inst)) &&
- "Expected binop, select, or intrinsic for reduction matching");
- RdxKind = getRdxKind(Inst);
-
- // We could have a initial reductions that is not an add.
- // r *= v1 + v2 + v3 + v4
- // In such a case start looking for a tree rooted in the first '+'.
- if (Phi) {
- if (getLHS(RdxKind, Inst) == Phi) {
- Phi = nullptr;
- Inst = dyn_cast<Instruction>(getRHS(RdxKind, Inst));
- if (!Inst)
- return false;
- RdxKind = getRdxKind(Inst);
- } else if (getRHS(RdxKind, Inst) == Phi) {
- Phi = nullptr;
- Inst = dyn_cast<Instruction>(getLHS(RdxKind, Inst));
- if (!Inst)
- return false;
- RdxKind = getRdxKind(Inst);
- }
- }
-
- if (!isVectorizable(RdxKind, Inst))
+ RdxKind = HorizontalReduction::getRdxKind(Root);
+ if (!isVectorizable(RdxKind, Root))
return false;
// Analyze "regular" integer/FP types for reductions - no target-specific
// types or pointers.
- Type *Ty = Inst->getType();
+ Type *Ty = Root->getType();
if (!isValidElementType(Ty) || Ty->isPointerTy())
return false;
// Though the ultimate reduction may have multiple uses, its condition must
// have only single use.
- if (auto *Sel = dyn_cast<SelectInst>(Inst))
+ if (auto *Sel = dyn_cast<SelectInst>(Root))
if (!Sel->getCondition()->hasOneUse())
return false;
- ReductionRoot = Inst;
+ ReductionRoot = Root;
// Iterate through all the operands of the possible reduction tree and
// gather all the reduced values, sorting them by their value id.
- BasicBlock *BB = Inst->getParent();
- bool IsCmpSelMinMax = isCmpSelMinMax(Inst);
- SmallVector<Instruction *> Worklist(1, Inst);
+ BasicBlock *BB = Root->getParent();
+ bool IsCmpSelMinMax = isCmpSelMinMax(Root);
+ SmallVector<Instruction *> Worklist(1, Root);
// Checks if the operands of the \p TreeN instruction are also reduction
// operations or should be treated as reduced values or an extra argument,
// which is not part of the reduction.
@@ -12907,7 +12873,7 @@ class HorizontalReduction {
// instructions (grouping them by the predicate).
MapVector<size_t, MapVector<size_t, MapVector<Value *, unsigned>>>
PossibleReducedVals;
- initReductionOps(Inst);
+ initReductionOps(Root);
DenseMap<Value *, SmallVector<LoadInst *>> LoadsMap;
SmallSet<size_t, 2> LoadKeyUsed;
SmallPtrSet<Value *, 4> DoNotReverseVals;
@@ -14012,6 +13978,26 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) {
return false;
}
+/// Try to find a different reduction root for \p Root when \p Phi is one of
+/// its two reduction operands.
+///
+/// We could have an initial reduction that is not an add.
+/// r *= v1 + v2 + v3 + v4
+/// In such a case start looking for a tree rooted in the first '+'.
+///
+/// \Returns the reduction operand of \p Root that is not \p Phi, if that
+/// operand is an Instruction. Returns nullptr when \p Phi is neither of the
+/// two reduction operands, and also when the other operand is not an
+/// Instruction; the returned value alone does not distinguish these cases.
+///
+/// NOTE(review): the code removed from matchAssociativeReduction() returned
+/// false when the other operand was not an Instruction, while the caller in
+/// vectorizeHorReduction() keeps the original root on nullptr, so matching
+/// continues from \p Root in that case. Confirm this is acceptable for an
+/// NFC change.
+/// NOTE(review): "Scondary" is presumably a typo for "Secondary"; renaming
+/// requires updating the caller as well.
+static Instruction *tryGetScondaryReductionRoot(PHINode *Phi,
+ Instruction *Root) {
+ // Only binops, selects and intrinsics are expected as reduction roots here.
+ assert((isa<BinaryOperator>(Root) || isa<SelectInst>(Root) ||
+ isa<IntrinsicInst>(Root)) &&
+ "Expected binop, select, or intrinsic for reduction matching");
+ // The two reduction operands start at getFirstOperandIndex(Root); for the
+ // cmp+select min/max form this presumably skips the select condition --
+ // confirm against HorizontalReduction::getFirstOperandIndex().
+ Value *LHS =
+ Root->getOperand(HorizontalReduction::getFirstOperandIndex(Root));
+ Value *RHS =
+ Root->getOperand(HorizontalReduction::getFirstOperandIndex(Root) + 1);
+ // If Phi feeds one side, the other side is the candidate new root.
+ if (LHS == Phi)
+ return dyn_cast<Instruction>(RHS);
+ if (RHS == Phi)
+ return dyn_cast<Instruction>(LHS);
+ // Phi is not a reduction operand of Root: no secondary root.
+ return nullptr;
+}
+
bool SLPVectorizerPass::vectorizeHorReduction(
PHINode *P, Value *V, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI,
SmallVectorImpl<WeakTrackingVH> &PostponedInsts) {
@@ -14049,8 +14035,14 @@ bool SLPVectorizerPass::vectorizeHorReduction(
bool IsBinop = matchRdxBop(Inst, B0, B1);
bool IsSelect = match(Inst, m_Select(m_Value(), m_Value(), m_Value()));
if (IsBinop || IsSelect) {
+ assert((!P || is_contained(P->operands(), Inst)) &&
+ "Phi needs to use the binary operator");
+ if (P && HorizontalReduction::getRdxKind(Inst) != RecurKind::None)
+ if (Instruction *NewRoot = tryGetScondaryReductionRoot(P, Inst))
+ Inst = NewRoot;
+
HorizontalReduction HorRdx;
- if (HorRdx.matchAssociativeReduction(P, Inst, *SE, *DL, *TLI))
+ if (HorRdx.matchAssociativeReduction(Inst, *SE, *DL, *TLI))
return HorRdx.tryToReduce(R, TTI, *TLI);
}
return nullptr;
More information about the llvm-commits
mailing list