[llvm-branch-commits] [llvm] 8590d24 - [SLP] move reduction createOp functions; NFC

Sanjay Patel via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 20 08:20:37 PST 2021


Author: Sanjay Patel
Date: 2021-01-20T11:14:48-05:00
New Revision: 8590d245434dd4205c89f0a05b4c22feccb7421c

URL: https://github.com/llvm/llvm-project/commit/8590d245434dd4205c89f0a05b4c22feccb7421c
DIFF: https://github.com/llvm/llvm-project/commit/8590d245434dd4205c89f0a05b4c22feccb7421c.diff

LOG: [SLP] move reduction createOp functions; NFC

We were able to remove almost all of the state from
OperationData, so the createOp() functions no longer make
sense as members of that class; make them static members of
HorizontalReduction and pass the RecurKind in as a parameter.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 24885e4d8257..3d657b0b898c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6397,7 +6397,7 @@ namespace {
 class HorizontalReduction {
   using ReductionOpsType = SmallVector<Value *, 16>;
   using ReductionOpsListType = SmallVector<ReductionOpsType, 2>;
-  ReductionOpsListType  ReductionOps;
+  ReductionOpsListType ReductionOps;
   SmallVector<Value *, 32> ReducedVals;
   // Use map vector to make stable output.
   MapVector<Instruction *, Value *> ExtraArgs;
@@ -6412,47 +6412,6 @@ class HorizontalReduction {
     /// Checks if the reduction operation can be vectorized.
     bool isVectorizable() const { return Kind != RecurKind::None; }
 
-    /// Creates reduction operation with the current opcode.
-    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
-                    const Twine &Name) const {
-      assert(isVectorizable() && "Unhandled reduction operation.");
-      unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
-      switch (Kind) {
-      case RecurKind::Add:
-      case RecurKind::Mul:
-      case RecurKind::Or:
-      case RecurKind::And:
-      case RecurKind::Xor:
-      case RecurKind::FAdd:
-      case RecurKind::FMul:
-        return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
-                                   Name);
-      case RecurKind::FMax:
-        return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS);
-      case RecurKind::FMin:
-        return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS);
-
-      case RecurKind::SMax: {
-        Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      case RecurKind::SMin: {
-        Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      case RecurKind::UMax: {
-        Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      case RecurKind::UMin: {
-        Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      default:
-        llvm_unreachable("Unknown reduction operation.");
-      }
-    }
-
   public:
     explicit OperationData() = default;
 
@@ -6580,40 +6539,6 @@ class HorizontalReduction {
         return nullptr;
       return I->getOperand(getFirstOperandIndex() + 1);
     }
-
-    /// Creates reduction operation with the current opcode with the IR flags
-    /// from \p ReductionOps.
-    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
-                    const Twine &Name,
-                    const ReductionOpsListType &ReductionOps) const {
-      assert(isVectorizable() &&
-             "Expected add|fadd or min/max reduction operation.");
-      Value *Op = createOp(Builder, LHS, RHS, Name);
-      if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind)) {
-        if (auto *Sel = dyn_cast<SelectInst>(Op))
-          propagateIRFlags(Sel->getCondition(), ReductionOps[0]);
-        propagateIRFlags(Op, ReductionOps[1]);
-        return Op;
-      }
-      propagateIRFlags(Op, ReductionOps[0]);
-      return Op;
-    }
-    /// Creates reduction operation with the current opcode with the IR flags
-    /// from \p I.
-    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
-                    const Twine &Name, Instruction *I) const {
-      assert(isVectorizable() &&
-             "Expected add|fadd or min/max reduction operation.");
-      Value *Op = createOp(Builder, LHS, RHS, Name);
-      if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind)) {
-        if (auto *Sel = dyn_cast<SelectInst>(Op)) {
-          propagateIRFlags(Sel->getCondition(),
-                           cast<SelectInst>(I)->getCondition());
-        }
-      }
-      propagateIRFlags(Op, I);
-      return Op;
-    }
   };
 
   WeakTrackingVH ReductionRoot;
@@ -6642,6 +6567,76 @@ class HorizontalReduction {
     }
   }
 
+  /// Creates reduction operation of kind \p Kind on \p LHS and \p RHS.
+  static Value *createOp(IRBuilder<> &Builder, RecurKind Kind, Value *LHS,
+                         Value *RHS, const Twine &Name) {
+    unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+    switch (Kind) {
+    case RecurKind::Add:
+    case RecurKind::Mul:
+    case RecurKind::Or:
+    case RecurKind::And:
+    case RecurKind::Xor:
+    case RecurKind::FAdd:
+    case RecurKind::FMul:
+      return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
+                                 Name);
+    case RecurKind::FMax:
+      return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS);
+    case RecurKind::FMin:
+      return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS);
+    // Integer min/max are emitted as an icmp feeding a select.
+    case RecurKind::SMax: {
+      Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name);
+      return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+    }
+    case RecurKind::SMin: {
+      Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name);
+      return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+    }
+    case RecurKind::UMax: {
+      Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name);
+      return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+    }
+    case RecurKind::UMin: {
+      Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name);
+      return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+    }
+    default:
+      llvm_unreachable("Unknown reduction operation.");
+    }
+  }
+
+  /// Creates a reduction operation of kind \p RdxKind, copying IR flags from
+  /// \p ReductionOps (int min/max: cmp <- [0], select <- [1]; else [0]).
+  static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS,
+                         Value *RHS, const Twine &Name,
+                         const ReductionOpsListType &ReductionOps) {
+    Value *Op = createOp(Builder, RdxKind, LHS, RHS, Name);
+    if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(RdxKind)) {
+      if (auto *Sel = dyn_cast<SelectInst>(Op))
+        propagateIRFlags(Sel->getCondition(), ReductionOps[0]);
+      propagateIRFlags(Op, ReductionOps[1]);
+      return Op;
+    }
+    propagateIRFlags(Op, ReductionOps[0]);
+    return Op;
+  }
+  /// Creates a reduction operation of kind \p RdxKind, copying IR flags from
+  /// \p I (and, for int min/max, from \p I's select condition to the cmp).
+  static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS,
+                         Value *RHS, const Twine &Name, Instruction *I) {
+    Value *Op = createOp(Builder, RdxKind, LHS, RHS, Name);
+    if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(RdxKind)) {
+      if (auto *Sel = dyn_cast<SelectInst>(Op)) {
+        propagateIRFlags(Sel->getCondition(),
+                         cast<SelectInst>(I)->getCondition());
+      }
+    }
+    propagateIRFlags(Op, I);
+    return Op;
+  }
+
   static OperationData getOperationData(Instruction *I) {
     if (!I)
       return OperationData();
@@ -6995,8 +6990,9 @@ class HorizontalReduction {
       } else {
         // Update the final value in the reduction.
         Builder.SetCurrentDebugLocation(Loc);
-        VectorizedTree = RdxTreeInst.createOp(
-            Builder, VectorizedTree, ReducedSubTree, "op.rdx", ReductionOps);
+        VectorizedTree =
+            createOp(Builder, RdxTreeInst.getKind(), VectorizedTree,
+                     ReducedSubTree, "op.rdx", ReductionOps);
       }
       i += ReduxWidth;
       ReduxWidth = PowerOf2Floor(NumReducedVals - i);
@@ -7007,15 +7003,15 @@ class HorizontalReduction {
       for (; i < NumReducedVals; ++i) {
         auto *I = cast<Instruction>(ReducedVals[i]);
         Builder.SetCurrentDebugLocation(I->getDebugLoc());
-        VectorizedTree = RdxTreeInst.createOp(Builder, VectorizedTree, I, "",
-                                              ReductionOps);
+        VectorizedTree = createOp(Builder, RdxTreeInst.getKind(),
+                                  VectorizedTree, I, "", ReductionOps);
       }
       for (auto &Pair : ExternallyUsedValues) {
         // Add each externally used value to the final reduction.
         for (auto *I : Pair.second) {
           Builder.SetCurrentDebugLocation(I->getDebugLoc());
-          VectorizedTree = RdxTreeInst.createOp(Builder, VectorizedTree,
-                                                Pair.first, "op.extra", I);
+          VectorizedTree = createOp(Builder, RdxTreeInst.getKind(),
+                                    VectorizedTree, Pair.first, "op.extra", I);
         }
       }
 
@@ -7039,9 +7035,7 @@ class HorizontalReduction {
     return VectorizedTree != nullptr;
   }
 
-  unsigned numReductionValues() const {
-    return ReducedVals.size();
-  }
+  unsigned numReductionValues() const { return ReducedVals.size(); }
 
 private:
   /// Calculate the cost of a reduction.
@@ -7062,7 +7056,7 @@ class HorizontalReduction {
     case RecurKind::FMul: {
       unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
       VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
-                                                      /*IsPairwiseForm=*/false);
+                                                   /*IsPairwiseForm=*/false);
       ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
       break;
     }


        


More information about the llvm-branch-commits mailing list