[llvm] [NFC][Scalarizer][TargetTransformInfo] Add isTargetIntrinsicWithScalarOpAtArg api (PR #111441)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 7 14:20:36 PDT 2024


https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/111441

This change allows target intrinsics to have scalar arguments.

Fixes [#111440](https://github.com/llvm/llvm-project/issues/111440)

>From 9ecd706fdea0890e3466304247b44163d312b9ca Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Mon, 7 Oct 2024 17:17:01 -0400
Subject: [PATCH] [NFC][Scalarizer][TargetTransformInfo] Add
 isTargetIntrinsicWithScalarOpAtArg api. This change allows target intrinsics
 to have scalar args

---
 llvm/include/llvm/Analysis/TargetTransformInfo.h     | 11 +++++++++++
 llvm/include/llvm/Analysis/TargetTransformInfoImpl.h |  5 +++++
 llvm/include/llvm/CodeGen/BasicTTIImpl.h             |  5 +++++
 llvm/lib/Analysis/TargetTransformInfo.cpp            |  5 +++++
 .../Target/DirectX/DirectXTargetTransformInfo.cpp    | 12 +++++++++++-
 llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h |  2 ++
 llvm/lib/Transforms/Scalar/Scalarizer.cpp            |  3 ++-
 7 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 89a85bc8a90864..2befacea4df866 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -884,6 +884,9 @@ class TargetTransformInfo {
 
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
 
+  bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
+                                          unsigned ScalarOpdIdx) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1935,6 +1938,8 @@ class TargetTransformInfo::Concept {
   virtual bool shouldBuildRelLookupTables() = 0;
   virtual bool useColdCCForColdCall(Function &F) = 0;
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
+  virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
+                                                  unsigned ScalarOpdIdx) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2477,6 +2482,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) override {
     return Impl.isTargetIntrinsicTriviallyScalarizable(ID);
   }
+
+  bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
+                                          unsigned ScalarOpdIdx) override {
+    return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 50040dc8f6165b..01a16e7c7b1e59 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -377,6 +377,11 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
+                                          unsigned ScalarOpdIdx) const {
+    return false;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index c36a346c1b2e05..57d1fa33c8482c 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -793,6 +793,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
+                                          unsigned ScalarOpdIdx) const {
+    return false;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index b5195f764cbd1c..b612a3331e5737 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -592,6 +592,11 @@ bool TargetTransformInfo::isTargetIntrinsicTriviallyScalarizable(
   return TTIImpl->isTargetIntrinsicTriviallyScalarizable(ID);
 }
 
+bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx) const {
+  return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 1a59f04b214042..be714b5c87895a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -13,7 +13,17 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsDirectX.h"
 
-bool llvm::DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
+using namespace llvm;
+
+bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
+                                                        unsigned ScalarOpdIdx) {
+  switch (ID) {
+  default:
+    return false;
+  }
+}
+
+bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
   case Intrinsic::dx_frac:
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 48414549f83495..30b57ed97d6370 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -35,6 +35,8 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
         TLI(ST->getTargetLowering()) {}
   unsigned getMinVectorRegisterBitWidth() const { return 32; }
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
+  bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
+                                          unsigned ScalarOpdIdx);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index ee86e2e6c9751e..72728c0f839e5d 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -745,7 +745,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       Tys[0] = VS->RemainderTy;
 
     for (unsigned J = 0; J != NumArgs; ++J) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J) ||
+          TTI->isTargetIntrinsicWithScalarOpAtArg(ID, J)) {
         ScalarCallOps.push_back(ScalarOperands[J]);
       } else {
         ScalarCallOps.push_back(Scattered[J][I]);



More information about the llvm-commits mailing list