[llvm] [TTI] Make isLegalMasked{Load, Store} take an address space (PR #134006)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 1 16:30:05 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-hexagon

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

To accommodate targets that support masked loads/stores only in
certain address spaces (AMDGPU will support them in an upcoming
patch, but only for address space 7), add an AddressSpace parameter
to isLegalMaskedLoad and isLegalMaskedStore.

---
Full diff: https://github.com/llvm/llvm-project/pull/134006.diff


13 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+14-8) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+4-2) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+6-6) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+4-2) 
- (modified) llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp (+6-3) 
- (modified) llvm/lib/Target/ARM/ARMTargetTransformInfo.h (+4-3) 
- (modified) llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp (+4-2) 
- (modified) llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h (+4-2) 
- (modified) llvm/lib/Target/VE/VETargetTransformInfo.h (+4-2) 
- (modified) llvm/lib/Target/X86/X86TargetTransformInfo.cpp (+6-4) 
- (modified) llvm/lib/Target/X86/X86TargetTransformInfo.h (+4-2) 
- (modified) llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp (+6-2) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+12-8) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 99e21aca97631..4835c66a7a3bc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -791,9 +791,11 @@ class TargetTransformInfo {
                                                 ScalarEvolution *SE) const;
 
   /// Return true if the target supports masked store.
-  bool isLegalMaskedStore(Type *DataType, Align Alignment) const;
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddressSpace) const;
   /// Return true if the target supports masked load.
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment) const;
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
+                         unsigned AddressSpace) const;
 
   /// Return true if the target supports nontemporal store.
   bool isLegalNTStore(Type *DataType, Align Alignment) const;
@@ -2015,8 +2017,10 @@ class TargetTransformInfo::Concept {
                           TargetLibraryInfo *LibInfo) = 0;
   virtual AddressingModeKind
     getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const = 0;
-  virtual bool isLegalMaskedStore(Type *DataType, Align Alignment) = 0;
-  virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment) = 0;
+  virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                                  unsigned AddressSpace) = 0;
+  virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment,
+                                 unsigned AddressSpace) = 0;
   virtual bool isLegalNTStore(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalNTLoad(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalBroadcastLoad(Type *ElementTy,
@@ -2562,11 +2566,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                ScalarEvolution *SE) const override {
     return Impl.getPreferredAddressingMode(L, SE);
   }
-  bool isLegalMaskedStore(Type *DataType, Align Alignment) override {
-    return Impl.isLegalMaskedStore(DataType, Alignment);
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddressSpace) override {
+    return Impl.isLegalMaskedStore(DataType, Alignment, AddressSpace);
   }
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment) override {
-    return Impl.isLegalMaskedLoad(DataType, Alignment);
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
+                         unsigned AddressSpace) override {
+    return Impl.isLegalMaskedLoad(DataType, Alignment, AddressSpace);
   }
   bool isLegalNTStore(Type *DataType, Align Alignment) override {
     return Impl.isLegalNTStore(DataType, Alignment);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 745758426c714..261d5eacc91b0 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -276,11 +276,13 @@ class TargetTransformInfoImplBase {
     return TTI::AMK_None;
   }
 
-  bool isLegalMaskedStore(Type *DataType, Align Alignment) const {
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddressSpace) const {
     return false;
   }
 
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment) const {
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
+                         unsigned AddressSpace) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 4df551aca30a7..e3212135e9b19 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -462,14 +462,14 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
   return TTIImpl->getPreferredAddressingMode(L, SE);
 }
 
-bool TargetTransformInfo::isLegalMaskedStore(Type *DataType,
-                                             Align Alignment) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment);
+bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
+                                             unsigned AddressSpace) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
 }
 
-bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType,
-                                            Align Alignment) const {
-  return TTIImpl->isLegalMaskedLoad(DataType, Alignment);
+bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
+                                            unsigned AddressSpace) const {
+  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace);
 }
 
 bool TargetTransformInfo::isLegalNTStore(Type *DataType,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1b8c759fd90b4..ae0df6b895ec8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -290,11 +290,13 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return isElementTypeLegalForScalableVector(DataType->getScalarType());
   }
 
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
+                         unsigned /*AddressSpace*/) {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
-  bool isLegalMaskedStore(Type *DataType, Align Alignment) {
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned /*AddressSpace*/) {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 8f0db457a982e..1b134bbe5ff6a 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1122,7 +1122,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) {
   return false;
 }
 
-bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
+bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
+                                   unsigned /*AddressSpace*/) {
   if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
     return false;
 
@@ -1595,9 +1596,11 @@ ARMTTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
                                   unsigned AddressSpace,
                                   TTI::TargetCostKind CostKind) {
   if (ST->hasMVEIntegerOps()) {
-    if (Opcode == Instruction::Load && isLegalMaskedLoad(Src, Alignment))
+    if (Opcode == Instruction::Load &&
+        isLegalMaskedLoad(Src, Alignment, AddressSpace))
       return ST->getMVEVectorCostFactor(CostKind);
-    if (Opcode == Instruction::Store && isLegalMaskedStore(Src, Alignment))
+    if (Opcode == Instruction::Store &&
+        isLegalMaskedStore(Src, Alignment, AddressSpace))
       return ST->getMVEVectorCostFactor(CostKind);
   }
   if (!isa<FixedVectorType>(Src))
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 103d2ed1c6281..ca5129c997fb0 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -184,10 +184,11 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
 
   bool isProfitableLSRChainElement(Instruction *I);
 
-  bool isLegalMaskedLoad(Type *DataTy, Align Alignment);
+  bool isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace);
 
-  bool isLegalMaskedStore(Type *DataTy, Align Alignment) {
-    return isLegalMaskedLoad(DataTy, Alignment);
+  bool isLegalMaskedStore(Type *DataTy, Align Alignment,
+                          unsigned AddressSpace) {
+    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
   }
 
   bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment) {
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index bbb9d065b6243..c3c77b514882b 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -340,13 +340,15 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
   return 1;
 }
 
-bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/) {
+bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
+                                        unsigned /*AddressSpace*/) {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
 }
 
-bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/) {
+bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/,
+                                       unsigned /*AddressSpace*/) {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 826644d08d1ac..b23369ac054b9 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -157,8 +157,10 @@ class HexagonTTIImpl : public BasicTTIImplBase<HexagonTTIImpl> {
     return 1;
   }
 
-  bool isLegalMaskedStore(Type *DataType, Align Alignment);
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment);
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddressSpace);
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
+                         unsigned AddressSpace);
 
   /// @}
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 7a73280e76d95..f0fa01ef22912 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -133,10 +133,12 @@ class VETTIImpl : public BasicTTIImplBase<VETTIImpl> {
   }
 
   // Load & Store {
-  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) {
+  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment,
+                         unsigned /*AddressSpace*/) {
     return isVectorLaneType(*getLaneType(DataType));
   }
-  bool isLegalMaskedStore(Type *DataType, MaybeAlign Alignment) {
+  bool isLegalMaskedStore(Type *DataType, MaybeAlign Alignment,
+                          unsigned /*AddressSpace*/) {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, MaybeAlign Alignment) {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 8bee87a22db16..7d168d33bb3e9 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5368,8 +5368,8 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
   unsigned NumElem = SrcVTy->getNumElements();
   auto *MaskTy =
       FixedVectorType::get(Type::getInt8Ty(SrcVTy->getContext()), NumElem);
-  if ((IsLoad && !isLegalMaskedLoad(SrcVTy, Alignment)) ||
-      (IsStore && !isLegalMaskedStore(SrcVTy, Alignment))) {
+  if ((IsLoad && !isLegalMaskedLoad(SrcVTy, Alignment, AddressSpace)) ||
+      (IsStore && !isLegalMaskedStore(SrcVTy, Alignment, AddressSpace))) {
     // Scalarization
     APInt DemandedElts = APInt::getAllOnes(NumElem);
     InstructionCost MaskSplitCost = getScalarizationOverhead(
@@ -6253,7 +6253,8 @@ static bool isLegalMaskedLoadStore(Type *ScalarTy, const X86Subtarget *ST) {
          ((IntWidth == 8 || IntWidth == 16) && ST->hasBWI());
 }
 
-bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
+bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
+                                   unsigned AddressSpace) {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
@@ -6265,7 +6266,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
   return isLegalMaskedLoadStore(ScalarTy, ST);
 }
 
-bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment) {
+bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+                                    unsigned AddressSpace) {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 9a427d4388d0b..5b6204d665206 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -262,8 +262,10 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2);
   bool canMacroFuseCmp();
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment);
-  bool isLegalMaskedStore(Type *DataType, Align Alignment);
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
+                         unsigned AddressSpace);
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddressSpace);
   bool isLegalNTLoad(Type *DataType, Align Alignment);
   bool isLegalNTStore(Type *DataType, Align Alignment);
   bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const;
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 63fcc1760ccaf..e24088c294987 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1098,14 +1098,18 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       // Scalarize unsupported vector masked load
       if (TTI.isLegalMaskedLoad(
               CI->getType(),
-              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
+              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue(),
+              cast<PointerType>(CI->getArgOperand(0)->getType())
+                  ->getAddressSpace()))
         return false;
       scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_store:
       if (TTI.isLegalMaskedStore(
               CI->getArgOperand(0)->getType(),
-              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
+              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
+              cast<PointerType>(CI->getArgOperand(1)->getType())
+                  ->getAddressSpace()))
         return false;
       scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 55cc801e91452..ca77a4295f4f4 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1255,16 +1255,18 @@ class LoopVectorizationCostModel {
 
   /// Returns true if the target machine supports masked store operation
   /// for the given \p DataType and kind of access to \p Ptr.
-  bool isLegalMaskedStore(Type *DataType, Value *Ptr, Align Alignment) const {
+  bool isLegalMaskedStore(Type *DataType, Value *Ptr, Align Alignment,
+                          unsigned AddressSpace) const {
     return Legal->isConsecutivePtr(DataType, Ptr) &&
-           TTI.isLegalMaskedStore(DataType, Alignment);
+           TTI.isLegalMaskedStore(DataType, Alignment, AddressSpace);
   }
 
   /// Returns true if the target machine supports masked load operation
   /// for the given \p DataType and kind of access to \p Ptr.
-  bool isLegalMaskedLoad(Type *DataType, Value *Ptr, Align Alignment) const {
+  bool isLegalMaskedLoad(Type *DataType, Value *Ptr, Align Alignment,
+                         unsigned AddressSpace) const {
     return Legal->isConsecutivePtr(DataType, Ptr) &&
-           TTI.isLegalMaskedLoad(DataType, Alignment);
+           TTI.isLegalMaskedLoad(DataType, Alignment, AddressSpace);
   }
 
   /// Returns true if the target machine can represent \p V as a masked gather
@@ -3220,13 +3222,14 @@ bool LoopVectorizationCostModel::isScalarWithPredication(
   case Instruction::Store: {
     auto *Ptr = getLoadStorePointerOperand(I);
     auto *Ty = getLoadStoreType(I);
+    unsigned AS = getLoadStoreAddressSpace(I);
     Type *VTy = Ty;
     if (VF.isVector())
       VTy = VectorType::get(Ty, VF);
     const Align Alignment = getLoadStoreAlignment(I);
-    return isa<LoadInst>(I) ? !(isLegalMaskedLoad(Ty, Ptr, Alignment) ||
+    return isa<LoadInst>(I) ? !(isLegalMaskedLoad(Ty, Ptr, Alignment, AS) ||
                                 TTI.isLegalMaskedGather(VTy, Alignment))
-                            : !(isLegalMaskedStore(Ty, Ptr, Alignment) ||
+                            : !(isLegalMaskedStore(Ty, Ptr, Alignment, AS) ||
                                 TTI.isLegalMaskedScatter(VTy, Alignment));
   }
   case Instruction::UDiv:
@@ -3427,8 +3430,9 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(
 
   auto *Ty = getLoadStoreType(I);
   const Align Alignment = getLoadStoreAlignment(I);
-  return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty, Alignment)
-                          : TTI.isLegalMaskedStore(Ty, Alignment);
+  unsigned AS = getLoadStoreAddressSpace(I);
+  return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty, Alignment, AS)
+                          : TTI.isLegalMaskedStore(Ty, Alignment, AS);
 }
 
 bool LoopVectorizationCostModel::memoryInstructionCanBeWidened(

``````````

</details>


https://github.com/llvm/llvm-project/pull/134006


More information about the llvm-commits mailing list