[llvm-branch-commits] [llvm] d777533 - [SLP] simplify reduction matching

Sanjay Patel via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 21 12:10:59 PST 2021


Author: Sanjay Patel
Date: 2021-01-21T14:58:57-05:00
New Revision: d77753381fe024434ae8ffaaacfe4b9ed9d4d760

URL: https://github.com/llvm/llvm-project/commit/d77753381fe024434ae8ffaaacfe4b9ed9d4d760
DIFF: https://github.com/llvm/llvm-project/commit/d77753381fe024434ae8ffaaacfe4b9ed9d4d760.diff

LOG: [SLP] simplify reduction matching

This is intended to be NFC (no functional change). It removes
the "OperationData" class, which had become nothing more than
a wrapper around a recurrence (reduction) kind.

I adjusted the matching logic to distinguish
instructions from non-instructions - that's all that
the "IsLeafValue" member was keeping track of.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2597f88ab88d..73260016f443 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6401,44 +6401,9 @@ class HorizontalReduction {
   SmallVector<Value *, 32> ReducedVals;
   // Use map vector to make stable output.
   MapVector<Instruction *, Value *> ExtraArgs;
-
-  /// This wraps functionality around a RecurKind (reduction kind).
-  /// TODO: Remove this class if callers can use the 'Kind' value directly?
-  class OperationData {
-    /// Kind of the reduction operation.
-    RecurKind Kind = RecurKind::None;
-    bool IsLeafValue = false;
-
-  public:
-    explicit OperationData() = default;
-
-    /// Constructor for reduced values. They are identified by the bool only.
-    explicit OperationData(Instruction &I) { IsLeafValue = true; }
-
-    /// Constructor for reduction operations with opcode and type.
-    OperationData(RecurKind RdxKind) : Kind(RdxKind) {
-      assert(Kind != RecurKind::None && "Expected reduction operation.");
-    }
-
-    explicit operator bool() const {
-      return IsLeafValue || Kind != RecurKind::None;
-    }
-
-    /// Checks if two operation data are both a reduction op or both a reduced
-    /// value.
-    bool operator==(const OperationData &OD) const {
-      return Kind == OD.Kind && IsLeafValue == OD.IsLeafValue;
-    }
-    bool operator!=(const OperationData &OD) const { return !(*this == OD); }
-
-    /// Get kind of reduction data.
-    RecurKind getKind() const { return Kind; }
-  };
-
   WeakTrackingVH ReductionRoot;
-
-  /// The operation data of the reduction operation.
-  OperationData RdxTreeInst;
+  /// The type of reduction operation.
+  RecurKind RdxKind;
 
   /// Checks if instruction is associative and can be vectorized.
   static bool isVectorizable(RecurKind Kind, Instruction *I) {
@@ -6471,8 +6436,8 @@ class HorizontalReduction {
       // in this case.
       // Do not perform analysis of remaining operands of ParentStackElem.first
       // instruction, this whole instruction is an extra argument.
-      OperationData OpData = getOperationData(ParentStackElem.first);
-      ParentStackElem.second = getNumberOfOperands(OpData.getKind());
+      RecurKind RdxKind = getRdxKind(ParentStackElem.first);
+      ParentStackElem.second = getNumberOfOperands(RdxKind);
     } else {
       // We ran into something like:
       // ParentStackElem.first += ... + ExtraArg + ...
@@ -6550,39 +6515,37 @@ class HorizontalReduction {
     return Op;
   }
 
-  static OperationData getOperationData(Instruction *I) {
-    if (!I)
-      return OperationData();
-
+  static RecurKind getRdxKind(Instruction *I) {
+    assert(I && "Expected instruction for reduction matching");
     TargetTransformInfo::ReductionFlags RdxFlags;
     if (match(I, m_Add(m_Value(), m_Value())))
-      return OperationData(RecurKind::Add);
+      return RecurKind::Add;
     if (match(I, m_Mul(m_Value(), m_Value())))
-      return OperationData(RecurKind::Mul);
+      return RecurKind::Mul;
     if (match(I, m_And(m_Value(), m_Value())))
-      return OperationData(RecurKind::And);
+      return RecurKind::And;
     if (match(I, m_Or(m_Value(), m_Value())))
-      return OperationData(RecurKind::Or);
+      return RecurKind::Or;
     if (match(I, m_Xor(m_Value(), m_Value())))
-      return OperationData(RecurKind::Xor);
+      return RecurKind::Xor;
     if (match(I, m_FAdd(m_Value(), m_Value())))
-      return OperationData(RecurKind::FAdd);
+      return RecurKind::FAdd;
     if (match(I, m_FMul(m_Value(), m_Value())))
-      return OperationData(RecurKind::FMul);
+      return RecurKind::FMul;
 
     if (match(I, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_Value())))
-      return OperationData(RecurKind::FMax);
+      return RecurKind::FMax;
     if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value())))
-      return OperationData(RecurKind::FMin);
+      return RecurKind::FMin;
 
     if (match(I, m_SMax(m_Value(), m_Value())))
-      return OperationData(RecurKind::SMax);
+      return RecurKind::SMax;
     if (match(I, m_SMin(m_Value(), m_Value())))
-      return OperationData(RecurKind::SMin);
+      return RecurKind::SMin;
     if (match(I, m_UMax(m_Value(), m_Value())))
-      return OperationData(RecurKind::UMax);
+      return RecurKind::UMax;
     if (match(I, m_UMin(m_Value(), m_Value())))
-      return OperationData(RecurKind::UMin);
+      return RecurKind::UMin;
 
     if (auto *Select = dyn_cast<SelectInst>(I)) {
       // Try harder: look for min/max pattern based on instructions producing
@@ -6608,39 +6571,39 @@ class HorizontalReduction {
       if (match(Cond, m_Cmp(Pred, m_Specific(LHS), m_Instruction(L2)))) {
         if (!isa<ExtractElementInst>(RHS) ||
             !L2->isIdenticalTo(cast<Instruction>(RHS)))
-          return OperationData(*I);
+          return RecurKind::None;
       } else if (match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Specific(RHS)))) {
         if (!isa<ExtractElementInst>(LHS) ||
             !L1->isIdenticalTo(cast<Instruction>(LHS)))
-          return OperationData(*I);
+          return RecurKind::None;
       } else {
         if (!isa<ExtractElementInst>(LHS) || !isa<ExtractElementInst>(RHS))
-          return OperationData(*I);
+          return RecurKind::None;
         if (!match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2))) ||
             !L1->isIdenticalTo(cast<Instruction>(LHS)) ||
             !L2->isIdenticalTo(cast<Instruction>(RHS)))
-          return OperationData(*I);
+          return RecurKind::None;
       }
 
       TargetTransformInfo::ReductionFlags RdxFlags;
       switch (Pred) {
       default:
-        return OperationData(*I);
+        return RecurKind::None;
       case CmpInst::ICMP_SGT:
       case CmpInst::ICMP_SGE:
-        return OperationData(RecurKind::SMax);
+        return RecurKind::SMax;
       case CmpInst::ICMP_SLT:
       case CmpInst::ICMP_SLE:
-        return OperationData(RecurKind::SMin);
+        return RecurKind::SMin;
       case CmpInst::ICMP_UGT:
       case CmpInst::ICMP_UGE:
-        return OperationData(RecurKind::UMax);
+        return RecurKind::UMax;
       case CmpInst::ICMP_ULT:
       case CmpInst::ICMP_ULE:
-        return OperationData(RecurKind::UMin);
+        return RecurKind::UMin;
       }
     }
-    return OperationData(*I);
+    return RecurKind::None;
   }
 
   /// Return true if this operation is a cmp+select idiom.
@@ -6724,24 +6687,28 @@ class HorizontalReduction {
     assert((!Phi || is_contained(Phi->operands(), B)) &&
            "Phi needs to use the binary operator");
 
-    RdxTreeInst = getOperationData(B);
+    RdxKind = getRdxKind(B);
 
     // We could have a initial reductions that is not an add.
     //  r *= v1 + v2 + v3 + v4
     // In such a case start looking for a tree rooted in the first '+'.
     if (Phi) {
-      if (getLHS(RdxTreeInst.getKind(), B) == Phi) {
+      if (getLHS(RdxKind, B) == Phi) {
         Phi = nullptr;
-        B = dyn_cast<Instruction>(getRHS(RdxTreeInst.getKind(), B));
-        RdxTreeInst = getOperationData(B);
-      } else if (getRHS(RdxTreeInst.getKind(), B) == Phi) {
+        B = dyn_cast<Instruction>(getRHS(RdxKind, B));
+        if (!B)
+          return false;
+        RdxKind = getRdxKind(B);
+      } else if (getRHS(RdxKind, B) == Phi) {
         Phi = nullptr;
-        B = dyn_cast<Instruction>(getLHS(RdxTreeInst.getKind(), B));
-        RdxTreeInst = getOperationData(B);
+        B = dyn_cast<Instruction>(getLHS(RdxKind, B));
+        if (!B)
+          return false;
+        RdxKind = getRdxKind(B);
       }
     }
 
-    if (!isVectorizable(RdxTreeInst.getKind(), B))
+    if (!isVectorizable(RdxKind, B))
       return false;
 
     // Analyze "regular" integer/FP types for reductions - no target-specific
@@ -6761,18 +6728,16 @@ class HorizontalReduction {
     // Post order traverse the reduction tree starting at B. We only handle true
     // trees containing only binary operators.
     SmallVector<std::pair<Instruction *, unsigned>, 32> Stack;
-    Stack.push_back(
-        std::make_pair(B, getFirstOperandIndex(RdxTreeInst.getKind())));
-    initReductionOps(RdxTreeInst.getKind());
+    Stack.push_back(std::make_pair(B, getFirstOperandIndex(RdxKind)));
+    initReductionOps(RdxKind);
     while (!Stack.empty()) {
       Instruction *TreeN = Stack.back().first;
       unsigned EdgeToVisit = Stack.back().second++;
-      const OperationData OpData = getOperationData(TreeN);
-      bool IsReducedValue = OpData != RdxTreeInst;
+      const RecurKind TreeRdxKind = getRdxKind(TreeN);
+      bool IsReducedValue = TreeRdxKind != RdxKind;
 
       // Postorder visit.
-      if (IsReducedValue ||
-          EdgeToVisit == getNumberOfOperands(OpData.getKind())) {
+      if (IsReducedValue || EdgeToVisit == getNumberOfOperands(TreeRdxKind)) {
         if (IsReducedValue)
           ReducedVals.push_back(TreeN);
         else {
@@ -6790,7 +6755,7 @@ class HorizontalReduction {
             markExtraArg(Stack[Stack.size() - 2], TreeN);
             ExtraArgs.erase(TreeN);
           } else
-            addReductionOps(RdxTreeInst.getKind(), TreeN);
+            addReductionOps(RdxKind, TreeN);
         }
         // Retract.
         Stack.pop_back();
@@ -6798,9 +6763,15 @@ class HorizontalReduction {
       }
 
       // Visit left or right.
-      Value *NextV = TreeN->getOperand(EdgeToVisit);
-      auto *I = dyn_cast<Instruction>(NextV);
-      const OperationData EdgeOpData = getOperationData(I);
+      Value *EdgeVal = TreeN->getOperand(EdgeToVisit);
+      auto *I = dyn_cast<Instruction>(EdgeVal);
+      if (!I) {
+        // Edge value is not a reduction instruction or a leaf instruction.
+        // (It may be a constant, function argument, or something else.)
+        markExtraArg(Stack.back(), EdgeVal);
+        continue;
+      }
+      RecurKind EdgeRdxKind = getRdxKind(I);
       // Continue analysis if the next operand is a reduction operation or
       // (possibly) a leaf value. If the leaf value opcode is not set,
       // the first met operation != reduction operation is considered as the
@@ -6808,14 +6779,14 @@ class HorizontalReduction {
       // Only handle trees in the current basic block.
       // Each tree node needs to have minimal number of users except for the
       // ultimate reduction.
-      const bool IsRdxInst = EdgeOpData == RdxTreeInst;
-      if (I && I != Phi && I != B &&
-          hasSameParent(RdxTreeInst.getKind(), I, B->getParent(), IsRdxInst) &&
-          hasRequiredNumberOfUses(RdxTreeInst.getKind(), I, IsRdxInst) &&
+      const bool IsRdxInst = EdgeRdxKind == RdxKind;
+      if (I != Phi && I != B &&
+          hasSameParent(RdxKind, I, B->getParent(), IsRdxInst) &&
+          hasRequiredNumberOfUses(RdxKind, I, IsRdxInst) &&
           (!LeafOpcode || LeafOpcode == I->getOpcode() || IsRdxInst)) {
         if (IsRdxInst) {
           // We need to be able to reassociate the reduction operations.
-          if (!isVectorizable(EdgeOpData.getKind(), I)) {
+          if (!isVectorizable(EdgeRdxKind, I)) {
             // I is an extra argument for TreeN (its parent operation).
             markExtraArg(Stack.back(), I);
             continue;
@@ -6823,12 +6794,11 @@ class HorizontalReduction {
         } else if (!LeafOpcode) {
           LeafOpcode = I->getOpcode();
         }
-        Stack.push_back(
-            std::make_pair(I, getFirstOperandIndex(EdgeOpData.getKind())));
+        Stack.push_back(std::make_pair(I, getFirstOperandIndex(EdgeRdxKind)));
         continue;
       }
-      // NextV is an extra argument for TreeN (its parent operation).
-      markExtraArg(Stack.back(), NextV);
+      // I is an extra argument for TreeN (its parent operation).
+      markExtraArg(Stack.back(), I);
     }
     return true;
   }
@@ -6922,7 +6892,7 @@ class HorizontalReduction {
       }
       if (V.isTreeTinyAndNotFullyVectorizable())
         break;
-      if (V.isLoadCombineReductionCandidate(RdxTreeInst.getKind()))
+      if (V.isLoadCombineReductionCandidate(RdxKind))
         break;
 
       V.computeMinimumValueSizes();
@@ -6965,7 +6935,7 @@ class HorizontalReduction {
       // Emit a reduction. If the root is a select (min/max idiom), the insert
       // point is the compare condition of that select.
       Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
-      if (isCmpSel(RdxTreeInst.getKind()))
+      if (isCmpSel(RdxKind))
         Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst));
       else
         Builder.SetInsertPoint(RdxRootInst);
@@ -6979,9 +6949,8 @@ class HorizontalReduction {
       } else {
         // Update the final value in the reduction.
         Builder.SetCurrentDebugLocation(Loc);
-        VectorizedTree =
-            createOp(Builder, RdxTreeInst.getKind(), VectorizedTree,
-                     ReducedSubTree, "op.rdx", ReductionOps);
+        VectorizedTree = createOp(Builder, RdxKind, VectorizedTree,
+                                  ReducedSubTree, "op.rdx", ReductionOps);
       }
       i += ReduxWidth;
       ReduxWidth = PowerOf2Floor(NumReducedVals - i);
@@ -6992,15 +6961,15 @@ class HorizontalReduction {
       for (; i < NumReducedVals; ++i) {
         auto *I = cast<Instruction>(ReducedVals[i]);
         Builder.SetCurrentDebugLocation(I->getDebugLoc());
-        VectorizedTree = createOp(Builder, RdxTreeInst.getKind(),
-                                  VectorizedTree, I, "", ReductionOps);
+        VectorizedTree =
+            createOp(Builder, RdxKind, VectorizedTree, I, "", ReductionOps);
       }
       for (auto &Pair : ExternallyUsedValues) {
         // Add each externally used value to the final reduction.
         for (auto *I : Pair.second) {
           Builder.SetCurrentDebugLocation(I->getDebugLoc());
-          VectorizedTree = createOp(Builder, RdxTreeInst.getKind(),
-                                    VectorizedTree, Pair.first, "op.extra", I);
+          VectorizedTree = createOp(Builder, RdxKind, VectorizedTree,
+                                    Pair.first, "op.extra", I);
         }
       }
 
@@ -7008,7 +6977,7 @@ class HorizontalReduction {
       // select, we also have to RAUW for the compare instruction feeding the
       // reduction root. That's because the original compare may have extra uses
       // besides the final select of the reduction.
-      if (isCmpSel(RdxTreeInst.getKind())) {
+      if (isCmpSel(RdxKind)) {
         if (auto *VecSelect = dyn_cast<SelectInst>(VectorizedTree)) {
           Instruction *ScalarCmp =
               getCmpForMinMaxReduction(cast<Instruction>(ReductionRoot));
@@ -7032,10 +7001,8 @@ class HorizontalReduction {
                        unsigned ReduxWidth) {
     Type *ScalarTy = FirstReducedVal->getType();
     FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
-
-    RecurKind Kind = RdxTreeInst.getKind();
     int VectorCost, ScalarCost;
-    switch (Kind) {
+    switch (RdxKind) {
     case RecurKind::Add:
     case RecurKind::Mul:
     case RecurKind::Or:
@@ -7043,7 +7010,7 @@ class HorizontalReduction {
     case RecurKind::Xor:
     case RecurKind::FAdd:
     case RecurKind::FMul: {
-      unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+      unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind);
       VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
                                                    /*IsPairwiseForm=*/false);
       ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
@@ -7066,7 +7033,8 @@ class HorizontalReduction {
     case RecurKind::UMax:
     case RecurKind::UMin: {
       auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
-      bool IsUnsigned = Kind == RecurKind::UMax || Kind == RecurKind::UMin;
+      bool IsUnsigned =
+          RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin;
       VectorCost =
           TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
                                       /*IsPairwiseForm=*/false, IsUnsigned);
@@ -7098,8 +7066,7 @@ class HorizontalReduction {
     // FIXME: The builder should use an FMF guard. It should not be hard-coded
     //        to 'fast'.
     assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF");
-    return createSimpleTargetReduction(Builder, TTI, VectorizedValue,
-                                       RdxTreeInst.getKind(),
+    return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
                                        ReductionOps.back());
   }
 };


        


More information about the llvm-branch-commits mailing list