[llvm] [TLI] ReplaceWithVecLib pass uses CostModel (PR #78688)

via llvm-commits <llvm-commits at lists.llvm.org>
Fri Jan 19 01:43:56 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Paschalis Mpeis (paschalis-mpeis)

<details>
<summary>Changes</summary>

The `replace-with-veclib` pass now replaces an instruction with a veclib call only when the cost of that call is not found to be higher than the cost of the original instruction.

---
Full diff: https://github.com/llvm/llvm-project/pull/78688.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/ReplaceWithVeclib.cpp (+68-12) 
- (modified) llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll (+1-1) 
- (modified) llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll (+1-1) 


``````````diff
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 7b0215535a92c8..c57156c00a74e8 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -9,6 +9,8 @@
 // Replaces LLVM IR instructions with vector operands (i.e., the frem
 // instruction or calls to LLVM intrinsics) with matching calls to functions
 // from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
+// This happens only when the cost of calling the vector library is not found to
+// be more than the cost of the original instruction.
 //
 //===----------------------------------------------------------------------===//
 
@@ -20,11 +22,15 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/InstructionCost.h"
 #include "llvm/Support/TypeSize.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
@@ -95,15 +101,54 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
     Replacement->copyFastMathFlags(&I);
 }
 
+/// Returns whether the vector library call \p TLIFunc costs more than the
+/// original instruction \p I.
+static bool isVeclibCallSlower(const TargetLibraryInfo &TLI,
+                               const TargetTransformInfo &TTI, Instruction &I,
+                               VectorType *VectorTy, CallInst *CI,
+                               Function *TLIFunc) {
+  SmallVector<Type *, 4> OpTypes;
+  for (auto &Op : CI ? CI->args() : I.operands())
+    OpTypes.push_back(Op->getType());
+
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost DefaultCost;
+  if (CI) {
+    FastMathFlags FMF;
+    if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
+      FMF = FPMO->getFastMathFlags();
+
+    SmallVector<const Value *> Args(CI->args());
+    IntrinsicCostAttributes CostAttrs(CI->getIntrinsicID(), VectorTy, Args,
+                                      OpTypes, FMF,
+                                      dyn_cast<IntrinsicInst>(CI));
+    DefaultCost = TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
+  } else {
+    assert((I.getOpcode() == Instruction::FRem) && "Only FRem is supported");
+    auto Op2Info = TTI.getOperandInfo(I.getOperand(1));
+    SmallVector<const Value *, 4> OpValues(I.operand_values());
+    DefaultCost = TTI.getArithmeticInstrCost(
+        I.getOpcode(), VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Op2Info, OpValues, &I);
+  }
+
+  InstructionCost VecLibCost =
+      TTI.getCallInstrCost(TLIFunc, VectorTy, OpTypes, CostKind);
+  return VecLibCost > DefaultCost;
+}
+
 /// Returns true when successfully replaced \p I with a suitable function taking
-/// vector arguments, based on available mappings in the \p TLI. Currently only
-/// works when \p I is a call to vectorized intrinsic or the frem instruction.
+/// vector arguments, based on available mappings in the \p TLI and costs.
+/// Currently only works when \p I is a call to vectorized intrinsic or the frem
+/// instruction.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                    const TargetTransformInfo &TTI,
                                     Instruction &I) {
   // At the moment VFABI assumes the return type is always widened unless it is
   // a void type.
-  auto *VTy = dyn_cast<VectorType>(I.getType());
-  ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+  auto *VectorTy = dyn_cast<VectorType>(I.getType());
+  ElementCount EC(VectorTy ? VectorTy->getElementCount() : ElementCount::getFixed(0));
 
   // Compute the argument types of the corresponding scalar call and the scalar
   // function name. For calls, it additionally finds the function to replace
@@ -124,9 +169,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
         ScalarArgTypes.push_back(VectorArgTy->getElementType());
         // When return type is void, set EC to the first vector argument, and
         // disallow vector arguments with different ECs.
-        if (EC.isZero())
+        if (EC.isZero()) {
           EC = VectorArgTy->getElementCount();
-        else if (EC != VectorArgTy->getElementCount())
+          VectorTy = VectorArgTy;
+        } else if (EC != VectorArgTy->getElementCount())
           return false;
       } else
         // Exit when it is supposed to be a vector argument but it isn't.
@@ -138,8 +184,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
                      ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
                      : Intrinsic::getName(IID).str();
   } else {
-    assert(VTy && "Return type must be a vector");
-    auto *ScalarTy = VTy->getScalarType();
+    assert(VectorTy && "Return type must be a vector");
+    auto *ScalarTy = VectorTy->getScalarType();
     LibFunc Func;
     if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
       return false;
@@ -199,6 +245,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
                                      VD->getVectorFnName(), FuncToReplace);
 
+  if (isVeclibCallSlower(TLI, TTI, I, VectorTy, CI, TLIFunc))
+    return false;
+
   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
                     << "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -219,13 +268,14 @@ static bool isSupportedInstruction(Instruction *I) {
   return false;
 }
 
-static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
+static bool runImpl(const TargetLibraryInfo &TLI,
+                    const TargetTransformInfo &TTI, Function &F) {
   bool Changed = false;
   SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
     if (!isSupportedInstruction(&I))
       continue;
-    if (replaceWithCallToVeclib(TLI, I)) {
+    if (replaceWithCallToVeclib(TLI, TTI, I)) {
       ReplacedCalls.push_back(&I);
       Changed = true;
     }
@@ -243,7 +293,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
                                          FunctionAnalysisManager &AM) {
   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
-  auto Changed = runImpl(TLI, F);
+  const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  auto Changed = runImpl(TLI, TTI, F);
   if (Changed) {
     LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
                       << NumCallsReplaced << "\n");
@@ -251,6 +302,7 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
     PreservedAnalyses PA;
     PA.preserveSet<CFGAnalyses>();
     PA.preserve<TargetLibraryAnalysis>();
+    PA.preserve<TargetIRAnalysis>();
     PA.preserve<ScalarEvolutionAnalysis>();
     PA.preserve<LoopAccessAnalysis>();
     PA.preserve<DemandedBitsAnalysis>();
@@ -268,13 +320,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
   const TargetLibraryInfo &TLI =
       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
-  return runImpl(TLI, F);
+  const TargetTransformInfo &TTI =
+      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  return runImpl(TLI, TTI, F);
 }
 
 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesCFG();
   AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
   AU.addPreserved<TargetLibraryInfoWrapperPass>();
+  AU.addPreserved<TargetTransformInfoWrapperPass>();
   AU.addPreserved<ScalarEvolutionWrapperPass>();
   AU.addPreserved<AAResultsWrapperPass>();
   AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index 758df0493cc504..d3e1ae338f2caa 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
 define <2 x double> @frem_f64(<2 x double> %in) {
 ; CHECK-LABEL: define <2 x double> @frem_f64
 ; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <2 x double> [[IN]], [[IN]]
 ; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   %1= frem <2 x double> %in, %in
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
index f408df570fdc00..69b16b02adaa2d 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
@@ -386,7 +386,7 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {
 
 define <2 x double> @frem_f64(<2 x double> %in) {
 ; CHECK-LABEL: @frem_f64(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <2 x double> [[IN:%.*]], [[IN]]
 ; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   %1= frem <2 x double> %in, %in

``````````

</details>


https://github.com/llvm/llvm-project/pull/78688


More information about the llvm-commits mailing list