[llvm] [TTI] Use MemIntrinsicCostAttributes for getStridedMemoryOpCost (PR #170436)

Shih-Po Hung via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 2 23:03:55 PST 2025


https://github.com/arcbbb created https://github.com/llvm/llvm-project/pull/170436

- Follows #168029. This is a step toward a unified cost-modeling interface for masked, gather/scatter, strided, and expand/compress memory operations.
- Replace the ad-hoc parameter list with a single attributes object.

API change:
```
- InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
-                                        const Value *Ptr, bool VariableMask,
-                                        Align Alignment,
-                                        TTI::TargetCostKind CostKind,
-                                        const Instruction *I = nullptr);
+ InstructionCost getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                                        TTI::TargetCostKind CostKind);
```

Notes:
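- NFCI intended: callers populate MemIntrinsicCostAttributes with the same information as before. Implementations that still need a load/store opcode now derive it from `MICA.getID()` (`Intrinsic::experimental_vp_strided_load` maps to a load; anything else is treated as a store).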

>From f3e35ff72e45315132f7d1c72a62e8bfa8a610fb Mon Sep 17 00:00:00 2001
From: ShihPo Hung <shihpo.hung at sifive.com>
Date: Tue, 2 Dec 2025 22:48:44 -0800
Subject: [PATCH] [TTI] Use MemIntrinsicCostAttributes for getStridedMemoryOpCost

---
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  6 ++--
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      | 32 +++++++------------
 .../Target/RISCV/RISCVTargetTransformInfo.cpp | 18 ++++++++---
 .../Target/RISCV/RISCVTargetTransformInfo.h   |  8 ++---
 4 files changed, 29 insertions(+), 35 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9b2a7f432a544..5f1d855621c93 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -864,10 +864,8 @@ class TargetTransformInfoImplBase {
   }
 
   virtual InstructionCost
-  getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
-                         bool VariableMask, Align Alignment,
-                         TTI::TargetCostKind CostKind,
-                         const Instruction *I = nullptr) const {
+  getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const {
     return InstructionCost::getInvalid();
   }
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 314830652f0b6..fceff5f93b765 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1599,19 +1599,19 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                        /*IsGatherScatter*/ true, CostKind);
   }
 
-  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
-                                         const Value *Ptr, bool VariableMask,
-                                         Align Alignment,
-                                         TTI::TargetCostKind CostKind,
-                                         const Instruction *I) const override {
+  InstructionCost
+  getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override {
     // For a target without strided memory operations (or for an illegal
     // operation type on one which does), assume we lower to a gather/scatter
     // operation.  (Which may in turn be scalarized.)
-    unsigned IID = Opcode == Instruction::Load ? Intrinsic::masked_gather
-                                               : Intrinsic::masked_scatter;
+    unsigned IID = MICA.getID() == Intrinsic::experimental_vp_strided_load
+                       ? Intrinsic::masked_gather
+                       : Intrinsic::masked_scatter;
     return thisT()->getGatherScatterOpCost(
-        MemIntrinsicCostAttributes(IID, DataTy, Ptr, VariableMask, Alignment,
-                                   I),
+        MemIntrinsicCostAttributes(IID, MICA.getDataType(), MICA.getPointer(),
+                                   MICA.getVariableMask(), MICA.getAlignment(),
+                                   MICA.getInst()),
         CostKind);
   }
 
@@ -3062,21 +3062,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
                            TTI::TargetCostKind CostKind) const override {
     unsigned Id = MICA.getID();
-    Type *DataTy = MICA.getDataType();
-    const Value *Ptr = MICA.getPointer();
-    const Instruction *I = MICA.getInst();
-    bool VariableMask = MICA.getVariableMask();
-    Align Alignment = MICA.getAlignment();
 
     switch (Id) {
     case Intrinsic::experimental_vp_strided_load:
-    case Intrinsic::experimental_vp_strided_store: {
-      unsigned Opcode = Id == Intrinsic::experimental_vp_strided_load
-                            ? Instruction::Load
-                            : Instruction::Store;
-      return thisT()->getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                             Alignment, CostKind, I);
-    }
+    case Intrinsic::experimental_vp_strided_store:
+      return thisT()->getStridedMemoryOpCost(MICA, CostKind);
     case Intrinsic::masked_scatter:
     case Intrinsic::masked_gather:
     case Intrinsic::vp_scatter:
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 1d431959eaea3..c740e263da024 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1212,14 +1212,22 @@ InstructionCost RISCVTTIImpl::getExpandCompressMemoryOpCost(
          LT.first * getRISCVInstructionCost(Opcodes, LT.second, CostKind);
 }
 
-InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
-    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
-    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+InstructionCost
+RISCVTTIImpl::getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                                     TTI::TargetCostKind CostKind) const {
+
+  unsigned Opcode = MICA.getID() == Intrinsic::experimental_vp_strided_load
+                        ? Instruction::Load
+                        : Instruction::Store;
+
+  Type *DataTy = MICA.getDataType();
+  Align Alignment = MICA.getAlignment();
+  const Instruction *I = MICA.getInst();
+
   if (((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
        !isLegalStridedLoadStore(DataTy, Alignment)) ||
       (Opcode != Instruction::Load && Opcode != Instruction::Store))
-    return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
+    return BaseT::getStridedMemoryOpCost(MICA, CostKind);
 
   if (CostKind == TTI::TCK_CodeSize)
     return TTI::TCC_Basic;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index e32b1c553c57a..c1746e6d13166 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -202,11 +202,9 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
   getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
                                 TTI::TargetCostKind CostKind) const override;
 
-  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
-                                         const Value *Ptr, bool VariableMask,
-                                         Align Alignment,
-                                         TTI::TargetCostKind CostKind,
-                                         const Instruction *I) const override;
+  InstructionCost
+  getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                         TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
   getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const override;



More information about the llvm-commits mailing list