[llvm] 51541c0 - [CostModel] Unify ExtractElement cost.

Sam Parker via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 15 00:28:01 PDT 2020


Author: Sam Parker
Date: 2020-06-15T08:27:14+01:00
New Revision: 51541c068a83355720e54509a3fca0c2558ee555

URL: https://github.com/llvm/llvm-project/commit/51541c068a83355720e54509a3fca0c2558ee555
DIFF: https://github.com/llvm/llvm-project/commit/51541c068a83355720e54509a3fca0c2558ee555.diff

LOG: [CostModel] Unify ExtractElement cost.

Move the cost modelling, with the reduction pattern matching, from
getInstructionThroughput into generic TTIImpl::getUserCost. The
modelling in the AMDGPU backend can now be removed.

Differential Revision: https://reviews.llvm.org/D81643

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 06051552b9d5..247b8a6c1991 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -39,6 +39,7 @@ class BlockFrequencyInfo;
 class DominatorTree;
 class BranchInst;
 class CallBase;
+class ExtractElementInst;
 class Function;
 class GlobalValue;
 class IntrinsicInst;
@@ -805,6 +806,36 @@ class TargetTransformInfo {
                          ///< shuffle mask.
   };
 
+  /// Kind of the reduction data.
+  enum ReductionKind {
+    RK_None,           /// Not a reduction.
+    RK_Arithmetic,     /// Binary reduction data.
+    RK_MinMax,         /// Min/max reduction data.
+    RK_UnsignedMinMax, /// Unsigned min/max reduction data.
+  };
+
+  /// Contains opcode + LHS/RHS parts of the reduction operations.
+  struct ReductionData {
+    ReductionData() = delete;
+    ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS)
+        : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
+      assert(Kind != RK_None && "expected binary or min/max reduction only.");
+    }
+    unsigned Opcode = 0;
+    Value *LHS = nullptr;
+    Value *RHS = nullptr;
+    ReductionKind Kind = RK_None;
+    bool hasSameData(ReductionData &RD) const {
+      return Kind == RD.Kind && Opcode == RD.Opcode;
+    }
+  };
+
+  static ReductionKind matchPairwiseReduction(
+    const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty);
+
+  static ReductionKind matchVectorSplittingReduction(
+    const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty);
+
   /// Additional information about an operand's possible values.
   enum OperandValueKind {
     OK_AnyValue,               // Operand can have any value.

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index cedffe7da4ba..3d0af73d445c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -944,6 +944,54 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
       return TargetTTI->getShuffleCost(TTI::SK_PermuteTwoSrc, VecTy, 0,
                                        nullptr);
     }
+    case Instruction::ExtractElement: {
+      unsigned Idx = -1;
+      auto *EEI = cast<ExtractElementInst>(U);
+      auto *CI = dyn_cast<ConstantInt>(EEI->getOperand(1));
+      if (CI)
+        Idx = CI->getZExtValue();
+
+      // Try to match a reduction sequence (series of shufflevector and
+      // vector  adds followed by a extractelement).
+      unsigned ReduxOpCode;
+      VectorType *ReduxType;
+
+      switch (TTI::matchVectorSplittingReduction(EEI, ReduxOpCode,
+                                                 ReduxType)) {
+      case TTI::RK_Arithmetic:
+        return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
+                                          /*IsPairwiseForm=*/false,
+                                          CostKind);
+      case TTI::RK_MinMax:
+        return TargetTTI->getMinMaxReductionCost(
+            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
+            /*IsPairwiseForm=*/false, /*IsUnsigned=*/false, CostKind);
+      case TTI::RK_UnsignedMinMax:
+        return TargetTTI->getMinMaxReductionCost(
+            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
+            /*IsPairwiseForm=*/false, /*IsUnsigned=*/true, CostKind);
+      case TTI::RK_None:
+        break;
+      }
+
+      switch (TTI::matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+      case TTI::RK_Arithmetic:
+        return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
+                                          /*IsPairwiseForm=*/true, CostKind);
+      case TTI::RK_MinMax:
+        return TargetTTI->getMinMaxReductionCost(
+            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
+            /*IsPairwiseForm=*/true, /*IsUnsigned=*/false, CostKind);
+      case TTI::RK_UnsignedMinMax:
+        return TargetTTI->getMinMaxReductionCost(
+            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
+            /*IsPairwiseForm=*/true, /*IsUnsigned=*/true, CostKind);
+      case TTI::RK_None:
+        break;
+      }
+      return TargetTTI->getVectorInstrCost(Opcode, U->getOperand(0)->getType(),
+                                           Idx);
+    }
     }
     // By default, just classify everything as 'basic'.
     return TTI::TCC_Basic;

diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a862397a592c..12c5c84e5b0b 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -970,35 +970,10 @@ static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
   return Mask == ActualMask;
 }
 
-namespace {
-/// Kind of the reduction data.
-enum ReductionKind {
-  RK_None,           /// Not a reduction.
-  RK_Arithmetic,     /// Binary reduction data.
-  RK_MinMax,         /// Min/max reduction data.
-  RK_UnsignedMinMax, /// Unsigned min/max reduction data.
-};
-/// Contains opcode + LHS/RHS parts of the reduction operations.
-struct ReductionData {
-  ReductionData() = delete;
-  ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS)
-      : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
-    assert(Kind != RK_None && "expected binary or min/max reduction only.");
-  }
-  unsigned Opcode = 0;
-  Value *LHS = nullptr;
-  Value *RHS = nullptr;
-  ReductionKind Kind = RK_None;
-  bool hasSameData(ReductionData &RD) const {
-    return Kind == RD.Kind && Opcode == RD.Opcode;
-  }
-};
-} // namespace
-
-static Optional<ReductionData> getReductionData(Instruction *I) {
+static Optional<TTI::ReductionData> getReductionData(Instruction *I) {
   Value *L, *R;
   if (m_BinOp(m_Value(L), m_Value(R)).match(I))
-    return ReductionData(RK_Arithmetic, I->getOpcode(), L, R);
+    return TTI::ReductionData(TTI::RK_Arithmetic, I->getOpcode(), L, R);
   if (auto *SI = dyn_cast<SelectInst>(I)) {
     if (m_SMin(m_Value(L), m_Value(R)).match(SI) ||
         m_SMax(m_Value(L), m_Value(R)).match(SI) ||
@@ -1007,20 +982,20 @@ static Optional<ReductionData> getReductionData(Instruction *I) {
         m_UnordFMin(m_Value(L), m_Value(R)).match(SI) ||
         m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) {
       auto *CI = cast<CmpInst>(SI->getCondition());
-      return ReductionData(RK_MinMax, CI->getOpcode(), L, R);
+      return TTI::ReductionData(TTI::RK_MinMax, CI->getOpcode(), L, R);
     }
     if (m_UMin(m_Value(L), m_Value(R)).match(SI) ||
         m_UMax(m_Value(L), m_Value(R)).match(SI)) {
       auto *CI = cast<CmpInst>(SI->getCondition());
-      return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R);
+      return TTI::ReductionData(TTI::RK_UnsignedMinMax, CI->getOpcode(), L, R);
     }
   }
   return llvm::None;
 }
 
-static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
-                                                   unsigned Level,
-                                                   unsigned NumLevels) {
+static TTI::ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
+                                                        unsigned Level,
+                                                        unsigned NumLevels) {
   // Match one level of pairwise operations.
   // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
   //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
@@ -1028,24 +1003,24 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
   //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
   // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
   if (!I)
-    return RK_None;
+    return TTI::RK_None;
 
   assert(I->getType()->isVectorTy() && "Expecting a vector type");
 
-  Optional<ReductionData> RD = getReductionData(I);
+  Optional<TTI::ReductionData> RD = getReductionData(I);
   if (!RD)
-    return RK_None;
+    return TTI::RK_None;
 
   ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS);
   if (!LS && Level)
-    return RK_None;
+    return TTI::RK_None;
   ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS);
   if (!RS && Level)
-    return RK_None;
+    return TTI::RK_None;
 
   // On level 0 we can omit one shufflevector instruction.
   if (!Level && !RS && !LS)
-    return RK_None;
+    return TTI::RK_None;
 
   // Shuffle inputs must match.
   Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
@@ -1054,7 +1029,7 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
   if (NextLevelOpR && NextLevelOpL) {
     // If we have two shuffles their operands must match.
     if (NextLevelOpL != NextLevelOpR)
-      return RK_None;
+      return TTI::RK_None;
 
     NextLevelOp = NextLevelOpL;
   } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
@@ -1065,32 +1040,32 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
     //  %NextLevelOpL = shufflevector %R, <1, undef ...>
     //  %BinOp        = fadd          %NextLevelOpL, %R
     if (NextLevelOpL && NextLevelOpL != RD->RHS)
-      return RK_None;
+      return TTI::RK_None;
     else if (NextLevelOpR && NextLevelOpR != RD->LHS)
-      return RK_None;
+      return TTI::RK_None;
 
     NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS;
   } else
-    return RK_None;
+    return TTI::RK_None;
 
   // Check that the next levels binary operation exists and matches with the
   // current one.
   if (Level + 1 != NumLevels) {
-    Optional<ReductionData> NextLevelRD =
+    Optional<TTI::ReductionData> NextLevelRD =
         getReductionData(cast<Instruction>(NextLevelOp));
     if (!NextLevelRD || !RD->hasSameData(*NextLevelRD))
-      return RK_None;
+      return TTI::RK_None;
   }
 
   // Shuffle mask for pairwise operation must match.
   if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) {
     if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level))
-      return RK_None;
+      return TTI::RK_None;
   } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) {
     if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level))
-      return RK_None;
+      return TTI::RK_None;
   } else {
-    return RK_None;
+    return TTI::RK_None;
   }
 
   if (++Level == NumLevels)
@@ -1101,11 +1076,10 @@ static ReductionKind matchPairwiseReductionAtLevel(Instruction *I,
                                        NumLevels);
 }
 
-static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
-                                            unsigned &Opcode,
-                                            VectorType *&Ty) {
+TTI::ReductionKind TTI::matchPairwiseReduction(
+  const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty) {
   if (!EnableReduxCost)
-    return RK_None;
+    return TTI::RK_None;
 
   // Need to extract the first element.
   ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
@@ -1113,19 +1087,19 @@ static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
   if (CI)
     Idx = CI->getZExtValue();
   if (Idx != 0)
-    return RK_None;
+    return TTI::RK_None;
 
   auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
   if (!RdxStart)
-    return RK_None;
-  Optional<ReductionData> RD = getReductionData(RdxStart);
+    return TTI::RK_None;
+  Optional<TTI::ReductionData> RD = getReductionData(RdxStart);
   if (!RD)
-    return RK_None;
+    return TTI::RK_None;
 
   auto *VecTy = cast<VectorType>(RdxStart->getType());
   unsigned NumVecElems = VecTy->getNumElements();
   if (!isPowerOf2_32(NumVecElems))
-    return RK_None;
+    return TTI::RK_None;
 
   // We look for a sequence of shuffle,shuffle,add triples like the following
   // that builds a pairwise reduction tree.
@@ -1146,8 +1120,8 @@ static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
   // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
   // %r = extractelement <4 x float> %bin.rdx8, i32 0
   if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) ==
-      RK_None)
-    return RK_None;
+      TTI::RK_None)
+    return TTI::RK_None;
 
   Opcode = RD->Opcode;
   Ty = VecTy;
@@ -1166,11 +1140,11 @@ getShuffleAndOtherOprd(Value *L, Value *R) {
   return std::make_pair(L, S);
 }
 
-static ReductionKind
-matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
-                              unsigned &Opcode, VectorType *&Ty) {
+TTI::ReductionKind TTI::matchVectorSplittingReduction(
+  const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty) {
+
   if (!EnableReduxCost)
-    return RK_None;
+    return TTI::RK_None;
 
   // Need to extract the first element.
   ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
@@ -1178,19 +1152,19 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
   if (CI)
     Idx = CI->getZExtValue();
   if (Idx != 0)
-    return RK_None;
+    return TTI::RK_None;
 
   auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0));
   if (!RdxStart)
-    return RK_None;
-  Optional<ReductionData> RD = getReductionData(RdxStart);
+    return TTI::RK_None;
+  Optional<TTI::ReductionData> RD = getReductionData(RdxStart);
   if (!RD)
-    return RK_None;
+    return TTI::RK_None;
 
   auto *VecTy = cast<VectorType>(ReduxRoot->getOperand(0)->getType());
   unsigned NumVecElems = VecTy->getNumElements();
   if (!isPowerOf2_32(NumVecElems))
-    return RK_None;
+    return TTI::RK_None;
 
   // We look for a sequence of shuffles and adds like the following matching one
   // fadd, shuffle vector pair at a time.
@@ -1210,10 +1184,10 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
   while (NumVecElemsRemain - 1) {
     // Check for the right reduction operation.
     if (!RdxOp)
-      return RK_None;
-    Optional<ReductionData> RDLevel = getReductionData(RdxOp);
+      return TTI::RK_None;
+    Optional<TTI::ReductionData> RDLevel = getReductionData(RdxOp);
     if (!RDLevel || !RDLevel->hasSameData(*RD))
-      return RK_None;
+      return TTI::RK_None;
 
     Value *NextRdxOp;
     ShuffleVectorInst *Shuffle;
@@ -1222,9 +1196,9 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
 
     // Check the current reduction operation and the shuffle use the same value.
     if (Shuffle == nullptr)
-      return RK_None;
+      return TTI::RK_None;
     if (Shuffle->getOperand(0) != NextRdxOp)
-      return RK_None;
+      return TTI::RK_None;
 
     // Check that shuffle masks matches.
     for (unsigned j = 0; j != MaskStart; ++j)
@@ -1234,7 +1208,7 @@ matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
 
     ArrayRef<int> Mask = Shuffle->getShuffleMask();
     if (ShuffleMask != Mask)
-      return RK_None;
+      return TTI::RK_None;
 
     RdxOp = dyn_cast<Instruction>(NextRdxOp);
     NumVecElemsRemain /= 2;
@@ -1291,10 +1265,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
   case Instruction::Select:
   case Instruction::ICmp:
   case Instruction::FCmp:
-    return getUserCost(I, CostKind);
   case Instruction::Store:
   case Instruction::Load:
-    return getUserCost(I, CostKind);
   case Instruction::ZExt:
   case Instruction::SExt:
   case Instruction::FPToUI:
@@ -1308,59 +1280,10 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
   case Instruction::FPTrunc:
   case Instruction::BitCast:
   case Instruction::AddrSpaceCast:
-    return getUserCost(I, CostKind);
-  case Instruction::ExtractElement: {
-    const ExtractElementInst *EEI = cast<ExtractElementInst>(I);
-    ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
-    unsigned Idx = -1;
-    if (CI)
-      Idx = CI->getZExtValue();
-
-    // Try to match a reduction sequence (series of shufflevector and vector
-    // adds followed by a extractelement).
-    unsigned ReduxOpCode;
-    VectorType *ReduxType;
-
-    switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) {
-    case RK_Arithmetic:
-      return getArithmeticReductionCost(ReduxOpCode, ReduxType,
-                                        /*IsPairwiseForm=*/false,
-                                        CostKind);
-    case RK_MinMax:
-      return getMinMaxReductionCost(
-          ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-          /*IsPairwiseForm=*/false, /*IsUnsigned=*/false);
-    case RK_UnsignedMinMax:
-      return getMinMaxReductionCost(
-          ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-          /*IsPairwiseForm=*/false, /*IsUnsigned=*/true);
-    case RK_None:
-      break;
-    }
-
-    switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
-    case RK_Arithmetic:
-      return getArithmeticReductionCost(ReduxOpCode, ReduxType,
-                                        /*IsPairwiseForm=*/true, CostKind);
-    case RK_MinMax:
-      return getMinMaxReductionCost(
-          ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-          /*IsPairwiseForm=*/true, /*IsUnsigned=*/false);
-    case RK_UnsignedMinMax:
-      return getMinMaxReductionCost(
-          ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-          /*IsPairwiseForm=*/true, /*IsUnsigned=*/true);
-    case RK_None:
-      break;
-    }
-
-    return getVectorInstrCost(I->getOpcode(), EEI->getOperand(0)->getType(),
-                              Idx);
-  }
+  case Instruction::ExtractElement:
   case Instruction::InsertElement:
   case Instruction::ExtractValue:
   case Instruction::ShuffleVector:
-    return getUserCost(I, CostKind);
   case Instruction::Call:
     return getUserCost(I, CostKind);
   default:

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index 38ee6d1e7aaf..2c42b7a31aa2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -992,13 +992,6 @@ GCNTTIImpl::getUserCost(const User *U, ArrayRef<const Value *> Operands,
 
   // Estimate different operations to be optimized out
   switch (I->getOpcode()) {
-  case Instruction::ExtractElement: {
-    ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
-    unsigned Idx = -1;
-    if (CI)
-      Idx = CI->getZExtValue();
-    return getVectorInstrCost(I->getOpcode(), I->getOperand(0)->getType(), Idx);
-  }
   case Instruction::FNeg:
     return getArithmeticInstrCost(I->getOpcode(), I->getType(), CostKind,
                                   TTI::OK_AnyValue, TTI::OK_AnyValue,


        


More information about the llvm-commits mailing list