[llvm] [TTI] Add VectorInstrContext for context-aware insert/extract costs. (PR #175982)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 16 13:08:44 PST 2026


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/175982

>From 576f57aa87c4889fd9b38d101b0553106dcf927c Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 14 Jan 2026 16:19:00 +0000
Subject: [PATCH] [TTI] Add VectorInstrContext for context-aware insert/extract
 costs.

This commit introduces the VectorInstrContext (VIC) infrastructure to
improve cost estimates for insert/extracts based on the context
instruction in which the insert/extract is used.

This is similar to CastContextHint, and allows providing context on how
the insert/extract is going to be used before creating IR. This is
useful in the LoopVectorizer, where costs need to be estimated before
creating IR.

The new hint currently only replaces an existing check in AArch64, but
I plan to add additional uses of the store context hint as a follow-up.
---
 .../llvm/Analysis/TargetTransformInfo.h       | 44 ++++++++---
 .../llvm/Analysis/TargetTransformInfoImpl.h   | 32 ++++----
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      | 73 ++++++++++++-------
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 48 ++++++++----
 .../AArch64/AArch64TargetTransformInfo.cpp    | 37 +++++-----
 .../AArch64/AArch64TargetTransformInfo.h      | 42 ++++++-----
 .../AMDGPU/AMDGPUTargetTransformInfo.cpp      | 14 ++--
 .../Target/AMDGPU/AMDGPUTargetTransformInfo.h |  9 ++-
 .../Target/AMDGPU/R600TargetTransformInfo.cpp | 15 ++--
 .../Target/AMDGPU/R600TargetTransformInfo.h   |  9 ++-
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp | 13 ++--
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |  9 ++-
 .../Hexagon/HexagonTargetTransformInfo.cpp    | 10 +--
 .../Hexagon/HexagonTargetTransformInfo.h      |  9 ++-
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   | 11 ++-
 .../Target/PowerPC/PPCTargetTransformInfo.cpp | 14 ++--
 .../Target/PowerPC/PPCTargetTransformInfo.h   |  9 ++-
 .../Target/RISCV/RISCVTargetTransformInfo.cpp | 15 ++--
 .../Target/RISCV/RISCVTargetTransformInfo.h   | 20 +++--
 .../SystemZ/SystemZTargetTransformInfo.cpp    | 14 ++--
 .../SystemZ/SystemZTargetTransformInfo.h      | 20 +++--
 .../WebAssemblyTargetTransformInfo.cpp        |  4 +-
 .../WebAssemblyTargetTransformInfo.h          |  9 ++-
 .../lib/Target/X86/X86TargetTransformInfo.cpp | 23 +++---
 llvm/lib/Target/X86/X86TargetTransformInfo.h  | 20 +++--
 .../Transforms/Vectorize/LoopVectorize.cpp    | 16 +++-
 llvm/lib/Transforms/Vectorize/VPlan.cpp       |  8 +-
 llvm/lib/Transforms/Vectorize/VPlanHelpers.h  | 14 ++--
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  5 +-
 llvm/lib/Transforms/Vectorize/VPlanUtils.h    |  2 +-
 30 files changed, 337 insertions(+), 231 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index ff91b24ff17e5..ea8122b119c76 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1041,6 +1041,24 @@ class TargetTransformInfo {
   isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
                                                    int RetIdx) const;
 
+  /// Represents a hint about the context in which an insert/extract is used.
+  ///
+  /// On some targets, inserts/extracts can cheaply be folded into loads/stores.
+  ///
+  /// This enum allows the vectorizer to give getVectorInstrCost an idea of how
+  /// inserts/extracts are used.
+  ///
+  /// See \c getVectorInstrContextHint to compute a VectorInstrContext from an
+  /// insert/extract Instruction*.
+  enum class VectorInstrContext : uint8_t {
+    None,  ///< The insert/extract is not used with a load/store.
+    Load,  ///< The value being inserted comes from a load (InsertElement only).
+    Store, ///< The extracted value is stored (ExtractElement only).
+  };
+
+  /// Calculates a VectorInstrContext from \p I.
+  static VectorInstrContext getVectorInstrContextHint(const Instruction *I);
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.  The involved values may be passed in VL if
@@ -1048,12 +1066,14 @@ class TargetTransformInfo {
   LLVM_ABI InstructionCost getScalarizationOverhead(
       VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
       TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
-      ArrayRef<Value *> VL = {}) const;
+      ArrayRef<Value *> VL = {},
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const;
 
   /// Estimate the overhead of scalarizing operands with the given types. The
   /// (potentially vector) types to use for each of argument are passes via Tys.
   LLVM_ABI InstructionCost getOperandsScalarizationOverhead(
-      ArrayRef<Type *> Tys, TTI::TargetCostKind CostKind) const;
+      ArrayRef<Type *> Tys, TTI::TargetCostKind CostKind,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const;
 
   /// If target has efficient vector element load/store instructions, it can
   /// return true here so that insertion/extraction costs are not added to
@@ -1570,11 +1590,11 @@ class TargetTransformInfo {
   /// This is used when the instruction is not available; a typical use
   /// case is to provision the cost of vectorization/scalarization in
   /// vectorizer passes.
-  LLVM_ABI InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                              TTI::TargetCostKind CostKind,
-                                              unsigned Index = -1,
-                                              const Value *Op0 = nullptr,
-                                              const Value *Op1 = nullptr) const;
+  LLVM_ABI InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+      unsigned Index = -1, const Value *Op0 = nullptr,
+      const Value *Op1 = nullptr,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const;
 
   /// \return The expected cost of vector Insert and Extract.
   /// Use -1 to indicate that there is no information on the index value.
@@ -1588,7 +1608,8 @@ class TargetTransformInfo {
   LLVM_ABI InstructionCost getVectorInstrCost(
       unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
       Value *Scalar,
-      ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const;
+      ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const;
 
   /// \return The expected cost of vector Insert and Extract.
   /// This is used when instruction is available, and implementation
@@ -1596,9 +1617,10 @@ class TargetTransformInfo {
   ///
   /// A typical suitable use case is cost estimation when vector instruction
   /// exists (e.g., from basic blocks during transformation).
-  LLVM_ABI InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
-                                              TTI::TargetCostKind CostKind,
-                                              unsigned Index = -1) const;
+  LLVM_ABI InstructionCost getVectorInstrCost(
+      const Instruction &I, Type *Val, TTI::TargetCostKind CostKind,
+      unsigned Index = -1,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const;
 
   /// \return The expected cost of inserting or extracting a lane that is \p
   /// Index elements from the end of a vector, i.e. the mathematical expression
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 07b3755924fd1..2a93b29930ad4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -500,13 +500,16 @@ class TargetTransformInfoImplBase {
   virtual InstructionCost getScalarizationOverhead(
       VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
       TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
-      ArrayRef<Value *> VL = {}) const {
+      ArrayRef<Value *> VL = {},
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const {
+    // Default implementation returns 0.
+    // BasicTTIImpl provides the actual implementation.
     return 0;
   }
 
-  virtual InstructionCost
-  getOperandsScalarizationOverhead(ArrayRef<Type *> Tys,
-                                   TTI::TargetCostKind CostKind) const {
+  virtual InstructionCost getOperandsScalarizationOverhead(
+      ArrayRef<Type *> Tys, TTI::TargetCostKind CostKind,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const {
     return 0;
   }
 
@@ -830,10 +833,10 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
-  virtual InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                             TTI::TargetCostKind CostKind,
-                                             unsigned Index, const Value *Op0,
-                                             const Value *Op1) const {
+  virtual InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      const Value *Op0, const Value *Op1,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const {
     return 1;
   }
 
@@ -844,13 +847,15 @@ class TargetTransformInfoImplBase {
   virtual InstructionCost getVectorInstrCost(
       unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
       Value *Scalar,
-      ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
+      ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const {
     return 1;
   }
 
-  virtual InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
-                                             TTI::TargetCostKind CostKind,
-                                             unsigned Index) const {
+  virtual InstructionCost getVectorInstrCost(
+      const Instruction &I, Type *Val, TTI::TargetCostKind CostKind,
+      unsigned Index,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const {
     return 1;
   }
 
@@ -1574,7 +1579,8 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
       if (auto *CI = dyn_cast<ConstantInt>(Operands[2]))
         if (CI->getValue().getActiveBits() <= 32)
           Idx = CI->getZExtValue();
-      return TargetTTI->getVectorInstrCost(*IE, Ty, CostKind, Idx);
+      return TargetTTI->getVectorInstrCost(*IE, Ty, CostKind, Idx,
+                                           TTI::getVectorInstrContextHint(IE));
     }
     case Instruction::ShuffleVector: {
       auto *Shuffle = dyn_cast<ShuffleVectorInst>(U);
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index ef91c845ce9e7..db455c0fbf87a 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -380,6 +380,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   ~BasicTTIImplBase() override = default;
 
   using TargetTransformInfoImplBase::DL;
+  using TargetTransformInfoImplBase::getScalarizationOverhead;
 
 public:
   /// \name Scalar TTI Implementations
@@ -893,10 +894,13 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
-  InstructionCost getScalarizationOverhead(
-      VectorType *InTy, const APInt &DemandedElts, bool Insert, bool Extract,
-      TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
-      ArrayRef<Value *> VL = {}) const override {
+  InstructionCost
+  getScalarizationOverhead(VectorType *InTy, const APInt &DemandedElts,
+                           bool Insert, bool Extract,
+                           TTI::TargetCostKind CostKind,
+                           bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
+                           TTI::VectorInstrContext VIC =
+                               TTI::VectorInstrContext::None) const override {
     /// FIXME: a bitfield is not a reasonable abstraction for talking about
     /// which elements are needed from a scalable vector
     if (isa<ScalableVectorType>(InTy))
@@ -914,12 +918,13 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         continue;
       if (Insert) {
         Value *InsertedVal = VL.empty() ? nullptr : VL[i];
-        Cost += thisT()->getVectorInstrCost(Instruction::InsertElement, Ty,
-                                            CostKind, i, nullptr, InsertedVal);
+        Cost +=
+            thisT()->getVectorInstrCost(Instruction::InsertElement, Ty,
+                                        CostKind, i, nullptr, InsertedVal, VIC);
       }
       if (Extract)
         Cost += thisT()->getVectorInstrCost(Instruction::ExtractElement, Ty,
-                                            CostKind, i, nullptr, nullptr);
+                                            CostKind, i, nullptr, nullptr, VIC);
     }
 
     return Cost;
@@ -947,23 +952,27 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
-  InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
-                                           bool Extract,
-                                           TTI::TargetCostKind CostKind) const {
+  InstructionCost getScalarizationOverhead(
+      VectorType *InTy, bool Insert, bool Extract, TTI::TargetCostKind CostKind,
+      bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const {
     if (isa<ScalableVectorType>(InTy))
       return InstructionCost::getInvalid();
     auto *Ty = cast<FixedVectorType>(InTy);
 
     APInt DemandedElts = APInt::getAllOnes(Ty->getNumElements());
+    // Use CRTP to allow target overrides
     return thisT()->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
-                                             CostKind);
+                                             CostKind, ForPoisonSrc, VL, VIC);
   }
 
   /// Estimate the overhead of scalarizing an instruction's
   /// operands. The (potentially vector) types to use for each of
   /// argument are passes via Tys.
   InstructionCost getOperandsScalarizationOverhead(
-      ArrayRef<Type *> Tys, TTI::TargetCostKind CostKind) const override {
+      ArrayRef<Type *> Tys, TTI::TargetCostKind CostKind,
+      TTI::VectorInstrContext VIC =
+          TTI::VectorInstrContext::None) const override {
     InstructionCost Cost = 0;
     for (Type *Ty : Tys) {
       // Disregard things like metadata arguments.
@@ -973,7 +982,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
       if (auto *VecTy = dyn_cast<VectorType>(Ty))
         Cost += getScalarizationOverhead(VecTy, /*Insert*/ false,
-                                         /*Extract*/ true, CostKind);
+                                         /*Extract*/ true, CostKind,
+                                         /*ForPoisonSrc=*/true, {}, VIC);
     }
 
     return Cost;
@@ -1428,10 +1438,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return 1;
   }
 
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override {
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override {
     return getRegUsageForType(Val->getScalarType());
   }
 
@@ -1439,26 +1450,32 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   /// vector with 'Scalar' being the value being extracted,'User' being the user
   /// of the extract(nullptr if user is not known before vectorization) and
   /// 'Idx' being the extract lane.
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, Value *Scalar,
-                                     ArrayRef<std::tuple<Value *, User *, int>>
-                                         ScalarUserAndIdx) const override {
-    return thisT()->getVectorInstrCost(Opcode, Val, CostKind, Index, nullptr,
-                                       nullptr);
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Scalar,
+      ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx,
+      TTI::VectorInstrContext VIC =
+          TTI::VectorInstrContext::None) const override {
+    return getVectorInstrCost(Opcode, Val, CostKind, Index, nullptr, nullptr,
+                              VIC);
   }
 
-  InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index) const override {
+  InstructionCost
+  getVectorInstrCost(const Instruction &I, Type *Val,
+                     TTI::TargetCostKind CostKind, unsigned Index,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override {
     Value *Op0 = nullptr;
     Value *Op1 = nullptr;
     if (auto *IE = dyn_cast<InsertElementInst>(&I)) {
       Op0 = IE->getOperand(0);
       Op1 = IE->getOperand(1);
     }
+    // If VIC is None, compute it from the instruction
+    if (VIC == TTI::VectorInstrContext::None)
+      VIC = TTI::getVectorInstrContextHint(&I);
     return thisT()->getVectorInstrCost(I.getOpcode(), Val, CostKind, Index, Op0,
-                                       Op1);
+                                       Op1, VIC);
   }
 
   InstructionCost
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index b2b77da4914d6..36ce3cfbdbacc 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -646,17 +646,35 @@ bool TargetTransformInfo::isTargetIntrinsicWithStructReturnOverloadAtField(
   return TTIImpl->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
 }
 
+TargetTransformInfo::VectorInstrContext
+TargetTransformInfo::getVectorInstrContextHint(const Instruction *I) {
+  if (!I)
+    return VectorInstrContext::None;
+
+  // For inserts, check if the value being inserted comes from a load.
+  if (isa<InsertElementInst>(I) && isa<LoadInst>(I->getOperand(1)))
+    return VectorInstrContext::Load;
+
+  // For extracts, check if it has a single use that is a store.
+  if (isa<ExtractElementInst>(I) && I->hasOneUse() &&
+      isa<StoreInst>(*I->user_begin()))
+    return VectorInstrContext::Store;
+
+  return VectorInstrContext::None;
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-    TTI::TargetCostKind CostKind, bool ForPoisonSrc,
-    ArrayRef<Value *> VL) const {
+    TTI::TargetCostKind CostKind, bool ForPoisonSrc, ArrayRef<Value *> VL,
+    TTI::VectorInstrContext VIC) const {
   return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
-                                           CostKind, ForPoisonSrc, VL);
+                                           CostKind, ForPoisonSrc, VL, VIC);
 }
 
 InstructionCost TargetTransformInfo::getOperandsScalarizationOverhead(
-    ArrayRef<Type *> Tys, TTI::TargetCostKind CostKind) const {
-  return TTIImpl->getOperandsScalarizationOverhead(Tys, CostKind);
+    ArrayRef<Type *> Tys, TTI::TargetCostKind CostKind,
+    TTI::VectorInstrContext VIC) const {
+  return TTIImpl->getOperandsScalarizationOverhead(Tys, CostKind, VIC);
 }
 
 bool TargetTransformInfo::supportsEfficientVectorElementLoadStore() const {
@@ -1124,37 +1142,37 @@ InstructionCost TargetTransformInfo::getCmpSelInstrCost(
 
 InstructionCost TargetTransformInfo::getVectorInstrCost(
     unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
-    const Value *Op0, const Value *Op1) const {
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   assert((Opcode == Instruction::InsertElement ||
           Opcode == Instruction::ExtractElement) &&
          "Expecting Opcode to be insertelement/extractelement.");
   InstructionCost Cost =
-      TTIImpl->getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
+      TTIImpl->getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1, VIC);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
 
 InstructionCost TargetTransformInfo::getVectorInstrCost(
     unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
-    Value *Scalar,
-    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
+    Value *Scalar, ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx,
+    TTI::VectorInstrContext VIC) const {
   assert((Opcode == Instruction::InsertElement ||
           Opcode == Instruction::ExtractElement) &&
          "Expecting Opcode to be insertelement/extractelement.");
   InstructionCost Cost = TTIImpl->getVectorInstrCost(
-      Opcode, Val, CostKind, Index, Scalar, ScalarUserAndIdx);
+      Opcode, Val, CostKind, Index, Scalar, ScalarUserAndIdx, VIC);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
 
-InstructionCost
-TargetTransformInfo::getVectorInstrCost(const Instruction &I, Type *Val,
-                                        TTI::TargetCostKind CostKind,
-                                        unsigned Index) const {
+InstructionCost TargetTransformInfo::getVectorInstrCost(
+    const Instruction &I, Type *Val, TTI::TargetCostKind CostKind,
+    unsigned Index, TTI::VectorInstrContext VIC) const {
   // FIXME: Assert that Opcode is either InsertElement or ExtractElement.
   // This is mentioned in the interface description and respected by all
   // callers, but never asserted upon.
-  InstructionCost Cost = TTIImpl->getVectorInstrCost(I, Val, CostKind, Index);
+  InstructionCost Cost =
+      TTIImpl->getVectorInstrCost(I, Val, CostKind, Index, VIC);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index e365ad3ca56d6..45db627a27357 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3996,7 +3996,8 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
 InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
     unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
     const Instruction *I, Value *Scalar,
-    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
+    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx,
+    TTI::VectorInstrContext VIC) const {
   assert(Val->isVectorTy() && "This must be a vector type");
 
   if (Index != -1U) {
@@ -4025,7 +4026,7 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
     // register instruction. I.e., if this is an `insertelement` instruction,
     // and its second operand is a load, then we will generate a LD1, which
     // are expensive instructions on some uArchs.
-    if (I && isa<LoadInst>(I->getOperand(1))) {
+    if (VIC == TTI::VectorInstrContext::Load) {
       if (ST->hasFastLD1Single())
         return 0;
       return CostKind == TTI::TCK_CodeSize
@@ -4166,33 +4167,33 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
                                        : ST->getVectorInsertExtractBaseCost();
 }
 
-InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
-                                                   TTI::TargetCostKind CostKind,
-                                                   unsigned Index,
-                                                   const Value *Op0,
-                                                   const Value *Op1) const {
+InstructionCost AArch64TTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   // Treat insert at lane 0 into a poison vector as having zero cost. This
   // ensures vector broadcasts via an insert + shuffle (and will be lowered to a
   // single dup) are treated as cheap.
   if (Opcode == Instruction::InsertElement && Index == 0 && Op0 &&
       isa<PoisonValue>(Op0))
     return 0;
-  return getVectorInstrCostHelper(Opcode, Val, CostKind, Index);
+  return getVectorInstrCostHelper(Opcode, Val, CostKind, Index, nullptr,
+                                  nullptr, {}, VIC);
 }
 
 InstructionCost AArch64TTIImpl::getVectorInstrCost(
     unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
-    Value *Scalar,
-    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
+    Value *Scalar, ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx,
+    TTI::VectorInstrContext VIC) const {
   return getVectorInstrCostHelper(Opcode, Val, CostKind, Index, nullptr, Scalar,
-                                  ScalarUserAndIdx);
+                                  ScalarUserAndIdx, VIC);
 }
 
-InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
-                                                   Type *Val,
-                                                   TTI::TargetCostKind CostKind,
-                                                   unsigned Index) const {
-  return getVectorInstrCostHelper(I.getOpcode(), Val, CostKind, Index, &I);
+InstructionCost
+AArch64TTIImpl::getVectorInstrCost(const Instruction &I, Type *Val,
+                                   TTI::TargetCostKind CostKind, unsigned Index,
+                                   TTI::VectorInstrContext VIC) const {
+  return getVectorInstrCostHelper(I.getOpcode(), Val, CostKind, Index, &I,
+                                  nullptr, {}, VIC);
 }
 
 InstructionCost
@@ -4215,8 +4216,8 @@ AArch64TTIImpl::getIndexedVectorInstrCostFromEnd(unsigned Opcode, Type *Val,
 
 InstructionCost AArch64TTIImpl::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-    TTI::TargetCostKind CostKind, bool ForPoisonSrc,
-    ArrayRef<Value *> VL) const {
+    TTI::TargetCostKind CostKind, bool ForPoisonSrc, ArrayRef<Value *> VL,
+    TTI::VectorInstrContext VIC) const {
   if (isa<ScalableVectorType>(Ty))
     return InstructionCost::getInvalid();
   if (Ty->getElementType()->isFloatingPointTy())
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index c9bf44b15144a..24f0848991c82 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -81,7 +81,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   InstructionCost getVectorInstrCostHelper(
       unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
       const Instruction *I = nullptr, Value *Scalar = nullptr,
-      ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx = {}) const;
+      ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx = {},
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None) const;
 
 public:
   explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F)
@@ -214,24 +215,28 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
                                  const Instruction *I = nullptr) const override;
 
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
 
   /// \param ScalarUserAndIdx encodes the information about extracts from a
   /// vector with 'Scalar' being the value being extracted,'User' being the user
   /// of the extract(nullptr if user is not known before vectorization) and
   /// 'Idx' being the extract lane.
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, Value *Scalar,
-                                     ArrayRef<std::tuple<Value *, User *, int>>
-                                         ScalarUserAndIdx) const override;
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Scalar,
+      ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx,
+      TTI::VectorInstrContext VIC =
+          TTI::VectorInstrContext::None) const override;
 
-  InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index) const override;
+  InstructionCost
+  getVectorInstrCost(const Instruction &I, Type *Val,
+                     TTI::TargetCostKind CostKind, unsigned Index,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
 
   InstructionCost
   getIndexedVectorInstrCostFromEnd(unsigned Opcode, Type *Val,
@@ -500,10 +505,13 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                  VectorType *SubTp, ArrayRef<const Value *> Args = {},
                  const Instruction *CxtI = nullptr) const override;
 
-  InstructionCost getScalarizationOverhead(
-      VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-      TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
-      ArrayRef<Value *> VL = {}) const override;
+  InstructionCost
+  getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
+                           bool Insert, bool Extract,
+                           TTI::TargetCostKind CostKind,
+                           bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
+                           TTI::VectorInstrContext VIC =
+                               TTI::VectorInstrContext::None) const override;
 
   /// Return the cost of the scaling factor used in the addressing
   /// mode represented by AM for this target, for a load/store
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index de02369ce4667..b1cf252450890 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -883,10 +883,9 @@ GCNTTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
   return LT.first * getHalfRateInstrCost(CostKind);
 }
 
-InstructionCost GCNTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
-                                               TTI::TargetCostKind CostKind,
-                                               unsigned Index, const Value *Op0,
-                                               const Value *Op1) const {
+InstructionCost GCNTTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *ValTy, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   switch (Opcode) {
   case Instruction::ExtractElement:
   case Instruction::InsertElement: {
@@ -895,8 +894,8 @@ InstructionCost GCNTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
     if (EltSize < 32) {
       if (EltSize == 16 && Index == 0 && ST->has16BitInsts())
         return 0;
-      return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0,
-                                       Op1);
+      return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1,
+                                       VIC);
     }
 
     // Extracts are just reads of a subregister, so are free. Inserts are
@@ -907,7 +906,8 @@ InstructionCost GCNTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
     return Index == ~0u ? 2 : 0;
   }
   default:
-    return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1);
+    return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1,
+                                     VIC);
   }
 }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index 4dcf381a9af93..3ec157aacd0aa 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -176,10 +176,11 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
                                      ArrayRef<unsigned> Indices = {}) const;
 
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *ValTy,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *ValTy, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
 
   bool isReadRegisterSourceOfDivergence(const IntrinsicInst *ReadReg) const;
 
diff --git a/llvm/lib/Target/AMDGPU/R600TargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/R600TargetTransformInfo.cpp
index 3093227279a31..c08edc1bb5512 100644
--- a/llvm/lib/Target/AMDGPU/R600TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/R600TargetTransformInfo.cpp
@@ -108,19 +108,17 @@ InstructionCost R600TTIImpl::getCFInstrCost(unsigned Opcode,
   }
 }
 
-InstructionCost R600TTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
-                                                TTI::TargetCostKind CostKind,
-                                                unsigned Index,
-                                                const Value *Op0,
-                                                const Value *Op1) const {
+InstructionCost R600TTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *ValTy, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   switch (Opcode) {
   case Instruction::ExtractElement:
   case Instruction::InsertElement: {
     unsigned EltSize =
         DL.getTypeSizeInBits(cast<VectorType>(ValTy)->getElementType());
     if (EltSize < 32) {
-      return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0,
-                                       Op1);
+      return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1,
+                                       VIC);
     }
 
     // Extracts are just reads of a subregister, so are free. Inserts are
@@ -131,7 +129,8 @@ InstructionCost R600TTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
     return Index == ~0u ? 2 : 0;
   }
   default:
-    return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1);
+    return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1,
+                                     VIC);
   }
 }
 
diff --git a/llvm/lib/Target/AMDGPU/R600TargetTransformInfo.h b/llvm/lib/Target/AMDGPU/R600TargetTransformInfo.h
index 3deae69bfc8c9..ade1b1518215c 100644
--- a/llvm/lib/Target/AMDGPU/R600TargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/R600TargetTransformInfo.h
@@ -62,10 +62,11 @@ class R600TTIImpl final : public BasicTTIImplBase<R600TTIImpl> {
   InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
                                  const Instruction *I = nullptr) const override;
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *ValTy,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *ValTy, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index b947c8a10e2d8..94048360c0719 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -953,10 +953,9 @@ InstructionCost ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
       BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
 }
 
-InstructionCost ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
-                                               TTI::TargetCostKind CostKind,
-                                               unsigned Index, const Value *Op0,
-                                               const Value *Op1) const {
+InstructionCost ARMTTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *ValTy, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   // Penalize inserting into an D-subregister. We end up with a three times
   // lower estimated throughput on swift.
   if (ST->hasSlowLoadDSubregister() && Opcode == Instruction::InsertElement &&
@@ -975,7 +974,8 @@ InstructionCost ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
     if (ValTy->isVectorTy() &&
         ValTy->getScalarSizeInBits() <= 32)
       return std::max<InstructionCost>(
-          BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1),
+          BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1,
+                                    VIC),
           2U);
   }
 
@@ -989,7 +989,8 @@ InstructionCost ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
     return LT.first * (ValTy->getScalarType()->isIntegerTy() ? 4 : 1);
   }
 
-  return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1);
+  return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1,
+                                   VIC);
 }
 
 InstructionCost ARMTTIImpl::getCmpSelInstrCost(
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index fafd2d44a818c..94804152d96ec 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -354,10 +354,11 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
       const Instruction *I = nullptr) const override;
 
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
 
   InstructionCost
   getAddressComputationCost(Type *Val, ScalarEvolution *SE, const SCEV *Ptr,
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 59c6201e07081..25ede7d262544 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -306,11 +306,9 @@ InstructionCost HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy,
   return 1;
 }
 
-InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
-                                                   TTI::TargetCostKind CostKind,
-                                                   unsigned Index,
-                                                   const Value *Op0,
-                                                   const Value *Op1) const {
+InstructionCost HexagonTTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   Type *ElemTy = Val->isVectorTy() ? cast<VectorType>(Val)->getElementType()
                                    : Val;
   if (Opcode == Instruction::InsertElement) {
@@ -320,7 +318,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
       return Cost;
     // If it's not a 32-bit value, there will need to be an extract.
     return Cost + getVectorInstrCost(Instruction::ExtractElement, Val, CostKind,
-                                     Index, Op0, Op1);
+                                     Index, Op0, Op1, VIC);
   }
 
   if (Opcode == Instruction::ExtractElement)
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index edf88cf476f6d..0bd07a97ff3d5 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -145,10 +145,11 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
                    TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
                    const Instruction *I = nullptr) const override;
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
 
   InstructionCost
   getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index ae12a6ea3baa3..40eb161bc8666 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -119,10 +119,13 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
       ArrayRef<const Value *> Args = {},
       const Instruction *CxtI = nullptr) const override;
 
-  InstructionCost getScalarizationOverhead(
-      VectorType *InTy, const APInt &DemandedElts, bool Insert, bool Extract,
-      TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
-      ArrayRef<Value *> VL = {}) const override {
+  InstructionCost
+  getScalarizationOverhead(VectorType *InTy, const APInt &DemandedElts,
+                           bool Insert, bool Extract,
+                           TTI::TargetCostKind CostKind,
+                           bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
+                           TTI::VectorInstrContext VIC =
+                               TTI::VectorInstrContext::None) const override {
     if (!InTy->getElementCount().isFixed())
       return InstructionCost::getInvalid();
 
diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
index fbed34277dbab..3125d7dc4fea9 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -681,10 +681,9 @@ InstructionCost PPCTTIImpl::getCmpSelInstrCost(
   return Cost * CostFactor;
 }
 
-InstructionCost PPCTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
-                                               TTI::TargetCostKind CostKind,
-                                               unsigned Index, const Value *Op0,
-                                               const Value *Op1) const {
+InstructionCost PPCTTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   assert(Val->isVectorTy() && "This must be a vector type");
 
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
@@ -695,7 +694,7 @@ InstructionCost PPCTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
     return InstructionCost::getMax();
 
   InstructionCost Cost =
-      BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
+      BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1, VIC);
   Cost *= CostFactor;
 
   if (ST->hasVSX() && Val->getScalarType()->isDoubleTy()) {
@@ -858,8 +857,9 @@ InstructionCost PPCTTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
   if (Src->isVectorTy() && Opcode == Instruction::Store)
     for (int I = 0, E = cast<FixedVectorType>(Src)->getNumElements(); I < E;
          ++I)
-      Cost += getVectorInstrCost(Instruction::ExtractElement, Src, CostKind, I,
-                                 nullptr, nullptr);
+      Cost +=
+          getVectorInstrCost(Instruction::ExtractElement, Src, CostKind, I,
+                             nullptr, nullptr, TTI::VectorInstrContext::None);
 
   return Cost;
 }
diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
index f80ebdbce7f64..283b149ed71be 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h
@@ -128,10 +128,11 @@ class PPCTTIImpl final : public BasicTTIImplBase<PPCTTIImpl> {
       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
       const Instruction *I = nullptr) const override;
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
   InstructionCost getMemoryOpCost(
       unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace,
       TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index e812d092c3ea0..379b6af3572bc 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -984,8 +984,8 @@ static unsigned isM1OrSmaller(MVT VT) {
 
 InstructionCost RISCVTTIImpl::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-    TTI::TargetCostKind CostKind, bool ForPoisonSrc,
-    ArrayRef<Value *> VL) const {
+    TTI::TargetCostKind CostKind, bool ForPoisonSrc, ArrayRef<Value *> VL,
+    TTI::VectorInstrContext VIC) const {
   if (isa<ScalableVectorType>(Ty))
     return InstructionCost::getInvalid();
 
@@ -2413,11 +2413,9 @@ InstructionCost RISCVTTIImpl::getCFInstrCost(unsigned Opcode,
   return 0;
 }
 
-InstructionCost RISCVTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
-                                                 TTI::TargetCostKind CostKind,
-                                                 unsigned Index,
-                                                 const Value *Op0,
-                                                 const Value *Op1) const {
+InstructionCost RISCVTTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   assert(Val->isVectorTy() && "This must be a vector type");
 
   // TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
@@ -2429,7 +2427,8 @@ InstructionCost RISCVTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 
   if (Opcode != Instruction::ExtractElement &&
       Opcode != Instruction::InsertElement)
-    return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
+    return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1,
+                                     VIC);
 
   // Legalize the type.
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 6e38951520039..3434898d6d3e7 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -179,10 +179,13 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
                  VectorType *SubTp, ArrayRef<const Value *> Args = {},
                  const Instruction *CxtI = nullptr) const override;
 
-  InstructionCost getScalarizationOverhead(
-      VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-      TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
-      ArrayRef<Value *> VL = {}) const override;
+  InstructionCost
+  getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
+                           bool Insert, bool Extract,
+                           TTI::TargetCostKind CostKind,
+                           bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
+                           TTI::VectorInstrContext VIC =
+                               TTI::VectorInstrContext::None) const override;
 
   InstructionCost
   getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
@@ -246,10 +249,11 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
                                  const Instruction *I = nullptr) const override;
 
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
 
   InstructionCost
   getIndexedVectorInstrCostFromEnd(unsigned Opcode, Type *Val,
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 2611c291abaa6..4322773f4afd6 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -493,8 +493,8 @@ static bool isFreeEltLoad(const Value *Op) {
 
 InstructionCost SystemZTTIImpl::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-    TTI::TargetCostKind CostKind, bool ForPoisonSrc,
-    ArrayRef<Value *> VL) const {
+    TTI::TargetCostKind CostKind, bool ForPoisonSrc, ArrayRef<Value *> VL,
+    TTI::VectorInstrContext VIC) const {
   unsigned NumElts = cast<FixedVectorType>(Ty)->getNumElements();
   InstructionCost Cost = 0;
 
@@ -1181,11 +1181,9 @@ InstructionCost SystemZTTIImpl::getCmpSelInstrCost(
                                    Op1Info, Op2Info);
 }
 
-InstructionCost SystemZTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
-                                                   TTI::TargetCostKind CostKind,
-                                                   unsigned Index,
-                                                   const Value *Op0,
-                                                   const Value *Op1) const {
+InstructionCost SystemZTTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   if (Opcode == Instruction::InsertElement) {
     // Vector Element Load.
     if (Op1 != nullptr && isFreeEltLoad(Op1))
@@ -1208,7 +1206,7 @@ InstructionCost SystemZTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
     return Cost;
   }
 
-  return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
+  return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1, VIC);
 }
 
 // Check if a load may be folded as a memory operand in its user.
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index fc681dec1859a..f4ba29c987f09 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -85,10 +85,13 @@ class SystemZTTIImpl final : public BasicTTIImplBase<SystemZTTIImpl> {
   bool hasDivRemOp(Type *DataType, bool IsSigned) const override;
   bool prefersVectorizedAddressing() const override { return false; }
   bool LSRWithInstrQueries() const override { return true; }
-  InstructionCost getScalarizationOverhead(
-      VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-      TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
-      ArrayRef<Value *> VL = {}) const override;
+  InstructionCost
+  getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
+                           bool Insert, bool Extract,
+                           TTI::TargetCostKind CostKind,
+                           bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
+                           TTI::VectorInstrContext VIC =
+                               TTI::VectorInstrContext::None) const override;
   bool supportsEfficientVectorElementLoadStore() const override { return true; }
   bool enableInterleavedAccessVectorization() const override { return true; }
 
@@ -118,10 +121,11 @@ class SystemZTTIImpl final : public BasicTTIImplBase<SystemZTTIImpl> {
       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
       const Instruction *I = nullptr) const override;
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
   bool isFoldableLoad(const LoadInst *Ld,
                       const Instruction *&FoldedValue) const;
   InstructionCost getMemoryOpCost(
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 434827c689a8c..e136a3e6dac75 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -383,9 +383,9 @@ InstructionCost WebAssemblyTTIImpl::getInterleavedMemoryOpCost(
 
 InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
     unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
-    const Value *Op0, const Value *Op1) const {
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   InstructionCost Cost = BasicTTIImplBase::getVectorInstrCost(
-      Opcode, Val, CostKind, Index, Op0, Op1);
+      Opcode, Val, CostKind, Index, Op0, Op1, VIC);
 
   // SIMD128's insert/extract currently only take constant indices.
   if (Index == -1u)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 3e9e8972395ab..5308e8ac20b04 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -86,10 +86,11 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
       bool UseMaskForCond, bool UseMaskForGaps) const override;
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
   InstructionCost getPartialReductionCost(
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 608727b745925..f4d630976d1eb 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -4803,10 +4803,9 @@ X86TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
 }
 
-InstructionCost X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
-                                               TTI::TargetCostKind CostKind,
-                                               unsigned Index, const Value *Op0,
-                                               const Value *Op1) const {
+InstructionCost X86TTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    const Value *Op0, const Value *Op1, TTI::VectorInstrContext VIC) const {
   static const CostTblEntry SLMCostTbl[] = {
      { ISD::EXTRACT_VECTOR_ELT,       MVT::i8,      4 },
      { ISD::EXTRACT_VECTOR_ELT,       MVT::i16,     4 },
@@ -4948,14 +4947,15 @@ InstructionCost X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
     return ShuffleCost + IntOrFpCost + RegisterFileMoveCost;
   }
 
-  return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1) +
+  return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1,
+                                   VIC) +
          RegisterFileMoveCost;
 }
 
 InstructionCost X86TTIImpl::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-    TTI::TargetCostKind CostKind, bool ForPoisonSrc,
-    ArrayRef<Value *> VL) const {
+    TTI::TargetCostKind CostKind, bool ForPoisonSrc, ArrayRef<Value *> VL,
+    TTI::VectorInstrContext VIC) const {
   assert(DemandedElts.getBitWidth() ==
              cast<FixedVectorType>(Ty)->getNumElements() &&
          "Vector size mismatch");
@@ -4987,7 +4987,8 @@ InstructionCost X86TTIImpl::getScalarizationOverhead(
         continue;
       Cost += getVectorInstrCost(Instruction::InsertElement, Ty, CostKind, I,
                                  Constant::getNullValue(Ty),
-                                 VL.empty() ? nullptr : VL[I]);
+                                 VL.empty() ? nullptr : VL[I],
+                                 TTI::VectorInstrContext::None);
     }
     return Cost;
   }
@@ -5792,7 +5793,8 @@ X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
 
   // Add the final extract element to the cost.
   return ReductionCost + getVectorInstrCost(Instruction::ExtractElement, Ty,
-                                            CostKind, 0, nullptr, nullptr);
+                                            CostKind, 0, nullptr, nullptr,
+                                            TTI::VectorInstrContext::None);
 }
 
 InstructionCost X86TTIImpl::getMinMaxCost(Intrinsic::ID IID, Type *Ty,
@@ -5971,7 +5973,8 @@ X86TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *ValTy,
 
   // Add the final extract element to the cost.
   return MinMaxCost + getVectorInstrCost(Instruction::ExtractElement, Ty,
-                                         CostKind, 0, nullptr, nullptr);
+                                         CostKind, 0, nullptr, nullptr,
+                                         TTI::VectorInstrContext::None);
 }
 
 /// Calculate the cost of materializing a 64-bit value. This helper
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 4f672793a0fcf..b3dde1555d0a0 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -165,14 +165,18 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
       const Instruction *I = nullptr) const override;
   using BaseT::getVectorInstrCost;
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, const Value *Op0,
-                                     const Value *Op1) const override;
-  InstructionCost getScalarizationOverhead(
-      VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-      TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
-      ArrayRef<Value *> VL = {}) const override;
+  InstructionCost
+  getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
+                     unsigned Index, const Value *Op0, const Value *Op1,
+                     TTI::VectorInstrContext VIC =
+                         TTI::VectorInstrContext::None) const override;
+  InstructionCost
+  getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
+                           bool Insert, bool Extract,
+                           TTI::TargetCostKind CostKind,
+                           bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
+                           TTI::VectorInstrContext VIC =
+                               TTI::VectorInstrContext::None) const override;
   InstructionCost
   getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
                             const APInt &DemandedDstElts,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index ce0aab1e2bde8..eb02db5f9feeb 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5637,11 +5637,17 @@ LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I,
   if (!RetTy->isVoidTy() &&
       (!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore())) {
 
+    TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None;
+    if (isa<LoadInst>(I))
+      VIC = TTI::VectorInstrContext::Load;
+    else if (isa<StoreInst>(I))
+      VIC = TTI::VectorInstrContext::Store;
+
     for (Type *VectorTy : getContainedTypes(RetTy)) {
       Cost += TTI.getScalarizationOverhead(
           cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
-          /*Insert=*/true,
-          /*Extract=*/false, CostKind);
+          /*Insert=*/true, /*Extract=*/false, CostKind,
+          /*ForPoisonSrc=*/true, {}, VIC);
     }
   }
 
@@ -5662,7 +5668,11 @@ LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I,
   SmallVector<Type *> Tys;
   for (auto *V : filterExtractingOperands(Ops, VF))
     Tys.push_back(maybeVectorizeType(V->getType(), VF));
-  return Cost + TTI.getOperandsScalarizationOverhead(Tys, CostKind);
+
+  TTI::VectorInstrContext OperandVIC = isa<StoreInst>(I)
+                                           ? TTI::VectorInstrContext::Store
+                                           : TTI::VectorInstrContext::None;
+  return Cost + TTI.getOperandsScalarizationOverhead(Tys, CostKind, OperandVIC);
 }
 
 void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index a6a46e36b397d..68a5e183ff3f3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1743,7 +1743,7 @@ VPCostContext::getOperandInfo(VPValue *V) const {
 
 InstructionCost VPCostContext::getScalarizationOverhead(
     Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF,
-    bool AlwaysIncludeReplicatingR) {
+    TTI::VectorInstrContext VIC, bool AlwaysIncludeReplicatingR) {
   if (VF.isScalar())
     return 0;
 
@@ -1757,8 +1757,8 @@ InstructionCost VPCostContext::getScalarizationOverhead(
          to_vector(getContainedTypes(toVectorizedTy(ResultTy, VF)))) {
       ScalarizationCost += TTI.getScalarizationOverhead(
           cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
-          /*Insert=*/true,
-          /*Extract=*/false, CostKind);
+          /*Insert=*/true, /*Extract=*/false, CostKind,
+          /*ForPoisonSrc=*/true, {}, VIC);
     }
   }
   // Compute the cost of scalarizing the operands, skipping ones that do not
@@ -1776,5 +1776,5 @@ InstructionCost VPCostContext::getScalarizationOverhead(
     Tys.push_back(toVectorizedTy(Types.inferScalarType(Op), VF));
   }
   return ScalarizationCost +
-         TTI.getOperandsScalarizationOverhead(Tys, CostKind);
+         TTI.getOperandsScalarizationOverhead(Tys, CostKind, VIC);
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
index 973cbaa4944d6..79ef2c065f8b1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
@@ -366,12 +366,14 @@ struct VPCostContext {
 
   /// Estimate the overhead of scalarizing a recipe with result type \p ResultTy
   /// and \p Operands with \p VF. This is a convenience wrapper for the
-  /// type-based getScalarizationOverhead API. If \p AlwaysIncludeReplicatingR
-  /// is true, always compute the cost of scalarizing replicating operands.
-  InstructionCost
-  getScalarizationOverhead(Type *ResultTy, ArrayRef<const VPValue *> Operands,
-                           ElementCount VF,
-                           bool AlwaysIncludeReplicatingR = false);
+  /// type-based getScalarizationOverhead API. \p VIC provides context about
+  /// whether the scalarization is for a load/store operation. If \p
+  /// AlwaysIncludeReplicatingR is true, always compute the cost of scalarizing
+  /// replicating operands.
+  InstructionCost getScalarizationOverhead(
+      Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF,
+      TTI::VectorInstrContext VIC = TTI::VectorInstrContext::None,
+      bool AlwaysIncludeReplicatingR = false);
 };
 
 /// This class can be used to assign names to VPValues. For VPValues without
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 701235de8b377..8656bf391ac71 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -3468,8 +3468,11 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
         ResultTy = Ctx.Types.inferScalarType(this);
     }
 
+    TTI::VectorInstrContext VIC =
+        IsLoad ? TTI::VectorInstrContext::Load : TTI::VectorInstrContext::Store;
     return (ScalarCost * VF.getFixedValue()) +
-           Ctx.getScalarizationOverhead(ResultTy, OpsToScalarize, VF, true);
+           Ctx.getScalarizationOverhead(ResultTy, OpsToScalarize, VF, VIC,
+                                        true);
   }
   case Instruction::SExt:
   case Instruction::ZExt:
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.h b/llvm/lib/Transforms/Vectorize/VPlanUtils.h
index 4e7ed1f5a4ab7..e3c2a062a8b97 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.h
@@ -103,6 +103,7 @@ inline VPIRFlags getFlagsFromIndDesc(const InductionDescriptor &ID) {
          "Expected int induction");
   return VPIRFlags::WrapFlagsTy(false, false);
 }
+
 } // namespace vputils
 
 //===----------------------------------------------------------------------===//
@@ -254,7 +255,6 @@ class VPBlockUtils {
   /// Returns true if \p VPB is a loop latch, using isHeader().
   static bool isLatch(const VPBlockBase *VPB, const VPDominatorTree &VPDT);
 };
-
 } // namespace llvm
 
 #endif



More information about the llvm-commits mailing list