[llvm] 03f22b0 - [SLP] Remove LHS and RHS from OperationData.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 24 10:57:48 PDT 2020


Author: Craig Topper
Date: 2020-09-24T10:57:11-07:00
New Revision: 03f22b08e2a387a415dcbb3cf021e41e629c3d34

URL: https://github.com/llvm/llvm-project/commit/03f22b08e2a387a415dcbb3cf021e41e629c3d34
DIFF: https://github.com/llvm/llvm-project/commit/03f22b08e2a387a415dcbb3cf021e41e629c3d34.diff

LOG: [SLP] Remove LHS and RHS from OperationData.

These were only really used for two things. One was to check whether an operand matches the phi, if one exists. The other was to supply the operands to the createOp method when building the reduction.

For the first case, we still have the operation; we just need to know how to index its operands. So I've modified getLHS/getRHS to use the opcode/kind to find the right operands on an instruction that is now passed in.

For the other case, we had to create a temporary OperationData object just to hold the LHS/RHS values, copying the opcode/kind from another object, and then call createOp on that temporary. Instead, I've made LHS/RHS arguments to createOp and removed all of these temporary objects.

Differential Revision: https://reviews.llvm.org/D88193

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c19a7b5c56ae..8722ff9ba95c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6312,20 +6312,13 @@ class HorizontalReduction {
     /// Opcode of the instruction.
     unsigned Opcode = 0;
 
-    /// Left operand of the reduction operation.
-    Value *LHS = nullptr;
-
-    /// Right operand of the reduction operation.
-    Value *RHS = nullptr;
-
     /// Kind of the reduction operation.
     ReductionKind Kind = RK_None;
 
     /// Checks if the reduction operation can be vectorized.
     bool isVectorizable() const {
-      return LHS && RHS &&
-             // We currently only support add/mul/logical && min/max reductions.
-             ((Kind == RK_Arithmetic &&
+      // We currently only support add/mul/logical && min/max reductions.
+      return ((Kind == RK_Arithmetic &&
                (Opcode == Instruction::Add || Opcode == Instruction::FAdd ||
                 Opcode == Instruction::Mul || Opcode == Instruction::FMul ||
                 Opcode == Instruction::And || Opcode == Instruction::Or ||
@@ -6336,7 +6329,8 @@ class HorizontalReduction {
     }
 
     /// Creates reduction operation with the current opcode.
-    Value *createOp(IRBuilder<> &Builder, const Twine &Name) const {
+    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
+                    const Twine &Name) const {
       assert(isVectorizable() &&
              "Expected add|fadd or min/max reduction operation.");
       Value *Cmp = nullptr;
@@ -6377,8 +6371,8 @@ class HorizontalReduction {
 
     /// Constructor for reduction operations with opcode and its left and
     /// right operands.
-    OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind)
-        : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
+    OperationData(unsigned Opcode, ReductionKind Kind)
+        : Opcode(Opcode), Kind(Kind) {
       assert(Kind != RK_None && "One of the reduction operations is expected.");
     }
 
@@ -6411,16 +6405,14 @@ class HorizontalReduction {
 
     /// Total number of operands in the reduction operation.
     unsigned getNumberOfOperands() const {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       return isMinMax() ? 3 : 2;
     }
 
     /// Checks if the instruction is in basic block \p BB.
     /// For a min/max reduction check that both compare and select are in \p BB.
     bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       if (IsRedOp && isMinMax()) {
         auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
         return I->getParent() == BB && Cmp && Cmp->getParent() == BB;
@@ -6430,8 +6422,7 @@ class HorizontalReduction {
 
     /// Expected number of uses for reduction operations/reduced values.
     bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       // SelectInst must be used twice while the condition op must have single
       // use only.
       if (isMinMax())
@@ -6445,8 +6436,7 @@ class HorizontalReduction {
 
     /// Initializes the list of reduction operations.
     void initReductionOps(ReductionOpsListType &ReductionOps) {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       if (isMinMax())
         ReductionOps.assign(2, ReductionOpsType());
       else
@@ -6455,8 +6445,7 @@ class HorizontalReduction {
 
     /// Add all reduction operations for the reduction instruction \p I.
     void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       if (isMinMax()) {
         ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
         ReductionOps[1].emplace_back(I);
@@ -6467,8 +6456,7 @@ class HorizontalReduction {
 
     /// Checks if instruction is associative and can be vectorized.
     bool isAssociative(Instruction *I) const {
-      assert(Kind != RK_None && *this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && *this && "Expected reduction operation.");
       switch (Kind) {
       case RK_Arithmetic:
         return I->isAssociative();
@@ -6493,15 +6481,13 @@ class HorizontalReduction {
     /// Checks if two operation data are both a reduction op or both a reduced
     /// value.
     bool operator==(const OperationData &OD) const {
-      assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
+      assert(((Kind != OD.Kind) || (Opcode != 0 && OD.Opcode != 0)) &&
              "One of the comparing operations is incorrect.");
-      return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode);
+      return Kind == OD.Kind && Opcode == OD.Opcode;
     }
     bool operator!=(const OperationData &OD) const { return !(*this == OD); }
     void clear() {
       Opcode = 0;
-      LHS = nullptr;
-      RHS = nullptr;
       Kind = RK_None;
     }
 
@@ -6513,19 +6499,25 @@ class HorizontalReduction {
 
     /// Get kind of reduction data.
     ReductionKind getKind() const { return Kind; }
-    Value *getLHS() const { return LHS; }
-    Value *getRHS() const { return RHS; }
-    Type *getConditionType() const {
-      return isMinMax() ? CmpInst::makeCmpResultType(LHS->getType()) : nullptr;
+    Value *getLHS(Instruction *I) const {
+      if (Kind == RK_None)
+        return nullptr;
+      return I->getOperand(getFirstOperandIndex());
+    }
+    Value *getRHS(Instruction *I) const {
+      if (Kind == RK_None)
+        return nullptr;
+      return I->getOperand(getFirstOperandIndex() + 1);
     }
 
     /// Creates reduction operation with the current opcode with the IR flags
     /// from \p ReductionOps.
-    Value *createOp(IRBuilder<> &Builder, const Twine &Name,
+    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
+                    const Twine &Name,
                     const ReductionOpsListType &ReductionOps) const {
       assert(isVectorizable() &&
              "Expected add|fadd or min/max reduction operation.");
-      auto *Op = createOp(Builder, Name);
+      auto *Op = createOp(Builder, LHS, RHS, Name);
       switch (Kind) {
       case RK_Arithmetic:
         propagateIRFlags(Op, ReductionOps[0]);
@@ -6545,11 +6537,11 @@ class HorizontalReduction {
     }
     /// Creates reduction operation with the current opcode with the IR flags
     /// from \p I.
-    Value *createOp(IRBuilder<> &Builder, const Twine &Name,
-                    Instruction *I) const {
+    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
+                    const Twine &Name, Instruction *I) const {
       assert(isVectorizable() &&
              "Expected add|fadd or min/max reduction operation.");
-      auto *Op = createOp(Builder, Name);
+      auto *Op = createOp(Builder, LHS, RHS, Name);
       switch (Kind) {
       case RK_Arithmetic:
         propagateIRFlags(Op, I);
@@ -6637,19 +6629,18 @@ class HorizontalReduction {
     Value *LHS;
     Value *RHS;
     if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(I)) {
-      return OperationData(cast<BinaryOperator>(I)->getOpcode(), LHS, RHS,
-                           RK_Arithmetic);
+      return OperationData(cast<BinaryOperator>(I)->getOpcode(), RK_Arithmetic);
     }
     if (auto *Select = dyn_cast<SelectInst>(I)) {
       // Look for a min/max pattern.
       if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
+        return OperationData(Instruction::ICmp, RK_UMin);
       } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin);
+        return OperationData(Instruction::ICmp, RK_SMin);
       } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
+        return OperationData(Instruction::ICmp, RK_UMax);
       } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax);
+        return OperationData(Instruction::ICmp, RK_SMax);
       } else {
         // Try harder: look for min/max pattern based on instructions producing
         // same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2).
@@ -6693,19 +6684,19 @@ class HorizontalReduction {
 
         case CmpInst::ICMP_ULT:
         case CmpInst::ICMP_ULE:
-          return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
+          return OperationData(Instruction::ICmp, RK_UMin);
 
         case CmpInst::ICMP_SLT:
         case CmpInst::ICMP_SLE:
-          return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin);
+          return OperationData(Instruction::ICmp, RK_SMin);
 
         case CmpInst::ICMP_UGT:
         case CmpInst::ICMP_UGE:
-          return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
+          return OperationData(Instruction::ICmp, RK_UMax);
 
         case CmpInst::ICMP_SGT:
         case CmpInst::ICMP_SGE:
-          return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax);
+          return OperationData(Instruction::ICmp, RK_SMax);
         }
       }
     }
@@ -6726,13 +6717,13 @@ class HorizontalReduction {
     //  r *= v1 + v2 + v3 + v4
     // In such a case start looking for a tree rooted in the first '+'.
     if (Phi) {
-      if (ReductionData.getLHS() == Phi) {
+      if (ReductionData.getLHS(B) == Phi) {
         Phi = nullptr;
-        B = dyn_cast<Instruction>(ReductionData.getRHS());
+        B = dyn_cast<Instruction>(ReductionData.getRHS(B));
         ReductionData = getOperationData(B);
-      } else if (ReductionData.getRHS() == Phi) {
+      } else if (ReductionData.getRHS(B) == Phi) {
         Phi = nullptr;
-        B = dyn_cast<Instruction>(ReductionData.getLHS());
+        B = dyn_cast<Instruction>(ReductionData.getLHS(B));
         ReductionData = getOperationData(B);
       }
     }
@@ -6984,11 +6975,8 @@ class HorizontalReduction {
       } else {
         // Update the final value in the reduction.
         Builder.SetCurrentDebugLocation(Loc);
-        OperationData VectReductionData(ReductionData.getOpcode(),
-                                        VectorizedTree, ReducedSubTree,
-                                        ReductionData.getKind());
-        VectorizedTree =
-            VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
+        VectorizedTree = ReductionData.createOp(
+            Builder, VectorizedTree, ReducedSubTree, "op.rdx", ReductionOps);
       }
       i += ReduxWidth;
       ReduxWidth = PowerOf2Floor(NumReducedVals - i);
@@ -6999,19 +6987,15 @@ class HorizontalReduction {
       for (; i < NumReducedVals; ++i) {
         auto *I = cast<Instruction>(ReducedVals[i]);
         Builder.SetCurrentDebugLocation(I->getDebugLoc());
-        OperationData VectReductionData(ReductionData.getOpcode(),
-                                        VectorizedTree, I,
-                                        ReductionData.getKind());
-        VectorizedTree = VectReductionData.createOp(Builder, "", ReductionOps);
+        VectorizedTree = ReductionData.createOp(Builder, VectorizedTree, I, "",
+                                                ReductionOps);
       }
       for (auto &Pair : ExternallyUsedValues) {
         // Add each externally used value to the final reduction.
         for (auto *I : Pair.second) {
           Builder.SetCurrentDebugLocation(I->getDebugLoc());
-          OperationData VectReductionData(ReductionData.getOpcode(),
-                                          VectorizedTree, Pair.first,
-                                          ReductionData.getKind());
-          VectorizedTree = VectReductionData.createOp(Builder, "op.extra", I);
+          VectorizedTree = ReductionData.createOp(Builder, VectorizedTree,
+                                                  Pair.first, "op.extra", I);
         }
       }
 
@@ -7133,9 +7117,8 @@ class HorizontalReduction {
           Builder.CreateShuffleVector(TmpVec, LeftMask, "rdx.shuf.l");
       Value *RightShuf =
           Builder.CreateShuffleVector(TmpVec, RightMask, "rdx.shuf.r");
-      OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf,
-                                      RightShuf, ReductionData.getKind());
-      TmpVec = VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
+      TmpVec = ReductionData.createOp(Builder, LeftShuf, RightShuf, "op.rdx",
+                                      ReductionOps);
     }
 
     // The result is in the first element of the vector.


        


More information about the llvm-commits mailing list