[llvm] [NFC][Scalarizer][TargetTransformInfo] Add `isVectorIntrinsicWithOverloadTypeAtArg` api (PR #114849)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 4 10:30:23 PST 2024


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/114849

This change allows target intrinsics to specify and overwrite overloaded types.

This change will let us add scalarization for `asdouble`:  #114847

>From b95d4f3c79f2a286bfb079599150abb04c76f39b Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 4 Nov 2024 17:38:36 +0000
Subject: [PATCH] [NFC][Scalarizer][TargetTransformInfo] Add
 `isVectorIntrinsicWithOverloadTypeAtArg`

This change allows target intrinsics to specify overloaded types.

This change will let us add scalarization for `asdouble`: #114847
---
 llvm/include/llvm/Analysis/TargetTransformInfo.h   | 14 ++++++++++++++
 .../llvm/Analysis/TargetTransformInfoImpl.h        |  6 ++++++
 llvm/include/llvm/CodeGen/BasicTTIImpl.h           |  6 ++++++
 llvm/lib/Analysis/TargetTransformInfo.cpp          |  6 ++++++
 .../Target/DirectX/DirectXTargetTransformInfo.cpp  |  8 ++++++++
 .../Target/DirectX/DirectXTargetTransformInfo.h    |  3 +++
 llvm/lib/Transforms/Scalar/Scalarizer.cpp          |  9 ++++++---
 7 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..796b4011d71c0c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -896,6 +896,10 @@ class TargetTransformInfo {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1969,6 +1973,9 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
+  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      unsigned ScalarOpdIdx,
+                                                      bool Default) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2530,6 +2537,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                       Default);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..42d1082cf4d9eb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index db3b5cddd7c1c3..b2841e778947dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..bf9733f971fdac 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                         Default);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..2be5fc4de2409c 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
+  switch (ID) {
+  default:
+    return Default;
+  }
+}
+
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..ff82b7404ca58a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 772f4c6c35ddec..719dce704872ae 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+          ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +768,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+              ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
         Tys.push_back(OpI->getType());
     }
   }



More information about the llvm-commits mailing list