[llvm] [TLI] ReplaceWithVecLib pass uses CostModel (PR #78688)
Paschalis Mpeis via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 19 01:43:25 PST 2024
https://github.com/paschalis-mpeis created https://github.com/llvm/llvm-project/pull/78688
Pass `replace-with-veclib` now replaces an instruction with a vector library (veclib) call only when the cost model does not estimate the call to be more expensive than the original instruction.
>From e476a0e02c024d3787b4b2268faa72c9fe80a60c Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Tue, 16 Jan 2024 13:52:11 +0000
Subject: [PATCH] [TLI] ReplaceWithVecLib pass uses CostModel
Pass replace-with-veclib now replaces an instruction with a veclib call
only when the call is not estimated to cost more than the original
instruction.
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 80 ++++++++++++++++---
.../AArch64/replace-with-veclib-armpl.ll | 2 +-
.../AArch64/replace-with-veclib-sleef.ll | 2 +-
3 files changed, 70 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 7b0215535a92c89..c57156c00a74e84 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -9,6 +9,8 @@
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
+// A replacement is made only when the cost model does not estimate the vector
+// library call to be more expensive than the original instruction.
//
//===----------------------------------------------------------------------===//
@@ -20,11 +22,15 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -95,15 +101,54 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Replacement->copyFastMathFlags(&I);
}
+/// Returns whether the vector library call \p TLIFunc is estimated by \p TTI
+/// to cost more than the original instruction \p I.
+///
+/// \p CI is the call being replaced (presumably \p I itself), or null, in
+/// which case \p I must be an frem instruction. \p VectorTy is the vector type
+/// the replacement operates on; it is only used as the result type of the frem
+/// cost query. Costs are compared as reciprocal throughput. \p TLI is
+/// currently unused.
+static bool isVeclibCallSlower(const TargetLibraryInfo &TLI,
+                               const TargetTransformInfo &TTI, Instruction &I,
+                               VectorType *VectorTy, CallInst *CI,
+                               Function *TLIFunc) {
+  const TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  InstructionCost DefaultCost;
+  if (CI) {
+    FastMathFlags FMF;
+    if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
+      FMF = FPMO->getFastMathFlags();
+
+    SmallVector<Type *, 4> ArgTypes;
+    for (const Value *Arg : CI->args())
+      ArgTypes.push_back(Arg->getType());
+    SmallVector<const Value *> Args(CI->args());
+    // Use the call's own return type: for void-returning calls the caller
+    // passes the first vector argument's type as VectorTy, which is not the
+    // result type of the intrinsic.
+    IntrinsicCostAttributes CostAttrs(CI->getIntrinsicID(), CI->getType(),
+                                      Args, ArgTypes, FMF,
+                                      dyn_cast<IntrinsicInst>(CI));
+    DefaultCost = TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
+  } else {
+    assert(I.getOpcode() == Instruction::FRem && "Only FRem is supported");
+    // Derive operand properties for both operands consistently.
+    TTI::OperandValueInfo Op1Info = TTI.getOperandInfo(I.getOperand(0));
+    TTI::OperandValueInfo Op2Info = TTI.getOperandInfo(I.getOperand(1));
+    SmallVector<const Value *, 4> OpValues(I.operand_values());
+    DefaultCost = TTI.getArithmeticInstrCost(I.getOpcode(), VectorTy, CostKind,
+                                             Op1Info, Op2Info, OpValues, &I);
+  }
+
+  // Cost the replacement using its actual signature. Its parameter list may
+  // differ from the original operands (e.g. if the vector variant takes a
+  // mask), and its return type may be void.
+  FunctionType *VecLibFTy = TLIFunc->getFunctionType();
+  InstructionCost VecLibCost = TTI.getCallInstrCost(
+      TLIFunc, VecLibFTy->getReturnType(), VecLibFTy->params(), CostKind);
+  return VecLibCost > DefaultCost;
+}
+
/// Returns true when successfully replaced \p I with a suitable function taking
-/// vector arguments, based on available mappings in the \p TLI. Currently only
-/// works when \p I is a call to vectorized intrinsic or the frem instruction.
+/// vector arguments, based on available mappings in the \p TLI, and only when
+/// \p TTI does not estimate the call to be more expensive than \p I.
+/// Currently only works when \p I is a call to a vectorized intrinsic or the
+/// frem instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
+ const TargetTransformInfo &TTI,
Instruction &I) {
// At the moment VFABI assumes the return type is always widened unless it is
// a void type.
- auto *VTy = dyn_cast<VectorType>(I.getType());
- ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+ auto *VectorTy = dyn_cast<VectorType>(I.getType());
+ ElementCount EC(VectorTy ? VectorTy->getElementCount() : ElementCount::getFixed(0));
// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
@@ -124,9 +169,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
ScalarArgTypes.push_back(VectorArgTy->getElementType());
// When return type is void, set EC to the first vector argument, and
// disallow vector arguments with different ECs.
- if (EC.isZero())
+ if (EC.isZero()) {
EC = VectorArgTy->getElementCount();
- else if (EC != VectorArgTy->getElementCount())
+ VectorTy = VectorArgTy;
+ } else if (EC != VectorArgTy->getElementCount())
return false;
} else
// Exit when it is supposed to be a vector argument but it isn't.
@@ -138,8 +184,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
: Intrinsic::getName(IID).str();
} else {
- assert(VTy && "Return type must be a vector");
- auto *ScalarTy = VTy->getScalarType();
+ assert(VectorTy && "Return type must be a vector");
+ auto *ScalarTy = VectorTy->getScalarType();
LibFunc Func;
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
return false;
@@ -199,6 +245,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);
+ if (isVeclibCallSlower(TLI, TTI, I, VectorTy, CI, TLIFunc))
+ return false;
+
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -219,13 +268,14 @@ static bool isSupportedInstruction(Instruction *I) {
return false;
}
-static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
+static bool runImpl(const TargetLibraryInfo &TLI,
+ const TargetTransformInfo &TTI, Function &F) {
bool Changed = false;
SmallVector<Instruction *> ReplacedCalls;
for (auto &I : instructions(F)) {
if (!isSupportedInstruction(&I))
continue;
- if (replaceWithCallToVeclib(TLI, I)) {
+ if (replaceWithCallToVeclib(TLI, TTI, I)) {
ReplacedCalls.push_back(&I);
Changed = true;
}
@@ -243,7 +293,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
PreservedAnalyses ReplaceWithVeclib::run(Function &F,
FunctionAnalysisManager &AM) {
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
- auto Changed = runImpl(TLI, F);
+ const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+ auto Changed = runImpl(TLI, TTI, F);
if (Changed) {
LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
<< NumCallsReplaced << "\n");
@@ -251,6 +302,7 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
PA.preserve<TargetLibraryAnalysis>();
+ PA.preserve<TargetIRAnalysis>();
PA.preserve<ScalarEvolutionAnalysis>();
PA.preserve<LoopAccessAnalysis>();
PA.preserve<DemandedBitsAnalysis>();
@@ -268,13 +320,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
const TargetLibraryInfo &TLI =
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- return runImpl(TLI, F);
+ const TargetTransformInfo &TTI =
+ getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ return runImpl(TLI, TTI, F);
}
void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addPreserved<TargetLibraryInfoWrapperPass>();
+ AU.addPreserved<TargetTransformInfoWrapperPass>();
AU.addPreserved<ScalarEvolutionWrapperPass>();
AU.addPreserved<AAResultsWrapperPass>();
AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index 758df0493cc5046..d3e1ae338f2caa6 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: define <2 x double> @frem_f64
; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN]], [[IN]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
index f408df570fdc00d..69b16b02adaa2d4 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
@@ -386,7 +386,7 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {
define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: @frem_f64(
-; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
+; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN:%.*]], [[IN]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in
More information about the llvm-commits
mailing list