[llvm] r309566 - [SLP] Initial rework for min/max horizontal reduction vectorization, NFC.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 31 07:36:05 PDT 2017


Author: abataev
Date: Mon Jul 31 07:36:05 2017
New Revision: 309566

URL: http://llvm.org/viewvc/llvm-project?rev=309566&view=rev
Log:
[SLP] Initial rework for min/max horizontal reduction vectorization, NFC.

Summary: All getReductionCost() functions are renamed to getArithmeticReductionCost(), and basic infrastructure is added to handle non-binary reduction operations.

Reviewers: spatel, mzolotukhin, Ayal, mkuper, gilr, hfinkel

Subscribers: RKSimon, llvm-commits

Differential Revision: https://reviews.llvm.org/D29402

Modified:
    llvm/trunk/lib/Analysis/CostModel.cpp
    llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp

Modified: llvm/trunk/lib/Analysis/CostModel.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/CostModel.cpp?rev=309566&r1=309565&r2=309566&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/CostModel.cpp (original)
+++ llvm/trunk/lib/Analysis/CostModel.cpp Mon Jul 31 07:36:05 2017
@@ -24,12 +24,14 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 using namespace llvm;
+using namespace PatternMatch;
 
 #define CM_NAME "cost-model"
 #define DEBUG_TYPE CM_NAME
@@ -183,27 +185,47 @@ static bool matchPairwiseShuffleMask(Shu
   return Mask == ActualMask;
 }
 
-static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp,
-                                          unsigned Level, unsigned NumLevels) {
+namespace {
+/// Contains opcode + LHS/RHS parts of the reduction operations.
+struct ReductionData {
+  explicit ReductionData() = default;
+  ReductionData(unsigned Opcode, Value *LHS, Value *RHS)
+      : Opcode(Opcode), LHS(LHS), RHS(RHS) {}
+  unsigned Opcode = 0;
+  Value *LHS = nullptr;
+  Value *RHS = nullptr;
+};
+} // namespace
+
+/// Returns the opcode and operands of \p I if it is a binary operator.
+static Optional<ReductionData> getReductionData(Instruction *I) {
+  Value *L, *R;
+  if (m_BinOp(m_Value(L), m_Value(R)).match(I))
+    return ReductionData(I->getOpcode(), L, R);
+  return llvm::None;
+}
+
+static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level,
+                                          unsigned NumLevels) {
   // Match one level of pairwise operations.
   // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
   //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
   // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
   //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
   // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
-  if (BinOp == nullptr)
+  if (!I)
     return false;
 
-  assert(BinOp->getType()->isVectorTy() && "Expecting a vector type");
+  assert(I->getType()->isVectorTy() && "Expecting a vector type");
 
-  unsigned Opcode = BinOp->getOpcode();
-  Value *L = BinOp->getOperand(0);
-  Value *R = BinOp->getOperand(1);
+  Optional<ReductionData> RD = getReductionData(I);
+  if (!RD)
+    return false;
 
-  ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L);
+  ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS);
   if (!LS && Level)
     return false;
-  ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R);
+  ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS);
   if (!RS && Level)
     return false;
 
@@ -228,31 +250,34 @@ static bool matchPairwiseReductionAtLeve
     // Example:
     //  %NextLevelOpL = shufflevector %R, <1, undef ...>
     //  %BinOp        = fadd          %NextLevelOpL, %R
-    if (NextLevelOpL && NextLevelOpL != R)
+    if (NextLevelOpL && NextLevelOpL != RD->RHS)
       return false;
-    else if (NextLevelOpR && NextLevelOpR != L)
+    else if (NextLevelOpR && NextLevelOpR != RD->LHS)
       return false;
 
-    NextLevelOp = NextLevelOpL ? R : L;
+    NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS;
   } else
     return false;
 
   // Check that the next levels binary operation exists and matches with the
   // current one.
-  BinaryOperator *NextLevelBinOp = nullptr;
   if (Level + 1 != NumLevels) {
-    if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp)))
-      return false;
-    else if (NextLevelBinOp->getOpcode() != Opcode)
+    // NextLevelOp may be a non-instruction value (e.g. a function argument
+    // or constant); cast<Instruction> would assert, so bail out instead.
+    if (!isa<Instruction>(NextLevelOp))
+      return false;
+    Optional<ReductionData> NextLevelRD =
+        getReductionData(cast<Instruction>(NextLevelOp));
+    if (!NextLevelRD || RD->Opcode != NextLevelRD->Opcode)
       return false;
   }
 
   // Shuffle mask for pairwise operation must match.
-  if (matchPairwiseShuffleMask(LS, true, Level)) {
-    if (!matchPairwiseShuffleMask(RS, false, Level))
+  if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) {
+    if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level))
       return false;
-  } else if (matchPairwiseShuffleMask(RS, true, Level)) {
-    if (!matchPairwiseShuffleMask(LS, false, Level))
+  } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) {
+    if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level))
       return false;
   } else
     return false;
@@ -261,7 +286,8 @@ static bool matchPairwiseReductionAtLeve
     return true;
 
   // Match next level.
-  return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels);
+  return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level,
+                                       NumLevels);
 }
 
 static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
@@ -277,11 +303,14 @@ static bool matchPairwiseReduction(const
   if (Idx != 0)
     return false;
 
-  BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
+  auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
   if (!RdxStart)
     return false;
+  Optional<ReductionData> RD = getReductionData(RdxStart);
+  if (!RD)
+    return false;
 
-  Type *VecTy = ReduxRoot->getOperand(0)->getType();
+  Type *VecTy = RdxStart->getType();
   unsigned NumVecElems = VecTy->getVectorNumElements();
   if (!isPowerOf2_32(NumVecElems))
     return false;
@@ -307,17 +336,14 @@ static bool matchPairwiseReduction(const
   if (!matchPairwiseReductionAtLevel(RdxStart, 0,  Log2_32(NumVecElems)))
     return false;
 
-  Opcode = RdxStart->getOpcode();
+  Opcode = RD->Opcode;
   Ty = VecTy;
 
   return true;
 }
 
 static std::pair<Value *, ShuffleVectorInst *>
-getShuffleAndOtherOprd(BinaryOperator *B) {
-
-  Value *L = B->getOperand(0);
-  Value *R = B->getOperand(1);
+getShuffleAndOtherOprd(Value *L, Value *R) {
   ShuffleVectorInst *S = nullptr;
 
   if ((S = dyn_cast<ShuffleVectorInst>(L)))
@@ -340,10 +366,12 @@ static bool matchVectorSplittingReductio
   if (Idx != 0)
     return false;
 
-  BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
+  auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
   if (!RdxStart)
     return false;
-  unsigned RdxOpcode = RdxStart->getOpcode();
+  Optional<ReductionData> RD = getReductionData(RdxStart);
+  if (!RD)
+    return false;
 
   Type *VecTy = ReduxRoot->getOperand(0)->getType();
   unsigned NumVecElems = VecTy->getVectorNumElements();
@@ -362,20 +390,21 @@ static bool matchVectorSplittingReductio
   // %r = extractelement <4 x float> %bin.rdx8, i32 0
 
   unsigned MaskStart = 1;
-  Value *RdxOp = RdxStart;
+  Instruction *RdxOp = RdxStart;
   SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
   unsigned NumVecElemsRemain = NumVecElems;
   while (NumVecElemsRemain - 1) {
     // Check for the right reduction operation.
-    BinaryOperator *BinOp;
-    if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp)))
+    if (!RdxOp)
       return false;
-    if (BinOp->getOpcode() != RdxOpcode)
+    Optional<ReductionData> RDLevel = getReductionData(RdxOp);
+    if (!RDLevel || RDLevel->Opcode != RD->Opcode)
       return false;
 
     Value *NextRdxOp;
     ShuffleVectorInst *Shuffle;
-    std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp);
+    std::tie(NextRdxOp, Shuffle) =
+        getShuffleAndOtherOprd(RDLevel->LHS, RDLevel->RHS);
 
     // Check the current reduction operation and the shuffle use the same value.
     if (Shuffle == nullptr)
@@ -393,12 +422,12 @@ static bool matchVectorSplittingReductio
     if (ShuffleMask != Mask)
       return false;
 
-    RdxOp = NextRdxOp;
+    RdxOp = dyn_cast<Instruction>(NextRdxOp);
     NumVecElemsRemain /= 2;
     MaskStart *= 2;
   }
 
-  Opcode = RdxOpcode;
+  Opcode = RD->Opcode;
   Ty = VecTy;
   return true;
 }
@@ -495,10 +524,14 @@ unsigned CostModelAnalysis::getInstructi
     unsigned ReduxOpCode;
     Type *ReduxType;
 
-    if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType))
-      return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, false);
-    else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType))
-      return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, true);
+    if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) {
+      return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
+                                             /*IsPairwiseForm=*/false);
+    }
+    if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+      return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
+                                             /*IsPairwiseForm=*/true);
+    }
 
     return TTI->getVectorInstrCost(I->getOpcode(),
                                    EEI->getOperand(0)->getType(), Idx);

Modified: llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp?rev=309566&r1=309565&r2=309566&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp (original)
+++ llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp Mon Jul 31 07:36:05 2017
@@ -33,6 +33,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/NoFolder.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/Verifier.h"
@@ -48,6 +49,7 @@
 #include <memory>
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 using namespace slpvectorizer;
 
 #define SV_NAME "slp-vectorizer"
@@ -4321,12 +4323,104 @@ class HorizontalReduction {
   // Use map vector to make stable output.
   MapVector<Instruction *, Value *> ExtraArgs;
 
-  BinaryOperator *ReductionRoot = nullptr;
+  /// Contains info about operation, like its opcode, left and right operands.
+  struct OperationData {
+    /// true if the operation is a reduced value, false if reduction operation.
+    bool IsReducedValue = false;
+    /// Opcode of the instruction.
+    unsigned Opcode = 0;
+    /// Left operand of the reduction operation.
+    Value *LHS = nullptr;
+    /// Right operand of the reduction operation.
+    Value *RHS = nullptr;
+
+    /// Checks if the reduction operation can be vectorized.
+    bool isVectorizable() const {
+      return LHS && RHS &&
+             // We currently only support adds.
+             (Opcode == Instruction::Add || Opcode == Instruction::FAdd);
+    }
+
+  public:
+    explicit OperationData() = default;
+    /// Constructor for reduced values. They are identified by opcode only and
+    /// don't have associated LHS/RHS values.
+    explicit OperationData(Value *V) : IsReducedValue(true) {
+      if (auto *I = dyn_cast<Instruction>(V))
+        Opcode = I->getOpcode();
+    }
+    /// Constructor for binary reduction operations with opcode and its left and
+    /// right operands.
+    OperationData(unsigned Opcode, Value *LHS, Value *RHS)
+        : IsReducedValue(false), Opcode(Opcode), LHS(LHS), RHS(RHS) {}
+    explicit operator bool() const { return Opcode; }
+    /// Get the index of the first operand.
+    unsigned getFirstOperandIndex() const {
+      assert(!!*this && "The opcode is not set.");
+      return 0;
+    }
+    /// Total number of operands in the reduction operation.
+    unsigned getNumberOfOperands() const {
+      assert(!IsReducedValue && !!*this && LHS && RHS &&
+             "Expected reduction operation.");
+      return 2;
+    }
+    /// Expected number of uses for reduction operations/reduced values.
+    unsigned getRequiredNumberOfUses() const {
+      assert(!IsReducedValue && !!*this && LHS && RHS &&
+             "Expected reduction operation.");
+      return 1;
+    }
+    /// Checks if instruction \p I is associative.
+    bool isAssociative(Instruction *I) const {
+      assert(!IsReducedValue && !!*this && LHS && RHS &&
+             "Expected reduction operation.");
+      return I->isAssociative();
+    }
+    /// Checks if the operation can be vectorized and \p I is associative.
+    bool isVectorizable(Instruction *I) const {
+      return isVectorizable() && isAssociative(I);
+    }
+
+    /// Checks if two operation data are both a reduction op or both a reduced
+    /// value, with the same opcode.
+    bool operator==(const OperationData &OD) const {
+      assert((IsReducedValue != OD.IsReducedValue) ||
+             ((!LHS == !OD.LHS) && (!RHS == !OD.RHS)) &&
+                 "One of the comparing operations is incorrect.");
+      return this == &OD ||
+             (IsReducedValue == OD.IsReducedValue && Opcode == OD.Opcode);
+    }
+    bool operator!=(const OperationData &OD) const { return !(*this == OD); }
+    void clear() {
+      IsReducedValue = false;
+      Opcode = 0;
+      LHS = nullptr;
+      RHS = nullptr;
+    }
+    /// Get the opcode of the reduction operation.
+    unsigned getOpcode() const {
+      assert(isVectorizable() && "Expected vectorizable operation.");
+      return Opcode;
+    }
+    Value *getLHS() const { return LHS; }
+    Value *getRHS() const { return RHS; }
+    /// Creates reduction operation with the current opcode.
+    Value *createOp(IRBuilder<> &Builder, const Twine &Name = "") const {
+      assert(!IsReducedValue &&
+             (Opcode == Instruction::FAdd || Opcode == Instruction::Add) &&
+             "Expected add|fadd reduction operation.");
+      return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
+                                 Name);
+    }
+  };
 
-  /// The opcode of the reduction.
-  Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd;
-  /// The opcode of the values we perform a reduction on.
-  unsigned ReducedValueOpcode = 0;
+  Instruction *ReductionRoot = nullptr;
+
+  /// The operation data of the reduction operation.
+  OperationData ReductionData;
+  /// The operation data of the values we perform a reduction on.
+  OperationData ReducedValueData;
   /// Should we model this reduction as a pairwise reduction tree or a tree that
   /// splits the vector in halves and adds those halves.
   bool IsPairwiseReduction = false;
@@ -4351,55 +4445,65 @@ class HorizontalReduction {
     }
   }
 
+  static OperationData getOperationData(Value *V) {
+    if (!V)
+      return OperationData();
+
+    Value *LHS;
+    Value *RHS;
+    if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V))
+      return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS);
+    return OperationData(V);
+  }
+
 public:
   HorizontalReduction() = default;
 
   /// \brief Try to find a reduction tree.
-  bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) {
+  bool matchAssociativeReduction(PHINode *Phi, Instruction *B) {
     assert((!Phi || is_contained(Phi->operands(), B)) &&
-           "Thi phi needs to use the binary operator");
+           "The phi needs to use the reduction operation");
 
+    ReductionData = getOperationData(B);
+
     // We could have a initial reductions that is not an add.
     //  r *= v1 + v2 + v3 + v4
     // In such a case start looking for a tree rooted in the first '+'.
     if (Phi) {
-      if (B->getOperand(0) == Phi) {
+      if (ReductionData.getLHS() == Phi) {
         Phi = nullptr;
-        B = dyn_cast<BinaryOperator>(B->getOperand(1));
-      } else if (B->getOperand(1) == Phi) {
+        B = dyn_cast<Instruction>(ReductionData.getRHS());
+        ReductionData = getOperationData(B);
+      } else if (ReductionData.getRHS() == Phi) {
         Phi = nullptr;
-        B = dyn_cast<BinaryOperator>(B->getOperand(0));
+        B = dyn_cast<Instruction>(ReductionData.getLHS());
+        ReductionData = getOperationData(B);
       }
     }
 
-    if (!B)
+    if (!ReductionData.isVectorizable(B))
       return false;
 
     Type *Ty = B->getType();
     if (!isValidElementType(Ty))
       return false;
 
-    ReductionOpcode = B->getOpcode();
-    ReducedValueOpcode = 0;
+    ReducedValueData.clear();
     ReductionRoot = B;
 
-    // We currently only support adds.
-    if ((ReductionOpcode != Instruction::Add &&
-         ReductionOpcode != Instruction::FAdd) ||
-        !B->isAssociative())
-      return false;
-
     // Post order traverse the reduction tree starting at B. We only handle true
-    // trees containing only binary operators or selects.
+    // trees containing only binary operators.
     SmallVector<std::pair<Instruction *, unsigned>, 32> Stack;
-    Stack.push_back(std::make_pair(B, 0));
+    Stack.push_back(std::make_pair(B, ReductionData.getFirstOperandIndex()));
+    const unsigned NUses = ReductionData.getRequiredNumberOfUses();
     while (!Stack.empty()) {
       Instruction *TreeN = Stack.back().first;
       unsigned EdgeToVist = Stack.back().second++;
-      bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode;
+      OperationData OpData = getOperationData(TreeN);
+      bool IsReducedValue = OpData != ReductionData;
 
       // Postorder vist.
-      if (EdgeToVist == 2 || IsReducedValue) {
+      if (IsReducedValue || EdgeToVist == OpData.getNumberOfOperands()) {
         if (IsReducedValue)
           ReducedVals.push_back(TreeN);
         else {
@@ -4428,12 +4532,13 @@ public:
       Value *NextV = TreeN->getOperand(EdgeToVist);
       if (NextV != Phi) {
         auto *I = dyn_cast<Instruction>(NextV);
+        OpData = getOperationData(I);
         // Continue analysis if the next operand is a reduction operation or
         // (possibly) a reduced value. If the reduced value opcode is not set,
         // the first met operation != reduction operation is considered as the
         // reduced value class.
-        if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode ||
-                  I->getOpcode() == ReductionOpcode)) {
+        if (I && (!ReducedValueData || OpData == ReducedValueData ||
+                  OpData == ReductionData)) {
           // Only handle trees in the current basic block.
           if (I->getParent() != B->getParent()) {
             // I is an extra argument for TreeN (its parent operation).
@@ -4441,32 +4546,32 @@ public:
             continue;
           }
 
-          // Each tree node needs to have one user except for the ultimate
-          // reduction.
-          if (!I->hasOneUse() && I != B) {
+          // Each tree node needs to have minimal number of users except for the
+          // ultimate reduction.
+          if (!I->hasNUses(NUses) && I != B) {
             // I is an extra argument for TreeN (its parent operation).
             markExtraArg(Stack.back(), I);
             continue;
           }
 
-          if (I->getOpcode() == ReductionOpcode) {
+          if (OpData == ReductionData) {
             // We need to be able to reassociate the reduction operations.
-            if (!I->isAssociative()) {
+            if (!OpData.isAssociative(I)) {
               // I is an extra argument for TreeN (its parent operation).
               markExtraArg(Stack.back(), I);
               continue;
             }
-          } else if (ReducedValueOpcode &&
-                     ReducedValueOpcode != I->getOpcode()) {
+          } else if (ReducedValueData &&
+                     ReducedValueData != OpData) {
             // Make sure that the opcodes of the operations that we are going to
             // reduce match.
             // I is an extra argument for TreeN (its parent operation).
             markExtraArg(Stack.back(), I);
             continue;
-          } else if (!ReducedValueOpcode)
-            ReducedValueOpcode = I->getOpcode();
+          } else if (!ReducedValueData)
+            ReducedValueData = OpData;
 
-          Stack.push_back(std::make_pair(I, 0));
+          Stack.push_back(std::make_pair(I, OpData.getFirstOperandIndex()));
           continue;
         }
       }
@@ -4539,8 +4644,9 @@ public:
           emitReduction(VectorizedRoot, Builder, ReduxWidth, ReductionOps, TTI);
       if (VectorizedTree) {
         Builder.SetCurrentDebugLocation(Loc);
-        VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree,
-                                             ReducedSubTree, "bin.rdx");
+        OperationData VectReductionData(ReductionData.getOpcode(),
+                                        VectorizedTree, ReducedSubTree);
+        VectorizedTree = VectReductionData.createOp(Builder, "bin.rdx");
         propagateIRFlags(VectorizedTree, ReductionOps);
       } else
         VectorizedTree = ReducedSubTree;
@@ -4553,8 +4659,9 @@ public:
       for (; i < NumReducedVals; ++i) {
         auto *I = cast<Instruction>(ReducedVals[i]);
         Builder.SetCurrentDebugLocation(I->getDebugLoc());
-        VectorizedTree =
-            Builder.CreateBinOp(ReductionOpcode, VectorizedTree, I);
+        OperationData VectReductionData(ReductionData.getOpcode(),
+                                        VectorizedTree, I);
+        VectorizedTree = VectReductionData.createOp(Builder);
         propagateIRFlags(VectorizedTree, ReductionOps);
       }
       for (auto &Pair : ExternallyUsedValues) {
@@ -4563,8 +4670,9 @@ public:
         // Add each externally used value to the final reduction.
         for (auto *I : Pair.second) {
           Builder.SetCurrentDebugLocation(I->getDebugLoc());
-          VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree,
-                                               Pair.first, "bin.extra");
+          OperationData VectReductionData(ReductionData.getOpcode(),
+                                          VectorizedTree, Pair.first);
+          VectorizedTree = VectReductionData.createOp(Builder, "bin.extra");
           propagateIRFlags(VectorizedTree, I);
         }
       }
@@ -4586,16 +4694,18 @@ private:
     Type *VecTy = VectorType::get(ScalarTy, ReduxWidth);
 
     int PairwiseRdxCost =
-        TTI->getArithmeticReductionCost(ReductionOpcode, VecTy, true);
+        TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
+                                        /*IsPairwiseForm=*/true);
     int SplittingRdxCost =
-        TTI->getArithmeticReductionCost(ReductionOpcode, VecTy, false);
+        TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
+                                        /*IsPairwiseForm=*/false);
 
     IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost;
     int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost;
 
     int ScalarReduxCost =
         (ReduxWidth - 1) *
-        TTI->getArithmeticInstrCost(ReductionOpcode, ScalarTy);
+        TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy);
 
     DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost
                  << " for reduction that starts with " << *FirstReducedVal
@@ -4616,7 +4726,7 @@ private:
 
     if (!IsPairwiseReduction)
       return createSimpleTargetReduction(
-          Builder, TTI, ReductionOpcode, VectorizedValue,
+          Builder, TTI, ReductionData.getOpcode(), VectorizedValue,
           TargetTransformInfo::ReductionFlags(), RedOps);
 
     Value *TmpVec = VectorizedValue;
@@ -4631,8 +4741,9 @@ private:
       Value *RightShuf = Builder.CreateShuffleVector(
           TmpVec, UndefValue::get(TmpVec->getType()), (RightMask),
           "rdx.shuf.r");
-      TmpVec =
-          Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf, "bin.rdx");
+      OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf,
+                                      RightShuf);
+      TmpVec = VectReductionData.createOp(Builder, "bin.rdx");
       propagateIRFlags(TmpVec, RedOps);
     }
 




More information about the llvm-commits mailing list