[llvm] 1c86f4a - [TTI] Use MemIntrinsicCostAttributes for getGatherScatterOpCost (#168650)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 2 19:01:40 PST 2025


Author: Shih-Po Hung
Date: 2025-12-03T03:01:35Z
New Revision: 1c86f4a8f1a254a6286342a5bffb13c99168267b

URL: https://github.com/llvm/llvm-project/commit/1c86f4a8f1a254a6286342a5bffb13c99168267b
DIFF: https://github.com/llvm/llvm-project/commit/1c86f4a8f1a254a6286342a5bffb13c99168267b.diff

LOG: [TTI] Use MemIntrinsicCostAttributes for getGatherScatterOpCost (#168650)

- Follow-up to #168029. This is a step toward a unified interface for
masked/gather-scatter/strided/expand-compress cost modeling.
- Replace the ad-hoc parameter list with a single attributes object.

API change:
```
- InstructionCost getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                        Alignment, CostKind, Inst);

+ InstructionCost getGatherScatterOpCost(MemIntrinsicCostAttributes,
+                                       CostKind);
```

Notes:
- NFCI (no functional change intended): callers populate
MemIntrinsicCostAttributes with the same information as before.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 624302bc6d0a3..9b2a7f432a544 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -852,10 +852,8 @@ class TargetTransformInfoImplBase {
   }
 
   virtual InstructionCost
-  getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const {
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const {
     return 1;
   }
 

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 4d9cfea5d1bab..314830652f0b6 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1571,10 +1571,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   InstructionCost
-  getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const override {
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override {
+    unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
+                       MICA.getID() == Intrinsic::vp_gather)
+                          ? Instruction::Load
+                          : Instruction::Store;
+    Type *DataTy = MICA.getDataType();
+    bool VariableMask = MICA.getVariableMask();
+    Align Alignment = MICA.getAlignment();
     return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, VariableMask,
                                        true, CostKind);
   }
@@ -1602,8 +1607,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     // For a target without strided memory operations (or for an illegal
     // operation type on one which does), assume we lower to a gather/scatter
     // operation.  (Which may in turn be scalarized.)
-    return thisT()->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                           Alignment, CostKind, I);
+    unsigned IID = Opcode == Instruction::Load ? Intrinsic::masked_gather
+                                               : Intrinsic::masked_scatter;
+    return thisT()->getGatherScatterOpCost(
+        MemIntrinsicCostAttributes(IID, DataTy, Ptr, VariableMask, Alignment,
+                                   I),
+        CostKind);
   }
 
   InstructionCost getInterleavedMemoryOpCost(
@@ -3071,14 +3080,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     case Intrinsic::masked_scatter:
     case Intrinsic::masked_gather:
     case Intrinsic::vp_scatter:
-    case Intrinsic::vp_gather: {
-      unsigned Opcode =
-          (Id == Intrinsic::masked_gather || Id == Intrinsic::vp_gather)
-              ? Instruction::Load
-              : Instruction::Store;
-      return thisT()->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                             Alignment, CostKind, I);
-    }
+    case Intrinsic::vp_gather:
+      return thisT()->getGatherScatterOpCost(MICA, CostKind);
     case Intrinsic::masked_load:
     case Intrinsic::masked_store:
       return thisT()->getMaskedMemoryOpCost(MICA, CostKind);

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 3a5f1499f9d2d..8b08b30388cc2 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4794,12 +4794,21 @@ static unsigned getSVEGatherScatterOverhead(unsigned Opcode,
   }
 }
 
-InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+InstructionCost
+AArch64TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                       TTI::TargetCostKind CostKind) const {
+
+  unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
+                     MICA.getID() == Intrinsic::vp_gather)
+                        ? Instruction::Load
+                        : Instruction::Store;
+
+  Type *DataTy = MICA.getDataType();
+  Align Alignment = MICA.getAlignment();
+  const Instruction *I = MICA.getInst();
+
   if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
   auto *VT = cast<VectorType>(DataTy);
   auto LT = getTypeLegalizationCost(DataTy);
   if (!LT.first.isValid())

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe3bb5e7981d2..c4f402b7c3b8e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -192,10 +192,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                         TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
-  getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const override;
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
 
   bool isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
                           Type *Src) const;

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index cecc9544ed835..bdaf9c5e7105c 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1699,13 +1699,19 @@ InstructionCost ARMTTIImpl::getInterleavedMemoryOpCost(
                                            UseMaskForCond, UseMaskForGaps);
 }
 
-InstructionCost ARMTTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+InstructionCost
+ARMTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                   TTI::TargetCostKind CostKind) const {
+
+  Type *DataTy = MICA.getDataType();
+  const Value *Ptr = MICA.getPointer();
+  bool VariableMask = MICA.getVariableMask();
+  Align Alignment = MICA.getAlignment();
+  const Instruction *I = MICA.getInst();
+
   using namespace PatternMatch;
   if (!ST->hasMVEIntegerOps() || !EnableMaskedGatherScatters)
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
 
   assert(DataTy->isVectorTy() && "Can't do gather/scatters on scalar!");
   auto *VTy = cast<FixedVectorType>(DataTy);

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 30f2151b41239..e5f9dd5fe26d9 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -288,10 +288,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
       bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
 
   InstructionCost
-  getGatherScatterOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const override;
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
   getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 3f84cbb6555ed..c7d6ee306be84 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -238,11 +238,10 @@ HexagonTTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
   return 1;
 }
 
-InstructionCost HexagonTTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
-  return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                       Alignment, CostKind, I);
+InstructionCost
+HexagonTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                       TTI::TargetCostKind CostKind) const {
+  return BaseT::getGatherScatterOpCost(MICA, CostKind);
 }
 
 InstructionCost HexagonTTIImpl::getInterleavedMemoryOpCost(

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 67388984bb3e3..3135958ba4b50 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -127,11 +127,9 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
                  ArrayRef<int> Mask, TTI::TargetCostKind CostKind, int Index,
                  VectorType *SubTp, ArrayRef<const Value *> Args = {},
                  const Instruction *CxtI = nullptr) const override;
-  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
-                                         const Value *Ptr, bool VariableMask,
-                                         Align Alignment,
-                                         TTI::TargetCostKind CostKind,
-                                         const Instruction *I) const override;
+  InstructionCost
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
   InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index e697b9810e824..1d431959eaea3 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1139,19 +1139,24 @@ InstructionCost RISCVTTIImpl::getInterleavedMemoryOpCost(
   return MemCost + ShuffleCost;
 }
 
-InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+InstructionCost
+RISCVTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                     TTI::TargetCostKind CostKind) const {
+
+  bool IsLoad = MICA.getID() == Intrinsic::masked_gather ||
+                MICA.getID() == Intrinsic::vp_gather;
+  unsigned Opcode = IsLoad ? Instruction::Load : Instruction::Store;
+  Type *DataTy = MICA.getDataType();
+  Align Alignment = MICA.getAlignment();
+  const Instruction *I = MICA.getInst();
   if (CostKind != TTI::TCK_RecipThroughput)
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
 
   if ((Opcode == Instruction::Load &&
        !isLegalMaskedGather(DataTy, Align(Alignment))) ||
       (Opcode == Instruction::Store &&
        !isLegalMaskedScatter(DataTy, Align(Alignment))))
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
 
   // Cost is proportional to the number of memory operations implied.  For
   // scalable vectors, we use an estimate on that number since we don't

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 2a73fe6255382..e32b1c553c57a 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -194,11 +194,9 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
       bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
 
-  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
-                                         const Value *Ptr, bool VariableMask,
-                                         Align Alignment,
-                                         TTI::TargetCostKind CostKind,
-                                         const Instruction *I) const override;
+  InstructionCost
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
   getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA,

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index a9dc96b53d530..aa75d2c60803e 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6258,10 +6258,15 @@ InstructionCost X86TTIImpl::getGSVectorCost(unsigned Opcode,
 }
 
 /// Calculate the cost of Gather / Scatter operation
-InstructionCost X86TTIImpl::getGatherScatterOpCost(
-    unsigned Opcode, Type *SrcVTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind,
-    const Instruction *I = nullptr) const {
+InstructionCost
+X86TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                   TTI::TargetCostKind CostKind) const {
+  bool IsLoad = MICA.getID() == Intrinsic::masked_gather ||
+                MICA.getID() == Intrinsic::vp_gather;
+  unsigned Opcode = IsLoad ? Instruction::Load : Instruction::Store;
+  Type *SrcVTy = MICA.getDataType();
+  const Value *Ptr = MICA.getPointer();
+  Align Alignment = MICA.getAlignment();
   if ((Opcode == Instruction::Load &&
        (!isLegalMaskedGather(SrcVTy, Align(Alignment)) ||
         forceScalarizeMaskedGather(cast<VectorType>(SrcVTy),
@@ -6270,8 +6275,7 @@ InstructionCost X86TTIImpl::getGatherScatterOpCost(
        (!isLegalMaskedScatter(SrcVTy, Align(Alignment)) ||
         forceScalarizeMaskedScatter(cast<VectorType>(SrcVTy),
                                     Align(Alignment)))))
-    return BaseT::getGatherScatterOpCost(Opcode, SrcVTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getGatherScatterOpCost(MICA, CostKind);
 
   assert(SrcVTy->isVectorTy() && "Unexpected data type for Gather/Scatter");
   PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index d6dea9427990b..d35911965d8b5 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -185,11 +185,9 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   InstructionCost
   getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
                         TTI::TargetCostKind CostKind) const override;
-  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
-                                         const Value *Ptr, bool VariableMask,
-                                         Align Alignment,
-                                         TTI::TargetCostKind CostKind,
-                                         const Instruction *I) const override;
+  InstructionCost
+  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
   InstructionCost
   getPointersChainCost(ArrayRef<const Value *> Ptrs, const Value *Base,
                        const TTI::PointersChainInfo &Info, Type *AccessTy,


        


More information about the llvm-commits mailing list