[llvm] 6189a8d - [TTI] add wrapper for matching vector reduction to reduce code duplication; NFC

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 23 10:49:18 PDT 2020


Author: Sanjay Patel
Date: 2020-09-23T13:48:57-04:00
New Revision: 6189a8d9f56ac9434eac94d6c515d3e460fdecd0

URL: https://github.com/llvm/llvm-project/commit/6189a8d9f56ac9434eac94d6c515d3e460fdecd0
DIFF: https://github.com/llvm/llvm-project/commit/6189a8d9f56ac9434eac94d6c515d3e460fdecd0.diff

LOG: [TTI] add wrapper for matching vector reduction to reduce code duplication; NFC

I'm not sure what this means, but the order in which we try the matches
(splitting reduction first, then pairwise reduction) makes a difference on
at least 1 regression test, so the new wrapper keeps the original order.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 297cc8e55fb4..b7bcc34b3b3b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -876,6 +876,10 @@ class TargetTransformInfo {
   static ReductionKind matchVectorSplittingReduction(
     const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty);
 
+  static ReductionKind matchVectorReduction(const ExtractElementInst *ReduxRoot,
+                                            unsigned &Opcode, VectorType *&Ty,
+                                            bool &IsPairwise);
+
   /// Additional information about an operand's possible values.
   enum OperandValueKind {
     OK_AnyValue,               // Operand can have any value.

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index ebd1beb6e39e..22708d073a1b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -1004,41 +1004,23 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
       if (CI)
         Idx = CI->getZExtValue();
 
-      // Try to match a reduction sequence (series of shufflevector and
-      // vector  adds followed by a extractelement).
-      unsigned ReduxOpCode;
-      VectorType *ReduxType;
-
-      switch (TTI::matchVectorSplittingReduction(EEI, ReduxOpCode,
-                                                 ReduxType)) {
-      case TTI::RK_Arithmetic:
-        return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
-                                          /*IsPairwiseForm=*/false,
-                                          CostKind);
-      case TTI::RK_MinMax:
-        return TargetTTI->getMinMaxReductionCost(
-            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-            /*IsPairwiseForm=*/false, /*IsUnsigned=*/false, CostKind);
-      case TTI::RK_UnsignedMinMax:
-        return TargetTTI->getMinMaxReductionCost(
-            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-            /*IsPairwiseForm=*/false, /*IsUnsigned=*/true, CostKind);
-      case TTI::RK_None:
-        break;
-      }
-
-      switch (TTI::matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+      // Try to match a reduction (a series of shufflevector and vector ops
+      // followed by an extractelement).
+      unsigned RdxOpcode;
+      VectorType *RdxType;
+      bool IsPairwise;
+      switch (TTI::matchVectorReduction(EEI, RdxOpcode, RdxType, IsPairwise)) {
       case TTI::RK_Arithmetic:
-        return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
-                                          /*IsPairwiseForm=*/true, CostKind);
+        return TargetTTI->getArithmeticReductionCost(RdxOpcode, RdxType,
+                                                     IsPairwise, CostKind);
       case TTI::RK_MinMax:
         return TargetTTI->getMinMaxReductionCost(
-            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-            /*IsPairwiseForm=*/true, /*IsUnsigned=*/false, CostKind);
+            RdxType, cast<VectorType>(CmpInst::makeCmpResultType(RdxType)),
+            IsPairwise, /*IsUnsigned=*/false, CostKind);
       case TTI::RK_UnsignedMinMax:
         return TargetTTI->getMinMaxReductionCost(
-            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-            /*IsPairwiseForm=*/true, /*IsUnsigned=*/true, CostKind);
+            RdxType, cast<VectorType>(CmpInst::makeCmpResultType(RdxType)),
+            IsPairwise, /*IsUnsigned=*/true, CostKind);
       case TTI::RK_None:
         break;
       }

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index efecb4501853..4836d80ddb2d 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1308,6 +1308,18 @@ TTI::ReductionKind TTI::matchVectorSplittingReduction(
   return RD->Kind;
 }
 
+TTI::ReductionKind
+TTI::matchVectorReduction(const ExtractElementInst *Root, unsigned &Opcode,
+                          VectorType *&Ty, bool &IsPairwise) {
+  TTI::ReductionKind RdxKind = matchVectorSplittingReduction(Root, Opcode, Ty);
+  if (RdxKind != TTI::ReductionKind::RK_None) {
+    IsPairwise = false;
+    return RdxKind;
+  }
+  IsPairwise = true;
+  return matchPairwiseReduction(Root, Opcode, Ty);
+}
+
 int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
 


        


More information about the llvm-commits mailing list