[llvm] [NFC][Scalarizer][TargetTransformInfo] Add `isVectorIntrinsicWithOverloadTypeAtArg` api (PR #114849)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 15 14:08:51 PST 2024


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/114849

>From b95d4f3c79f2a286bfb079599150abb04c76f39b Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 4 Nov 2024 17:38:36 +0000
Subject: [PATCH 1/8] [NFC][Scalarizer][TargetTransformInfo] Add
 `isVectorIntrinsicWithOverloadTypeAtArg`

This change allows target intrinsics to specify overloaded types.

This change will let us add scalarization for `asdouble`:
---
 llvm/include/llvm/Analysis/TargetTransformInfo.h   | 14 ++++++++++++++
 .../llvm/Analysis/TargetTransformInfoImpl.h        |  6 ++++++
 llvm/include/llvm/CodeGen/BasicTTIImpl.h           |  6 ++++++
 llvm/lib/Analysis/TargetTransformInfo.cpp          |  6 ++++++
 .../Target/DirectX/DirectXTargetTransformInfo.cpp  |  8 ++++++++
 .../Target/DirectX/DirectXTargetTransformInfo.h    |  3 +++
 llvm/lib/Transforms/Scalar/Scalarizer.cpp          |  9 ++++++---
 7 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..796b4011d71c0c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -896,6 +896,10 @@ class TargetTransformInfo {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1969,6 +1973,9 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
+  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      unsigned ScalarOpdIdx,
+                                                      bool Default) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2530,6 +2537,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                       Default);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..42d1082cf4d9eb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index db3b5cddd7c1c3..b2841e778947dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..bf9733f971fdac 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                         Default);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..2be5fc4de2409c 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
+  switch (ID) {
+  default:
+    return Default;
+  }
+}
+
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..ff82b7404ca58a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 772f4c6c35ddec..719dce704872ae 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+          ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +768,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
         Tys.push_back(OpI->getType());
     }
   }

>From 58925a494fb2bbd1596eb08329ff105cb2b95ea8 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 4 Nov 2024 21:36:08 +0000
Subject: [PATCH 2/8] review comments

- change code pattern to avoid calling TTI if we don't have a
target-specific intrinsic
  - change type from unsigned to int
---
 .../include/llvm/Analysis/TargetTransformInfo.h | 12 ++++--------
 .../llvm/Analysis/TargetTransformInfoImpl.h     |  5 ++---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h        |  5 ++---
 llvm/lib/Analysis/TargetTransformInfo.cpp       |  5 ++---
 .../DirectX/DirectXTargetTransformInfo.cpp      |  6 +++---
 .../Target/DirectX/DirectXTargetTransformInfo.h |  3 +--
 llvm/lib/Transforms/Scalar/Scalarizer.cpp       | 17 +++++++++++------
 7 files changed, 25 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 796b4011d71c0c..26fa20be184383 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -897,8 +897,7 @@ class TargetTransformInfo {
                                           unsigned ScalarOpdIdx) const;
 
   bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx,
-                                              bool Default) const;
+                                              int ScalarOpdIdx) const;
 
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
@@ -1974,8 +1973,7 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
   virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                      unsigned ScalarOpdIdx,
-                                                      bool Default) = 0;
+                                                      int ScalarOpdIdx) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2538,10 +2536,8 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
 
   bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx,
-                                              bool Default) override {
-    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
-                                                       Default);
+                                              int ScalarOpdIdx) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 42d1082cf4d9eb..dd76e0b8dc8a21 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -393,9 +393,8 @@ class TargetTransformInfoImplBase {
   }
 
   bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx,
-                                              bool Default) const {
-    return Default;
+                                              int ScalarOpdIdx) const {
+    return ScalarOpdIdx == -1;
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b2841e778947dd..3ec63c095f7c9a 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -799,9 +799,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx,
-                                              bool Default) const {
-    return Default;
+                                              int ScalarOpdIdx) const {
+    return ScalarOpdIdx == -1;
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index bf9733f971fdac..f79348d27e78ff 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -613,9 +613,8 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
 }
 
 bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
-    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
-  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
-                                                         Default);
+    Intrinsic::ID ID, int ScalarOpdIdx) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
 }
 
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 2be5fc4de2409c..1c54887d9f56e2 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,11 +25,11 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
-bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
-    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                            int ScalarOpdIdx) {
   switch (ID) {
   default:
-    return Default;
+    return ScalarOpdIdx == -1;
   }
 }
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index ff82b7404ca58a..a18e4a28625756 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -38,8 +38,7 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
   bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx,
-                                              bool Default);
+                                              int ScalarOpdIdx);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 719dce704872ae..1c61aa2712e8c7 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -280,6 +280,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
   bool visit(Function &F);
 
   bool isTriviallyScalarizable(Intrinsic::ID ID);
+  bool isIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int ScalarOpdIdx);
 
   // InstVisitor methods.  They return true if the instruction was scalarized,
   // false if nothing changed.
@@ -696,6 +697,13 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
          TTI->isTargetIntrinsicTriviallyScalarizable(ID);
 }
 
+bool ScalarizerVisitor::isIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                         int ScalarOpdIdx) {
+  return Intrinsic::isTargetIntrinsic(ID)
+             ? TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx)
+             : isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+}
+
 /// If a call to a vector typed intrinsic function, split into a scalar call per
 /// element if possible for the intrinsic.
 bool ScalarizerVisitor::splitCall(CallInst &CI) {
@@ -727,8 +735,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
-          ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
+  if (isIntrinsicWithOverloadTypeAtArg(ID, -1))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -768,15 +775,13 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
-              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
+      if (isIntrinsicWithOverloadTypeAtArg(ID, I)) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
-              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
+      if (isIntrinsicWithOverloadTypeAtArg(ID, I))
         Tys.push_back(OpI->getType());
     }
   }

>From 628f9836a9779864fbcef55cd9f58c64bad4d4e4 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Wed, 13 Nov 2024 18:47:46 +0000
Subject: [PATCH 3/8] review comments:

- consolidate the usage of isVectorIntrinsicWithOverloadTypeAtArg into
one place in VectorUtils.cpp
---
 llvm/include/llvm/Analysis/VectorUtils.h  |  3 ++-
 llvm/lib/Analysis/VectorUtils.cpp         |  7 +++++--
 llvm/lib/Transforms/Scalar/Scalarizer.cpp | 14 +++-----------
 3 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 467d5932cacf91..5e5bb8d753caff 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -152,7 +152,8 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
 
 /// Identifies if the vector form of the intrinsic is overloaded on the type of
 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
-bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
+bool isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI = nullptr);
 
 /// Identifies if the vector form of the intrinsic that returns a struct is
 /// overloaded at the struct element index \p RetIdx.
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 37c443011719b6..4044b5a6efda20 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -131,10 +131,13 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
-bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                  int OpdIdx) {
+bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI) {
   assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
 
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+
   switch (ID) {
   case Intrinsic::fptosi_sat:
   case Intrinsic::fptoui_sat:
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 1c61aa2712e8c7..64875d0d86cf75 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -280,7 +280,6 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
   bool visit(Function &F);
 
   bool isTriviallyScalarizable(Intrinsic::ID ID);
-  bool isIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int ScalarOpdIdx);
 
   // InstVisitor methods.  They return true if the instruction was scalarized,
   // false if nothing changed.
@@ -697,13 +696,6 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
          TTI->isTargetIntrinsicTriviallyScalarizable(ID);
 }
 
-bool ScalarizerVisitor::isIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                         int ScalarOpdIdx) {
-  return Intrinsic::isTargetIntrinsic(ID)
-             ? TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx)
-             : isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
-}
-
 /// If a call to a vector typed intrinsic function, split into a scalar call per
 /// element if possible for the intrinsic.
 bool ScalarizerVisitor::splitCall(CallInst &CI) {
@@ -735,7 +727,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -775,13 +767,13 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
         Tys.push_back(OpI->getType());
     }
   }

>From 7ff6fb72610c2c5bf3577ec32cef732e905db28e Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 15 Nov 2024 19:09:02 +0000
Subject: [PATCH 4/8] remove default TTI arg

---
 llvm/include/llvm/Analysis/VectorUtils.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 5e5bb8d753caff..6478803770c57c 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -152,8 +152,8 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
 
 /// Identifies if the vector form of the intrinsic is overloaded on the type of
 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
-bool isVectorIntrinsicWithOverloadTypeAtArg(
-    Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI = nullptr);
+bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
+                                            const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic that returns a struct is
 /// overloaded at the struct element index \p RetIdx.

>From df90c487333463ea3a545e99d9db78efee6658d5 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 15 Nov 2024 19:12:26 +0000
Subject: [PATCH 5/8] update ReplaceWithVeclib uses

- this pass is used to replace a vectorizable intrinsic function with a
generic veclib function
- we make the assumption that a target-specific intrinsic is not going
to be replaced with a generic function call and so we don't need to
provide the TTI
---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 7f3c5cf6cb4436..8d457f58e6eede 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -110,7 +110,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
 
   // OloadTys collects types used in scalar intrinsic overload name.
   SmallVector<Type *, 3> OloadTys;
-  if (!RetTy->isVoidTy() && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
+  if (!RetTy->isVoidTy() &&
+      isVectorIntrinsicWithOverloadTypeAtArg(IID, -1, /*TTI=*/nullptr))
     OloadTys.push_back(ScalarRetTy);
 
   // Compute the argument types of the corresponding scalar call and check that
@@ -118,7 +119,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   SmallVector<Type *, 8> ScalarArgTypes;
   for (auto Arg : enumerate(II->args())) {
     auto *ArgTy = Arg.value()->getType();
-    bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index());
+    bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
+                                                            /*TTI=*/nullptr);
     if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
       ScalarArgTypes.push_back(ArgTy);
       if (IsOloadTy)

>From 3a236c7d60bd33280ddffd1c73a0675d5bc8575d Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 15 Nov 2024 19:13:00 +0000
Subject: [PATCH 6/8] update uses in SLPVectorizer

- use the already defined TTI as the argument
---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index a11e3f3815cbf7..810b477c12eec4 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -15147,7 +15147,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       SmallVector<Value *> OpVecs;
       SmallVector<Type *, 2> TysForDecl;
       // Add return type if intrinsic is overloaded on it.
-      if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+      if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
         TysForDecl.push_back(VecTy);
       auto *CEI = cast<CallInst>(VL0);
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
@@ -15162,7 +15162,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
               It->second.first < DL->getTypeSizeInBits(CEI->getType()))
             ScalarArg = Builder.getFalse();
           OpVecs.push_back(ScalarArg);
-          if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+          if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
             TysForDecl.push_back(ScalarArg->getType());
           continue;
         }
@@ -15184,7 +15184,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         }
         LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
         OpVecs.push_back(OpVec);
-        if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+        if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
           TysForDecl.push_back(OpVec->getType());
       }
 

>From 71d75fe1374ce7bbf6c5d0b0c2dc5fad4893dbce Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 15 Nov 2024 19:31:59 +0000
Subject: [PATCH 7/8] update VPTransformState to include TargetTransformInfo

- pass down the TTI from the LoopVectorizationPlanner into the
VPTransformState
- this will make the TTI available to all VPRecipe::execute functions
when they generate new IR
---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 3 ++-
 llvm/lib/Transforms/Vectorize/VPlan.cpp         | 5 +++--
 llvm/lib/Transforms/Vectorize/VPlan.h           | 6 ++++--
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e8653498d32a12..5c48720575aefe 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7631,7 +7631,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
   LLVM_DEBUG(BestVPlan.dump());
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan);
+  VPTransformState State(&TTI, BestVF, BestUF, LI, DT, ILV.Builder, &ILV,
+                         &BestVPlan);
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
   // before making any changes to the CFG.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index c1b97791331bcf..a002a913128064 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -222,10 +222,11 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
   return It;
 }
 
-VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
+VPTransformState::VPTransformState(const TargetTransformInfo *TTI,
+                                   ElementCount VF, unsigned UF, LoopInfo *LI,
                                    DominatorTree *DT, IRBuilderBase &Builder,
                                    InnerLoopVectorizer *ILV, VPlan *Plan)
-    : VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
+    : TTI(TTI), VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
       LVer(nullptr), TypeAnalysis(Plan->getCanonicalIV()->getScalarType()) {}
 
 Value *VPTransformState::get(VPValue *Def, const VPLane &Lane) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 59a084401cc9bf..2e75fba35e6098 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -234,9 +234,11 @@ class VPLane {
 /// VPTransformState holds information passed down when "executing" a VPlan,
 /// needed for generating the output IR.
 struct VPTransformState {
-  VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
-                   DominatorTree *DT, IRBuilderBase &Builder,
+  VPTransformState(const TargetTransformInfo *TTI, ElementCount VF, unsigned UF,
+                   LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder,
                    InnerLoopVectorizer *ILV, VPlan *Plan);
+  /// Target Transform Info.
+  const TargetTransformInfo *TTI;
 
   /// The chosen Vectorization Factor of the loop being vectorized.
   ElementCount VF;

>From 59ef511356bd92057c32b8e1b1db70f71d972316 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 15 Nov 2024 19:32:27 +0000
Subject: [PATCH 8/8] update VPlanRecipe uses

- use the updated VPTransformState to pass down the TTI
---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 41f13cc2d9a978..2e72205449ccac 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -974,7 +974,7 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
 
   SmallVector<Type *, 2> TysForDecl;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1))
+  if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1, State.TTI))
     TysForDecl.push_back(VectorType::get(getResultType(), State.VF));
   SmallVector<Value *, 4> Args;
   for (const auto &I : enumerate(operands())) {
@@ -985,7 +985,8 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
       Arg = State.get(I.value(), VPLane(0));
     else
       Arg = State.get(I.value(), onlyFirstLaneUsed(I.value()));
-    if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index()))
+    if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index(),
+                                               State.TTI))
       TysForDecl.push_back(Arg->getType());
     Args.push_back(Arg);
   }



More information about the llvm-commits mailing list