[llvm] 92b2a26 - [NFC][SLP] Cleanup: Moves code that changes the reduction root into a separate function.

Vasileios Porpodas via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 28 10:06:32 PDT 2023


Author: Vasileios Porpodas
Date: 2023-04-28T10:05:32-07:00
New Revision: 92b2a266e99066a085aa9c371efe9ea13b0755cd

URL: https://github.com/llvm/llvm-project/commit/92b2a266e99066a085aa9c371efe9ea13b0755cd
DIFF: https://github.com/llvm/llvm-project/commit/92b2a266e99066a085aa9c371efe9ea13b0755cd.diff

LOG: [NFC][SLP] Cleanup: Moves code that changes the reduction root into a separate function.

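This makes `matchAssociativeReduction()` a bit simpler.

Note that one corner case is not strictly NFC: when the Phi is one of the two
reduction operands of the root and the other operand is not an Instruction, the
old code gave up on the reduction, whereas the new code keeps the original
instruction as the root and still calls `matchAssociativeReduction()` on it.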

Differential Revision: https://reviews.llvm.org/D149452

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ba882fe83a3c..c7685455c117 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -12645,6 +12645,7 @@ class HorizontalReduction {
     return Op;
   }
 
+public:
   static RecurKind getRdxKind(Value *V) {
     auto *I = dyn_cast<Instruction>(V);
     if (!I)
@@ -12747,6 +12748,7 @@ class HorizontalReduction {
     return isCmpSelMinMax(I) ? 1 : 0;
   }
 
+private:
   /// Total number of operands in the reduction operation.
   static unsigned getNumberOfOperands(Instruction *I) {
     return isCmpSelMinMax(I) ? 3 : 2;
@@ -12795,17 +12797,6 @@ class HorizontalReduction {
     }
   }
 
-  static Value *getLHS(RecurKind Kind, Instruction *I) {
-    if (Kind == RecurKind::None)
-      return nullptr;
-    return I->getOperand(getFirstOperandIndex(I));
-  }
-  static Value *getRHS(RecurKind Kind, Instruction *I) {
-    if (Kind == RecurKind::None)
-      return nullptr;
-    return I->getOperand(getFirstOperandIndex(I) + 1);
-  }
-
   static bool isGoodForReduction(ArrayRef<Value *> Data) {
     int Sz = Data.size();
     auto *I = dyn_cast<Instruction>(Data.front());
@@ -12817,57 +12808,32 @@ class HorizontalReduction {
   HorizontalReduction() = default;
 
   /// Try to find a reduction tree.
-  bool matchAssociativeReduction(PHINode *Phi, Instruction *Inst,
+  bool matchAssociativeReduction(Instruction *Root,
                                  ScalarEvolution &SE, const DataLayout &DL,
                                  const TargetLibraryInfo &TLI) {
-    assert((!Phi || is_contained(Phi->operands(), Inst)) &&
-           "Phi needs to use the binary operator");
-    assert((isa<BinaryOperator>(Inst) || isa<SelectInst>(Inst) ||
-            isa<IntrinsicInst>(Inst)) &&
-           "Expected binop, select, or intrinsic for reduction matching");
-    RdxKind = getRdxKind(Inst);
-
-    // We could have a initial reductions that is not an add.
-    //  r *= v1 + v2 + v3 + v4
-    // In such a case start looking for a tree rooted in the first '+'.
-    if (Phi) {
-      if (getLHS(RdxKind, Inst) == Phi) {
-        Phi = nullptr;
-        Inst = dyn_cast<Instruction>(getRHS(RdxKind, Inst));
-        if (!Inst)
-          return false;
-        RdxKind = getRdxKind(Inst);
-      } else if (getRHS(RdxKind, Inst) == Phi) {
-        Phi = nullptr;
-        Inst = dyn_cast<Instruction>(getLHS(RdxKind, Inst));
-        if (!Inst)
-          return false;
-        RdxKind = getRdxKind(Inst);
-      }
-    }
-
-    if (!isVectorizable(RdxKind, Inst))
+    RdxKind = HorizontalReduction::getRdxKind(Root);
+    if (!isVectorizable(RdxKind, Root))
       return false;
 
     // Analyze "regular" integer/FP types for reductions - no target-specific
     // types or pointers.
-    Type *Ty = Inst->getType();
+    Type *Ty = Root->getType();
     if (!isValidElementType(Ty) || Ty->isPointerTy())
       return false;
 
     // Though the ultimate reduction may have multiple uses, its condition must
     // have only single use.
-    if (auto *Sel = dyn_cast<SelectInst>(Inst))
+    if (auto *Sel = dyn_cast<SelectInst>(Root))
       if (!Sel->getCondition()->hasOneUse())
         return false;
 
-    ReductionRoot = Inst;
+    ReductionRoot = Root;
 
     // Iterate through all the operands of the possible reduction tree and
     // gather all the reduced values, sorting them by their value id.
-    BasicBlock *BB = Inst->getParent();
-    bool IsCmpSelMinMax = isCmpSelMinMax(Inst);
-    SmallVector<Instruction *> Worklist(1, Inst);
+    BasicBlock *BB = Root->getParent();
+    bool IsCmpSelMinMax = isCmpSelMinMax(Root);
+    SmallVector<Instruction *> Worklist(1, Root);
     // Checks if the operands of the \p TreeN instruction are also reduction
     // operations or should be treated as reduced values or an extra argument,
     // which is not part of the reduction.
@@ -12907,7 +12873,7 @@ class HorizontalReduction {
     // instructions (grouping them by the predicate).
     MapVector<size_t, MapVector<size_t, MapVector<Value *, unsigned>>>
         PossibleReducedVals;
-    initReductionOps(Inst);
+    initReductionOps(Root);
     DenseMap<Value *, SmallVector<LoadInst *>> LoadsMap;
     SmallSet<size_t, 2> LoadKeyUsed;
     SmallPtrSet<Value *, 4> DoNotReverseVals;
@@ -14012,6 +13978,26 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) {
   return false;
 }
 
+/// We could have an initial reduction that is not an add.
+///  r *= v1 + v2 + v3 + v4
+/// In such a case start looking for a tree rooted in the first '+'.
+/// \returns the reduction operand of \p Root other than \p Phi if that operand
+/// is an Instruction, nullptr otherwise (also when \p Phi is not an operand).
+static Instruction *tryGetScondaryReductionRoot(PHINode *Phi,
+                                                Instruction *Root) {
+  assert((isa<BinaryOperator>(Root) || isa<SelectInst>(Root) ||
+          isa<IntrinsicInst>(Root)) &&
+         "Expected binop, select, or intrinsic for reduction matching");
+  unsigned FirstIdx = HorizontalReduction::getFirstOperandIndex(Root);
+  Value *LHS = Root->getOperand(FirstIdx);
+  Value *RHS = Root->getOperand(FirstIdx + 1);
+  if (LHS == Phi)
+    return dyn_cast<Instruction>(RHS);
+  if (RHS == Phi)
+    return dyn_cast<Instruction>(LHS);
+  return nullptr;
+}
+
 bool SLPVectorizerPass::vectorizeHorReduction(
     PHINode *P, Value *V, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI,
     SmallVectorImpl<WeakTrackingVH> &PostponedInsts) {
@@ -14049,8 +14035,14 @@ bool SLPVectorizerPass::vectorizeHorReduction(
     bool IsBinop = matchRdxBop(Inst, B0, B1);
     bool IsSelect = match(Inst, m_Select(m_Value(), m_Value(), m_Value()));
     if (IsBinop || IsSelect) {
+      assert((!P || is_contained(P->operands(), Inst)) &&
+             "Phi needs to use the binary operator");
+      if (P && HorizontalReduction::getRdxKind(Inst) != RecurKind::None)
+        if (Instruction *NewRoot = tryGetScondaryReductionRoot(P, Inst))
+          Inst = NewRoot;
+
       HorizontalReduction HorRdx;
-      if (HorRdx.matchAssociativeReduction(P, Inst, *SE, *DL, *TLI))
+      if (HorRdx.matchAssociativeReduction(Inst, *SE, *DL, *TLI))
         return HorRdx.tryToReduce(R, TTI, *TLI);
     }
     return nullptr;


        


More information about the llvm-commits mailing list