[llvm] [TTI]Add support for strided loads/stores. (PR #80329)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 1 11:42:41 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Alexey Bataev (alexey-bataev)
<details>
<summary>Changes</summary>
Added basic legality check and cost estimation functions.
---
Full diff: https://github.com/llvm/llvm-project/pull/80329.diff
5 Files Affected:
- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+34)
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+13)
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+14)
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+24)
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+11)
``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 3b615bc700bbb..58577a6b6eb5c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -781,6 +781,9 @@ class TargetTransformInfo {
/// Return true if the target supports masked expand load.
bool isLegalMaskedExpandLoad(Type *DataType) const;
+ /// Return true if the target supports strided load/store.
+ bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;
+
/// Return true if this is an alternating opcode pattern that can be lowered
/// to a single instruction on the target. In X86 this is for the addsub
/// instruction which corrsponds to a Shuffle + Fadd + FSub pattern in IR.
@@ -1412,6 +1415,20 @@ class TargetTransformInfo {
Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
const Instruction *I = nullptr) const;
+ /// \return The cost of strided memory operations.
+ /// \p Opcode - the type of memory access: Load or Store
+ /// \p DataTy - a vector type of the data to be loaded or stored
+ /// \p Ptr - pointer [or vector of pointers] - address[es] in memory
+ /// \p VariableMask - true when the memory access is predicated with a mask
+ /// that is not a compile-time constant
+ /// \p Alignment - alignment of single element
+ /// \p I - the optional original context instruction, if one exists, e.g. the
+ /// load/store to transform or the call to the gather/scatter intrinsic
+ InstructionCost getStridedMemoryOpCost(
+ unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+ Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
+ const Instruction *I = nullptr) const;
+
/// \return The cost of the interleaved memory operation.
/// \p Opcode is the memory operation code
/// \p VecTy is the vector type of the interleaved access.
@@ -1848,6 +1865,7 @@ class TargetTransformInfo::Concept {
Align Alignment) = 0;
virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+ virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
unsigned Opcode1,
const SmallBitVector &OpcodeMask) const = 0;
@@ -2023,6 +2041,11 @@ class TargetTransformInfo::Concept {
bool VariableMask, Align Alignment,
TTI::TargetCostKind CostKind,
const Instruction *I = nullptr) = 0;
+ virtual InstructionCost
+ getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+ bool VariableMask, Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I = nullptr) = 0;
virtual InstructionCost getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
@@ -2341,6 +2364,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
bool isLegalMaskedExpandLoad(Type *DataType) override {
return Impl.isLegalMaskedExpandLoad(DataType);
}
+ bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
+ return Impl.isLegalStridedLoadStore(DataType, Alignment);
+ }
bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
const SmallBitVector &OpcodeMask) const override {
return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
@@ -2671,6 +2697,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
Alignment, CostKind, I);
}
+ InstructionCost
+ getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+ bool VariableMask, Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I = nullptr) override {
+ return Impl.getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+ Alignment, CostKind, I);
+ }
InstructionCost getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9958b4daa6ed8..1fe126379ac75 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -304,6 +304,10 @@ class TargetTransformInfoImplBase {
bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
+ bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const {
+ return false;
+ }
+
bool enableOrderedReductions() const { return false; }
bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; }
@@ -687,6 +691,15 @@ class TargetTransformInfoImplBase {
return 1;
}
+ InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+ const Value *Ptr, bool VariableMask,
+ Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I = nullptr) const {
+ return CostKind == TTI::TCK_RecipThroughput ? TTI::TCC_Expensive
+ : TTI::TCC_Basic;
+ }
+
unsigned getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 8902dde37cbca..8158c519562ee 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -500,6 +500,11 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
return TTIImpl->isLegalMaskedExpandLoad(DataType);
}
+bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
+ Align Alignment) const {
+ return TTIImpl->isLegalStridedLoadStore(DataType, Alignment);
+}
+
bool TargetTransformInfo::enableOrderedReductions() const {
return TTIImpl->enableOrderedReductions();
}
@@ -1041,6 +1046,15 @@ InstructionCost TargetTransformInfo::getGatherScatterOpCost(
return Cost;
}
+InstructionCost TargetTransformInfo::getStridedMemoryOpCost(
+ unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+ Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+ InstructionCost Cost = TTIImpl->getStridedMemoryOpCost(
+ Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
+ assert(Cost >= 0 && "TTI should not produce negative costs!");
+ return Cost;
+}
+
InstructionCost TargetTransformInfo::getInterleavedMemoryOpCost(
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index fe1cdb2dfa423..3e349458ec0db 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -658,6 +658,30 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
return NumLoads * MemOpCost;
}
+InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
+ unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+ Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
+ if (CostKind != TTI::TCK_RecipThroughput)
+ return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+ Alignment, CostKind, I);
+
+ if (((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
+ !isLegalStridedLoadStore(DataTy, Alignment)) ||
+ (Opcode != Instruction::Load && Opcode != Instruction::Store))
+ return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+ Alignment, CostKind, I);
+
+ // Cost is proportional to the number of memory operations implied. For
+ // scalable vectors, we use an estimate on that number since we don't
+ // know exactly what VL will be.
+ auto &VTy = *cast<VectorType>(DataTy);
+ InstructionCost MemOpCost =
+ getMemoryOpCost(Opcode, VTy.getElementType(), Alignment, 0, CostKind,
+ {TTI::OK_AnyValue, TTI::OP_None}, I);
+ unsigned NumLoads = getEstimatedVLFor(&VTy);
+ return NumLoads * MemOpCost;
+}
+
// Currently, these represent both throughput and codesize costs
// for the respective intrinsics. The costs in this table are simply
// instruction counts with the following adjustments made:
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 0747a778fe9a2..af36e9d5d5e88 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -143,6 +143,12 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
TTI::TargetCostKind CostKind,
const Instruction *I);
+ InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+ const Value *Ptr, bool VariableMask,
+ Align Alignment,
+ TTI::TargetCostKind CostKind,
+ const Instruction *I);
+
InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind,
@@ -250,6 +256,11 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
return ST->is64Bit() && !ST->hasVInstructionsI64();
}
+ bool isLegalStridedLoadStore(Type *DataType, Align Alignment) {
+ EVT DataTypeVT = TLI->getValueType(DL, DataType);
+ return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
+ }
+
bool isVScaleKnownToBeAPowerOfTwo() const {
return TLI->isVScaleKnownToBeAPowerOfTwo();
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/80329
More information about the llvm-commits
mailing list