[llvm] [SLP] Allow targets to add cost for nonstandard conditions (PR #95328)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 12 16:05:07 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Jeffrey Byrnes (jrbyrnes)

<details>
<summary>Changes</summary>

There are conditions under which vectorization is profitable but which cannot be expressed by the current cost model. As an example, the vectorization profit may be based entirely on conditions of the users of the tree entry.

This gives targets a chance to express things of this nature.

---
Full diff: https://github.com/llvm/llvm-project/pull/95328.diff


5 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+16) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+5) 
- (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+5) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+5) 
- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+8) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f55f21c94a85a..49117ca8c74c5 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -891,6 +891,11 @@ class TargetTransformInfo {
                                            bool Insert, bool Extract,
                                            TTI::TargetCostKind CostKind) const;
 
+  /// Whether or not there is any target-specific condition that imposes an
+  /// overhead for scalarization
+  bool hasScalarizationOverhead(ArrayRef<Value *> VL,
+                                std::pair<bool, bool> &ScalarizationKind) const;
+
   /// Estimate the overhead of scalarizing an instructions unique
   /// non-constant operands. The (potentially vector) types to use for each of
   /// argument are passes via Tys.
@@ -1921,6 +1926,10 @@ class TargetTransformInfo::Concept {
   getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
                                    ArrayRef<Type *> Tys,
                                    TargetCostKind CostKind) = 0;
+
+  virtual bool
+  hasScalarizationOverhead(ArrayRef<Value *> VL,
+                           std::pair<bool, bool> &ScalarizationKind) = 0;
   virtual bool supportsEfficientVectorElementLoadStore() = 0;
   virtual bool supportsTailCalls() = 0;
   virtual bool supportsTailCallFor(const CallBase *CB) = 0;
@@ -2456,6 +2465,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
                                          CostKind);
   }
+
+  bool
+  hasScalarizationOverhead(ArrayRef<Value *> VL,
+                           std::pair<bool, bool> &ScalarizationKind) override {
+    return Impl.hasScalarizationOverhead(VL, ScalarizationKind);
+  }
+
   InstructionCost
   getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
                                    ArrayRef<Type *> Tys,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7828bdc1f1f43..1d3e6752006d9 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -371,6 +371,11 @@ class TargetTransformInfoImplBase {
     return 0;
   }
 
+  bool hasScalarizationOverhead(ArrayRef<Value *> VL,
+                                std::pair<bool, bool> &ScalarizationKind) {
+    return false;
+  }
+
   InstructionCost
   getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
                                    ArrayRef<Type *> Tys,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 9f8d3ded9b3c1..2aa36c724bc03 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -807,6 +807,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                              CostKind);
   }
 
+  bool hasScalarizationOverhead(ArrayRef<Value *> VL,
+                                std::pair<bool, bool> &ScalarizationKind) {
+    return false;
+  }
+
   /// Estimate the overhead of scalarizing an instructions unique
   /// non-constant operands. The (potentially vector) types to use for each of
   /// argument are passes via Tys.
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7e721cbc87f3f..81b1e6b181bb0 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -594,6 +594,11 @@ InstructionCost TargetTransformInfo::getScalarizationOverhead(
                                            CostKind);
 }
 
+bool TargetTransformInfo::hasScalarizationOverhead(
+    ArrayRef<Value *> VL, std::pair<bool, bool> &ScalarizeKind) const {
+  return TTIImpl->hasScalarizationOverhead(VL, ScalarizeKind);
+}
+
 InstructionCost TargetTransformInfo::getOperandsScalarizationOverhead(
     ArrayRef<const Value *> Args, ArrayRef<Type *> Tys,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ae0819c964bef..f189e9b6ba14b 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9084,6 +9084,14 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         E, ScalarTy, *TTI, VectorizedVals, *this, CheckedExtracts);
   }
   InstructionCost CommonCost = 0;
+  std::pair<bool, bool> ScalarizationKind(false, false);
+  if (TTI->hasScalarizationOverhead(VL, ScalarizationKind)) {
+    APInt DemandedElts = APInt::getAllOnes(VL.size());
+    CommonCost -= TTI->getScalarizationOverhead(
+        VecTy, DemandedElts,
+        /*Insert*/ ScalarizationKind.first,
+        /*Extract*/ ScalarizationKind.second, CostKind);
+  }
   SmallVector<int> Mask;
   bool IsReverseOrder = isReverseOrder(E->ReorderIndices);
   if (!E->ReorderIndices.empty() &&

``````````

</details>


https://github.com/llvm/llvm-project/pull/95328


More information about the llvm-commits mailing list