[llvm] [TTI] Add alignment argument to TTI for compress/expand support (PR #83516)

Nikolay Panchenko via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 5 15:11:23 PST 2024


https://github.com/npanchen updated https://github.com/llvm/llvm-project/pull/83516

>From c463571d3f417bad336db534e49a3214c75555b3 Mon Sep 17 00:00:00 2001
From: Kolya Panchenko <kolya.panchenko at sifive.com>
Date: Thu, 29 Feb 2024 13:22:01 -0800
Subject: [PATCH 1/2] [TTI] Add alignment argument to TTI for compress/expand
 support

Since `llvm.compressstore` and `llvm.expandload` access memory, some
targets need to check the pointer alignment to decide whether these
intrinsics can be lowered to target-specific instructions.
---
 .../llvm/Analysis/TargetTransformInfo.h       | 16 ++++++------
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  8 ++++--
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 10 +++++---
 .../Target/RISCV/RISCVTargetTransformInfo.cpp | 25 +++++++++++++++++++
 .../Target/RISCV/RISCVTargetTransformInfo.h   |  2 ++
 .../lib/Target/X86/X86TargetTransformInfo.cpp |  6 ++---
 llvm/lib/Target/X86/X86TargetTransformInfo.h  |  4 +--
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       |  8 ++++--
 8 files changed, 58 insertions(+), 21 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58577a6b6eb5c0..4eab357f1b33b6 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -777,9 +777,9 @@ class TargetTransformInfo {
   bool forceScalarizeMaskedScatter(VectorType *Type, Align Alignment) const;
 
   /// Return true if the target supports masked compress store.
-  bool isLegalMaskedCompressStore(Type *DataType) const;
+  bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const;
   /// Return true if the target supports masked expand load.
-  bool isLegalMaskedExpandLoad(Type *DataType) const;
+  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const;
 
   /// Return true if the target supports strided load.
   bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;
@@ -1863,8 +1863,8 @@ class TargetTransformInfo::Concept {
                                           Align Alignment) = 0;
   virtual bool forceScalarizeMaskedScatter(VectorType *DataType,
                                            Align Alignment) = 0;
-  virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
-  virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) = 0;
+  virtual bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
                                unsigned Opcode1,
@@ -2358,11 +2358,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                    Align Alignment) override {
     return Impl.forceScalarizeMaskedScatter(DataType, Alignment);
   }
-  bool isLegalMaskedCompressStore(Type *DataType) override {
-    return Impl.isLegalMaskedCompressStore(DataType);
+  bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) override {
+    return Impl.isLegalMaskedCompressStore(DataType, Alignment);
   }
-  bool isLegalMaskedExpandLoad(Type *DataType) override {
-    return Impl.isLegalMaskedExpandLoad(DataType);
+  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) override {
+    return Impl.isLegalMaskedExpandLoad(DataType, Alignment);
   }
   bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
     return Impl.isLegalStridedLoadStore(DataType, Alignment);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 13379cc126a40c..95fb13d1c97154 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -295,14 +295,18 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
-  bool isLegalMaskedCompressStore(Type *DataType) const { return false; }
+  bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const {
+    return false;
+  }
 
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const {
     return false;
   }
 
-  bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
+  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const {
+    return false;
+  }
 
   bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const {
     return false;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1f11f0d7dd620e..15311be4dba277 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -492,12 +492,14 @@ bool TargetTransformInfo::forceScalarizeMaskedScatter(VectorType *DataType,
   return TTIImpl->forceScalarizeMaskedScatter(DataType, Alignment);
 }
 
-bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType) const {
-  return TTIImpl->isLegalMaskedCompressStore(DataType);
+bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType,
+                                                     Align Alignment) const {
+  return TTIImpl->isLegalMaskedCompressStore(DataType, Alignment);
 }
 
-bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
-  return TTIImpl->isLegalMaskedExpandLoad(DataType);
+bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType,
+                                                  Align Alignment) const {
+  return TTIImpl->isLegalMaskedExpandLoad(DataType, Alignment);
 }
 
 bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2e4e69fb4f920f..0bd623e1196e1f 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1609,3 +1609,28 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                   C2.NumIVMuls, C2.NumBaseAdds,
                   C2.ScaleCost, C2.ImmCost, C2.SetupCost);
 }
+
+bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
+  auto *VTy = dyn_cast<VectorType>(DataTy);
+  if (!VTy || VTy->isScalableTy() || !ST->hasVInstructions())
+    return false;
+
+  Type *ScalarTy = VTy->getScalarType();
+  if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
+    return true;
+
+  if (!ScalarTy->isIntegerTy())
+    return false;
+
+  switch (ScalarTy->getIntegerBitWidth()) {
+  case 8:
+  case 16:
+  case 32:
+  case 64:
+    break;
+  default:
+    return false;
+  }
+
+  return getRegUsageForType(VTy) <= 8;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index af36e9d5d5e886..8daf6845dc8bc9 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -261,6 +261,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
     return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
   }
 
+  bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
+
   bool isVScaleKnownToBeAPowerOfTwo() const {
     return TLI->isVScaleKnownToBeAPowerOfTwo();
   }
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 18bf32fe1acaad..9c1e4b2f83ab7f 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5938,7 +5938,7 @@ bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy,
          ElementTy == Type::getDoubleTy(ElementTy->getContext());
 }
 
-bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) {
+bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
   if (!isa<VectorType>(DataTy))
     return false;
 
@@ -5962,8 +5962,8 @@ bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) {
          ((IntWidth == 8 || IntWidth == 16) && ST->hasVBMI2());
 }
 
-bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy) {
-  return isLegalMaskedExpandLoad(DataTy);
+bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
+  return isLegalMaskedExpandLoad(DataTy, Alignment);
 }
 
 bool X86TTIImpl::supportsGather() const {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 07a3fff4f84b3e..1a5e6bc886aa67 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -269,8 +269,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment);
   bool isLegalMaskedGather(Type *DataType, Align Alignment);
   bool isLegalMaskedScatter(Type *DataType, Align Alignment);
-  bool isLegalMaskedExpandLoad(Type *DataType);
-  bool isLegalMaskedCompressStore(Type *DataType);
+  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);
+  bool isLegalMaskedCompressStore(Type *DataType, Align Alignment);
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const;
   bool hasDivRemOp(Type *DataType, bool IsSigned);
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index c01d03f6447240..d545c0ae49f5a1 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -969,12 +969,16 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       return true;
     }
     case Intrinsic::masked_expandload:
-      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
+      if (TTI.isLegalMaskedExpandLoad(
+              CI->getType(),
+              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
         return false;
       scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_compressstore:
-      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
+      if (TTI.isLegalMaskedCompressStore(
+              CI->getArgOperand(0)->getType(),
+              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
         return false;
       scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
       return true;

>From 98ca9a958046b50cac4affa40bdd5a7fdf30828d Mon Sep 17 00:00:00 2001
From: Kolya Panchenko <kolya.panchenko at sifive.com>
Date: Fri, 1 Mar 2024 05:48:14 -0800
Subject: [PATCH 2/2] Remove isLegalMaskedCompressStore from RISCVTTIImpl

---
 .../Target/RISCV/RISCVTargetTransformInfo.cpp | 25 -------------------
 .../Target/RISCV/RISCVTargetTransformInfo.h   |  2 --
 2 files changed, 27 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 0bd623e1196e1f..2e4e69fb4f920f 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1609,28 +1609,3 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                   C2.NumIVMuls, C2.NumBaseAdds,
                   C2.ScaleCost, C2.ImmCost, C2.SetupCost);
 }
-
-bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
-  auto *VTy = dyn_cast<VectorType>(DataTy);
-  if (!VTy || VTy->isScalableTy() || !ST->hasVInstructions())
-    return false;
-
-  Type *ScalarTy = VTy->getScalarType();
-  if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
-    return true;
-
-  if (!ScalarTy->isIntegerTy())
-    return false;
-
-  switch (ScalarTy->getIntegerBitWidth()) {
-  case 8:
-  case 16:
-  case 32:
-  case 64:
-    break;
-  default:
-    return false;
-  }
-
-  return getRegUsageForType(VTy) <= 8;
-}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 8daf6845dc8bc9..af36e9d5d5e886 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -261,8 +261,6 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
     return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
   }
 
-  bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
-
   bool isVScaleKnownToBeAPowerOfTwo() const {
     return TLI->isVScaleKnownToBeAPowerOfTwo();
   }



More information about the llvm-commits mailing list