[llvm] 8663b87 - [NFC][VectorUtils][TargetTransformInfo] Add `isVectorIntrinsicWithOverloadTypeAtArg` api (#114849)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 21 11:04:28 PST 2024


Author: Finn Plummer
Date: 2024-11-21T11:04:25-08:00
New Revision: 8663b8777e8108f74f91a2a33115b3a00d57c043

URL: https://github.com/llvm/llvm-project/commit/8663b8777e8108f74f91a2a33115b3a00d57c043
DIFF: https://github.com/llvm/llvm-project/commit/8663b8777e8108f74f91a2a33115b3a00d57c043.diff

LOG: [NFC][VectorUtils][TargetTransformInfo] Add `isVectorIntrinsicWithOverloadTypeAtArg` api (#114849)

This change allows target intrinsics to specify and override their overloaded types.

- Updates `ReplaceWithVecLib` to pass `nullptr` instead of a TTI, as there is most likely no use case for target intrinsics there
- Updates `SLPVectorizer` to use available TTI
- Updates `VPTransformState` to pass down TTI
- Updates `VPlanRecipe` to use passed-down TTI

This change will let us add scalarization for `asdouble`: #114847

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/Analysis/VectorUtils.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Analysis/VectorUtils.cpp
    llvm/lib/CodeGen/ReplaceWithVeclib.cpp
    llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
    llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
    llvm/lib/Transforms/Scalar/Scalarizer.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/lib/Transforms/Vectorize/VPlan.cpp
    llvm/lib/Transforms/Vectorize/VPlan.h
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index e37bce3118bcb2..985ca1532e0149 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -901,6 +901,12 @@ class TargetTransformInfo {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
+  /// Identifies if the vector form of the intrinsic is overloaded on the type
+  /// of the operand at index \p ScalarOpdIdx, or on the return type if
+  /// \p ScalarOpdIdx is -1.
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int ScalarOpdIdx) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1993,6 +1999,8 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
+  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      int ScalarOpdIdx) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2569,6 +2577,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int ScalarOpdIdx) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 72038c090b7922..38aba183f6a173 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -396,6 +396,11 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int ScalarOpdIdx) const {
+    return ScalarOpdIdx == -1;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,

diff  --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 467d5932cacf91..c1016dd7bdddbd 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -152,7 +152,10 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
 
 /// Identifies if the vector form of the intrinsic is overloaded on the type of
 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
-bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
+/// \p TTI is used to query target-specific intrinsics; pass nullptr when no
+/// target-specific intrinsics need to be considered.
+bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
+                                            const TargetTransformInfo *TTI);
 
 /// Identifies if the vector form of the intrinsic that returns a struct is
 /// overloaded at the struct element index \p RetIdx.

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 3b098c42f2741c..b3583e2819ee4c 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -801,6 +801,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int ScalarOpdIdx) const {
+    return ScalarOpdIdx == -1;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 174e5e87abe538..1fb2b9836de0cc 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -615,6 +615,11 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, int ScalarOpdIdx) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {

diff  --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 15e325a0fffca5..1789671276ffaf 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -133,10 +133,13 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
-bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
-                                                  int OpdIdx) {
+bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI) {
   assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
 
+  if (TTI && Intrinsic::isTargetIntrinsic(ID))
+    return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
+
   switch (ID) {
   case Intrinsic::fptosi_sat:
   case Intrinsic::fptoui_sat:

diff  --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 7f3c5cf6cb4436..8d457f58e6eede 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -110,7 +110,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
 
   // OloadTys collects types used in scalar intrinsic overload name.
   SmallVector<Type *, 3> OloadTys;
-  if (!RetTy->isVoidTy() && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
+  if (!RetTy->isVoidTy() &&
+      isVectorIntrinsicWithOverloadTypeAtArg(IID, -1, /*TTI=*/nullptr))
     OloadTys.push_back(ScalarRetTy);
 
   // Compute the argument types of the corresponding scalar call and check that
@@ -118,7 +119,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   SmallVector<Type *, 8> ScalarArgTypes;
   for (auto Arg : enumerate(II->args())) {
     auto *ArgTy = Arg.value()->getType();
-    bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index());
+    bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
+                                                            /*TTI=*/nullptr);
     if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
       ScalarArgTypes.push_back(ArgTy);
       if (IsOloadTy)

diff  --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index b0436a39423405..182cdaa4e9a7d7 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                            int ScalarOpdIdx) {
+  switch (ID) {
+  default:
+    return ScalarOpdIdx == -1;
+  }
+}
+
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {

diff  --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..a18e4a28625756 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,8 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              int ScalarOpdIdx);
 };
 } // namespace llvm
 

diff  --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 03d069c9fcb36d..3b701e6ca09761 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +767,13 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
         Tys.push_back(OpI->getType());
     }
   }

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index fda6550a375480..2854c1462014f9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7684,7 +7684,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
   LLVM_DEBUG(BestVPlan.dump());
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan);
+  VPTransformState State(&TTI, BestVF, BestUF, LI, DT, ILV.Builder, &ILV,
+                         &BestVPlan);
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
   // before making any changes to the CFG.

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index dd87d34d1f01a4..f13d0d80d382a4 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -15655,7 +15655,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       SmallVector<Value *> OpVecs;
       SmallVector<Type *, 2> TysForDecl;
       // Add return type if intrinsic is overloaded on it.
-      if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+      if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
         TysForDecl.push_back(VecTy);
       auto *CEI = cast<CallInst>(VL0);
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
@@ -15670,7 +15670,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
               It->second.first < DL->getTypeSizeInBits(CEI->getType()))
             ScalarArg = Builder.getFalse();
           OpVecs.push_back(ScalarArg);
-          if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+          if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
             TysForDecl.push_back(ScalarArg->getType());
           continue;
         }
@@ -15692,7 +15692,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         }
         LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
         OpVecs.push_back(OpVec);
-        if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+        if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
           TysForDecl.push_back(OpVec->getType());
       }
 

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 8b1a4aeb88f81f..a24a86b4201c31 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -219,10 +219,11 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
   return It;
 }
 
-VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
+VPTransformState::VPTransformState(const TargetTransformInfo *TTI,
+                                   ElementCount VF, unsigned UF, LoopInfo *LI,
                                    DominatorTree *DT, IRBuilderBase &Builder,
                                    InnerLoopVectorizer *ILV, VPlan *Plan)
-    : VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
+    : TTI(TTI), VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
       LVer(nullptr), TypeAnalysis(Plan->getCanonicalIV()->getScalarType()) {}
 
 Value *VPTransformState::get(VPValue *Def, const VPLane &Lane) {

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index abfe97b4ab55b6..9ef85a7f7a7524 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -234,9 +234,11 @@ class VPLane {
 /// VPTransformState holds information passed down when "executing" a VPlan,
 /// needed for generating the output IR.
 struct VPTransformState {
-  VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
-                   DominatorTree *DT, IRBuilderBase &Builder,
+  VPTransformState(const TargetTransformInfo *TTI, ElementCount VF, unsigned UF,
+                   LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder,
                    InnerLoopVectorizer *ILV, VPlan *Plan);
+  /// Target Transform Info.
+  const TargetTransformInfo *TTI;
 
   /// The chosen Vectorization Factor of the loop being vectorized.
   ElementCount VF;

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ef2ca9af7268d1..71aca3be9e5dcb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -941,7 +941,7 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
 
   SmallVector<Type *, 2> TysForDecl;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1))
+  if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1, State.TTI))
     TysForDecl.push_back(VectorType::get(getResultType(), State.VF));
   SmallVector<Value *, 4> Args;
   for (const auto &I : enumerate(operands())) {
@@ -952,7 +952,8 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
       Arg = State.get(I.value(), VPLane(0));
     else
       Arg = State.get(I.value(), onlyFirstLaneUsed(I.value()));
-    if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index()))
+    if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index(),
+                                               State.TTI))
       TysForDecl.push_back(Arg->getType());
     Args.push_back(Arg);
   }


        


More information about the llvm-commits mailing list