[llvm-branch-commits] [llvm] 21a3a02 - [SLP] replace local reduction enum with RecurrenceKind; NFCI

Sanjay Patel via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Dec 29 11:57:35 PST 2020


Author: Sanjay Patel
Date: 2020-12-29T14:52:11-05:00
New Revision: 21a3a0225d84cd35227fc9d4d08234918a54f8d3

URL: https://github.com/llvm/llvm-project/commit/21a3a0225d84cd35227fc9d4d08234918a54f8d3
DIFF: https://github.com/llvm/llvm-project/commit/21a3a0225d84cd35227fc9d4d08234918a54f8d3.diff

LOG: [SLP] replace local reduction enum with RecurrenceKind; NFCI

I'm not sure whether the SLP enum was created before
RecurrenceDescriptor / RecurrenceKind (in IVDescriptors.h) existed,
but the SLP enum is now redundant with that class, and having both
just makes things more complicated. We eventually call LoopUtils
createSimpleTargetReduction() to create reduction ops, so we
might as well standardize on those enum names.

There's still a question of whether we should use TTI::ReductionFlags
or MinMaxRecurrenceKind, but that can be another clean-up step.

Another option would just be to flatten the enums in RecurrenceDescriptor
into a single enum. There isn't much benefit (smaller switches?) to
having a min/max subset.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 9f1768907227..eff0690eda82 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -35,6 +35,7 @@
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/DemandedBits.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryLocation.h"
@@ -6445,16 +6446,7 @@ class HorizontalReduction {
   SmallVector<Value *, 32> ReducedVals;
   // Use map vector to make stable output.
   MapVector<Instruction *, Value *> ExtraArgs;
-
-  /// Kind of the reduction data.
-  enum ReductionKind {
-    RK_None,       /// Not a reduction.
-    RK_Arithmetic, /// Binary reduction data.
-    RK_SMin,       /// Signed minimum reduction data.
-    RK_UMin,       /// Unsigned minimum reduction data.
-    RK_SMax,       /// Signed maximum reduction data.
-    RK_UMax,       /// Unsigned maximum reduction data.
-  };
+  using RD = RecurrenceDescriptor;
 
   /// Contains info about operation, like its opcode, left and right operands.
   class OperationData {
@@ -6462,20 +6454,27 @@ class HorizontalReduction {
     unsigned Opcode = 0;
 
     /// Kind of the reduction operation.
-    ReductionKind Kind = RK_None;
+    RD::RecurrenceKind Kind = RD::RK_NoRecurrence;
+    TargetTransformInfo::ReductionFlags RdxFlags;
 
     /// Checks if the reduction operation can be vectorized.
     bool isVectorizable() const {
       switch (Kind) {
-      case RK_Arithmetic:
-        return Opcode == Instruction::Add || Opcode == Instruction::FAdd ||
-               Opcode == Instruction::Mul || Opcode == Instruction::FMul ||
-               Opcode == Instruction::And || Opcode == Instruction::Or ||
-               Opcode == Instruction::Xor;
-      case RK_SMin:
-      case RK_SMax:
-      case RK_UMin:
-      case RK_UMax:
+      case RD::RK_IntegerAdd:
+        return Opcode == Instruction::Add;
+      case RD::RK_IntegerMult:
+        return Opcode == Instruction::Mul;
+      case RD::RK_IntegerOr:
+        return Opcode == Instruction::Or;
+      case RD::RK_IntegerAnd:
+        return Opcode == Instruction::And;
+      case RD::RK_IntegerXor:
+        return Opcode == Instruction::Xor;
+      case RD::RK_FloatAdd:
+        return Opcode == Instruction::FAdd;
+      case RD::RK_FloatMult:
+        return Opcode == Instruction::FMul;
+      case RD::RK_IntegerMinMax:
         return Opcode == Instruction::ICmp;
       default:
         return false;
@@ -6485,33 +6484,31 @@ class HorizontalReduction {
     /// Creates reduction operation with the current opcode.
     Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
                     const Twine &Name) const {
-      assert(isVectorizable() &&
-             "Expected add|fadd or min/max reduction operation.");
-      Value *Cmp = nullptr;
+      assert(isVectorizable() && "Unhandled reduction operation.");
       switch (Kind) {
-      case RK_Arithmetic:
+      case RD::RK_IntegerAdd:
+      case RD::RK_IntegerMult:
+      case RD::RK_IntegerOr:
+      case RD::RK_IntegerAnd:
+      case RD::RK_IntegerXor:
+      case RD::RK_FloatAdd:
+      case RD::RK_FloatMult:
         return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
                                    Name);
-      case RK_SMin:
-        assert(Opcode == Instruction::ICmp && "Expected integer types.");
-        Cmp = Builder.CreateICmpSLT(LHS, RHS);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      case RK_SMax:
-        assert(Opcode == Instruction::ICmp && "Expected integer types.");
-        Cmp = Builder.CreateICmpSGT(LHS, RHS);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      case RK_UMin:
-        assert(Opcode == Instruction::ICmp && "Expected integer types.");
-        Cmp = Builder.CreateICmpULT(LHS, RHS);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      case RK_UMax:
+
+      case RD::RK_IntegerMinMax: {
         assert(Opcode == Instruction::ICmp && "Expected integer types.");
-        Cmp = Builder.CreateICmpUGT(LHS, RHS);
+        ICmpInst::Predicate Pred;
+        if (RdxFlags.IsMaxOp)
+          Pred = RdxFlags.IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
+        else
+          Pred = RdxFlags.IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
+        Value *Cmp = Builder.CreateICmp(Pred, LHS, RHS, Name);
         return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      case RK_None:
-        break;
       }
-      llvm_unreachable("Unknown reduction operation.");
+      default:
+        llvm_unreachable("Unknown reduction operation.");
+      }
     }
 
   public:
@@ -6523,50 +6520,42 @@ class HorizontalReduction {
       Opcode = I.getOpcode();
     }
 
-    /// Constructor for reduction operations with opcode and its left and
-    /// right operands.
-    OperationData(unsigned Opcode, ReductionKind Kind)
-        : Opcode(Opcode), Kind(Kind) {
-      assert(Kind != RK_None && "One of the reduction operations is expected.");
+    /// Constructor for reduction operations with opcode and type.
+    OperationData(unsigned Opcode, RD::RecurrenceKind Kind,
+                  TargetTransformInfo::ReductionFlags Flags)
+        : Opcode(Opcode), Kind(Kind), RdxFlags(Flags) {
+      assert(Kind != RD::RK_NoRecurrence && "Expected reduction operation.");
     }
 
     explicit operator bool() const { return Opcode; }
 
     /// Return true if this operation is any kind of minimum or maximum.
     bool isMinMax() const {
-      switch (Kind) {
-      case RK_Arithmetic:
-        return false;
-      case RK_SMin:
-      case RK_SMax:
-      case RK_UMin:
-      case RK_UMax:
-        return true;
-      case RK_None:
-        break;
-      }
-      llvm_unreachable("Reduction kind is not set");
+      assert(Kind != RD::RK_NoRecurrence && "Expected reduction operation.");
+      return Kind == RD::RK_IntegerMinMax;
     }
 
     /// Get the index of the first operand.
     unsigned getFirstOperandIndex() const {
       assert(!!*this && "The opcode is not set.");
       // We allow calling this before 'Kind' is set, so handle that specially.
-      if (Kind == RK_None)
+      if (Kind == RD::RK_NoRecurrence)
         return 0;
       return isMinMax() ? 1 : 0;
     }
 
     /// Total number of operands in the reduction operation.
     unsigned getNumberOfOperands() const {
-      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
+      assert(Kind != RD::RK_NoRecurrence && !!*this &&
+             "Expected reduction operation.");
       return isMinMax() ? 3 : 2;
     }
 
     /// Checks if the instruction is in basic block \p BB.
     /// For a min/max reduction check that both compare and select are in \p BB.
     bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const {
-      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
+      assert(Kind != RD::RK_NoRecurrence && !!*this &&
+             "Expected reduction operation.");
       if (IsRedOp && isMinMax()) {
         auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
         return I->getParent() == BB && Cmp && Cmp->getParent() == BB;
@@ -6576,7 +6565,8 @@ class HorizontalReduction {
 
     /// Expected number of uses for reduction operations/reduced values.
     bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const {
-      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
+      assert(Kind != RD::RK_NoRecurrence && !!*this &&
+             "Expected reduction operation.");
       // SelectInst must be used twice while the condition op must have single
       // use only.
       if (isMinMax())
@@ -6590,7 +6580,8 @@ class HorizontalReduction {
 
     /// Initializes the list of reduction operations.
     void initReductionOps(ReductionOpsListType &ReductionOps) {
-      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
+      assert(Kind != RD::RK_NoRecurrence && !!*this &&
+             "Expected reduction operation.");
       if (isMinMax())
         ReductionOps.assign(2, ReductionOpsType());
       else
@@ -6599,7 +6590,8 @@ class HorizontalReduction {
 
     /// Add all reduction operations for the reduction instruction \p I.
     void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) {
-      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
+      assert(Kind != RD::RK_NoRecurrence && !!*this &&
+             "Expected reduction operation.");
       if (isMinMax()) {
         ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
         ReductionOps[1].emplace_back(I);
@@ -6610,21 +6602,16 @@ class HorizontalReduction {
 
     /// Checks if instruction is associative and can be vectorized.
     bool isAssociative(Instruction *I) const {
-      assert(Kind != RK_None && *this && "Expected reduction operation.");
+      assert(Kind != RD::RK_NoRecurrence && *this &&
+             "Expected reduction operation.");
       switch (Kind) {
-      case RK_Arithmetic:
-        return I->isAssociative();
-      case RK_SMin:
-      case RK_SMax:
-      case RK_UMin:
-      case RK_UMax:
+      case RD::RK_IntegerMinMax:
         assert(Opcode == Instruction::ICmp &&
                "Only integer compare operation is expected.");
         return true;
-      case RK_None:
-        break;
+      default:
+        return I->isAssociative();
       }
-      llvm_unreachable("Reduction kind is not set");
     }
 
     /// Checks if the reduction operation can be vectorized.
@@ -6642,7 +6629,7 @@ class HorizontalReduction {
     bool operator!=(const OperationData &OD) const { return !(*this == OD); }
     void clear() {
       Opcode = 0;
-      Kind = RK_None;
+      Kind = RD::RK_NoRecurrence;
     }
 
     /// Get the opcode of the reduction operation.
@@ -6652,14 +6639,14 @@ class HorizontalReduction {
     }
 
     /// Get kind of reduction data.
-    ReductionKind getKind() const { return Kind; }
+    RD::RecurrenceKind getKind() const { return Kind; }
     Value *getLHS(Instruction *I) const {
-      if (Kind == RK_None)
+      if (Kind == RD::RK_NoRecurrence)
         return nullptr;
       return I->getOperand(getFirstOperandIndex());
     }
     Value *getRHS(Instruction *I) const {
-      if (Kind == RK_None)
+      if (Kind == RD::RK_NoRecurrence)
         return nullptr;
       return I->getOperand(getFirstOperandIndex() + 1);
     }
@@ -6673,21 +6660,23 @@ class HorizontalReduction {
              "Expected add|fadd or min/max reduction operation.");
       auto *Op = createOp(Builder, LHS, RHS, Name);
       switch (Kind) {
-      case RK_Arithmetic:
+      case RD::RK_IntegerAdd:
+      case RD::RK_IntegerMult:
+      case RD::RK_IntegerOr:
+      case RD::RK_IntegerAnd:
+      case RD::RK_IntegerXor:
+      case RD::RK_FloatAdd:
+      case RD::RK_FloatMult:
         propagateIRFlags(Op, ReductionOps[0]);
         return Op;
-      case RK_SMin:
-      case RK_SMax:
-      case RK_UMin:
-      case RK_UMax:
+      case RD::RK_IntegerMinMax:
         if (auto *SI = dyn_cast<SelectInst>(Op))
           propagateIRFlags(SI->getCondition(), ReductionOps[0]);
         propagateIRFlags(Op, ReductionOps[1]);
         return Op;
-      case RK_None:
-        break;
+      default:
+        llvm_unreachable("Unknown reduction operation.");
       }
-      llvm_unreachable("Unknown reduction operation.");
     }
     /// Creates reduction operation with the current opcode with the IR flags
     /// from \p I.
@@ -6697,51 +6686,28 @@ class HorizontalReduction {
              "Expected add|fadd or min/max reduction operation.");
       auto *Op = createOp(Builder, LHS, RHS, Name);
       switch (Kind) {
-      case RK_Arithmetic:
+      case RD::RK_IntegerAdd:
+      case RD::RK_IntegerMult:
+      case RD::RK_IntegerOr:
+      case RD::RK_IntegerAnd:
+      case RD::RK_IntegerXor:
+      case RD::RK_FloatAdd:
+      case RD::RK_FloatMult:
         propagateIRFlags(Op, I);
         return Op;
-      case RK_SMin:
-      case RK_SMax:
-      case RK_UMin:
-      case RK_UMax:
+      case RD::RK_IntegerMinMax:
         if (auto *SI = dyn_cast<SelectInst>(Op)) {
           propagateIRFlags(SI->getCondition(),
                            cast<SelectInst>(I)->getCondition());
         }
         propagateIRFlags(Op, I);
         return Op;
-      case RK_None:
-        break;
+      default:
+        llvm_unreachable("Unknown reduction operation.");
       }
-      llvm_unreachable("Unknown reduction operation.");
     }
 
-    TargetTransformInfo::ReductionFlags getFlags() const {
-      TargetTransformInfo::ReductionFlags Flags;
-      switch (Kind) {
-      case RK_Arithmetic:
-        break;
-      case RK_SMin:
-        Flags.IsSigned = true;
-        Flags.IsMaxOp = false;
-        break;
-      case RK_SMax:
-        Flags.IsSigned = true;
-        Flags.IsMaxOp = true;
-        break;
-      case RK_UMin:
-        Flags.IsSigned = false;
-        Flags.IsMaxOp = false;
-        break;
-      case RK_UMax:
-        Flags.IsSigned = false;
-        Flags.IsMaxOp = true;
-        break;
-      case RK_None:
-        llvm_unreachable("Reduction kind is not set");
-      }
-      return Flags;
-    }
+    TargetTransformInfo::ReductionFlags getFlags() const { return RdxFlags; }
   };
 
   WeakTrackingVH ReductionRoot;
@@ -6781,79 +6747,95 @@ class HorizontalReduction {
     if (!I)
       return OperationData();
 
-    Value *LHS;
-    Value *RHS;
-    if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(I)) {
-      return OperationData(cast<BinaryOperator>(I)->getOpcode(), RK_Arithmetic);
+    TargetTransformInfo::ReductionFlags RdxFlags;
+    if (match(I, m_Add(m_Value(), m_Value())))
+      return OperationData(I->getOpcode(), RD::RK_IntegerAdd, RdxFlags);
+    if (match(I, m_Mul(m_Value(), m_Value())))
+      return OperationData(I->getOpcode(), RD::RK_IntegerMult, RdxFlags);
+    if (match(I, m_And(m_Value(), m_Value())))
+      return OperationData(I->getOpcode(), RD::RK_IntegerAnd, RdxFlags);
+    if (match(I, m_Or(m_Value(), m_Value())))
+      return OperationData(I->getOpcode(), RD::RK_IntegerOr, RdxFlags);
+    if (match(I, m_Xor(m_Value(), m_Value())))
+      return OperationData(I->getOpcode(), RD::RK_IntegerXor, RdxFlags);
+    if (match(I, m_FAdd(m_Value(), m_Value())))
+      return OperationData(I->getOpcode(), RD::RK_FloatAdd, RdxFlags);
+    if (match(I, m_FMul(m_Value(), m_Value())))
+      return OperationData(I->getOpcode(), RD::RK_FloatMult, RdxFlags);
+
+    if (match(I, m_MaxOrMin(m_Value(), m_Value()))) {
+      RdxFlags.IsMaxOp = match(I, m_UMax(m_Value(), m_Value())) ||
+                         match(I, m_SMax(m_Value(), m_Value()));
+      RdxFlags.IsSigned = match(I, m_SMin(m_Value(), m_Value())) ||
+                          match(I, m_SMax(m_Value(), m_Value()));
+      return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax, RdxFlags);
     }
+
+
     if (auto *Select = dyn_cast<SelectInst>(I)) {
-      // Look for a min/max pattern.
-      if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, RK_UMin);
-      } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, RK_SMin);
-      } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, RK_UMax);
-      } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, RK_SMax);
+      // Try harder: look for min/max pattern based on instructions producing
+      // same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2).
+      // During the intermediate stages of SLP, it's very common to have
+      // pattern like this (since optimizeGatherSequence is run only once
+      // at the end):
+      // %1 = extractelement <2 x i32> %a, i32 0
+      // %2 = extractelement <2 x i32> %a, i32 1
+      // %cond = icmp sgt i32 %1, %2
+      // %3 = extractelement <2 x i32> %a, i32 0
+      // %4 = extractelement <2 x i32> %a, i32 1
+      // %select = select i1 %cond, i32 %3, i32 %4
+      CmpInst::Predicate Pred;
+      Instruction *L1;
+      Instruction *L2;
+
+      Value *LHS = Select->getTrueValue();
+      Value *RHS = Select->getFalseValue();
+      Value *Cond = Select->getCondition();
+
+      // TODO: Support inverse predicates.
+      if (match(Cond, m_Cmp(Pred, m_Specific(LHS), m_Instruction(L2)))) {
+        if (!isa<ExtractElementInst>(RHS) ||
+            !L2->isIdenticalTo(cast<Instruction>(RHS)))
+          return OperationData(*I);
+      } else if (match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Specific(RHS)))) {
+        if (!isa<ExtractElementInst>(LHS) ||
+            !L1->isIdenticalTo(cast<Instruction>(LHS)))
+          return OperationData(*I);
       } else {
-        // Try harder: look for min/max pattern based on instructions producing
-        // same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2).
-        // During the intermediate stages of SLP, it's very common to have
-        // pattern like this (since optimizeGatherSequence is run only once
-        // at the end):
-        // %1 = extractelement <2 x i32> %a, i32 0
-        // %2 = extractelement <2 x i32> %a, i32 1
-        // %cond = icmp sgt i32 %1, %2
-        // %3 = extractelement <2 x i32> %a, i32 0
-        // %4 = extractelement <2 x i32> %a, i32 1
-        // %select = select i1 %cond, i32 %3, i32 %4
-        CmpInst::Predicate Pred;
-        Instruction *L1;
-        Instruction *L2;
-
-        LHS = Select->getTrueValue();
-        RHS = Select->getFalseValue();
-        Value *Cond = Select->getCondition();
-
-        // TODO: Support inverse predicates.
-        if (match(Cond, m_Cmp(Pred, m_Specific(LHS), m_Instruction(L2)))) {
-          if (!isa<ExtractElementInst>(RHS) ||
-              !L2->isIdenticalTo(cast<Instruction>(RHS)))
-            return OperationData(*I);
-        } else if (match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Specific(RHS)))) {
-          if (!isa<ExtractElementInst>(LHS) ||
-              !L1->isIdenticalTo(cast<Instruction>(LHS)))
-            return OperationData(*I);
-        } else {
-          if (!isa<ExtractElementInst>(LHS) || !isa<ExtractElementInst>(RHS))
-            return OperationData(*I);
-          if (!match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2))) ||
-              !L1->isIdenticalTo(cast<Instruction>(LHS)) ||
-              !L2->isIdenticalTo(cast<Instruction>(RHS)))
-            return OperationData(*I);
-        }
-        switch (Pred) {
-        default:
+        if (!isa<ExtractElementInst>(LHS) || !isa<ExtractElementInst>(RHS))
           return OperationData(*I);
+        if (!match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2))) ||
+            !L1->isIdenticalTo(cast<Instruction>(LHS)) ||
+            !L2->isIdenticalTo(cast<Instruction>(RHS)))
+          return OperationData(*I);
+      }
 
-        case CmpInst::ICMP_ULT:
-        case CmpInst::ICMP_ULE:
-          return OperationData(Instruction::ICmp, RK_UMin);
-
-        case CmpInst::ICMP_SLT:
-        case CmpInst::ICMP_SLE:
-          return OperationData(Instruction::ICmp, RK_SMin);
-
-        case CmpInst::ICMP_UGT:
-        case CmpInst::ICMP_UGE:
-          return OperationData(Instruction::ICmp, RK_UMax);
-
-        case CmpInst::ICMP_SGT:
-        case CmpInst::ICMP_SGE:
-          return OperationData(Instruction::ICmp, RK_SMax);
-        }
+      TargetTransformInfo::ReductionFlags RdxFlags;
+      switch (Pred) {
+      default:
+        return OperationData(*I);
+      case CmpInst::ICMP_ULT:
+      case CmpInst::ICMP_ULE:
+        RdxFlags.IsMaxOp = false;
+        RdxFlags.IsSigned = false;
+        break;
+      case CmpInst::ICMP_SLT:
+      case CmpInst::ICMP_SLE:
+        RdxFlags.IsMaxOp = false;
+        RdxFlags.IsSigned = true;
+        break;
+      case CmpInst::ICMP_UGT:
+      case CmpInst::ICMP_UGE:
+        RdxFlags.IsMaxOp = true;
+        RdxFlags.IsSigned = false;
+        break;
+      case CmpInst::ICMP_SGT:
+      case CmpInst::ICMP_SGE:
+        RdxFlags.IsMaxOp = true;
+        RdxFlags.IsSigned = true;
+        break;
       }
+      return OperationData(Instruction::ICmp, RD::RK_IntegerMinMax, RdxFlags);
     }
     return OperationData(*I);
   }
@@ -7188,7 +7170,13 @@ class HorizontalReduction {
     int PairwiseRdxCost;
     int SplittingRdxCost;
     switch (RdxTreeInst.getKind()) {
-    case RK_Arithmetic:
+    case RD::RK_IntegerAdd:
+    case RD::RK_IntegerMult:
+    case RD::RK_IntegerOr:
+    case RD::RK_IntegerAnd:
+    case RD::RK_IntegerXor:
+    case RD::RK_FloatAdd:
+    case RD::RK_FloatMult:
       PairwiseRdxCost =
           TTI->getArithmeticReductionCost(RdxTreeInst.getOpcode(), VecTy,
                                           /*IsPairwiseForm=*/true);
@@ -7196,13 +7184,10 @@ class HorizontalReduction {
           TTI->getArithmeticReductionCost(RdxTreeInst.getOpcode(), VecTy,
                                           /*IsPairwiseForm=*/false);
       break;
-    case RK_SMin:
-    case RK_SMax:
-    case RK_UMin:
-    case RK_UMax: {
+    case RD::RK_IntegerMinMax: {
       auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VecTy));
-      bool IsUnsigned = RdxTreeInst.getKind() == RK_UMin ||
-                        RdxTreeInst.getKind() == RK_UMax;
+      bool IsUnsigned = !RdxTreeInst.getFlags().IsSigned;
+
       PairwiseRdxCost =
           TTI->getMinMaxReductionCost(VecTy, VecCondTy,
                                       /*IsPairwiseForm=*/true, IsUnsigned);
@@ -7211,7 +7196,7 @@ class HorizontalReduction {
                                       /*IsPairwiseForm=*/false, IsUnsigned);
       break;
     }
-    case RK_None:
+    default:
       llvm_unreachable("Expected arithmetic or min/max reduction operation");
     }
 
@@ -7220,20 +7205,23 @@ class HorizontalReduction {
 
     int ScalarReduxCost = 0;
     switch (RdxTreeInst.getKind()) {
-    case RK_Arithmetic:
+    case RD::RK_IntegerAdd:
+    case RD::RK_IntegerMult:
+    case RD::RK_IntegerOr:
+    case RD::RK_IntegerAnd:
+    case RD::RK_IntegerXor:
+    case RD::RK_FloatAdd:
+    case RD::RK_FloatMult:
       ScalarReduxCost =
           TTI->getArithmeticInstrCost(RdxTreeInst.getOpcode(), ScalarTy);
       break;
-    case RK_SMin:
-    case RK_SMax:
-    case RK_UMin:
-    case RK_UMax:
+    case RD::RK_IntegerMinMax:
       ScalarReduxCost =
           TTI->getCmpSelInstrCost(RdxTreeInst.getOpcode(), ScalarTy) +
           TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
                                   CmpInst::makeCmpResultType(ScalarTy));
       break;
-    case RK_None:
+    default:
       llvm_unreachable("Expected arithmetic or min/max reduction operation");
     }
     ScalarReduxCost *= (ReduxWidth - 1);


        


More information about the llvm-branch-commits mailing list