[llvm] c2409b4 - [TTI] Remove masked/gather-scatter/strided/expand-compress costing from TTIImpl (#169885)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 3 17:34:34 PST 2025


Author: Shih-Po Hung
Date: 2025-12-04T01:34:29Z
New Revision: c2409b4bcaca6662f86f7c9956f27413a7aecf0e

URL: https://github.com/llvm/llvm-project/commit/c2409b4bcaca6662f86f7c9956f27413a7aecf0e
DIFF: https://github.com/llvm/llvm-project/commit/c2409b4bcaca6662f86f7c9956f27413a7aecf0e.diff

LOG: [TTI] Remove masked/gather-scatter/strided/expand-compress costing from TTIImpl (#169885)

Following #165532, this patch moves scalarization-cost computation into
BaseT::getMemIntrinsicInstrCost and lets backends override it via their own
getMemIntrinsicInstrCost.
It also removes the masked/gather-scatter/strided/expand-compress
costing interfaces from TTIImpl.
Targets may keep them locally, as non-virtual helpers, if needed.

Stacked on #170426 and #170436.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 5f1d855621c93..835eb7701ccfa 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -845,30 +845,6 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
-  virtual InstructionCost
-  getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                        TTI::TargetCostKind CostKind) const {
-    return 1;
-  }
-
-  virtual InstructionCost
-  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const {
-    return 1;
-  }
-
-  virtual InstructionCost
-  getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                                TTI::TargetCostKind CostKind) const {
-    return 1;
-  }
-
-  virtual InstructionCost
-  getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const {
-    return InstructionCost::getInvalid();
-  }
-
   virtual InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
@@ -928,8 +904,20 @@ class TargetTransformInfoImplBase {
   virtual InstructionCost
   getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
                            TTI::TargetCostKind CostKind) const {
-    return 1;
+    switch (MICA.getID()) {
+    case Intrinsic::masked_scatter:
+    case Intrinsic::masked_gather:
+    case Intrinsic::masked_load:
+    case Intrinsic::masked_store:
+    case Intrinsic::vp_scatter:
+    case Intrinsic::vp_gather:
+    case Intrinsic::masked_compressstore:
+    case Intrinsic::masked_expandload:
+      return 1;
+    }
+    return InstructionCost::getInvalid();
   }
+
   virtual InstructionCost getCallInstrCost(Function *F, Type *RetTy,
                                            ArrayRef<Type *> Tys,
                                            TTI::TargetCostKind CostKind) const {

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index fceff5f93b765..494199835a19c 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1557,64 +1557,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return Cost;
   }
 
-  InstructionCost
-  getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                        TTI::TargetCostKind CostKind) const override {
-    Type *DataTy = MICA.getDataType();
-    Align Alignment = MICA.getAlignment();
-    unsigned Opcode = MICA.getID() == Intrinsic::masked_load
-                          ? Instruction::Load
-                          : Instruction::Store;
-    // TODO: Pass on AddressSpace when we have test coverage.
-    return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, true, false,
-                                       CostKind);
-  }
-
-  InstructionCost
-  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const override {
-    unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
-                       MICA.getID() == Intrinsic::vp_gather)
-                          ? Instruction::Load
-                          : Instruction::Store;
-    Type *DataTy = MICA.getDataType();
-    bool VariableMask = MICA.getVariableMask();
-    Align Alignment = MICA.getAlignment();
-    return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, VariableMask,
-                                       true, CostKind);
-  }
-
-  InstructionCost
-  getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                                TTI::TargetCostKind CostKind) const override {
-    unsigned Opcode = MICA.getID() == Intrinsic::masked_expandload
-                          ? Instruction::Load
-                          : Instruction::Store;
-    Type *DataTy = MICA.getDataType();
-    bool VariableMask = MICA.getVariableMask();
-    Align Alignment = MICA.getAlignment();
-    // Treat expand load/compress store as gather/scatter operation.
-    // TODO: implement more precise cost estimation for these intrinsics.
-    return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, VariableMask,
-                                       /*IsGatherScatter*/ true, CostKind);
-  }
-
-  InstructionCost
-  getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const override {
-    // For a target without strided memory operations (or for an illegal
-    // operation type on one which does), assume we lower to a gather/scatter
-    // operation.  (Which may in turn be scalarized.)
-    unsigned IID = MICA.getID() == Intrinsic::experimental_vp_strided_load
-                       ? Intrinsic::masked_gather
-                       : Intrinsic::masked_scatter;
-    return thisT()->getGatherScatterOpCost(
-        MemIntrinsicCostAttributes(IID, MICA.getDataType(), MICA.getPointer(),
-                                   MICA.getVariableMask(), MICA.getAlignment(),
-                                   MICA.getInst()),
-        CostKind);
-  }
-
   InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
@@ -3062,22 +3004,56 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
                            TTI::TargetCostKind CostKind) const override {
     unsigned Id = MICA.getID();
+    Type *DataTy = MICA.getDataType();
+    bool VariableMask = MICA.getVariableMask();
+    Align Alignment = MICA.getAlignment();
 
     switch (Id) {
     case Intrinsic::experimental_vp_strided_load:
-    case Intrinsic::experimental_vp_strided_store:
-      return thisT()->getStridedMemoryOpCost(MICA, CostKind);
+    case Intrinsic::experimental_vp_strided_store: {
+      unsigned Opcode = Id == Intrinsic::experimental_vp_strided_load
+                            ? Instruction::Load
+                            : Instruction::Store;
+      // For a target without strided memory operations (or for an illegal
+      // operation type on one which does), assume we lower to a gather/scatter
+      // operation.  (Which may in turn be scalarized.)
+      return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment,
+                                         VariableMask, true, CostKind);
+    }
     case Intrinsic::masked_scatter:
     case Intrinsic::masked_gather:
     case Intrinsic::vp_scatter:
-    case Intrinsic::vp_gather:
-      return thisT()->getGatherScatterOpCost(MICA, CostKind);
+    case Intrinsic::vp_gather: {
+      unsigned Opcode = (MICA.getID() == Intrinsic::masked_gather ||
+                         MICA.getID() == Intrinsic::vp_gather)
+                            ? Instruction::Load
+                            : Instruction::Store;
+
+      return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment,
+                                         VariableMask, true, CostKind);
+    }
+    case Intrinsic::vp_load:
+    case Intrinsic::vp_store:
+      return InstructionCost::getInvalid();
     case Intrinsic::masked_load:
-    case Intrinsic::masked_store:
-      return thisT()->getMaskedMemoryOpCost(MICA, CostKind);
+    case Intrinsic::masked_store: {
+      unsigned Opcode =
+          Id == Intrinsic::masked_load ? Instruction::Load : Instruction::Store;
+      // TODO: Pass on AddressSpace when we have test coverage.
+      return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment, true, false,
+                                         CostKind);
+    }
     case Intrinsic::masked_compressstore:
-    case Intrinsic::masked_expandload:
-      return thisT()->getExpandCompressMemoryOpCost(MICA, CostKind);
+    case Intrinsic::masked_expandload: {
+      unsigned Opcode = MICA.getID() == Intrinsic::masked_expandload
+                            ? Instruction::Load
+                            : Instruction::Store;
+      // Treat expand load/compress store as gather/scatter operation.
+      // TODO: implement more precise cost estimation for these intrinsics.
+      return getCommonMaskedMemoryOpCost(Opcode, DataTy, Alignment,
+                                         VariableMask,
+                                         /*IsGatherScatter*/ true, CostKind);
+    }
     case Intrinsic::vp_load_ff:
       return InstructionCost::getInvalid();
     default:

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 8b08b30388cc2..e8bbce202b407 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4746,13 +4746,27 @@ bool AArch64TTIImpl::prefersVectorizedAddressing() const {
   return ST->hasSVE();
 }
 
+InstructionCost
+AArch64TTIImpl::getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
+                                         TTI::TargetCostKind CostKind) const {
+  switch (MICA.getID()) {
+  case Intrinsic::masked_scatter:
+  case Intrinsic::masked_gather:
+    return getGatherScatterOpCost(MICA, CostKind);
+  case Intrinsic::masked_load:
+  case Intrinsic::masked_store:
+    return getMaskedMemoryOpCost(MICA, CostKind);
+  }
+  return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
+}
+
 InstructionCost
 AArch64TTIImpl::getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
                                       TTI::TargetCostKind CostKind) const {
   Type *Src = MICA.getDataType();
 
   if (useNeonVector(Src))
-    return BaseT::getMaskedMemoryOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
   auto LT = getTypeLegalizationCost(Src);
   if (!LT.first.isValid())
     return InstructionCost::getInvalid();
@@ -4808,7 +4822,7 @@ AArch64TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
   const Instruction *I = MICA.getInst();
 
   if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
-    return BaseT::getGatherScatterOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
   auto *VT = cast<VectorType>(DataTy);
   auto LT = getTypeLegalizationCost(DataTy);
   if (!LT.first.isValid())

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index c4f402b7c3b8e..7ca8482e152d1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -188,12 +188,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                                                   unsigned Opcode2) const;
 
   InstructionCost
-  getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                        TTI::TargetCostKind CostKind) const override;
+  getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
+                           TTI::TargetCostKind CostKind) const override;
 
-  InstructionCost
-  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const override;
+  InstructionCost getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                                        TTI::TargetCostKind CostKind) const;
+
+  InstructionCost getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                         TTI::TargetCostKind CostKind) const;
 
   bool isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
                           Type *Src) const;

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index bdaf9c5e7105c..88a7fb185bf16 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1636,6 +1636,20 @@ InstructionCost ARMTTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
                                            CostKind, OpInfo, I);
 }
 
+InstructionCost
+ARMTTIImpl::getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
+                                     TTI::TargetCostKind CostKind) const {
+  switch (MICA.getID()) {
+  case Intrinsic::masked_scatter:
+  case Intrinsic::masked_gather:
+    return getGatherScatterOpCost(MICA, CostKind);
+  case Intrinsic::masked_load:
+  case Intrinsic::masked_store:
+    return getMaskedMemoryOpCost(MICA, CostKind);
+  }
+  return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
+}
+
 InstructionCost
 ARMTTIImpl::getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
                                   TTI::TargetCostKind CostKind) const {
@@ -1652,7 +1666,7 @@ ARMTTIImpl::getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
       return ST->getMVEVectorCostFactor(CostKind);
   }
   if (!isa<FixedVectorType>(Src))
-    return BaseT::getMaskedMemoryOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
   // Scalar cost, which is currently very high due to the efficiency of the
   // generated code.
   return cast<FixedVectorType>(Src)->getNumElements() * 8;
@@ -1711,7 +1725,7 @@ ARMTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
 
   using namespace PatternMatch;
   if (!ST->hasMVEIntegerOps() || !EnableMaskedGatherScatters)
-    return BaseT::getGatherScatterOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
 
   assert(DataTy->isVectorTy() && "Can't do gather/scatters on scalar!");
   auto *VTy = cast<FixedVectorType>(DataTy);

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index e5f9dd5fe26d9..a23256364dd9a 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -279,17 +279,19 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
       const Instruction *I = nullptr) const override;
 
   InstructionCost
-  getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                        TTI::TargetCostKind CostKind) const override;
+  getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
+                           TTI::TargetCostKind CostKind) const override;
+
+  InstructionCost getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                                        TTI::TargetCostKind CostKind) const;
 
   InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
       bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
 
-  InstructionCost
-  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const override;
+  InstructionCost getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                         TTI::TargetCostKind CostKind) const;
 
   InstructionCost
   getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 74c2c896a8a88..aedd7f124cef5 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1022,6 +1022,22 @@ RISCVTTIImpl::getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
     return getMemoryOpCost(Instruction::Load, DataTy, Alignment, AS, CostKind,
                            {TTI::OK_AnyValue, TTI::OP_None}, nullptr);
   }
+  case Intrinsic::experimental_vp_strided_load:
+  case Intrinsic::experimental_vp_strided_store:
+    return getStridedMemoryOpCost(MICA, CostKind);
+  case Intrinsic::masked_compressstore:
+  case Intrinsic::masked_expandload:
+    return getExpandCompressMemoryOpCost(MICA, CostKind);
+  case Intrinsic::vp_scatter:
+  case Intrinsic::vp_gather:
+  case Intrinsic::masked_scatter:
+  case Intrinsic::masked_gather:
+    return getGatherScatterOpCost(MICA, CostKind);
+  case Intrinsic::vp_load:
+  case Intrinsic::vp_store:
+  case Intrinsic::masked_load:
+  case Intrinsic::masked_store:
+    return getMaskedMemoryOpCost(MICA, CostKind);
   }
   return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
 }
@@ -1037,7 +1053,7 @@ RISCVTTIImpl::getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
 
   if (!isLegalMaskedLoadStore(Src, Alignment) ||
       CostKind != TTI::TCK_RecipThroughput)
-    return BaseT::getMaskedMemoryOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
 
   return getMemoryOpCost(Opcode, Src, Alignment, AddressSpace, CostKind);
 }
@@ -1150,13 +1166,13 @@ RISCVTTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
   Align Alignment = MICA.getAlignment();
   const Instruction *I = MICA.getInst();
   if (CostKind != TTI::TCK_RecipThroughput)
-    return BaseT::getGatherScatterOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
 
   if ((Opcode == Instruction::Load &&
        !isLegalMaskedGather(DataTy, Align(Alignment))) ||
       (Opcode == Instruction::Store &&
        !isLegalMaskedScatter(DataTy, Align(Alignment))))
-    return BaseT::getGatherScatterOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
 
   // Cost is proportional to the number of memory operations implied.  For
   // scalable vectors, we use an estimate on that number since we don't
@@ -1183,7 +1199,7 @@ InstructionCost RISCVTTIImpl::getExpandCompressMemoryOpCost(
                  (Opcode == Instruction::Load &&
                   isLegalMaskedExpandLoad(DataTy, Alignment));
   if (!IsLegal || CostKind != TTI::TCK_RecipThroughput)
-    return BaseT::getExpandCompressMemoryOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
   // Example compressstore sequence:
   // vsetivli        zero, 8, e32, m2, ta, ma (ignored)
   // vcompress.vm    v10, v8, v0
@@ -1225,7 +1241,7 @@ RISCVTTIImpl::getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
   const Instruction *I = MICA.getInst();
 
   if (!isLegalStridedLoadStore(DataTy, Alignment))
-    return BaseT::getStridedMemoryOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
 
   if (CostKind == TTI::TCK_CodeSize)
     return TTI::TCC_Basic;

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index c1746e6d13166..e6b75d7c458db 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -147,9 +147,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
   getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
                            TTI::TargetCostKind CostKind) const override;
 
-  InstructionCost
-  getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                        TTI::TargetCostKind CostKind) const override;
+  InstructionCost getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                                        TTI::TargetCostKind CostKind) const;
 
   InstructionCost
   getPointersChainCost(ArrayRef<const Value *> Ptrs, const Value *Base,
@@ -194,17 +193,15 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
       bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
 
-  InstructionCost
-  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const override;
+  InstructionCost getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                         TTI::TargetCostKind CostKind) const;
 
   InstructionCost
   getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                                TTI::TargetCostKind CostKind) const override;
+                                TTI::TargetCostKind CostKind) const;
 
-  InstructionCost
-  getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const override;
+  InstructionCost getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                                         TTI::TargetCostKind CostKind) const;
 
   InstructionCost
   getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const override;

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index aa75d2c60803e..9fb97918cb71a 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5410,6 +5410,20 @@ InstructionCost X86TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
   return Cost;
 }
 
+InstructionCost
+X86TTIImpl::getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
+                                     TTI::TargetCostKind CostKind) const {
+  switch (MICA.getID()) {
+  case Intrinsic::masked_scatter:
+  case Intrinsic::masked_gather:
+    return getGatherScatterOpCost(MICA, CostKind);
+  case Intrinsic::masked_load:
+  case Intrinsic::masked_store:
+    return getMaskedMemoryOpCost(MICA, CostKind);
+  }
+  return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
+}
+
 InstructionCost
 X86TTIImpl::getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
                                   TTI::TargetCostKind CostKind) const {
@@ -6275,7 +6289,7 @@ X86TTIImpl::getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
        (!isLegalMaskedScatter(SrcVTy, Align(Alignment)) ||
         forceScalarizeMaskedScatter(cast<VectorType>(SrcVTy),
                                     Align(Alignment)))))
-    return BaseT::getGatherScatterOpCost(MICA, CostKind);
+    return BaseT::getMemIntrinsicInstrCost(MICA, CostKind);
 
   assert(SrcVTy->isVectorTy() && "Unexpected data type for Gather/Scatter");
   PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index d35911965d8b5..4f672793a0fcf 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -183,11 +183,12 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
       TTI::OperandValueInfo OpInfo = {TTI::OK_AnyValue, TTI::OP_None},
       const Instruction *I = nullptr) const override;
   InstructionCost
-  getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
-                        TTI::TargetCostKind CostKind) const override;
-  InstructionCost
-  getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
-                         TTI::TargetCostKind CostKind) const override;
+  getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
+                           TTI::TargetCostKind CostKind) const override;
+  InstructionCost getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
+                                        TTI::TargetCostKind CostKind) const;
+  InstructionCost getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
+                                         TTI::TargetCostKind CostKind) const;
   InstructionCost
   getPointersChainCost(ArrayRef<const Value *> Ptrs, const Value *Base,
                        const TTI::PointersChainInfo &Info, Type *AccessTy,


        


More information about the llvm-commits mailing list