[llvm] 60280e9 - [Analysis] TTI: Add CastContextHint for getCastInstrCost

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 29 05:33:13 PDT 2020


Author: David Green
Date: 2020-07-29T13:32:53+01:00
New Revision: 60280e9818a6274d6e16556e4392821bbb65bd31

URL: https://github.com/llvm/llvm-project/commit/60280e9818a6274d6e16556e4392821bbb65bd31
DIFF: https://github.com/llvm/llvm-project/commit/60280e9818a6274d6e16556e4392821bbb65bd31.diff

LOG: [Analysis] TTI: Add CastContextHint for getCastInstrCost

Currently, getCastInstrCost has limited information about the cast it's
rating, often just the opcode and types.  Sometimes there is a context
instruction as well, but it isn't trustworthy: for instance, when the
vectorizer is rating a plan, it calls getCastInstrCost with the old
instructions when, in fact, it's trying to evaluate the cost of the
instruction post-vectorization.  Thus, the current system can get the
cost of certain casts incorrect as the correct cost can vary greatly
based on the context in which it's used.

For example, if the vectorizer queries getCastInstrCost to evaluate the
cost of a sext(load) with tail predication enabled, getCastInstrCost
will think it's free most of the time, but it's not always free. On ARM
MVE, a VLD2 group cannot be extended like a normal VLDR can. Similar
situations can come up with how masked loads can be extended when being
split.

To fix that, this patch adds a new parameter to getCastInstrCost to give
it a hint about the context of the cast. It adds a CastContextHint enum
which contains the type of the load/store being created by the
vectorizer - one for each of the types it can produce.
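
As a rough illustration of the idea (a toy model only, not the code in
this patch), the sketch below shows how a cost query keyed on the hint
can tell a sext of a normal load apart from a sext of an interleaved
load. The enumerators mirror the CastContextHint enum added by the
patch; CastOp, isExtend, toyCastCost and the cost values 0 and 1 are
hypothetical names and numbers invented for the example, and real
targets use their own cost tables.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for TTI::CastContextHint; the enumerators mirror the patch.
enum class CastContextHint : std::uint8_t {
  None,          // The cast is not used with a load/store of any kind.
  Normal,        // The cast is used with a normal load/store.
  Masked,        // The cast is used with a masked load/store.
  GatherScatter, // The cast is used with a gather/scatter.
  Interleave,    // The cast is used with an interleaved load/store.
  Reversed       // The cast is used with a reversed load/store.
};

// Hypothetical opcode set for this sketch (not LLVM's Instruction opcodes).
enum class CastOp { SExt, ZExt, Trunc };

// True for the two extending casts modelled here.
constexpr bool isExtend(CastOp Op) {
  return Op == CastOp::SExt || Op == CastOp::ZExt;
}

// Hypothetical cost rule: an extend folds into a normal load for free,
// while every other combination is charged one unit. The numbers are
// made up for illustration.
constexpr int toyCastCost(CastOp Op, CastContextHint CCH) {
  return (isExtend(Op) && CCH == CastContextHint::Normal) ? 0 : 1;
}

// The same opcode gets a different cost depending on the hint.
static_assert(toyCastCost(CastOp::SExt, CastContextHint::Normal) == 0,
              "extend of a normal load is free in this toy model");
static_assert(toyCastCost(CastOp::SExt, CastContextHint::Interleave) == 1,
              "extend of an interleaved load is not free in this toy model");
```

The point of the sketch is that the caller, rather than a possibly
stale context instruction, states what kind of memory access the cast
will be attached to after vectorization.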

Original patch by Pierre van Houtryve

Differential Revision: https://reviews.llvm.org/D79162

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h
    llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
    llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index e1426e3e5192..092c9744a21b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1021,10 +1021,47 @@ class TargetTransformInfo {
   int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index = 0,
                      VectorType *SubTp = nullptr) const;
 
+  /// Represents a hint about the context in which a cast is used.
+  ///
+  /// For zext/sext, the context of the cast is the operand, which must be a
+  /// load of some kind. For trunc, the context of the cast is the single
+  /// user of the instruction, which must be a store of some kind.
+  ///
+  /// This enum allows the vectorizer to give getCastInstrCost an idea of the
+  /// type of cast it's dealing with, as not every cast is equal. For instance,
+  /// the zext of a load may be free, but the zext of an interleaving load can
+  /// be (very) expensive!
+  ///
+  /// See \c getCastContextHint to compute a CastContextHint from a cast
+  /// Instruction*. Callers can use it if they don't need to override the
+  /// context and just want it to be calculated from the instruction.
+  ///
+  /// FIXME: This handles the types of load/store that the vectorizer can
+  /// produce, which are the cases where the context instruction is most
+  /// likely to be incorrect. There are other situations where that can happen
+  /// too, which might be handled here but in the long run a more general
+  /// solution of costing multiple instructions at the same time may be better.
+  enum class CastContextHint : uint8_t {
+    None,          ///< The cast is not used with a load/store of any kind.
+    Normal,        ///< The cast is used with a normal load/store.
+    Masked,        ///< The cast is used with a masked load/store.
+    GatherScatter, ///< The cast is used with a gather/scatter.
+    Interleave,    ///< The cast is used with an interleaved load/store.
+    Reversed,      ///< The cast is used with a reversed load/store.
+  };
+
+  /// Calculates a CastContextHint from \p I.
+  /// This should be used by callers of getCastInstrCost if they wish to
+  /// determine the context from some instruction.
+  /// \returns the CastContextHint for ZExt/SExt/Trunc, None if \p I is nullptr,
+  /// or if it's another type of cast.
+  static CastContextHint getCastContextHint(const Instruction *I);
+
   /// \return The expected cost of cast instructions, such as bitcast, trunc,
   /// zext, etc. If there is an existing instruction that holds Opcode, it
   /// may be passed in the 'I' parameter.
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                       TTI::CastContextHint CCH,
                        TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency,
                        const Instruction *I = nullptr) const;
 
@@ -1454,6 +1491,7 @@ class TargetTransformInfo::Concept {
   virtual int getShuffleCost(ShuffleKind Kind, VectorType *Tp, int Index,
                              VectorType *SubTp) = 0;
   virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                               CastContextHint CCH,
                                TTI::TargetCostKind CostKind,
                                const Instruction *I) = 0;
   virtual int getExtractWithExtendCost(unsigned Opcode, Type *Dst,
@@ -1882,9 +1920,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getShuffleCost(Kind, Tp, Index, SubTp);
   }
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I) override {
-    return Impl.getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    return Impl.getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   }
   int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
                                unsigned Index) override {

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 73e5ff60a4be..4dc0a90cc9db 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -423,6 +423,7 @@ class TargetTransformInfoImplBase {
   }
 
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                            TTI::CastContextHint CCH,
                             TTI::TargetCostKind CostKind,
                             const Instruction *I) {
     switch (Opcode) {
@@ -915,7 +916,8 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
     case Instruction::SExt:
     case Instruction::ZExt:
     case Instruction::AddrSpaceCast:
-      return TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I);
+      return TargetTTI->getCastInstrCost(
+          Opcode, Ty, OpTy, TTI::getCastContextHint(I), CostKind, I);
     case Instruction::Store: {
       auto *SI = cast<StoreInst>(U);
       Type *ValTy = U->getOperand(0)->getType();

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index aabfb74b7cef..e9f632979705 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -716,9 +716,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                            TTI::CastContextHint CCH,
                             TTI::TargetCostKind CostKind,
                             const Instruction *I = nullptr) {
-    if (BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I) == 0)
+    if (BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I) == 0)
       return 0;
 
     const TargetLoweringBase *TLI = getTLI();
@@ -756,15 +757,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         return 0;
       LLVM_FALLTHROUGH;
     case Instruction::SExt:
-      if (!I)
-        break;
-
-      if (getTLI()->isExtFree(I))
+      if (I && getTLI()->isExtFree(I))
         return 0;
 
       // If this is a zext/sext of a load, return 0 if the corresponding
       // extending load exists on target.
-      if (I && isa<LoadInst>(I->getOperand(0))) {
+      if (CCH == TTI::CastContextHint::Normal) {
         EVT ExtVT = EVT::getEVT(Dst);
         EVT LoadVT = EVT::getEVT(Src);
         unsigned LType =
@@ -839,7 +837,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         unsigned SplitCost =
             (!SplitSrc || !SplitDst) ? TTI->getVectorSplitCost() : 0;
         return SplitCost +
-               (2 * TTI->getCastInstrCost(Opcode, SplitDstTy, SplitSrcTy,
+               (2 * TTI->getCastInstrCost(Opcode, SplitDstTy, SplitSrcTy, CCH,
                                           CostKind, I));
       }
 
@@ -847,7 +845,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       // the operation will get scalarized.
       unsigned Num = cast<FixedVectorType>(DstVTy)->getNumElements();
       unsigned Cost = thisT()->getCastInstrCost(
-          Opcode, Dst->getScalarType(), Src->getScalarType(), CostKind, I);
+          Opcode, Dst->getScalarType(), Src->getScalarType(), CCH, CostKind, I);
 
       // Return the cost of multiple scalar invocation plus the cost of
       // inserting and extracting the values.
@@ -872,7 +870,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return thisT()->getVectorInstrCost(Instruction::ExtractElement, VecTy,
                                        Index) +
            thisT()->getCastInstrCost(Opcode, Dst, VecTy->getElementType(),
-                                     TTI::TCK_RecipThroughput);
+                                     TTI::CastContextHint::None, TTI::TCK_RecipThroughput);
   }
 
   unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) {
@@ -1522,13 +1520,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
       unsigned ExtOp =
           IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
+      TTI::CastContextHint CCH = TTI::CastContextHint::None;
 
       unsigned Cost = 0;
-      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, RetTy, CostKind);
+      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, RetTy, CCH, CostKind);
       Cost +=
           thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
       Cost += 2 * thisT()->getCastInstrCost(Instruction::Trunc, RetTy, ExtTy,
-                                            CostKind);
+                                            CCH, CostKind);
       Cost += thisT()->getArithmeticInstrCost(Instruction::LShr, RetTy,
                                               CostKind, TTI::OK_AnyValue,
                                               TTI::OK_UniformConstantValue);
@@ -1587,13 +1586,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
       unsigned ExtOp =
           IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
+      TTI::CastContextHint CCH = TTI::CastContextHint::None;
 
       unsigned Cost = 0;
-      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, MulTy, CostKind);
+      Cost += 2 * thisT()->getCastInstrCost(ExtOp, ExtTy, MulTy, CCH, CostKind);
       Cost +=
           thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
       Cost += 2 * thisT()->getCastInstrCost(Instruction::Trunc, MulTy, ExtTy,
-                                            CostKind);
+                                            CCH, CostKind);
       Cost += thisT()->getArithmeticInstrCost(Instruction::LShr, MulTy,
                                               CostKind, TTI::OK_AnyValue,
                                               TTI::OK_UniformConstantValue);

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index c9e702ce56b8..944e621fc070 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -730,12 +730,57 @@ int TargetTransformInfo::getShuffleCost(ShuffleKind Kind, VectorType *Ty,
   return Cost;
 }
 
+TTI::CastContextHint
+TargetTransformInfo::getCastContextHint(const Instruction *I) {
+  if (!I)
+    return CastContextHint::None;
+
+  auto getLoadStoreKind = [](const Value *V, unsigned LdStOp, unsigned MaskedOp,
+                             unsigned GatScatOp) {
+    const Instruction *I = dyn_cast<Instruction>(V);
+    if (!I)
+      return CastContextHint::None;
+
+    if (I->getOpcode() == LdStOp)
+      return CastContextHint::Normal;
+
+    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+      if (II->getIntrinsicID() == MaskedOp)
+        return TTI::CastContextHint::Masked;
+      if (II->getIntrinsicID() == GatScatOp)
+        return TTI::CastContextHint::GatherScatter;
+    }
+
+    return TTI::CastContextHint::None;
+  };
+
+  switch (I->getOpcode()) {
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPExt:
+    return getLoadStoreKind(I->getOperand(0), Instruction::Load,
+                            Intrinsic::masked_load, Intrinsic::masked_gather);
+  case Instruction::Trunc:
+  case Instruction::FPTrunc:
+    if (I->hasOneUse())
+      return getLoadStoreKind(*I->user_begin(), Instruction::Store,
+                              Intrinsic::masked_store,
+                              Intrinsic::masked_scatter);
+    break;
+  default:
+    return CastContextHint::None;
+  }
+
+  return TTI::CastContextHint::None;
+}
+
 int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                          CastContextHint CCH,
                                           TTI::TargetCostKind CostKind,
                                           const Instruction *I) const {
   assert((I == nullptr || I->getOpcode() == Opcode) &&
          "Opcode should reflect passed instruction.");
-  int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cf6de797727b..5f5da63b21b6 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -270,6 +270,7 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
 }
 
 int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                     TTI::CastContextHint CCH,
                                      TTI::TargetCostKind CostKind,
                                      const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -306,7 +307,8 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+    return AdjustCost(
+        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 
   static const TypeConversionCostTblEntry
   ConversionTbl[] = {
@@ -410,7 +412,8 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                                  SrcTy.getSimpleVT()))
     return AdjustCost(Entry->Cost);
 
-  return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+  return AdjustCost(
+      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
@@ -442,12 +445,14 @@ int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
   // we may get the extension for free. If not, get the default cost for the
   // extend.
   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
-    return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+    return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                   CostKind);
 
   // The destination type should be larger than the element type. If not, get
   // the default cost for the extend.
   if (DstVT.getSizeInBits() < SrcVT.getSizeInBits())
-    return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+    return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                   CostKind);
 
   switch (Opcode) {
   default:
@@ -466,7 +471,8 @@ int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
   }
 
   // If we are unable to perform the extend for free, get the default cost.
-  return Cost + getCastInstrCost(Opcode, Dst, Src, CostKind);
+  return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
+                                 CostKind);
 }
 
 unsigned AArch64TTIImpl::getCFInstrCost(unsigned Opcode,

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1f029689a60e..5d1371f13fb3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -114,7 +114,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   unsigned getMaxInterleaveFactor(unsigned VF);
 
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
 
   int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 25d0a4c21fa6..59f33e694442 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -301,6 +301,7 @@ int ARMTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Im
 }
 
 int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -317,7 +318,8 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+    return AdjustCost(
+        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 
   // The extend of a load is free
   if (I && isa<LoadInst>(I->getOperand(0))) {
@@ -388,8 +390,8 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     };
     if (SrcTy.isVector() && ST->hasMVEIntegerOps()) {
       if (const auto *Entry =
-              ConvertCostTableLookup(MVELoadConversionTbl, ISD, SrcTy.getSimpleVT(),
-                                     DstTy.getSimpleVT()))
+              ConvertCostTableLookup(MVELoadConversionTbl, ISD,
+                                     SrcTy.getSimpleVT(), DstTy.getSimpleVT()))
         return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
     }
 
@@ -399,8 +401,8 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     };
     if (SrcTy.isVector() && ST->hasMVEFloatOps()) {
       if (const auto *Entry =
-              ConvertCostTableLookup(MVEFLoadConversionTbl, ISD, SrcTy.getSimpleVT(),
-                                     DstTy.getSimpleVT()))
+              ConvertCostTableLookup(MVEFLoadConversionTbl, ISD,
+                                     SrcTy.getSimpleVT(), DstTy.getSimpleVT()))
         return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
     }
   }
@@ -672,7 +674,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                      ? ST->getMVEVectorCostFactor()
                      : 1;
   return AdjustCost(
-    BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+      BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 093dfbbf5f02..ac7d0378d90b 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -210,7 +210,7 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
   }
 
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
 
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 80c8736cb74a..68efaf767502 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -270,7 +270,9 @@ unsigned HexagonTTIImpl::getArithmeticInstrCost(
 }
 
 unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy,
-      Type *SrcTy, TTI::TargetCostKind CostKind, const Instruction *I) {
+                                          Type *SrcTy, TTI::CastContextHint CCH,
+                                          TTI::TargetCostKind CostKind,
+                                          const Instruction *I) {
   if (SrcTy->isFPOrFPVectorTy() || DstTy->isFPOrFPVectorTy()) {
     unsigned SrcN = SrcTy->isFPOrFPVectorTy() ? getTypeNumElements(SrcTy) : 0;
     unsigned DstN = DstTy->isFPOrFPVectorTy() ? getTypeNumElements(DstTy) : 0;

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 5fe397486402..07e59fb5585e 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -146,8 +146,9 @@ class HexagonTTIImpl : public BasicTTIImplBase<HexagonTTIImpl> {
       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
       const Instruction *CxtI = nullptr);
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-            TTI::TargetCostKind CostKind,
-            const Instruction *I = nullptr);
+                            TTI::CastContextHint CCH,
+                            TTI::TargetCostKind CostKind,
+                            const Instruction *I = nullptr);
   unsigned getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
 
   unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) {

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
index ee8842f4d866..4171d4a35b25 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -879,11 +879,12 @@ int PPCTTIImpl::getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind) {
 }
 
 int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
 
-  int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
   Cost = vectorCostAdjustment(Cost, Opcode, Dst, Src);
   // TODO: Allow non-throughput costs that aren't binary.
   if (CostKind != TTI::TCK_RecipThroughput)

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
index a3453b00a864..d9aab298c7c1 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
@@ -106,7 +106,7 @@ class PPCTTIImpl : public BasicTTIImplBase<PPCTTIImpl> {
       const Instruction *CxtI = nullptr);
   int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 864200e5f71c..8758ddee0aab 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -699,11 +699,12 @@ getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst,
 }
 
 int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                     TTI::CastContextHint CCH,
                                      TTI::TargetCostKind CostKind,
                                      const Instruction *I) {
   // FIXME: Can the logic below also be used for these cost kinds?
   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency) {
-    int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
     return BaseCost == 0 ? BaseCost : 1;
   }
 
@@ -786,8 +787,8 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
       // Return the cost of multiple scalar invocation plus the cost of
       // inserting and extracting the values. Base implementation does not
       // realize float->int gets scalarized.
-      unsigned ScalarCost = getCastInstrCost(Opcode, Dst->getScalarType(),
-                                             Src->getScalarType(), CostKind);
+      unsigned ScalarCost = getCastInstrCost(
+          Opcode, Dst->getScalarType(), Src->getScalarType(), CCH, CostKind);
       unsigned TotCost = VF * ScalarCost;
       bool NeedsInserts = true, NeedsExtracts = true;
       // FP128 registers do not get inserted or extracted.
@@ -828,7 +829,7 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     }
   }
 
-  return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 }
 
 // Scalar i8 / i16 operations will typically be made after first extending

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index 7f8f7f6f923f..1aa31ff1690c 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -93,7 +93,7 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
   unsigned getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst,
                                          const Instruction *I);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                          TTI::TargetCostKind CostKind,

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 491078fd0542..a484114330cc 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1367,6 +1367,7 @@ int X86TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *BaseTp,
 }
 
 int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
+                                 TTI::CastContextHint CCH,
                                  TTI::TargetCostKind CostKind,
                                  const Instruction *I) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -1988,7 +1989,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
 
   // The function getSimpleVT only handles simple value types.
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind));
+    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind));
 
   MVT SimpleSrcTy = SrcTy.getSimpleVT();
   MVT SimpleDstTy = DstTy.getSimpleVT();
@@ -2049,7 +2050,8 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
       return AdjustCost(Entry->Cost);
   }
 
-  return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
+  return AdjustCost(
+      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
 int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index fb4719064613..8d2fa27ee7b0 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -130,7 +130,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   int getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index,
                      VectorType *SubTp);
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
-                       TTI::TargetCostKind CostKind,
+                       TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                        const Instruction *I = nullptr);
   int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
                          TTI::TargetCostKind CostKind,

diff  --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index dc2ad14ae61e..c344c6c68477 100644
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -2025,8 +2025,8 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain,
 
       Type *SrcTy = CI->getOperand(0)->getType();
       Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy,
-                                   TargetTransformInfo::TCK_SizeAndLatency,
-                                   CI);
+                                   TTI::getCastContextHint(CI),
+                                   TargetTransformInfo::TCK_SizeAndLatency, CI);
 
     } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) {
       // Cost of the address calculation

diff  --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 3527ca4bbbbc..c99d612a7768 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2150,8 +2150,9 @@ bool SCEVExpander::isHighCostExpansionHelper(
       llvm_unreachable("There are no other cast types.");
     }
     const SCEV *Op = CastExpr->getOperand();
-    BudgetRemaining -= TTI.getCastInstrCost(Opcode, /*Dst=*/S->getType(),
-                                            /*Src=*/Op->getType(), CostKind);
+    BudgetRemaining -= TTI.getCastInstrCost(
+        Opcode, /*Dst=*/S->getType(),
+        /*Src=*/Op->getType(), TTI::CastContextHint::None, CostKind);
     Worklist.emplace_back(Op);
     return false; // Will answer upon next entry into this function.
   }

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index d207ca250783..4b519514d543 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6458,13 +6458,54 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
   case Instruction::Trunc:
   case Instruction::FPTrunc:
   case Instruction::BitCast: {
+    // Computes the CastContextHint from a Load/Store instruction.
+    auto ComputeCCH = [&](Instruction *I) -> TTI::CastContextHint {
+      assert((isa<LoadInst>(I) || isa<StoreInst>(I)) &&
+             "Expected a load or a store!");
+
+      if (VF == 1)
+        return TTI::CastContextHint::Normal;
+
+      switch (getWideningDecision(I, VF)) {
+      case LoopVectorizationCostModel::CM_GatherScatter:
+        return TTI::CastContextHint::GatherScatter;
+      case LoopVectorizationCostModel::CM_Interleave:
+        return TTI::CastContextHint::Interleave;
+      case LoopVectorizationCostModel::CM_Scalarize:
+      case LoopVectorizationCostModel::CM_Widen:
+        return Legal->isMaskRequired(I) ? TTI::CastContextHint::Masked
+                                        : TTI::CastContextHint::Normal;
+      case LoopVectorizationCostModel::CM_Widen_Reverse:
+        return TTI::CastContextHint::Reversed;
+      case LoopVectorizationCostModel::CM_Unknown:
+        llvm_unreachable("Instr did not go through cost modelling?");
+      }
+
+      llvm_unreachable("Unhandled case!");
+    };
+
+    unsigned Opcode = I->getOpcode();
+    TTI::CastContextHint CCH = TTI::CastContextHint::None;
+    // For Trunc, the context is the only user, which must be a StoreInst.
+    if (Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) {
+      if (I->hasOneUse())
+        if (StoreInst *Store = dyn_cast<StoreInst>(*I->user_begin()))
+          CCH = ComputeCCH(Store);
+    }
+    // For Z/Sext, the context is the operand, which must be a LoadInst.
+    else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
+             Opcode == Instruction::FPExt) {
+      if (LoadInst *Load = dyn_cast<LoadInst>(I->getOperand(0)))
+        CCH = ComputeCCH(Load);
+    }
+
     // We optimize the truncation of induction variables having constant
     // integer steps. The cost of these truncations is the same as the scalar
     // operation.
     if (isOptimizableIVTruncate(I, VF)) {
       auto *Trunc = cast<TruncInst>(I);
       return TTI.getCastInstrCost(Instruction::Trunc, Trunc->getDestTy(),
-                                  Trunc->getSrcTy(), CostKind, Trunc);
+                                  Trunc->getSrcTy(), CCH, CostKind, Trunc);
     }
 
     Type *SrcScalarTy = I->getOperand(0)->getType();
@@ -6477,12 +6518,11 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
       //
       // Calculate the modified src and dest types.
       Type *MinVecTy = VectorTy;
-      if (I->getOpcode() == Instruction::Trunc) {
+      if (Opcode == Instruction::Trunc) {
         SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy);
         VectorTy =
             largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy);
-      } else if (I->getOpcode() == Instruction::ZExt ||
-                 I->getOpcode() == Instruction::SExt) {
+      } else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) {
         SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy);
         VectorTy =
             smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy);
@@ -6490,8 +6530,8 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
     }
 
     unsigned N = isScalarAfterVectorization(I, VF) ? VF : 1;
-    return N * TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy,
-                                    CostKind, I);
+    return N *
+           TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I);
   }
   case Instruction::Call: {
     bool NeedToScalarize;

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 5fb8ad56d8b3..d575b3245c65 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3399,8 +3399,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
                   Ext->getOpcode(), Ext->getType(), VecTy, i);
               // Add back the cost of s|zext which is subtracted separately.
               DeadCost += TTI->getCastInstrCost(
-                  Ext->getOpcode(), Ext->getType(), E->getType(), CostKind,
-                  Ext);
+                  Ext->getOpcode(), Ext->getType(), E->getType(),
+                  TTI::getCastContextHint(Ext), CostKind, Ext);
               continue;
             }
           }
@@ -3424,8 +3424,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
     case Instruction::BitCast: {
       Type *SrcTy = VL0->getOperand(0)->getType();
       int ScalarEltCost =
-          TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, CostKind,
-                                VL0);
+          TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy,
+                                TTI::getCastContextHint(VL0), CostKind, VL0);
       if (NeedToShuffleReuses) {
         ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost;
       }
@@ -3437,9 +3437,10 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       int VecCost = 0;
       // Check if the values are candidates to demote.
       if (!MinBWs.count(VL0) || VecTy != SrcVecTy) {
-        VecCost = ReuseShuffleCost +
-                  TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy,
-                                        CostKind, VL0);
+        VecCost =
+            ReuseShuffleCost +
+            TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy,
+                                  TTI::getCastContextHint(VL0), CostKind, VL0);
       }
       return VecCost - ScalarCost;
     }
@@ -3644,9 +3645,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
         auto *Src0Ty = FixedVectorType::get(Src0SclTy, VL.size());
         auto *Src1Ty = FixedVectorType::get(Src1SclTy, VL.size());
         VecCost = TTI->getCastInstrCost(E->getOpcode(), VecTy, Src0Ty,
-                                        CostKind);
+                                        TTI::CastContextHint::None, CostKind);
         VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty,
-                                         CostKind);
+                                         TTI::CastContextHint::None, CostKind);
       }
       VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, 0);
       return ReuseShuffleCost + VecCost - ScalarCost;
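
Reader's note on the LoopVectorize.cpp hunk above: the ComputeCCH lambda maps the cost model's widening decision for a load or store onto a cast context hint. The standalone sketch below models that mapping only. The enum and function names (WideningDecision, computeHint) are hypothetical stand-ins rather than LLVM API, and the Legal->isMaskRequired(I) query is reduced to a plain bool parameter.

```cpp
#include <cassert>

// Hypothetical stand-in for TTI::CastContextHint; enumerators mirror the
// values used in the patch.
enum class CastContextHint {
  None,
  Normal,
  Masked,
  GatherScatter,
  Interleave,
  Reversed
};

// Hypothetical stand-in for the LoopVectorizationCostModel::CM_* decisions.
enum class WideningDecision {
  Unknown,
  Widen,
  WidenReverse,
  Interleave,
  GatherScatter,
  Scalarize
};

// Mirrors the shape of the ComputeCCH lambda: a scalar VF always yields
// Normal; otherwise the widening decision picks the hint, and a plain or
// scalarized access is Masked only when a mask is required.
inline CastContextHint computeHint(unsigned VF, WideningDecision Decision,
                                   bool MaskRequired) {
  if (VF == 1)
    return CastContextHint::Normal;

  switch (Decision) {
  case WideningDecision::GatherScatter:
    return CastContextHint::GatherScatter;
  case WideningDecision::Interleave:
    return CastContextHint::Interleave;
  case WideningDecision::Scalarize:
  case WideningDecision::Widen:
    return MaskRequired ? CastContextHint::Masked : CastContextHint::Normal;
  case WideningDecision::WidenReverse:
    return CastContextHint::Reversed;
  case WideningDecision::Unknown:
    break;
  }

  // The patch treats an Unknown decision as unreachable; this sketch asserts
  // and falls back to None in builds where assertions are disabled.
  assert(false && "memory access did not go through cost modelling");
  return CastContextHint::None;
}
```

In the patch itself the hint is computed from the single StoreInst user for Trunc/FPTrunc and from the LoadInst operand for ZExt/SExt/FPExt; any other cast keeps CastContextHint::None.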


        

