[llvm] [TTI] Add alignment argument to TTI for compress/expand support (PR #83516)
Kolya Panchenko via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 1 05:48:59 PST 2024
https://github.com/nikolaypanchenko updated https://github.com/llvm/llvm-project/pull/83516
>From c463571d3f417bad336db534e49a3214c75555b3 Mon Sep 17 00:00:00 2001
From: Kolya Panchenko <kolya.panchenko at sifive.com>
Date: Thu, 29 Feb 2024 13:22:01 -0800
Subject: [PATCH 1/2] [TTI] Add alignment argument to TTI for compress/expand
support
Since `llvm.masked.compressstore` and `llvm.masked.expandload` access
memory, some targets need to check whether the alignment is sufficient
before they can lower these intrinsics to target-specific instructions.
---
.../llvm/Analysis/TargetTransformInfo.h | 16 ++++++------
.../llvm/Analysis/TargetTransformInfoImpl.h | 8 ++++--
llvm/lib/Analysis/TargetTransformInfo.cpp | 10 +++++---
.../Target/RISCV/RISCVTargetTransformInfo.cpp | 25 +++++++++++++++++++
.../Target/RISCV/RISCVTargetTransformInfo.h | 2 ++
.../lib/Target/X86/X86TargetTransformInfo.cpp | 6 ++---
llvm/lib/Target/X86/X86TargetTransformInfo.h | 4 +--
.../Scalar/ScalarizeMaskedMemIntrin.cpp | 8 ++++--
8 files changed, 58 insertions(+), 21 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58577a6b6eb5c0..4eab357f1b33b6 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -777,9 +777,9 @@ class TargetTransformInfo {
bool forceScalarizeMaskedScatter(VectorType *Type, Align Alignment) const;
/// Return true if the target supports masked compress store.
- bool isLegalMaskedCompressStore(Type *DataType) const;
+ bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const;
/// Return true if the target supports masked expand load.
- bool isLegalMaskedExpandLoad(Type *DataType) const;
+ bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const;
/// Return true if the target supports strided load.
bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;
@@ -1863,8 +1863,8 @@ class TargetTransformInfo::Concept {
Align Alignment) = 0;
virtual bool forceScalarizeMaskedScatter(VectorType *DataType,
Align Alignment) = 0;
- virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
- virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+ virtual bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) = 0;
+ virtual bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) = 0;
virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
unsigned Opcode1,
@@ -2358,11 +2358,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
Align Alignment) override {
return Impl.forceScalarizeMaskedScatter(DataType, Alignment);
}
- bool isLegalMaskedCompressStore(Type *DataType) override {
- return Impl.isLegalMaskedCompressStore(DataType);
+ bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) override {
+ return Impl.isLegalMaskedCompressStore(DataType, Alignment);
}
- bool isLegalMaskedExpandLoad(Type *DataType) override {
- return Impl.isLegalMaskedExpandLoad(DataType);
+ bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) override {
+ return Impl.isLegalMaskedExpandLoad(DataType, Alignment);
}
bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
return Impl.isLegalStridedLoadStore(DataType, Alignment);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 13379cc126a40c..95fb13d1c97154 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -295,14 +295,18 @@ class TargetTransformInfoImplBase {
return false;
}
- bool isLegalMaskedCompressStore(Type *DataType) const { return false; }
+ bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const {
+ return false;
+ }
bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
const SmallBitVector &OpcodeMask) const {
return false;
}
- bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
+ bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const {
+ return false;
+ }
bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const {
return false;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1f11f0d7dd620e..15311be4dba277 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -492,12 +492,14 @@ bool TargetTransformInfo::forceScalarizeMaskedScatter(VectorType *DataType,
return TTIImpl->forceScalarizeMaskedScatter(DataType, Alignment);
}
-bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType) const {
- return TTIImpl->isLegalMaskedCompressStore(DataType);
+bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType,
+ Align Alignment) const {
+ return TTIImpl->isLegalMaskedCompressStore(DataType, Alignment);
}
-bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
- return TTIImpl->isLegalMaskedExpandLoad(DataType);
+bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType,
+ Align Alignment) const {
+ return TTIImpl->isLegalMaskedExpandLoad(DataType, Alignment);
}
bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2e4e69fb4f920f..0bd623e1196e1f 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1609,3 +1609,28 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
C2.NumIVMuls, C2.NumBaseAdds,
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
}
+
+bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
+ auto *VTy = dyn_cast<VectorType>(DataTy);
+ if (!VTy || VTy->isScalableTy() || !ST->hasVInstructions())
+ return false;
+
+ Type *ScalarTy = VTy->getScalarType();
+ if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
+ return true;
+
+ if (!ScalarTy->isIntegerTy())
+ return false;
+
+ switch (ScalarTy->getIntegerBitWidth()) {
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ break;
+ default:
+ return false;
+ }
+
+ return getRegUsageForType(VTy) <= 8;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index af36e9d5d5e886..8daf6845dc8bc9 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -261,6 +261,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
}
+ bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
+
bool isVScaleKnownToBeAPowerOfTwo() const {
return TLI->isVScaleKnownToBeAPowerOfTwo();
}
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 18bf32fe1acaad..9c1e4b2f83ab7f 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5938,7 +5938,7 @@ bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy,
ElementTy == Type::getDoubleTy(ElementTy->getContext());
}
-bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) {
+bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
if (!isa<VectorType>(DataTy))
return false;
@@ -5962,8 +5962,8 @@ bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) {
((IntWidth == 8 || IntWidth == 16) && ST->hasVBMI2());
}
-bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy) {
- return isLegalMaskedExpandLoad(DataTy);
+bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
+ return isLegalMaskedExpandLoad(DataTy, Alignment);
}
bool X86TTIImpl::supportsGather() const {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 07a3fff4f84b3e..1a5e6bc886aa67 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -269,8 +269,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment);
bool isLegalMaskedGather(Type *DataType, Align Alignment);
bool isLegalMaskedScatter(Type *DataType, Align Alignment);
- bool isLegalMaskedExpandLoad(Type *DataType);
- bool isLegalMaskedCompressStore(Type *DataType);
+ bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);
+ bool isLegalMaskedCompressStore(Type *DataType, Align Alignment);
bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
const SmallBitVector &OpcodeMask) const;
bool hasDivRemOp(Type *DataType, bool IsSigned);
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index c01d03f6447240..d545c0ae49f5a1 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -969,12 +969,16 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
return true;
}
case Intrinsic::masked_expandload:
- if (TTI.isLegalMaskedExpandLoad(CI->getType()))
+ if (TTI.isLegalMaskedExpandLoad(
+ CI->getType(),
+ CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
return false;
scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
return true;
case Intrinsic::masked_compressstore:
- if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
+ if (TTI.isLegalMaskedCompressStore(
+ CI->getArgOperand(0)->getType(),
+ CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
return false;
scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
return true;
>From 98ca9a958046b50cac4affa40bdd5a7fdf30828d Mon Sep 17 00:00:00 2001
From: Kolya Panchenko <kolya.panchenko at sifive.com>
Date: Fri, 1 Mar 2024 05:48:14 -0800
Subject: [PATCH 2/2] removed isLegalMaskedCompressStore from RISCVTTIImpl
---
.../Target/RISCV/RISCVTargetTransformInfo.cpp | 25 -------------------
.../Target/RISCV/RISCVTargetTransformInfo.h | 2 --
2 files changed, 27 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 0bd623e1196e1f..2e4e69fb4f920f 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1609,28 +1609,3 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
C2.NumIVMuls, C2.NumBaseAdds,
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
}
-
-bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
- auto *VTy = dyn_cast<VectorType>(DataTy);
- if (!VTy || VTy->isScalableTy() || !ST->hasVInstructions())
- return false;
-
- Type *ScalarTy = VTy->getScalarType();
- if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
- return true;
-
- if (!ScalarTy->isIntegerTy())
- return false;
-
- switch (ScalarTy->getIntegerBitWidth()) {
- case 8:
- case 16:
- case 32:
- case 64:
- break;
- default:
- return false;
- }
-
- return getRegUsageForType(VTy) <= 8;
-}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 8daf6845dc8bc9..af36e9d5d5e886 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -261,8 +261,6 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
}
- bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
-
bool isVScaleKnownToBeAPowerOfTwo() const {
return TLI->isVScaleKnownToBeAPowerOfTwo();
}
More information about the llvm-commits
mailing list