[llvm] [NFC][TargetTransformInfo][VectorUtils] Consolidate `isVectorIntrinsic...` api (PR #117635)

Finn Plummer via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 14:14:31 PST 2024


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/117635

- update `VectorUtils:isVectorIntrinsicWithScalarOpAtArg` to use TTI for
all uses, to allow specification of target specific intrinsics
- add TTI to the `isVectorIntrinsicWithStructReturnOverloadAtField` api
- update TTI api to provide `isTargetIntrinsicWith...` functions and
  consistently name them
  
- update all uses of the api and provide the TTI parameter

Resolves #117030

>From c1c1373983fe1bb5ef62a3da8034c831419ad458 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 25 Nov 2024 21:21:32 +0000
Subject: [PATCH 1/8] [NFC][TargetTransformInfo][VectorUtils] Consolidate
 `isVectorIntrinsic...` api

- update `VectorUtils:isVectorIntrinsicWithScalarOpAtArg` to use TTI for
all uses, to allow specification of target specific intrinsics
- add TTI to the `isVectorIntrinsicWithStructReturnOverloadAtField` api
- update TTI api to provide `isTargetIntrinsicWith...` functions and
  consistently name them
---
 .../llvm/Analysis/TargetTransformInfo.h       | 28 ++++++++++++++-----
 .../llvm/Analysis/TargetTransformInfoImpl.h   | 11 ++++++--
 llvm/include/llvm/Analysis/VectorUtils.h      | 14 ++++++----
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      | 11 ++++++--
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 11 ++++++--
 llvm/lib/Analysis/VectorUtils.cpp             | 17 ++++++++---
 6 files changed, 67 insertions(+), 25 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 985ca1532e0149..cb4793e85e462e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -898,14 +898,20 @@ class TargetTransformInfo {
 
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
 
+  /// Identifies if the vector form of the intrinsic has a scalar operand.
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
   /// Identifies if the vector form of the intrinsic is overloaded on the type
   /// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
   /// -1.
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const;
+
+  /// Identifies if the vector form of the intrinsic that returns a struct is
+  /// overloaded at the struct element index \p RetIdx.
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const;
 
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
@@ -1999,8 +2005,11 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
-  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                      int ScalarOpdIdx) = 0;
+  virtual bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      int OpdIdx) = 0;
+  virtual bool
+  isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                   int RetIdx) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2577,9 +2586,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) override {
-    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) override {
+    return Impl.isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) override {
+    return Impl.isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 38aba183f6a173..9499c660822179 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -396,9 +396,14 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const {
-    return ScalarOpdIdx == -1;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const {
+    return OpdIdx == -1;
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const {
+    return RetIdx == 0;
   }
 
   InstructionCost getScalarizationOverhead(VectorType *Ty,
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index c1016dd7bdddbd..f7febf3e82b125 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -147,8 +147,10 @@ inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
 bool isTriviallyVectorizable(Intrinsic::ID ID);
 
 /// Identifies if the vector form of the intrinsic has a scalar operand.
-bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
-                                        unsigned ScalarOpdIdx);
+/// \p TTI is used to consider target specific intrinsics, if no target specific
+/// intrinsics will be considered then it is appropriate to pass in nullptr.
+bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
+                                        const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic is overloaded on the type of
 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
@@ -158,9 +160,11 @@ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
                                             const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic that returns a struct is
-/// overloaded at the struct element index \p RetIdx.
-bool isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
-                                                      int RetIdx);
+/// overloaded at the struct element index \p RetIdx. \p TTI is used to
+/// consider target specific intrinsics, if no target specific intrinsics
+/// will be considered then it is appropriate to pass in nullptr.
+bool isVectorIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);
 
 /// Returns intrinsic ID for call.
 /// For the input call instruction it finds mapping intrinsic and returns
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4c..1bad9c70223b8d 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -801,9 +801,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
-  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                              int ScalarOpdIdx) const {
-    return ScalarOpdIdx == -1;
+  bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int OpdIdx) const {
+    return OpdIdx == -1;
+  }
+
+  bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
+                                                        int RetIdx) const {
+    return RetIdx == 0;
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1fb2b9836de0cc..5cd09207bfc56a 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -615,9 +615,14 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
-bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
-    Intrinsic::ID ID, int ScalarOpdIdx) const {
-  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+bool TargetTransformInfo::isTargetIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, int OpdIdx) const {
+  return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+}
+
+bool TargetTransformInfo::isTargetIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx) const {
+  return TTIImpl->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
+}
 
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 1789671276ffaf..15edaa0a47d037 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -115,7 +115,12 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
 
 /// Identifies if the vector form of the intrinsic has a scalar operand.
 bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
-                                              unsigned ScalarOpdIdx) {
+                                              unsigned ScalarOpdIdx,
+                                              const TargetTransformInfo *TTI) {
+
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
+
   switch (ID) {
   case Intrinsic::abs:
   case Intrinsic::ctlz:
@@ -138,7 +143,7 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
   assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
 
   if (TTI && Intrinsic::isTargetIntrinsic(ID))
-    return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+    return TTI->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
 
   switch (ID) {
   case Intrinsic::fptosi_sat:
@@ -157,8 +162,12 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
   }
 }
 
-bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
-                                                            int RetIdx) {
+bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(
+    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI) {
+
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
+
   switch (ID) {
   case Intrinsic::frexp:
     return RetIdx == 0 || RetIdx == 1;

>From cd61b96fc4dbad2ef8b264c2b79cb3af33181d81 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 25 Nov 2024 21:22:28 +0000
Subject: [PATCH 2/8] update uses in ReplaceWithVeclib

- we don't expect a target intrinsic to be replaced with a general
libvec function, so we can just pass in nullptr for now
---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 8d457f58e6eede..a87c2063b1e353 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -121,7 +121,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
     auto *ArgTy = Arg.value()->getType();
     bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
                                                             /*TTI=*/nullptr);
-    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+    if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index(), /*TTI=*/nullptr)) {
       ScalarArgTypes.push_back(ArgTy);
       if (IsOloadTy)
         OloadTys.push_back(ArgTy);

>From 8d395583ba5ce50f9681628a43b47318d8c52f17 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 25 Nov 2024 21:23:31 +0000
Subject: [PATCH 3/8] update uses in Scalarizer

- we can pass in the already defined TTI
---
 llvm/lib/Transforms/Scalar/Scalarizer.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 3b701e6ca09761..2ba8389c3a30c3 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -743,7 +743,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       // will only scalarize when the struct elements have the same bitness.
       if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
         return false;
-      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I))
+      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
         Tys.push_back(CurrVS->SplitTy);
     }
   }
@@ -794,8 +794,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       Tys[0] = VS->RemainderTy;
 
     for (unsigned J = 0; J != NumArgs; ++J) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J) ||
-          TTI->isTargetIntrinsicWithScalarOpAtArg(ID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
         ScalarCallOps.push_back(ScalarOperands[J]);
       } else {
         ScalarCallOps.push_back(Scattered[J][I]);

>From 332ba9521215bfb4d4376d38816ace9194b1def7 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 25 Nov 2024 21:25:14 +0000
Subject: [PATCH 4/8] update uses in SLPVectorizer

- pass down the previously defined TTI
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 44 ++++++++++---------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f13d0d80d382a4..a54cddd69d35d3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1060,7 +1060,8 @@ static bool allSameType(ArrayRef<Value *> VL) {
 /// \returns True if in-tree use also needs extract. This refers to
 /// possible scalar operand in vectorized instruction.
 static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
-                                        TargetLibraryInfo *TLI) {
+                                        TargetLibraryInfo *TLI,
+                                        const TargetTransformInfo *TTI) {
   if (!UserInst)
     return false;
   unsigned Opcode = UserInst->getOpcode();
@@ -1077,7 +1078,7 @@ static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
     CallInst *CI = cast<CallInst>(UserInst);
     Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
     return any_of(enumerate(CI->args()), [&](auto &&Arg) {
-      return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index()) &&
+      return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index(), TTI) &&
              Arg.value().get() == Scalar;
     });
   }
@@ -6417,7 +6418,7 @@ void BoUpSLP::buildExternalUses(
           // be used.
           if (UseEntry->State == TreeEntry::ScatterVectorize ||
               !doesInTreeUserNeedToExtract(
-                  Scalar, getRootEntryInstruction(*UseEntry), TLI)) {
+                  Scalar, getRootEntryInstruction(*UseEntry), TLI, TTI)) {
             LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
                               << ".\n");
             assert(!UseEntry->isGather() && "Bad state");
@@ -7724,7 +7725,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     unsigned NumArgs = CI->arg_size();
     SmallVector<Value *, 4> ScalarArgs(NumArgs, nullptr);
     for (unsigned J = 0; J != NumArgs; ++J)
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, J))
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI))
         ScalarArgs[J] = CI->getArgOperand(J);
     for (Value *V : VL) {
       CallInst *CI2 = dyn_cast<CallInst>(V);
@@ -7740,7 +7741,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
       // Some intrinsics have scalar arguments and should be same in order for
       // them to be vectorized.
       for (unsigned J = 0; J != NumArgs; ++J) {
-        if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
+        if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
           Value *A1J = CI2->getArgOperand(J);
           if (ScalarArgs[J] != A1J) {
             LLVM_DEBUG(dbgs()
@@ -8613,7 +8614,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
         SmallVector<ValueList> Operands;
         for (unsigned I : seq<unsigned>(2, CI->arg_size())) {
           Operands.emplace_back();
-          if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
+          if (isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI))
             continue;
           for (Value *V : VL) {
             auto *CI2 = cast<CallInst>(V);
@@ -8634,7 +8635,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
         // For scalar operands no need to create an entry since no need to
         // vectorize it.
-        if (isVectorIntrinsicWithScalarOpAtArg(ID, I))
+        if (isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI))
           continue;
         ValueList Operands;
         // Prepare the operand vector.
@@ -10834,14 +10835,14 @@ TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
 
 /// Builds the arguments types vector for the given call instruction with the
 /// given \p ID for the specified vector factor.
-static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI,
-                                                  const Intrinsic::ID ID,
-                                                  const unsigned VF,
-                                                  unsigned MinBW) {
+static SmallVector<Type *>
+buildIntrinsicArgTypes(const CallInst *CI, const Intrinsic::ID ID,
+                       const unsigned VF, unsigned MinBW,
+                       const TargetTransformInfo *TTI) {
   SmallVector<Type *> ArgTys;
   for (auto [Idx, Arg] : enumerate(CI->args())) {
     if (ID != Intrinsic::not_intrinsic) {
-      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx, TTI)) {
         ArgTys.push_back(Arg->getType());
         continue;
       }
@@ -11520,9 +11521,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       auto *CI = cast<CallInst>(VL0);
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
-      SmallVector<Type *> ArgTys =
-          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
-                                 It != MinBWs.end() ? It->second.first : 0);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(
+          CI, ID, VecTy->getNumElements(),
+          It != MinBWs.end() ? It->second.first : 0, TTI);
       auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
@@ -15644,9 +15645,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
-      SmallVector<Type *> ArgTys =
-          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
-                                 It != MinBWs.end() ? It->second.first : 0);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(
+          CI, ID, VecTy->getNumElements(),
+          It != MinBWs.end() ? It->second.first : 0, TTI);
       auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
                           VecCallCosts.first <= VecCallCosts.second;
@@ -15662,7 +15663,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         ValueList OpVL;
         // Some intrinsics have scalar arguments. This argument should not be
         // vectorized.
-        if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
+        if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I, TTI)) {
           ScalarArg = CEI->getArgOperand(I);
           // if decided to reduce bitwidth of abs intrinsic, it second argument
           // must be set false (do not return poison, if value issigned min).
@@ -16195,7 +16196,7 @@ BoUpSLP::vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues,
                                E->State == TreeEntry::StridedVectorize) &&
                               doesInTreeUserNeedToExtract(
                                   Scalar, getRootEntryInstruction(*UseEntry),
-                                  TLI);
+                                  TLI, TTI);
                      })) &&
              "Scalar with nullptr User must be registered in "
              "ExternallyUsedValues map or remain as scalar in vectorized "
@@ -17679,7 +17680,8 @@ bool BoUpSLP::collectValuesToDemote(
     // Choose the best bitwidth based on cost estimations.
     auto Checker = [&](unsigned BitWidth, unsigned) {
       unsigned MinBW = PowerOf2Ceil(BitWidth);
-      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(IC, ID, VF, MinBW);
+      SmallVector<Type *> ArgTys =
+          buildIntrinsicArgTypes(IC, ID, VF, MinBW, TTI);
       auto VecCallCosts = getVectorCallCosts(
           IC, getWidenedType(IntegerType::get(IC->getContext(), MinBW), VF),
           TTI, TLI, ArgTys);

>From fda38cd7f2232b1c793cdd050196a2edecba9594 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 25 Nov 2024 21:25:53 +0000
Subject: [PATCH 5/8] update uses in VPlanRecipes

- use the previously defined TTI
---
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 71aca3be9e5dcb..1174d90e9d727e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -948,7 +948,8 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
     // Some intrinsics have a scalar argument - don't replace it with a
     // vector.
     Value *Arg;
-    if (isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index()))
+    if (isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index(),
+                                           State.TTI))
       Arg = State.get(I.value(), VPLane(0));
     else
       Arg = State.get(I.value(), onlyFirstLaneUsed(I.value()));

>From 3711df0791cd71f2fcf06f9dc77477ed1bb26253 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 25 Nov 2024 21:39:44 +0000
Subject: [PATCH 6/8] update uses in LoopVectorizationLegality

- use the already defined TTI
---
 llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index f1568781252c06..567d6ca2e5a319 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -927,7 +927,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
         auto *SE = PSE.getSE();
         Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI);
         for (unsigned Idx = 0; Idx < CI->arg_size(); ++Idx)
-          if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx)) {
+          if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx, TTI)) {
             if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(Idx)),
                                      TheLoop)) {
               reportVectorizationFailure("Found unvectorizable intrinsic",

>From 309b8f2167ff00386f46191051f137c551763914 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 25 Nov 2024 21:58:25 +0000
Subject: [PATCH 7/8] updates uses in VectorCombine

- use defined TTI
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index b9caf8c0df9be1..dd6b10643e51d4 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1832,7 +1832,7 @@ bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
     return false;
 
   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
-    if (isVectorIntrinsicWithScalarOpAtArg(IID, I) &&
+    if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI) &&
         II0->getArgOperand(I) != II1->getArgOperand(I))
       return false;
 
@@ -1847,7 +1847,7 @@ bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
   SmallVector<Type *> NewArgsTy;
   InstructionCost NewCost = 0;
   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
-    if (isVectorIntrinsicWithScalarOpAtArg(IID, I)) {
+    if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
       NewArgsTy.push_back(II0->getArgOperand(I)->getType());
     } else {
       auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
@@ -1868,7 +1868,7 @@ bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
 
   SmallVector<Value *> NewArgs;
   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
-    if (isVectorIntrinsicWithScalarOpAtArg(IID, I)) {
+    if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
       NewArgs.push_back(II0->getArgOperand(I));
     } else {
       Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I),
@@ -1960,7 +1960,8 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
                                   const SmallPtrSet<Use *, 4> &IdentityLeafs,
                                   const SmallPtrSet<Use *, 4> &SplatLeafs,
                                   const SmallPtrSet<Use *, 4> &ConcatLeafs,
-                                  IRBuilder<> &Builder) {
+                                  IRBuilder<> &Builder,
+                                  const TargetTransformInfo *TTI) {
   auto [FrontU, FrontLane] = Item.front();
 
   if (IdentityLeafs.contains(FrontU)) {
@@ -1995,13 +1996,14 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
   unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
   SmallVector<Value *> Ops(NumOps);
   for (unsigned Idx = 0; Idx < NumOps; Idx++) {
-    if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
+    if (II &&
+        isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) {
       Ops[Idx] = II->getOperand(Idx);
       continue;
     }
-    Ops[Idx] =
-        generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), Ty,
-                            IdentityLeafs, SplatLeafs, ConcatLeafs, Builder);
+    Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
+                                   Ty, IdentityLeafs, SplatLeafs, ConcatLeafs,
+                                   Builder, TTI);
   }
 
   SmallVector<Value *, 8> ValueList;
@@ -2173,7 +2175,8 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
                  II && isTriviallyVectorizable(II->getIntrinsicID()) &&
                  !II->hasOperandBundles()) {
         for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
-          if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
+          if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op,
+                                                 &TTI)) {
             if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
                   Value *FrontV = Item.front().first->get();
                   Use *U = IL.first;
@@ -2204,7 +2207,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
   // removed. Scan through again and generate the new tree of instructions.
   Builder.SetInsertPoint(&I);
   Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
-                                 ConcatLeafs, Builder);
+                                 ConcatLeafs, Builder, &TTI);
   replaceValue(I, *V);
   return true;
 }

>From 03cc71b70a1537e485570ecca35758b1b9918c85 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 25 Nov 2024 22:06:47 +0000
Subject: [PATCH 8/8] update use in ConstantFolding

- pass down the nullptr
---
 llvm/lib/Analysis/ConstantFolding.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 1971c28fc4c4de..caeaec70e3967c 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3438,7 +3438,7 @@ static Constant *ConstantFoldFixedVectorCall(
     // Gather a column of constants.
     for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
       // Some intrinsics use a scalar type for certain arguments.
-      if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J)) {
+      if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J, /*TTI=*/nullptr)) {
         Lane[J] = Operands[J];
         continue;
       }



More information about the llvm-commits mailing list