[llvm] 8ad14b6 - [TTI]Add support for strided loads/stores.

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 1 13:07:42 PST 2024


Author: Alexey Bataev
Date: 2024-02-01T16:07:38-05:00
New Revision: 8ad14b6d90121d2d0687a3a7f6f6c6f2b34c4aa7

URL: https://github.com/llvm/llvm-project/commit/8ad14b6d90121d2d0687a3a7f6f6c6f2b34c4aa7
DIFF: https://github.com/llvm/llvm-project/commit/8ad14b6d90121d2d0687a3a7f6f6c6f2b34c4aa7.diff

LOG: [TTI]Add support for strided loads/stores.

Added basic legality check and cost estimation functions for strided loads and stores.

These interfaces will be built upon in https://github.com/llvm/llvm-project/pull/80310.

Reviewers: preames

Reviewed By: preames

Pull Request: https://github.com/llvm/llvm-project/pull/80329

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 3b615bc700bbb..58577a6b6eb5c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -781,6 +781,9 @@ class TargetTransformInfo {
   /// Return true if the target supports masked expand load.
   bool isLegalMaskedExpandLoad(Type *DataType) const;
 
+  /// Return true if the target supports strided load/store.
+  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;
+
   /// Return true if this is an alternating opcode pattern that can be lowered
   /// to a single instruction on the target. In X86 this is for the addsub
   /// instruction which corrsponds to a Shuffle + Fadd + FSub pattern in IR.
@@ -1412,6 +1415,20 @@ class TargetTransformInfo {
       Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
       const Instruction *I = nullptr) const;
 
+  /// \return The cost of strided memory operations.
+  /// \p Opcode - is a type of memory access Load or Store
+  /// \p DataTy - a vector type of the data to be loaded or stored
+  /// \p Ptr - pointer [or vector of pointers] - address[es] in memory
+  /// \p VariableMask - true when the memory access is predicated with a mask
+  ///                   that is not a compile-time constant
+  /// \p Alignment - alignment of single element
+  /// \p I - the optional original context instruction, if one exists, e.g. the
+  ///        load/store to transform or the call to the gather/scatter intrinsic
+  InstructionCost getStridedMemoryOpCost(
+      unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+      Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
+      const Instruction *I = nullptr) const;
+
   /// \return The cost of the interleaved memory operation.
   /// \p Opcode is the memory operation code
   /// \p VecTy is the vector type of the interleaved access.
@@ -1848,6 +1865,7 @@ class TargetTransformInfo::Concept {
                                            Align Alignment) = 0;
   virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
   virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
                                unsigned Opcode1,
                                const SmallBitVector &OpcodeMask) const = 0;
@@ -2023,6 +2041,11 @@ class TargetTransformInfo::Concept {
                          bool VariableMask, Align Alignment,
                          TTI::TargetCostKind CostKind,
                          const Instruction *I = nullptr) = 0;
+  virtual InstructionCost
+  getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+                         bool VariableMask, Align Alignment,
+                         TTI::TargetCostKind CostKind,
+                         const Instruction *I = nullptr) = 0;
 
   virtual InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
@@ -2341,6 +2364,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   bool isLegalMaskedExpandLoad(Type *DataType) override {
     return Impl.isLegalMaskedExpandLoad(DataType);
   }
+  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
+    return Impl.isLegalStridedLoadStore(DataType, Alignment);
+  }
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const override {
     return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
@@ -2671,6 +2697,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                        Alignment, CostKind, I);
   }
+  InstructionCost
+  getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+                         bool VariableMask, Align Alignment,
+                         TTI::TargetCostKind CostKind,
+                         const Instruction *I = nullptr) override {
+    return Impl.getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                       Alignment, CostKind, I);
+  }
   InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9958b4daa6ed8..3d5db96e86b80 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -304,6 +304,10 @@ class TargetTransformInfoImplBase {
 
   bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
 
+  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const {
+    return false;
+  }
+
   bool enableOrderedReductions() const { return false; }
 
   bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; }
@@ -687,6 +691,14 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+                                         const Value *Ptr, bool VariableMask,
+                                         Align Alignment,
+                                         TTI::TargetCostKind CostKind,
+                                         const Instruction *I = nullptr) const {
+    return InstructionCost::getInvalid();
+  }
+
   unsigned getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 8902dde37cbca..1f11f0d7dd620 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -500,6 +500,11 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
   return TTIImpl->isLegalMaskedExpandLoad(DataType);
 }
 
+bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
+                                                  Align Alignment) const {
+  return TTIImpl->isLegalStridedLoadStore(DataType, Alignment);
+}
+
 bool TargetTransformInfo::enableOrderedReductions() const {
   return TTIImpl->enableOrderedReductions();
 }
@@ -1037,6 +1042,16 @@ InstructionCost TargetTransformInfo::getGatherScatterOpCost(
     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
   InstructionCost Cost = TTIImpl->getGatherScatterOpCost(
       Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
+  assert((!Cost.isValid() || Cost >= 0) &&
+         "TTI should not produce negative costs!");
+  return Cost;
+}
+
+InstructionCost TargetTransformInfo::getStridedMemoryOpCost(
+    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+  InstructionCost Cost = TTIImpl->getStridedMemoryOpCost(
+      Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index fe1cdb2dfa423..cb48720cc1902 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -658,6 +658,29 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
   return NumLoads * MemOpCost;
 }
 
+InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
+    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
+  // Non-load/store opcodes and illegal strided accesses go to the base class.
+  if ((Opcode != Instruction::Load && Opcode != Instruction::Store) ||
+      !isLegalStridedLoadStore(DataTy, Alignment))
+    return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                         Alignment, CostKind, I);
+
+  if (CostKind == TTI::TCK_CodeSize)
+    return TTI::TCC_Basic;
+
+  // Model the access as one scalar memory operation per element.  The element
+  // count of a scalable vector is only an estimate because the actual VL is
+  // not known at this point.
+  auto *VecTy = cast<VectorType>(DataTy);
+  InstructionCost ElemOpCost =
+      getMemoryOpCost(Opcode, VecTy->getElementType(), Alignment, 0, CostKind,
+                      {TTI::OK_AnyValue, TTI::OP_None}, I);
+  unsigned NumMemOps = getEstimatedVLFor(VecTy);
+  return NumMemOps * ElemOpCost;
+}
+
 // Currently, these represent both throughput and codesize costs
 // for the respective intrinsics.  The costs in this table are simply
 // instruction counts with the following adjustments made:

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 0747a778fe9a2..af36e9d5d5e88 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -143,6 +143,12 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
                                          TTI::TargetCostKind CostKind,
                                          const Instruction *I);
 
+  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+                                         const Value *Ptr, bool VariableMask,
+                                         Align Alignment,
+                                         TTI::TargetCostKind CostKind,
+                                         const Instruction *I);
+
   InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                    TTI::CastContextHint CCH,
                                    TTI::TargetCostKind CostKind,
@@ -250,6 +256,11 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
     return ST->is64Bit() && !ST->hasVInstructionsI64();
   }
 
+  bool isLegalStridedLoadStore(Type *DataType, Align Alignment) {
+    return TLI->isLegalStridedLoadStore(TLI->getValueType(DL, DataType),
+                                        Alignment);
+  }
+
   bool isVScaleKnownToBeAPowerOfTwo() const {
     return TLI->isVScaleKnownToBeAPowerOfTwo();
   }


        


More information about the llvm-commits mailing list