[llvm-branch-commits] [llvm] 1c54112 - [SLP] refactor more reduction functions; NFC

Sanjay Patel via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 20 08:20:39 PST 2021


Author: Sanjay Patel
Date: 2021-01-20T11:14:48-05:00
New Revision: 1c54112a5762ebab2c14a90c55f27d00bfced7f8

URL: https://github.com/llvm/llvm-project/commit/1c54112a5762ebab2c14a90c55f27d00bfced7f8
DIFF: https://github.com/llvm/llvm-project/commit/1c54112a5762ebab2c14a90c55f27d00bfced7f8.diff

LOG: [SLP] refactor more reduction functions; NFC

We were able to remove almost all of the state from
OperationData, so the reduction helper functions no longer
make sense as members of that class - just pass the RecurKind
in as a parameter instead.

More streamlining is possible, but I'm trying to avoid
introducing logic/typo bugs while refactoring this code.
Eventually, we should not need the `OperationData` class at all.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 3d657b0b898c..3192d7959f70 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6427,76 +6427,6 @@ class HorizontalReduction {
       return IsLeafValue || Kind != RecurKind::None;
     }
 
-    /// Return true if this operation is a cmp+select idiom.
-    bool isCmpSel() const {
-      assert(Kind != RecurKind::None && "Expected reduction operation.");
-      return RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind);
-    }
-
-    /// Get the index of the first operand.
-    unsigned getFirstOperandIndex() const {
-      assert(!!*this && "The opcode is not set.");
-      // We allow calling this before 'Kind' is set, so handle that specially.
-      if (Kind == RecurKind::None)
-        return 0;
-      return isCmpSel() ? 1 : 0;
-    }
-
-    /// Total number of operands in the reduction operation.
-    unsigned getNumberOfOperands() const {
-      assert(Kind != RecurKind::None && !!*this &&
-             "Expected reduction operation.");
-      return isCmpSel() ? 3 : 2;
-    }
-
-    /// Checks if the instruction is in basic block \p BB.
-    /// For a min/max reduction check that both compare and select are in \p BB.
-    bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const {
-      assert(Kind != RecurKind::None && !!*this &&
-             "Expected reduction operation.");
-      if (IsRedOp && isCmpSel()) {
-        auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
-        return I->getParent() == BB && Cmp && Cmp->getParent() == BB;
-      }
-      return I->getParent() == BB;
-    }
-
-    /// Expected number of uses for reduction operations/reduced values.
-    bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const {
-      assert(Kind != RecurKind::None && !!*this &&
-             "Expected reduction operation.");
-      // SelectInst must be used twice while the condition op must have single
-      // use only.
-      if (isCmpSel())
-        return I->hasNUses(2) &&
-               (!IsReductionOp ||
-                cast<SelectInst>(I)->getCondition()->hasOneUse());
-
-      // Arithmetic reduction operation must be used once only.
-      return I->hasOneUse();
-    }
-
-    /// Initializes the list of reduction operations.
-    void initReductionOps(ReductionOpsListType &ReductionOps) {
-      assert(Kind != RecurKind::None && !!*this &&
-             "Expected reduction operation.");
-      if (isCmpSel())
-        ReductionOps.assign(2, ReductionOpsType());
-      else
-        ReductionOps.assign(1, ReductionOpsType());
-    }
-
-    /// Add all reduction operations for the reduction instruction \p I.
-    void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) {
-      assert(Kind != RecurKind::None && "Expected reduction operation.");
-      if (isCmpSel()) {
-        ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
-        ReductionOps[1].emplace_back(I);
-      } else {
-        ReductionOps[0].emplace_back(I);
-      }
-    }
-
     /// Checks if instruction is associative and can be vectorized.
     bool isAssociative(Instruction *I) const {
       assert(Kind != RecurKind::None && "Expected reduction operation.");
@@ -6529,16 +6459,6 @@ class HorizontalReduction {
 
     /// Get kind of reduction data.
     RecurKind getKind() const { return Kind; }
-    Value *getLHS(Instruction *I) const {
-      if (Kind == RecurKind::None)
-        return nullptr;
-      return I->getOperand(getFirstOperandIndex());
-    }
-    Value *getRHS(Instruction *I) const {
-      if (Kind == RecurKind::None)
-        return nullptr;
-      return I->getOperand(getFirstOperandIndex() + 1);
-    }
   };
 
   WeakTrackingVH ReductionRoot;
@@ -6559,7 +6479,7 @@ class HorizontalReduction {
       // Do not perform analysis of remaining operands of ParentStackElem.first
       // instruction, this whole instruction is an extra argument.
       OperationData OpData = getOperationData(ParentStackElem.first);
-      ParentStackElem.second = OpData.getNumberOfOperands();
+      ParentStackElem.second = getNumberOfOperands(OpData.getKind());
     } else {
       // We ran into something like:
       // ParentStackElem.first += ... + ExtraArg + ...
@@ -6730,6 +6650,81 @@ class HorizontalReduction {
     return OperationData(*I);
   }
 
+  /// Return true if this operation is a cmp+select idiom.
+  static bool isCmpSel(RecurKind Kind) {
+    return RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind);
+  }
+
+  /// Get the index of the first operand.
+  static unsigned getFirstOperandIndex(RecurKind Kind) {
+    // We allow calling this before 'Kind' is set, so handle that specially.
+    if (Kind == RecurKind::None)
+      return 0;
+    return isCmpSel(Kind) ? 1 : 0;
+  }
+
+  /// Total number of operands in the reduction operation.
+  static unsigned getNumberOfOperands(RecurKind Kind) {
+    return isCmpSel(Kind) ? 3 : 2;
+  }
+
+  /// Checks if the instruction is in basic block \p BB.
+  /// For a min/max reduction check that both compare and select are in \p BB.
+  static bool hasSameParent(RecurKind Kind, Instruction *I, BasicBlock *BB,
+                            bool IsRedOp) {
+    if (IsRedOp && isCmpSel(Kind)) {
+      auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
+      return I->getParent() == BB && Cmp && Cmp->getParent() == BB;
+    }
+    return I->getParent() == BB;
+  }
+
+  /// Expected number of uses for reduction operations/reduced values.
+  static bool hasRequiredNumberOfUses(RecurKind Kind, Instruction *I,
+                                      bool IsReductionOp) {
+    // SelectInst must be used twice while the condition op must have single
+    // use only.
+    if (isCmpSel(Kind))
+      return I->hasNUses(2) &&
+             (!IsReductionOp ||
+              cast<SelectInst>(I)->getCondition()->hasOneUse());
+
+    // Arithmetic reduction operation must be used once only.
+    return I->hasOneUse();
+  }
+
+  /// Initializes the list of reduction operations.
+  static void initReductionOps(RecurKind Kind,
+                               ReductionOpsListType &ReductionOps) {
+    if (isCmpSel(Kind))
+      ReductionOps.assign(2, ReductionOpsType());
+    else
+      ReductionOps.assign(1, ReductionOpsType());
+  }
+
+  /// Add all reduction operations for the reduction instruction \p I.
+  static void addReductionOps(RecurKind Kind, Instruction *I,
+                              ReductionOpsListType &ReductionOps) {
+    assert(Kind != RecurKind::None && "Expected reduction operation.");
+    if (isCmpSel(Kind)) {
+      ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
+      ReductionOps[1].emplace_back(I);
+    } else {
+      ReductionOps[0].emplace_back(I);
+    }
+  }
+
+  static Value *getLHS(RecurKind Kind, Instruction *I) {
+    if (Kind == RecurKind::None)
+      return nullptr;
+    return I->getOperand(getFirstOperandIndex(Kind));
+  }
+  static Value *getRHS(RecurKind Kind, Instruction *I) {
+    if (Kind == RecurKind::None)
+      return nullptr;
+    return I->getOperand(getFirstOperandIndex(Kind) + 1);
+  }
+
 public:
   HorizontalReduction() = default;
 
@@ -6744,13 +6739,13 @@ class HorizontalReduction {
     //  r *= v1 + v2 + v3 + v4
     // In such a case start looking for a tree rooted in the first '+'.
     if (Phi) {
-      if (RdxTreeInst.getLHS(B) == Phi) {
+      if (getLHS(RdxTreeInst.getKind(), B) == Phi) {
         Phi = nullptr;
-        B = dyn_cast<Instruction>(RdxTreeInst.getRHS(B));
+        B = dyn_cast<Instruction>(getRHS(RdxTreeInst.getKind(), B));
         RdxTreeInst = getOperationData(B);
-      } else if (RdxTreeInst.getRHS(B) == Phi) {
+      } else if (getRHS(RdxTreeInst.getKind(), B) == Phi) {
         Phi = nullptr;
-        B = dyn_cast<Instruction>(RdxTreeInst.getLHS(B));
+        B = dyn_cast<Instruction>(getLHS(RdxTreeInst.getKind(), B));
         RdxTreeInst = getOperationData(B);
       }
     }
@@ -6775,8 +6770,9 @@ class HorizontalReduction {
     // Post order traverse the reduction tree starting at B. We only handle true
     // trees containing only binary operators.
     SmallVector<std::pair<Instruction *, unsigned>, 32> Stack;
-    Stack.push_back(std::make_pair(B, RdxTreeInst.getFirstOperandIndex()));
-    RdxTreeInst.initReductionOps(ReductionOps);
+    Stack.push_back(
+        std::make_pair(B, getFirstOperandIndex(RdxTreeInst.getKind())));
+    initReductionOps(RdxTreeInst.getKind(), ReductionOps);
     while (!Stack.empty()) {
       Instruction *TreeN = Stack.back().first;
       unsigned EdgeToVisit = Stack.back().second++;
@@ -6784,7 +6780,8 @@ class HorizontalReduction {
       bool IsReducedValue = OpData != RdxTreeInst;
 
       // Postorder visit.
-      if (IsReducedValue || EdgeToVisit == OpData.getNumberOfOperands()) {
+      if (IsReducedValue ||
+          EdgeToVisit == getNumberOfOperands(OpData.getKind())) {
         if (IsReducedValue)
           ReducedVals.push_back(TreeN);
         else {
@@ -6802,7 +6799,7 @@ class HorizontalReduction {
             markExtraArg(Stack[Stack.size() - 2], TreeN);
             ExtraArgs.erase(TreeN);
           } else
-            RdxTreeInst.addReductionOps(TreeN, ReductionOps);
+            addReductionOps(RdxTreeInst.getKind(), TreeN, ReductionOps);
         }
         // Retract.
         Stack.pop_back();
@@ -6822,8 +6819,8 @@ class HorizontalReduction {
       // ultimate reduction.
       const bool IsRdxInst = EdgeOpData == RdxTreeInst;
       if (I && I != Phi && I != B &&
-          RdxTreeInst.hasSameParent(I, B->getParent(), IsRdxInst) &&
-          RdxTreeInst.hasRequiredNumberOfUses(I, IsRdxInst) &&
+          hasSameParent(RdxTreeInst.getKind(), I, B->getParent(), IsRdxInst) &&
+          hasRequiredNumberOfUses(RdxTreeInst.getKind(), I, IsRdxInst) &&
           (!LeafOpcode || LeafOpcode == I->getOpcode() || IsRdxInst)) {
         if (IsRdxInst) {
           // We need to be able to reassociate the reduction operations.
@@ -6835,7 +6832,8 @@ class HorizontalReduction {
         } else if (!LeafOpcode) {
           LeafOpcode = I->getOpcode();
         }
-        Stack.push_back(std::make_pair(I, EdgeOpData.getFirstOperandIndex()));
+        Stack.push_back(
+            std::make_pair(I, getFirstOperandIndex(EdgeOpData.getKind())));
         continue;
       }
       // NextV is an extra argument for TreeN (its parent operation).
@@ -6976,7 +6974,7 @@ class HorizontalReduction {
       // Emit a reduction. If the root is a select (min/max idiom), the insert
       // point is the compare condition of that select.
       Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
-      if (RdxTreeInst.isCmpSel())
+      if (isCmpSel(RdxTreeInst.getKind()))
         Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst));
       else
         Builder.SetInsertPoint(RdxRootInst);
@@ -7019,7 +7017,7 @@ class HorizontalReduction {
       // select, we also have to RAUW for the compare instruction feeding the
       // reduction root. That's because the original compare may have extra uses
       // besides the final select of the reduction.
-      if (RdxTreeInst.isCmpSel()) {
+      if (isCmpSel(RdxTreeInst.getKind())) {
         if (auto *VecSelect = dyn_cast<SelectInst>(VectorizedTree)) {
           Instruction *ScalarCmp =
               getCmpForMinMaxReduction(cast<Instruction>(ReductionRoot));


        


More information about the llvm-branch-commits mailing list