[llvm] [TTI] Use MemIntrinsicCostAttributes for getGatherScatterOpCost (PR #168650)

Shih-Po Hung via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 18 19:07:35 PST 2025


https://github.com/arcbbb created https://github.com/llvm/llvm-project/pull/168650

- Follow-up to #168029. This is a step toward a unified interface for cost modeling of masked, gather/scatter, strided, and expand/compress memory operations.
- Replace the ad-hoc parameter list with a single attributes object.

API change:
```
- InstructionCost getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                        Alignment, CostKind, Inst);

+ InstructionCost getGatherScatterOpCost(MemIntrinsicCostAttributes,
+                                       CostKind);
```

Notes:
- NFCI: callers populate MemIntrinsicCostAttributes with the same information as before.

From 482785e8e1d89592829d0329f962ce8c53e0117d Mon Sep 17 00:00:00 2001
From: ShihPo Hung <shihpo.hung at sifive.com>
Date: Fri, 14 Nov 2025 06:19:43 -0800
Subject: [PATCH] [TTI] Use MemIntrinsicCostAttributes for
 getGatherScatterOpCost

- Follow-up to #168029. This is a step toward a unified interface for cost
  modeling of masked, gather/scatter, strided, and expand/compress memory
  operations.
- Replace the ad-hoc parameter list with a single attributes object.

API change:
```
- InstructionCost getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                        Alignment, CostKind, Inst);

+ InstructionCost getGatherScatterOpCost(MemIntrinsicCostAttributes,
+                                       CostKind);
```

Notes:
- NFCI: callers populate MemIntrinsicCostAttributes with the
  same information as before.
---
 .../llvm/Analysis/TargetTransformInfo.h       | 26 +++++++++++--
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  6 +--
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      | 39 ++++++++++++-------
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  7 ++--
 .../AArch64/AArch64TargetTransformInfo.cpp    | 19 ++++++---
 .../AArch64/AArch64TargetTransformInfo.h      |  6 +--
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp | 16 +++++---
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |  6 +--
 .../Hexagon/HexagonTargetTransformInfo.cpp    |  9 ++---
 .../Hexagon/HexagonTargetTransformInfo.h      |  8 ++--
 .../Target/RISCV/RISCVTargetTransformInfo.cpp | 19 +++++----
 .../Target/RISCV/RISCVTargetTransformInfo.h   |  8 ++--
 .../lib/Target/X86/X86TargetTransformInfo.cpp | 16 +++++---
 llvm/lib/Target/X86/X86TargetTransformInfo.h  |  8 ++--
 .../Transforms/Vectorize/LoopVectorize.cpp    |  9 +++--
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 23 ++++++-----
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp |  8 +++-
 17 files changed, 141 insertions(+), 92 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index a65e4667ab76c..7a4ff1035ac04 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -125,12 +125,23 @@ struct HardwareLoopInfo {
 
 /// Information for memory intrinsic cost model.
 class MemIntrinsicCostAttributes {
+  /// Optional context instruction, if one exists, e.g. the
+  /// load/store to transform to the intrinsic.
+  const Instruction *I = nullptr;
+
+  /// Address in memory.
+  const Value *Ptr = nullptr;
+
   /// Vector type of the data to be loaded or stored.
   Type *DataTy = nullptr;
 
   /// ID of the memory intrinsic.
   Intrinsic::ID IID;
 
+  /// True when the memory access is predicated with a mask
+  /// that is not a compile-time constant.
+  bool VariableMask = true;
+
   /// Address space of the pointer.
   unsigned AddressSpace = 0;
 
@@ -143,8 +154,18 @@ class MemIntrinsicCostAttributes {
       : DataTy(DataTy), IID(Id), AddressSpace(AddressSpace),
         Alignment(Alignment) {}
 
+  LLVM_ABI MemIntrinsicCostAttributes(Intrinsic::ID Id, Type *DataTy,
+                                      const Value *Ptr, bool VariableMask,
+                                      Align Alignment,
+                                      const Instruction *I = nullptr)
+      : I(I), Ptr(Ptr), DataTy(DataTy), IID(Id), VariableMask(VariableMask),
+        Alignment(Alignment) {}
+
   Intrinsic::ID getID() const { return IID; }
+  const Instruction *getInst() const { return I; }
+  const Value *getPointer() const { return Ptr; }
   Type *getDataType() const { return DataTy; }
+  bool getVariableMask() const { return VariableMask; }
   unsigned getAddressSpace() const { return AddressSpace; }
   Align getAlignment() const { return Alignment; }
 };
@@ -1595,9 +1616,8 @@ class TargetTransformInfo {
   /// \p I - the optional original context instruction, if one exists, e.g. the
   ///        load/store to transform or the call to the gather/scatter intrinsic
   LLVM_ABI InstructionCost getGatherScatterOpCost(
-      unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-      Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
-      const Instruction *I = nullptr) const;
+      const MemIntrinsicCostAttributes &MICA,
+      TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// \return The cost of Expand Load or Compress Store operation
   /// \p Opcode - is a type of memory access Load or Store
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index d8e35748f53e5..2232ec065a60f 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -848,10 +848,8 @@ class TargetTransformInfoImplBase {
   }
 
   virtual InstructionCost
-  getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const {
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const {
     return 1;
   }
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index cb389ae74ef46..2f78a4d919e30 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1571,10 +1571,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   InstructionCost
-  getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const override {
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override {
+    unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
+                       MICA.getID() == Intrinsic::vp_gather)
+                          ? Instruction::Load
+                          : Instruction::Store;
+    Type *DataTy = MICA.getDataType();
+    bool VariableMask = MICA.getVariableMask();
+    Align Alignment = MICA.getAlignment();
     return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, VariableMask,
                                        true, CostKind);
   }
@@ -1598,8 +1603,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     // For a target without strided memory operations (or for an illegal
     // operation type on one which does), assume we lower to a gather/scatter
     // operation.  (Which may in turn be scalarized.)
-    return thisT()->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                           Alignment, CostKind, I);
+    unsigned IID = Opcode == Instruction::Load ? Intrinsic::masked_gather
+                                               : Intrinsic::masked_scatter;
+    return thisT()->getGatherScatterOpCost(
+        {IID, DataTy, Ptr, VariableMask, Alignment, I}, CostKind);
   }
 
   InstructionCost getInterleavedMemoryOpCost(
@@ -1826,8 +1833,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
           Alignment = VPI->getPointerAlignment().valueOrOne();
         bool VarMask = isa<Constant>(ICA.getArgs()[2]);
         return thisT()->getGatherScatterOpCost(
-            Instruction::Store, ICA.getArgTypes()[0], ICA.getArgs()[1], VarMask,
-            Alignment, CostKind, nullptr);
+            {ICA.getID(), ICA.getArgTypes()[0], ICA.getArgs()[1], VarMask,
+             Alignment, nullptr},
+            CostKind);
       }
       if (ICA.getID() == Intrinsic::vp_gather) {
         if (ICA.isTypeBasedOnly()) {
@@ -1842,8 +1850,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
           Alignment = VPI->getPointerAlignment().valueOrOne();
         bool VarMask = isa<Constant>(ICA.getArgs()[1]);
         return thisT()->getGatherScatterOpCost(
-            Instruction::Load, ICA.getReturnType(), ICA.getArgs()[0], VarMask,
-            Alignment, CostKind, nullptr);
+            {ICA.getID(), ICA.getReturnType(), ICA.getArgs()[0], VarMask,
+             Alignment, nullptr},
+            CostKind);
       }
 
       if (ICA.getID() == Intrinsic::vp_select ||
@@ -1948,16 +1957,16 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       const Value *Mask = Args[2];
       bool VarMask = !isa<Constant>(Mask);
       Align Alignment = I->getParamAlign(1).valueOrOne();
-      return thisT()->getGatherScatterOpCost(Instruction::Store,
-                                             ICA.getArgTypes()[0], Args[1],
-                                             VarMask, Alignment, CostKind, I);
+      return thisT()->getGatherScatterOpCost(
+          {IID, ICA.getArgTypes()[0], Args[1], VarMask, Alignment, I},
+          CostKind);
     }
     case Intrinsic::masked_gather: {
       const Value *Mask = Args[1];
       bool VarMask = !isa<Constant>(Mask);
       Align Alignment = I->getParamAlign(0).valueOrOne();
-      return thisT()->getGatherScatterOpCost(Instruction::Load, RetTy, Args[0],
-                                             VarMask, Alignment, CostKind, I);
+      return thisT()->getGatherScatterOpCost(
+          {IID, RetTy, Args[0], VarMask, Alignment, I}, CostKind);
     }
     case Intrinsic::masked_compressstore: {
       const Value *Data = Args[0];
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 45369f0ffe137..3c52ebf8c0fd8 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1191,10 +1191,9 @@ InstructionCost TargetTransformInfo::getMaskedMemoryOpCost(
 }
 
 InstructionCost TargetTransformInfo::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
-  InstructionCost Cost = TTIImpl->getGatherScatterOpCost(
-      Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
+    const MemIntrinsicCostAttributes &MICA,
+    TTI::TargetCostKind CostKind) const {
+  InstructionCost Cost = TTIImpl->getGatherScatterOpCost(MICA, CostKind);
   assert((!Cost.isValid() || Cost >= 0) &&
          "TTI should not produce negative costs!");
   return Cost;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 0bae00bafee3c..17a1ba356a6cd 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4777,12 +4777,21 @@ static unsigned getSVEGatherScatterOverhead(unsigned Opcode,
   }
 }
 
-InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+InstructionCost
+AArch64TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                       TTI::TargetCostKind CostKind) const {
+
+  unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
+                     MICA.getID() == Intrinsic::vp_gather)
+                        ? Instruction::Load
+                        : Instruction::Store;
+
+  Type *DataTy = MICA.getDataType();
+  Align Alignment = MICA.getAlignment();
+  const Instruction *I = MICA.getInst();
+
   if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
   auto *VT = cast<VectorType>(DataTy);
   auto LT = getTypeLegalizationCost(DataTy);
   if (!LT.first.isValid())
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 6cc4987428567..26163fbd52331 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -192,10 +192,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                         TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
-  getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const override;
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
 
   bool isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
                           Type *Src) const;
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index d12b802fe234f..344f36700e1c9 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1693,13 +1693,19 @@ InstructionCost ARMTTIImpl::getInterleavedMemoryOpCost(
                                            UseMaskForCond, UseMaskForGaps);
 }
 
-InstructionCost ARMTTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+InstructionCost
+ARMTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                   TTI::TargetCostKind CostKind) const {
+
+  Type *DataTy = MICA.getDataType();
+  const Value *Ptr = MICA.getPointer();
+  bool VariableMask = MICA.getVariableMask();
+  Align Alignment = MICA.getAlignment();
+  const Instruction *I = MICA.getInst();
+
   using namespace PatternMatch;
   if (!ST->hasMVEIntegerOps() || !EnableMaskedGatherScatters)
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
 
   assert(DataTy->isVectorTy() && "Can't do gather/scatters on scalar!");
   auto *VTy = cast<FixedVectorType>(DataTy);
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 919a6fc9fd0b0..2cd0ee242c328 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -284,10 +284,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
       bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
 
   InstructionCost
-  getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const override;
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
   getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 8f3f0cc8abb01..e7cac3d3e95cd 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -238,11 +238,10 @@ HexagonTTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
   return 1;
 }
 
-InstructionCost HexagonTTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
-  return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                       Alignment, CostKind, I);
+InstructionCost
+HexagonTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                       TTI::TargetCostKind CostKind) const {
+  return BaseT::getGatherScatterOpCost(MICA, CostKind);
 }
 
 InstructionCost HexagonTTIImpl::getInterleavedMemoryOpCost(
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index e95b5a10b76a7..eea8333522a7e 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -127,11 +127,9 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
                  ArrayRef<int> Mask, TTI::TargetCostKind CostKind, int Index,
                  VectorType *SubTp, ArrayRef<const Value *> Args = {},
                  const Instruction *CxtI = nullptr) const override;
-  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
-                                         const Value *Ptr, bool VariableMask,
-                                         Align Alignment,
-                                         TTI::TargetCostKind CostKind,
-                                         const Instruction *I) const override;
+  InstructionCost
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
   InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 1a1a93a9cb178..f05a534160d8f 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1120,19 +1120,24 @@ InstructionCost RISCVTTIImpl::getInterleavedMemoryOpCost(
   return MemCost + ShuffleCost;
 }
 
-InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+InstructionCost
+RISCVTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                     TTI::TargetCostKind CostKind) const {
+
+  bool IsLoad = MICA.getID() == Intrinsic::masked_gather ||
+                MICA.getID() == Intrinsic::vp_gather;
+  unsigned Opcode = IsLoad ? Instruction::Load : Instruction::Store;
+  Type *DataTy = MICA.getDataType();
+  Align Alignment = MICA.getAlignment();
+  const Instruction *I = MICA.getInst();
   if (CostKind != TTI::TCK_RecipThroughput)
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
 
   if ((Opcode == Instruction::Load &&
        !isLegalMaskedGather(DataTy, Align(Alignment))) ||
       (Opcode == Instruction::Store &&
        !isLegalMaskedScatter(DataTy, Align(Alignment))))
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
 
   // Cost is proportional to the number of memory operations implied.  For
   // scalable vectors, we use an estimate on that number since we don't
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 39c1173e2986c..24342111585aa 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -190,11 +190,9 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
       bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
 
-  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
-                                         const Value *Ptr, bool VariableMask,
-                                         Align Alignment,
-                                         TTI::TargetCostKind CostKind,
-                                         const Instruction *I) const override;
+  InstructionCost
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
   getExpandCompressMemoryOpCost(unsigned Opcode, Type *Src, bool VariableMask,
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 4b77bf925b2ba..8b075996eafbe 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6258,10 +6258,15 @@ InstructionCost X86TTIImpl::getGSVectorCost(unsigned Opcode,
 }
 
 /// Calculate the cost of Gather / Scatter operation
-InstructionCost X86TTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *SrcVTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind,
-    const Instruction *I = nullptr) const {
+InstructionCost
+X86TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                   TTI::TargetCostKind CostKind) const {
+  bool IsLoad = MICA.getID() == Intrinsic::masked_gather ||
+                MICA.getID() == Intrinsic::vp_gather;
+  unsigned Opcode = IsLoad ? Instruction::Load : Instruction::Store;
+  Type *SrcVTy = MICA.getDataType();
+  const Value *Ptr = MICA.getPointer();
+  Align Alignment = MICA.getAlignment();
   if ((Opcode == Instruction::Load &&
        (!isLegalMaskedGather(SrcVTy, Align(Alignment)) ||
         forceScalarizeMaskedGather(cast<VectorType>(SrcVTy),
@@ -6270,8 +6275,7 @@ InstructionCost X86TTIImpl::getGatherScatterOpCost(
        (!isLegalMaskedScatter(SrcVTy, Align(Alignment)) ||
         forceScalarizeMaskedScatter(cast<VectorType>(SrcVTy),
                                     Align(Alignment)))))
-    return BaseT::getGatherScatterOpCost(Opcode, SrcVTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
 
   assert(SrcVTy->isVectorTy() && "Unexpected data type for Gather/Scatter");
   PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index df1393ce16ca1..82b6c1b45aacf 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -185,11 +185,9 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   InstructionCost
   getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
                         TTI::TargetCostKind CostKind) const override;
-  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
-                                         const Value *Ptr, bool VariableMask,
-                                         Align Alignment,
-                                         TTI::TargetCostKind CostKind,
-                                         const Instruction *I) const override;
+  InstructionCost
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
   InstructionCost
   getPointersChainCost(ArrayRef<const Value *> Ptrs, const Value *Base,
                        const TTI::PointersChainInfo &Info, Type *AccessTy,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index aa52f9e2a53ca..bb00ec2fca306 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5313,10 +5313,13 @@ LoopVectorizationCostModel::getGatherScatterCost(Instruction *I,
   if (!Legal->isUniform(Ptr, VF))
     PtrTy = toVectorTy(PtrTy, VF);
 
+  unsigned IID = I->getOpcode() == Instruction::Load
+                     ? Intrinsic::masked_gather
+                     : Intrinsic::masked_scatter;
   return TTI.getAddressComputationCost(PtrTy, nullptr, nullptr, CostKind) +
-         TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr,
-                                    Legal->isMaskRequired(I), Alignment,
-                                    CostKind, I);
+         TTI.getGatherScatterOpCost(
+             {IID, VectorTy, Ptr, Legal->isMaskRequired(I), Alignment, I},
+             CostKind);
 }
 
 InstructionCost
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e33ff724ccdd5..d23dda22d18e0 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7189,9 +7189,10 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
         ScalarGEPCost;
     // The cost of masked gather.
     InstructionCost MaskedGatherCost =
-        TTI.getGatherScatterOpCost(
-            Instruction::Load, VecTy, cast<LoadInst>(VL0)->getPointerOperand(),
-            /*VariableMask=*/false, CommonAlignment, CostKind) +
+        TTI.getGatherScatterOpCost({Intrinsic::masked_gather, VecTy,
+                                    cast<LoadInst>(VL0)->getPointerOperand(),
+                                    /*VariableMask=*/false, CommonAlignment},
+                                   CostKind) +
         (ProfitableGatherPointers ? 0 : VectorGEPCost);
     InstructionCost GatherCost =
         getScalarizationOverhead(TTI, ScalarTy, VecTy, DemandedElts,
@@ -7314,11 +7315,12 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
                                         {}, CostKind);
           break;
         case LoadsState::ScatterVectorize:
-          VecLdCost += TTI.getGatherScatterOpCost(Instruction::Load, SubVecTy,
-                                                  LI0->getPointerOperand(),
-                                                  /*VariableMask=*/false,
-                                                  CommonAlignment, CostKind) +
-                       VectorGEPCost;
+          VecLdCost +=
+              TTI.getGatherScatterOpCost(
+                  {Intrinsic::masked_gather, SubVecTy, LI0->getPointerOperand(),
+                   /*VariableMask=*/false, CommonAlignment},
+                  CostKind) +
+              VectorGEPCost;
           break;
         case LoadsState::Gather:
           // Gathers are already calculated - ignore.
@@ -15124,8 +15126,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         Align CommonAlignment =
             computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
         VecLdCost = TTI->getGatherScatterOpCost(
-            Instruction::Load, VecTy, LI0->getPointerOperand(),
-            /*VariableMask=*/false, CommonAlignment, CostKind);
+            {Intrinsic::masked_gather, VecTy, LI0->getPointerOperand(),
+             /*VariableMask=*/false, CommonAlignment},
+            CostKind);
         break;
       }
       case TreeEntry::CombinedVectorize:
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index e89e91b959926..de425b2e7e435 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -3584,10 +3584,14 @@ InstructionCost VPWidenMemoryRecipe::computeCost(ElementCount VF,
     if (!vputils::isSingleScalar(getAddr()))
       PtrTy = toVectorTy(PtrTy, VF);
 
+    unsigned IID = isa<VPWidenLoadRecipe>(this)      ? Intrinsic::masked_gather
+                   : isa<VPWidenStoreRecipe>(this)   ? Intrinsic::masked_scatter
+                   : isa<VPWidenLoadEVLRecipe>(this) ? Intrinsic::vp_gather
+                                                     : Intrinsic::vp_scatter;
     return Ctx.TTI.getAddressComputationCost(PtrTy, nullptr, nullptr,
                                              Ctx.CostKind) +
-           Ctx.TTI.getGatherScatterOpCost(Opcode, Ty, Ptr, IsMasked, Alignment,
-                                          Ctx.CostKind, &Ingredient);
+           Ctx.TTI.getGatherScatterOpCost(
+               {IID, Ty, Ptr, IsMasked, Alignment, &Ingredient}, Ctx.CostKind);
   }
 
   InstructionCost Cost = 0;



More information about the llvm-commits mailing list