[llvm] [TTI] Add support for strided loads/stores. (PR #80329)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 1 12:46:39 PST 2024
https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/80329
>From 0eea53562c0ffeef2771c953634a716e27ca97b5 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 1 Feb 2024 19:42:04 +0000
Subject: [PATCH 1/2] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Created using spr 1.3.5
---
.../llvm/Analysis/TargetTransformInfo.h | 34 +++++++++++++++++++
.../llvm/Analysis/TargetTransformInfoImpl.h | 13 +++++++
llvm/lib/Analysis/TargetTransformInfo.cpp | 14 ++++++++
.../Target/RISCV/RISCVTargetTransformInfo.cpp | 24 +++++++++++++
.../Target/RISCV/RISCVTargetTransformInfo.h | 11 ++++++
5 files changed, 96 insertions(+)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 3b615bc700bbb..58577a6b6eb5c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -781,6 +781,9 @@ class TargetTransformInfo {
/// Return true if the target supports masked expand load.
bool isLegalMaskedExpandLoad(Type *DataType) const;
+ /// Return true if the target supports strided load/store.
+ bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;
+
/// Return true if this is an alternating opcode pattern that can be lowered
/// to a single instruction on the target. In X86 this is for the addsub
/// instruction which corresponds to a Shuffle + Fadd + FSub pattern in IR.
@@ -1412,6 +1415,20 @@ class TargetTransformInfo {
Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
const Instruction *I = nullptr) const;
+ /// \return The cost of strided memory operations.
+ /// \p Opcode - the type of memory access: Load or Store
+ /// \p DataTy - a vector type of the data to be loaded or stored
+ /// \p Ptr - pointer [or vector of pointers] - address[es] in memory
+ /// \p VariableMask - true when the memory access is predicated with a mask
+ /// that is not a compile-time constant
+ /// \p Alignment - alignment of single element
+ /// \p I - the optional original context instruction, if one exists, e.g. the
+ /// load/store to transform or the call to the gather/scatter intrinsic
+ InstructionCost getStridedMemoryOpCost(
+ unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+ Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
+ const Instruction *I = nullptr) const;
+
/// \return The cost of the interleaved memory operation.
/// \p Opcode is the memory operation code
/// \p VecTy is the vector type of the interleaved access.
@@ -1848,6 +1865,7 @@ class TargetTransformInfo::Concept {
Align Alignment) = 0;
virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+ virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
unsigned Opcode1,
const SmallBitVector &OpcodeMask) const = 0;
@@ -2023,6 +2041,11 @@ class TargetTransformInfo::Concept {
bool VariableMask, Align Alignment,
TTI::TargetCostKind CostKind,
const Instruction *I = nullptr) = 0;
+ virtual InstructionCost
+ getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+ bool VariableMask, Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I = nullptr) = 0;
virtual InstructionCost getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
@@ -2341,6 +2364,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
bool isLegalMaskedExpandLoad(Type *DataType) override {
return Impl.isLegalMaskedExpandLoad(DataType);
}
+ bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
+ return Impl.isLegalStridedLoadStore(DataType, Alignment);
+ }
bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
const SmallBitVector &OpcodeMask) const override {
return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
@@ -2671,6 +2697,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
Alignment, CostKind, I);
}
+ InstructionCost
+ getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+ bool VariableMask, Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I = nullptr) override {
+ return Impl.getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+ Alignment, CostKind, I);
+ }
InstructionCost getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9958b4daa6ed8..1fe126379ac75 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -304,6 +304,10 @@ class TargetTransformInfoImplBase {
bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
+ bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const {
+ return false;
+ }
+
bool enableOrderedReductions() const { return false; }
bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; }
@@ -687,6 +691,15 @@ class TargetTransformInfoImplBase {
return 1;
}
+ InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+ const Value *Ptr, bool VariableMask,
+ Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I = nullptr) const {
+ return CostKind == TTI::TCK_RecipThroughput ? TTI::TCC_Expensive
+ : TTI::TCC_Basic;
+ }
+
unsigned getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 8902dde37cbca..8158c519562ee 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -500,6 +500,11 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
return TTIImpl->isLegalMaskedExpandLoad(DataType);
}
+bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
+ Align Alignment) const {
+ return TTIImpl->isLegalStridedLoadStore(DataType, Alignment);
+}
+
bool TargetTransformInfo::enableOrderedReductions() const {
return TTIImpl->enableOrderedReductions();
}
@@ -1041,6 +1046,15 @@ InstructionCost TargetTransformInfo::getGatherScatterOpCost(
return Cost;
}
+InstructionCost TargetTransformInfo::getStridedMemoryOpCost(
+ unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+ Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+ InstructionCost Cost = TTIImpl->getStridedMemoryOpCost(
+ Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
InstructionCost TargetTransformInfo::getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index fe1cdb2dfa423..3e349458ec0db 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -658,6 +658,30 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
return NumLoads * MemOpCost;
}
+InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
+ unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+ Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
+ if (CostKind != TTI::TCK_RecipThroughput)
+ return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+ Alignment, CostKind, I);
+
+ if (((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
+ !isLegalStridedLoadStore(DataTy, Alignment)) ||
+ (Opcode != Instruction::Load && Opcode != Instruction::Store))
+ return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+ Alignment, CostKind, I);
+
+ // Cost is proportional to the number of memory operations implied. For
+ // scalable vectors, we use an estimate on that number since we don't
+ // know exactly what VL will be.
+ auto &VTy = *cast<VectorType>(DataTy);
+ InstructionCost MemOpCost =
+ getMemoryOpCost(Opcode, VTy.getElementType(), Alignment, 0, CostKind,
+ {TTI::OK_AnyValue, TTI::OP_None}, I);
+ unsigned NumLoads = getEstimatedVLFor(&VTy);
+ return NumLoads * MemOpCost;
+}
+
// Currently, these represent both throughput and codesize costs
// for the respective intrinsics. The costs in this table are simply
// instruction counts with the following adjustments made:
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 0747a778fe9a2..af36e9d5d5e88 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -143,6 +143,12 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
TTI::TargetCostKind CostKind,
const Instruction *I);
+ InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+ const Value *Ptr, bool VariableMask,
+ Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I);
+
InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind,
@@ -250,6 +256,11 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
return ST->is64Bit() && !ST->hasVInstructionsI64();
}
+ bool isLegalStridedLoadStore(Type *DataType, Align Alignment) {
+ EVT DataTypeVT = TLI->getValueType(DL, DataType);
+ return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
+ }
+
bool isVScaleKnownToBeAPowerOfTwo() const {
return TLI->isVScaleKnownToBeAPowerOfTwo();
}
>From 40069ec8c0fd464592b12d35cf24d0d8a5de3dc0 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 1 Feb 2024 20:46:29 +0000
Subject: [PATCH 2/2] Address comments
Created using spr 1.3.5
---
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 3 +--
llvm/lib/Analysis/TargetTransformInfo.cpp | 3 ++-
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 7 +++----
3 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 1fe126379ac75..3d5db96e86b80 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -696,8 +696,7 @@ class TargetTransformInfoImplBase {
Align Alignment,
TTI::TargetCostKind CostKind,
const Instruction *I = nullptr) const {
- return CostKind == TTI::TCK_RecipThroughput ? TTI::TCC_Expensive
- : TTI::TCC_Basic;
+ return InstructionCost::getInvalid();
}
unsigned getInterleavedMemoryOpCost(
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 8158c519562ee..1f11f0d7dd620 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1042,7 +1042,8 @@ InstructionCost TargetTransformInfo::getGatherScatterOpCost(
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
InstructionCost Cost = TTIImpl->getGatherScatterOpCost(
Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
- assert(Cost >= 0 && "TTI should not produce negative costs!");
+ assert((!Cost.isValid() || Cost >= 0) &&
+ "TTI should not produce negative costs!");
return Cost;
}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 3e349458ec0db..cb48720cc1902 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -661,16 +661,15 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
- if (CostKind != TTI::TCK_RecipThroughput)
- return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
- Alignment, CostKind, I);
-
if (((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
!isLegalStridedLoadStore(DataTy, Alignment)) ||
(Opcode != Instruction::Load && Opcode != Instruction::Store))
return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
Alignment, CostKind, I);
+ if (CostKind == TTI::TCK_CodeSize)
+ return TTI::TCC_Basic;
+
// Cost is proportional to the number of memory operations implied. For
// scalable vectors, we use an estimate on that number since we don't
// know exactly what VL will be.
More information about the llvm-commits mailing list