[llvm] [TLI] Fix replace-with-veclib crash with invalid arguments (PR #77112)
Paschalis Mpeis via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 5 07:48:39 PST 2024
https://github.com/paschalis-mpeis created https://github.com/llvm/llvm-project/pull/77112
Fix a crash in the `replace-with-veclib` pass when the arguments of the TLI mapping
do not match those of the original call. After this patch, the pass simply ignores such cases.
# Stacked PR:
- (to be updated)
>From da179eb424fa86810b1d2e527f9263e0306a91d4 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 13 Dec 2023 17:33:58 +0000
Subject: [PATCH 1/8] [TLI] replace-with-veclib works with FRem Instruction.
Updated the SLEEF and ArmPL tests with fixed-width and scalable cases for
frem, which is mapped to fmod/fmodf.
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 122 ++++++++++--------
.../replace-intrinsics-with-veclib-armpl.ll | 42 +++++-
...e-intrinsics-with-veclib-sleef-scalable.ll | 20 ++-
.../replace-intrinsics-with-veclib-sleef.ll | 20 ++-
4 files changed, 149 insertions(+), 55 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 893aa4a91828d3..e3ba9e3c0c3fa3 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -69,52 +69,57 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
return TLIFunc;
}
-/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
-/// the corresponding function from the vector library ( \p TLIVecFunc ).
-static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
+/// Replace the Instruction \p I, that may be a vector intrinsic CallInst or
+/// the frem instruction, with a call to the corresponding function from the
+/// vector library ( \p TLIVecFunc ).
+static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Function *TLIVecFunc) {
- IRBuilder<> IRBuilder(&CalltoReplace);
- SmallVector<Value *> Args(CalltoReplace.args());
+ IRBuilder<> IRBuilder(&I);
+ auto *CI = dyn_cast<CallInst>(&I);
+ SmallVector<Value *> Args(CI ? CI->args() : I.operands());
if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
- auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
- Info.Shape.VF);
+ auto *MaskTy =
+ VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
Args.insert(Args.begin() + OptMaskpos.value(),
Constant::getAllOnesValue(MaskTy));
}
- // Preserve the operand bundles.
+ // Preserve the operand bundles for CallInsts.
SmallVector<OperandBundleDef, 1> OpBundles;
- CalltoReplace.getOperandBundlesAsDefs(OpBundles);
+ if (CI)
+ CI->getOperandBundlesAsDefs(OpBundles);
+
CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
- CalltoReplace.replaceAllUsesWith(Replacement);
+ I.replaceAllUsesWith(Replacement);
// Preserve fast math flags for FP math.
if (isa<FPMathOperator>(Replacement))
- Replacement->copyFastMathFlags(&CalltoReplace);
+ Replacement->copyFastMathFlags(&I);
}
-/// Returns true when successfully replaced \p CallToReplace with a suitable
-/// function taking vector arguments, based on available mappings in the \p TLI.
-/// Currently only works when \p CallToReplace is a call to vectorized
-/// intrinsic.
+/// Returns true when successfully replaced \p I with a suitable function taking
+/// vector arguments, based on available mappings in the \p TLI. Currently only
+/// works when \p I is a call to vectorized intrinsic or the FRem Instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
- CallInst &CallToReplace) {
- if (!CallToReplace.getCalledFunction())
- return false;
-
- auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
- // Replacement is only performed for intrinsic functions.
- if (IntrinsicID == Intrinsic::not_intrinsic)
- return false;
-
+ Instruction &I) {
+ CallInst *CI = dyn_cast<CallInst>(&I);
+ Intrinsic::ID IID = Intrinsic::not_intrinsic;
+ if (CI)
+ IID = CI->getCalledFunction()->getIntrinsicID();
// Compute arguments types of the corresponding scalar call. Additionally
// checks if in the vector call, all vector operands have the same EC.
ElementCount VF = ElementCount::getFixed(0);
- SmallVector<Type *> ScalarArgTypes;
- for (auto Arg : enumerate(CallToReplace.args())) {
+ SmallVector<Type *, 8> ScalarArgTypes;
+ for (auto Arg : enumerate(CI ? CI->args() : I.operands())) {
auto *ArgTy = Arg.value()->getType();
- if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
+ if (CI && isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
- } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
+ } else {
+ auto *VectorArgTy = dyn_cast<VectorType>(ArgTy);
+ // We are expecting only VectorTypes, as:
+ // - with a CallInst, scalar operands are handled earlier
+ // - with the FRem Instruction, both operands must be vectors.
+ if (!VectorArgTy)
+ return false;
ScalarArgTypes.push_back(ArgTy->getScalarType());
// Disallow vector arguments with different VFs. When processing the first
// vector argument, store it's VF, and for the rest ensure that they match
@@ -123,18 +128,22 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
VF = VectorArgTy->getElementCount();
else if (VF != VectorArgTy->getElementCount())
return false;
- } else
- // Exit when it is supposed to be a vector argument but it isn't.
- return false;
+ }
}
- // Try to reconstruct the name for the scalar version of this intrinsic using
- // the intrinsic ID and the argument types converted to scalar above.
- std::string ScalarName =
- (Intrinsic::isOverloaded(IntrinsicID)
- ? Intrinsic::getName(IntrinsicID, ScalarArgTypes,
- CallToReplace.getModule())
- : Intrinsic::getName(IntrinsicID).str());
+ // Try to reconstruct the name for the scalar version of the instruction.
+ std::string ScalarName;
+ if (CI) {
+ // For intrinsics, use scalar argument types
+ ScalarName = Intrinsic::isOverloaded(IID)
+ ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
+ : Intrinsic::getName(IID).str();
+ } else {
+ LibFunc Func;
+ if (!TLI.getLibFunc(I.getOpcode(), I.getType()->getScalarType(), Func))
+ return false;
+ ScalarName = TLI.getName(Func);
+ }
// Try to find the mapping for the scalar version of this intrinsic and the
// exact vector width of the call operands in the TargetLibraryInfo. First,
@@ -150,7 +159,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
// Replace the call to the intrinsic with a call to the vector library
// function.
- Type *ScalarRetTy = CallToReplace.getType()->getScalarType();
+ Type *ScalarRetTy = I.getType()->getScalarType();
FunctionType *ScalarFTy =
FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
const std::string MangledName = VD->getVectorFunctionABIVariantString();
@@ -162,27 +171,36 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
if (!VectorFTy)
return false;
- Function *FuncToReplace = CallToReplace.getCalledFunction();
- Function *TLIFunc = getTLIFunction(CallToReplace.getModule(), VectorFTy,
+ Function *FuncToReplace = CI ? CI->getCalledFunction() : nullptr;
+ Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);
- replaceWithTLIFunction(CallToReplace, *OptInfo, TLIFunc);
-
- LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
- << FuncToReplace->getName() << "` with call to `"
- << TLIFunc->getName() << "`.\n");
+ replaceWithTLIFunction(I, *OptInfo, TLIFunc);
+ LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
+ << "` with call to `" << TLIFunc->getName() << "`.\n");
++NumCallsReplaced;
return true;
}
+/// Supported Instructions \p I are either FRem or CallInsts to Intrinsics.
+static bool isSupportedInstruction(Instruction *I) {
+ if (auto *CI = dyn_cast<CallInst>(I)) {
+ if (!CI->getCalledFunction())
+ return false;
+ if (CI->getCalledFunction()->getIntrinsicID() == Intrinsic::not_intrinsic)
+ return false;
+ } else if (I->getOpcode() != Instruction::FRem)
+ return false;
+
+ return true;
+}
+
static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
bool Changed = false;
- SmallVector<CallInst *> ReplacedCalls;
+ SmallVector<Instruction *> ReplacedCalls;
for (auto &I : instructions(F)) {
- if (auto *CI = dyn_cast<CallInst>(&I)) {
- if (replaceWithCallToVeclib(TLI, *CI)) {
- ReplacedCalls.push_back(CI);
- Changed = true;
- }
+ if (isSupportedInstruction(&I) && replaceWithCallToVeclib(TLI, I)) {
+ ReplacedCalls.push_back(&I);
+ Changed = true;
}
}
// Erase the calls to the intrinsics that have been replaced
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
index d41870ec6e7915..4480a90a2728d3 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
@@ -15,7 +15,7 @@ declare <vscale x 2 x double> @llvm.cos.nxv2f64(<vscale x 2 x double>)
declare <vscale x 4 x float> @llvm.cos.nxv4f32(<vscale x 4 x float>)
;.
-; CHECK: @llvm.compiler.used = appending global [32 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [36 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vfmodq_f64, ptr @armpl_vfmodq_f32, ptr @armpl_svfmod_f64_x, ptr @armpl_svfmod_f32_x], section "llvm.metadata"
;.
define <2 x double> @llvm_cos_f64(<2 x double> %in) {
; CHECK-LABEL: define <2 x double> @llvm_cos_f64
@@ -424,6 +424,46 @@ define <vscale x 4 x float> @llvm_pow_vscale_f32(<vscale x 4 x float> %in, <vsca
ret <vscale x 4 x float> %1
}
+define <2 x double> @frem_f64(<2 x double> %in) {
+; CHECK-LABEL: define <2 x double> @frem_f64
+; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT: ret <2 x double> [[TMP1]]
+;
+ %1= frem <2 x double> %in, %in
+ ret <2 x double> %1
+}
+
+define <4 x float> @frem_f32(<4 x float> %in) {
+; CHECK-LABEL: define <4 x float> @frem_f32
+; CHECK-SAME: (<4 x float> [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @armpl_vfmodq_f32(<4 x float> [[IN]], <4 x float> [[IN]])
+; CHECK-NEXT: ret <4 x float> [[TMP1]]
+;
+ %1= frem <4 x float> %in, %in
+ ret <4 x float> %1
+}
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in) #0 {
+; CHECK-LABEL: define <vscale x 2 x double> @frem_vscale_f64
+; CHECK-SAME: (<vscale x 2 x double> [[IN:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[IN]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT: ret <vscale x 2 x double> [[TMP1]]
+;
+ %1= frem <vscale x 2 x double> %in, %in
+ ret <vscale x 2 x double> %1
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in) #0 {
+; CHECK-LABEL: define <vscale x 4 x float> @frem_vscale_f32
+; CHECK-SAME: (<vscale x 4 x float> [[IN:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[IN]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT: ret <vscale x 4 x float> [[TMP1]]
+;
+ %1= frem <vscale x 4 x float> %in, %in
+ ret <vscale x 4 x float> %1
+}
+
attributes #0 = { "target-features"="+sve" }
;.
; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
index c2ff6014bc6944..590dd9effac0ea 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
@@ -4,7 +4,7 @@
target triple = "aarch64-unknown-linux-gnu"
;.
-; CHECK: @llvm.compiler.used = appending global [16 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [18 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf, ptr @_ZGVsMxvv_fmod, ptr @_ZGVsMxvv_fmodf], section "llvm.metadata"
;.
define <vscale x 2 x double> @llvm_ceil_vscale_f64(<vscale x 2 x double> %in) {
; CHECK-LABEL: @llvm_ceil_vscale_f64(
@@ -384,6 +384,24 @@ define <vscale x 4 x float> @llvm_trunc_vscale_f32(<vscale x 4 x float> %in) {
ret <vscale x 4 x float> %1
}
+define <vscale x 2 x double> @frem_f64(<vscale x 2 x double> %in) {
+; CHECK-LABEL: @frem_f64(
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN:%.*]], <vscale x 2 x double> [[IN]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT: ret <vscale x 2 x double> [[TMP1]]
+;
+ %1= frem <vscale x 2 x double> %in, %in
+ ret <vscale x 2 x double> %1
+}
+
+define <vscale x 4 x float> @frem_f32(<vscale x 4 x float> %in) {
+; CHECK-LABEL: @frem_f32(
+; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN:%.*]], <vscale x 4 x float> [[IN]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT: ret <vscale x 4 x float> [[TMP1]]
+;
+ %1= frem <vscale x 4 x float> %in, %in
+ ret <vscale x 4 x float> %1
+}
+
declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
declare <vscale x 2 x double> @llvm.copysign.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll
index be247de368056e..865a46009b205f 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll
@@ -4,7 +4,7 @@
target triple = "aarch64-unknown-linux-gnu"
;.
-; CHECK: @llvm.compiler.used = appending global [16 x ptr] [ptr @_ZGVnN2v_cos, ptr @_ZGVnN4v_cosf, ptr @_ZGVnN2v_exp, ptr @_ZGVnN4v_expf, ptr @_ZGVnN2v_exp2, ptr @_ZGVnN4v_exp2f, ptr @_ZGVnN2v_exp10, ptr @_ZGVnN4v_exp10f, ptr @_ZGVnN2v_log, ptr @_ZGVnN4v_logf, ptr @_ZGVnN2v_log10, ptr @_ZGVnN4v_log10f, ptr @_ZGVnN2v_log2, ptr @_ZGVnN4v_log2f, ptr @_ZGVnN2v_sin, ptr @_ZGVnN4v_sinf], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [18 x ptr] [ptr @_ZGVnN2v_cos, ptr @_ZGVnN4v_cosf, ptr @_ZGVnN2v_exp, ptr @_ZGVnN4v_expf, ptr @_ZGVnN2v_exp2, ptr @_ZGVnN4v_exp2f, ptr @_ZGVnN2v_exp10, ptr @_ZGVnN4v_exp10f, ptr @_ZGVnN2v_log, ptr @_ZGVnN4v_logf, ptr @_ZGVnN2v_log10, ptr @_ZGVnN4v_log10f, ptr @_ZGVnN2v_log2, ptr @_ZGVnN4v_log2f, ptr @_ZGVnN2v_sin, ptr @_ZGVnN4v_sinf, ptr @_ZGVnN2vv_fmod, ptr @_ZGVnN4vv_fmodf], section "llvm.metadata"
;.
define <2 x double> @llvm_ceil_f64(<2 x double> %in) {
; CHECK-LABEL: @llvm_ceil_f64(
@@ -384,6 +384,24 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {
ret <4 x float> %1
}
+define <2 x double> @frem_f64(<2 x double> %in) {
+; CHECK-LABEL: @frem_f64(
+; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
+; CHECK-NEXT: ret <2 x double> [[TMP1]]
+;
+ %1= frem <2 x double> %in, %in
+ ret <2 x double> %1
+}
+
+define <4 x float> @frem_f32(<4 x float> %in) {
+; CHECK-LABEL: @frem_f32(
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @_ZGVnN4vv_fmodf(<4 x float> [[IN:%.*]], <4 x float> [[IN]])
+; CHECK-NEXT: ret <4 x float> [[TMP1]]
+;
+ %1= frem <4 x float> %in, %in
+ ret <4 x float> %1
+}
+
declare <2 x double> @llvm.ceil.v2f64(<2 x double>)
declare <4 x float> @llvm.ceil.v4f32(<4 x float>)
declare <2 x double> @llvm.copysign.v2f64(<2 x double>, <2 x double>)
>From 0daec94f528e731a7e00a3c428faed3ab40ca38c Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 22 Dec 2023 17:21:34 +0000
Subject: [PATCH 2/8] Split replaceWithCallToVeclib to two blocks
One block handles CallInsts and the other handles the frem instruction.
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 87 ++++++++++---------
...-armpl.ll => replace-with-veclib-armpl.ll} | 0
... => replace-with-veclib-sleef-scalable.ll} | 0
...-sleef.ll => replace-with-veclib-sleef.ll} | 0
4 files changed, 44 insertions(+), 43 deletions(-)
rename llvm/test/CodeGen/AArch64/{replace-intrinsics-with-veclib-armpl.ll => replace-with-veclib-armpl.ll} (100%)
rename llvm/test/CodeGen/AArch64/{replace-intrinsics-with-veclib-sleef-scalable.ll => replace-with-veclib-sleef-scalable.ll} (100%)
rename llvm/test/CodeGen/AArch64/{replace-intrinsics-with-veclib-sleef.ll => replace-with-veclib-sleef.ll} (100%)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index e3ba9e3c0c3fa3..9aaab2ab1c3503 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -6,9 +6,10 @@
//
//===----------------------------------------------------------------------===//
//
-// Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
-// with vector operands) with matching calls to functions from a vector
-// library (e.g., libmvec, SVML) according to TargetLibraryInfo.
+// Replaces instructions to LLVM vector intrinsics (i.e., the frem instruction
+// or calls to LLVM intrinsics with vector operands) with matching calls to
+// functions from a vector library (e.g., libmvec, SVML) according to
+// TargetLibraryInfo.
//
//===----------------------------------------------------------------------===//
@@ -69,9 +70,8 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
return TLIFunc;
}
-/// Replace the Instruction \p I, that may be a vector intrinsic CallInst or
-/// the frem instruction, with a call to the corresponding function from the
-/// vector library ( \p TLIVecFunc ).
+/// Replace the Instruction \p I with a call to the corresponding function from
+/// the vector library ( \p TLIVecFunc ).
static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Function *TLIVecFunc) {
IRBuilder<> IRBuilder(&I);
@@ -98,51 +98,53 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
/// Returns true when successfully replaced \p I with a suitable function taking
/// vector arguments, based on available mappings in the \p TLI. Currently only
-/// works when \p I is a call to vectorized intrinsic or the FRem Instruction.
+/// works when \p I is a call to vectorized intrinsic or the frem Instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Instruction &I) {
- CallInst *CI = dyn_cast<CallInst>(&I);
- Intrinsic::ID IID = Intrinsic::not_intrinsic;
- if (CI)
- IID = CI->getCalledFunction()->getIntrinsicID();
- // Compute arguments types of the corresponding scalar call. Additionally
- // checks if in the vector call, all vector operands have the same EC.
+ std::string ScalarName;
ElementCount VF = ElementCount::getFixed(0);
+ CallInst *CI = dyn_cast<CallInst>(&I);
SmallVector<Type *, 8> ScalarArgTypes;
- for (auto Arg : enumerate(CI ? CI->args() : I.operands())) {
- auto *ArgTy = Arg.value()->getType();
- if (CI && isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
- ScalarArgTypes.push_back(ArgTy);
- } else {
- auto *VectorArgTy = dyn_cast<VectorType>(ArgTy);
- // We are expecting only VectorTypes, as:
- // - with a CallInst, scalar operands are handled earlier
- // - with the FRem Instruction, both operands must be vectors.
- if (!VectorArgTy)
- return false;
- ScalarArgTypes.push_back(ArgTy->getScalarType());
- // Disallow vector arguments with different VFs. When processing the first
- // vector argument, store it's VF, and for the rest ensure that they match
- // it.
- if (VF.isZero())
- VF = VectorArgTy->getElementCount();
- else if (VF != VectorArgTy->getElementCount())
- return false;
- }
- }
-
- // Try to reconstruct the name for the scalar version of the instruction.
- std::string ScalarName;
if (CI) {
- // For intrinsics, use scalar argument types
+ Intrinsic::ID IID = Intrinsic::not_intrinsic;
+ IID = CI->getCalledFunction()->getIntrinsicID();
+ // Compute arguments types of the corresponding scalar call. Additionally
+ // checks if in the vector call, all vector operands have the same EC.
+ for (auto Arg : enumerate(CI ? CI->args() : I.operands())) {
+ auto *ArgTy = Arg.value()->getType();
+ if (CI && isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+ ScalarArgTypes.push_back(ArgTy);
+ } else {
+ auto *VectorArgTy = dyn_cast<VectorType>(ArgTy);
+ // We are expecting only VectorTypes, as:
+ // - with a CallInst, scalar operands are handled earlier
+ // - with the frem Instruction, both operands must be vectors.
+ if (!VectorArgTy)
+ return false;
+ ScalarArgTypes.push_back(ArgTy->getScalarType());
+ // Disallow vector arguments with different VFs. When processing the
+ // first vector argument, store it's VF, and for the rest ensure that
+ // they match it.
+ if (VF.isZero())
+ VF = VectorArgTy->getElementCount();
+ else if (VF != VectorArgTy->getElementCount())
+ return false;
+ }
+ }
+ // Try to reconstruct the name for the scalar version of the instruction,
+ // using scalar argument types.
ScalarName = Intrinsic::isOverloaded(IID)
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
: Intrinsic::getName(IID).str();
} else {
LibFunc Func;
- if (!TLI.getLibFunc(I.getOpcode(), I.getType()->getScalarType(), Func))
+ auto *ScalarTy = I.getType()->getScalarType();
+ if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
return false;
ScalarName = TLI.getName(Func);
+ ScalarArgTypes = {ScalarTy, ScalarTy};
+ if (auto *VTy = dyn_cast<VectorType>(I.getType()))
+ VF = VTy->getElementCount();
}
// Try to find the mapping for the scalar version of this intrinsic and the
@@ -181,12 +183,11 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
return true;
}
-/// Supported Instructions \p I are either FRem or CallInsts to Intrinsics.
+/// Supported Instructions \p I are either frem or CallInsts to Intrinsics.
static bool isSupportedInstruction(Instruction *I) {
if (auto *CI = dyn_cast<CallInst>(I)) {
- if (!CI->getCalledFunction())
- return false;
- if (CI->getCalledFunction()->getIntrinsicID() == Intrinsic::not_intrinsic)
+ if (!CI->getCalledFunction() ||
+ CI->getCalledFunction()->getIntrinsicID() == Intrinsic::not_intrinsic)
return false;
} else if (I->getOpcode() != Instruction::FRem)
return false;
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
similarity index 100%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
similarity index 100%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
similarity index 100%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
>From 9523a0a6c6b91001cbec31de30659fb74601abc7 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 3 Jan 2024 10:03:04 +0000
Subject: [PATCH 3/8] Addressing reviewers.
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 67 ++++++++++++--------------
1 file changed, 32 insertions(+), 35 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 9aaab2ab1c3503..075802b2c3b888 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -6,10 +6,9 @@
//
//===----------------------------------------------------------------------===//
//
-// Replaces instructions to LLVM vector intrinsics (i.e., the frem instruction
-// or calls to LLVM intrinsics with vector operands) with matching calls to
-// functions from a vector library (e.g., libmvec, SVML) according to
-// TargetLibraryInfo.
+// Replaces LLVM IR instructions with vector operands (i.e., the frem
+// instruction or calls to LLVM intrinsics) with matching calls to functions
+// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface
//
//===----------------------------------------------------------------------===//
@@ -70,7 +69,7 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
return TLIFunc;
}
-/// Replace the Instruction \p I with a call to the corresponding function from
+/// Replace the instruction \p I with a call to the corresponding function from
/// the vector library ( \p TLIVecFunc ).
static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Function *TLIVecFunc) {
@@ -84,7 +83,7 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Constant::getAllOnesValue(MaskTy));
}
- // Preserve the operand bundles for CallInsts.
+ // If it is a call instruction, preserve the operand bundles.
SmallVector<OperandBundleDef, 1> OpBundles;
if (CI)
CI->getOperandBundlesAsDefs(OpBundles);
@@ -98,38 +97,35 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
/// Returns true when successfully replaced \p I with a suitable function taking
/// vector arguments, based on available mappings in the \p TLI. Currently only
-/// works when \p I is a call to vectorized intrinsic or the frem Instruction.
+/// works when \p I is a call to vectorized intrinsic or the frem instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Instruction &I) {
std::string ScalarName;
- ElementCount VF = ElementCount::getFixed(0);
+ ElementCount EC = ElementCount::getFixed(0);
CallInst *CI = dyn_cast<CallInst>(&I);
SmallVector<Type *, 8> ScalarArgTypes;
+ // Compute the argument types of the corresponding scalar call, the scalar
+ // function name, and EC. For CI, it additionally checks if in the vector
+ // call, all vector operands have the same EC.
if (CI) {
Intrinsic::ID IID = Intrinsic::not_intrinsic;
IID = CI->getCalledFunction()->getIntrinsicID();
- // Compute arguments types of the corresponding scalar call. Additionally
- // checks if in the vector call, all vector operands have the same EC.
- for (auto Arg : enumerate(CI ? CI->args() : I.operands())) {
+ for (auto Arg : enumerate(CI->args())) {
auto *ArgTy = Arg.value()->getType();
- if (CI && isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+ if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
- } else {
- auto *VectorArgTy = dyn_cast<VectorType>(ArgTy);
- // We are expecting only VectorTypes, as:
- // - with a CallInst, scalar operands are handled earlier
- // - with the frem Instruction, both operands must be vectors.
- if (!VectorArgTy)
- return false;
- ScalarArgTypes.push_back(ArgTy->getScalarType());
+ } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
+ ScalarArgTypes.push_back(VectorArgTy->getElementType());
// Disallow vector arguments with different VFs. When processing the
// first vector argument, store it's VF, and for the rest ensure that
// they match it.
- if (VF.isZero())
- VF = VectorArgTy->getElementCount();
- else if (VF != VectorArgTy->getElementCount())
+ if (EC.isZero())
+ EC = VectorArgTy->getElementCount();
+ else if (EC != VectorArgTy->getElementCount())
return false;
- }
+ } else
+ // Exit when it is supposed to be a vector argument but it isn't.
+ return false;
}
// Try to reconstruct the name for the scalar version of the instruction,
// using scalar argument types.
@@ -137,6 +133,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
: Intrinsic::getName(IID).str();
} else {
+ assert(I.getType()->isVectorTy() && "Instruction must use vectors");
LibFunc Func;
auto *ScalarTy = I.getType()->getScalarType();
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
@@ -144,19 +141,19 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
ScalarName = TLI.getName(Func);
ScalarArgTypes = {ScalarTy, ScalarTy};
if (auto *VTy = dyn_cast<VectorType>(I.getType()))
- VF = VTy->getElementCount();
+ EC = VTy->getElementCount();
}
// Try to find the mapping for the scalar version of this intrinsic and the
// exact vector width of the call operands in the TargetLibraryInfo. First,
// check with a non-masked variant, and if that fails try with a masked one.
const VecDesc *VD =
- TLI.getVectorMappingInfo(ScalarName, VF, /*Masked*/ false);
- if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, VF, /*Masked*/ true)))
+ TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
+ if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
return false;
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
- << "` and vector width " << VF << " to: `"
+ << "` and vector width " << EC << " to: `"
<< VD->getVectorFnName() << "`.\n");
// Replace the call to the intrinsic with a call to the vector library
@@ -183,16 +180,16 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
return true;
}
-/// Supported Instructions \p I are either frem or CallInsts to Intrinsics.
+/// Supported instructions \p I are either frem or CallInsts to intrinsics.
static bool isSupportedInstruction(Instruction *I) {
if (auto *CI = dyn_cast<CallInst>(I)) {
- if (!CI->getCalledFunction() ||
- CI->getCalledFunction()->getIntrinsicID() == Intrinsic::not_intrinsic)
- return false;
- } else if (I->getOpcode() != Instruction::FRem)
- return false;
+ if (CI->getCalledFunction() &&
+ CI->getCalledFunction()->getIntrinsicID() != Intrinsic::not_intrinsic)
+ return true;
+ } else if (I->getOpcode() == Instruction::FRem && I->getType()->isVectorTy())
+ return true;
- return true;
+ return false;
}
static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
>From 285279b6b75170afe19baba65d8f5c03dba2f3e8 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 3 Jan 2024 14:40:10 +0000
Subject: [PATCH 4/8] Addressing reviewers (2)
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 29 +++++++++++++-------------
1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 075802b2c3b888..9a4f5df52bc73f 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -88,7 +88,7 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
if (CI)
CI->getOperandBundlesAsDefs(OpBundles);
- CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
+ auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
I.replaceAllUsesWith(Replacement);
// Preserve fast math flags for FP math.
if (isa<FPMathOperator>(Replacement))
@@ -102,14 +102,15 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Instruction &I) {
std::string ScalarName;
ElementCount EC = ElementCount::getFixed(0);
- CallInst *CI = dyn_cast<CallInst>(&I);
+ Function *FuncToReplace = nullptr;
SmallVector<Type *, 8> ScalarArgTypes;
// Compute the argument types of the corresponding scalar call, the scalar
- // function name, and EC. For CI, it additionally checks if in the vector
+ // function name, and EC. For calls, it additionally checks if in the vector
// call, all vector operands have the same EC.
- if (CI) {
- Intrinsic::ID IID = Intrinsic::not_intrinsic;
- IID = CI->getCalledFunction()->getIntrinsicID();
+ if (auto *CI = dyn_cast<CallInst>(&I)) {
+ Intrinsic::ID IID = CI->getCalledFunction()->getIntrinsicID();
+ assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
+ FuncToReplace = CI->getCalledFunction();
for (auto Arg : enumerate(CI->args())) {
auto *ArgTy = Arg.value()->getType();
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
@@ -170,7 +171,6 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
if (!VectorFTy)
return false;
- Function *FuncToReplace = CI ? CI->getCalledFunction() : nullptr;
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
@@ -182,13 +182,12 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
/// Supported instructions \p I are either frem or CallInsts to intrinsics.
static bool isSupportedInstruction(Instruction *I) {
- if (auto *CI = dyn_cast<CallInst>(I)) {
- if (CI->getCalledFunction() &&
- CI->getCalledFunction()->getIntrinsicID() != Intrinsic::not_intrinsic)
- return true;
- } else if (I->getOpcode() == Instruction::FRem && I->getType()->isVectorTy())
+ if (auto *CI = dyn_cast<CallInst>(I))
+ return CI->getCalledFunction() &&
+ CI->getCalledFunction()->getIntrinsicID() !=
+ Intrinsic::not_intrinsic;
+ if (I->getOpcode() == Instruction::FRem && I->getType()->isVectorTy())
return true;
-
return false;
}
@@ -196,7 +195,9 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
bool Changed = false;
SmallVector<Instruction *> ReplacedCalls;
for (auto &I : instructions(F)) {
- if (isSupportedInstruction(&I) && replaceWithCallToVeclib(TLI, I)) {
+ if (!isSupportedInstruction(&I))
+ continue;
+ if (replaceWithCallToVeclib(TLI, I)) {
ReplacedCalls.push_back(&I);
Changed = true;
}
>From 4b6ed6768e936c9426d86a29b6bb6019cb616d78 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 5 Jan 2024 09:06:48 +0000
Subject: [PATCH 5/8] Better handling of ElementCount
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 38 +++++++++++++-------------
1 file changed, 19 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 9a4f5df52bc73f..beb90ecd85489f 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -8,7 +8,7 @@
//
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
-// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface
+// from a vector library (e.g. libmvec, SVML) using the TargetLibraryInfo interface.
//
//===----------------------------------------------------------------------===//
@@ -70,7 +70,7 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
}
/// Replace the instruction \p I with a call to the corresponding function from
-/// the vector library ( \p TLIVecFunc ).
+/// the vector library (\p TLIVecFunc).
static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Function *TLIVecFunc) {
IRBuilder<> IRBuilder(&I);
@@ -100,26 +100,26 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
/// works when \p I is a call to vectorized intrinsic or the frem instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Instruction &I) {
+ auto *VTy = dyn_cast<VectorType>(I.getType());
+ ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+ // Compute the argument types of the corresponding scalar call and the scalar
+ // function name. For calls, it additionally finds the function to replace
+ // and checks that all vector operands match the previously found EC.
+ SmallVector<Type *, 8> ScalarArgTypes;
std::string ScalarName;
- ElementCount EC = ElementCount::getFixed(0);
Function *FuncToReplace = nullptr;
- SmallVector<Type *, 8> ScalarArgTypes;
- // Compute the argument types of the corresponding scalar call, the scalar
- // function name, and EC. For calls, it additionally checks if in the vector
- // call, all vector operands have the same EC.
if (auto *CI = dyn_cast<CallInst>(&I)) {
- Intrinsic::ID IID = CI->getCalledFunction()->getIntrinsicID();
- assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
FuncToReplace = CI->getCalledFunction();
+ Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
+ assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
for (auto Arg : enumerate(CI->args())) {
auto *ArgTy = Arg.value()->getType();
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
ScalarArgTypes.push_back(VectorArgTy->getElementType());
- // Disallow vector arguments with different VFs. When processing the
- // first vector argument, store it's VF, and for the rest ensure that
- // they match it.
+      // When the return type is void, take EC from the first vector argument,
+      // and disallow vector arguments with different ECs.
if (EC.isZero())
EC = VectorArgTy->getElementCount();
else if (EC != VectorArgTy->getElementCount())
@@ -134,15 +134,13 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
: Intrinsic::getName(IID).str();
} else {
- assert(I.getType()->isVectorTy() && "Instruction must use vectors");
+ assert(VTy && "Return type must be a vector");
+ auto *ScalarTy = VTy->getScalarType();
LibFunc Func;
- auto *ScalarTy = I.getType()->getScalarType();
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
return false;
ScalarName = TLI.getName(Func);
ScalarArgTypes = {ScalarTy, ScalarTy};
- if (auto *VTy = dyn_cast<VectorType>(I.getType()))
- EC = VTy->getElementCount();
}
// Try to find the mapping for the scalar version of this intrinsic and the
@@ -180,13 +178,15 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
return true;
}
-/// Supported instructions \p I are either frem or CallInsts to intrinsics.
+/// Supported instruction \p I must be a vectorized frem or a call to an
+/// intrinsic that returns either void or a vector.
static bool isSupportedInstruction(Instruction *I) {
+ Type *Ty = I->getType();
if (auto *CI = dyn_cast<CallInst>(I))
- return CI->getCalledFunction() &&
+ return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
CI->getCalledFunction()->getIntrinsicID() !=
Intrinsic::not_intrinsic;
- if (I->getOpcode() == Instruction::FRem && I->getType()->isVectorTy())
+ if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
return true;
return false;
}
>From 31f66085b1ca47606a6e5b70b49bdb18c26f934b Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 5 Jan 2024 12:42:54 +0000
Subject: [PATCH 6/8] Addressing reviewers (3)
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index beb90ecd85489f..56025aa5c45fb3 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -100,8 +100,11 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
/// works when \p I is a call to vectorized intrinsic or the frem instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Instruction &I) {
+ // At the moment VFABI assumes the return type is always widened unless it is
+ // a void type.
auto *VTy = dyn_cast<VectorType>(I.getType());
ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+
// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
// and checks that all vector operands match the previously found EC.
>From 254da80562c53538b96ae77146f303e776c68bd1 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 5 Jan 2024 15:13:21 +0000
Subject: [PATCH 7/8] Add test showing replace-with-veclib crashes with invalid arguments
---
llvm/unittests/Analysis/CMakeLists.txt | 1 +
.../Analysis/ReplaceWithVecLibTest.cpp | 86 +++++++++++++++++++
2 files changed, 87 insertions(+)
create mode 100644 llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 847430bf17697a..e7505f2633d92d 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ set(ANALYSIS_TEST_SOURCES
PluginInlineAdvisorAnalysisTest.cpp
PluginInlineOrderAnalysisTest.cpp
ProfileSummaryInfoTest.cpp
+ ReplaceWithVecLibTest.cpp
ScalarEvolutionTest.cpp
VectorFunctionABITest.cpp
SparsePropagation.cpp
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
new file mode 100644
index 00000000000000..8f80c67b2ed414
--- /dev/null
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -0,0 +1,86 @@
+//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ReplaceWithVeclib.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+ SMDiagnostic Err;
+ std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+ if (!Mod)
+ Err.print("ReplaceWithVecLibTest", errs());
+ return Mod;
+}
+
+/// Runs ReplaceWithVecLib with different TLIIs that have custom VecDescs. This
+/// allows checking that the pass won't crash when the function to replace (from
+/// the input IR) does not match the replacement function (derived from the
+/// VecDesc mapping).
+class ReplaceWithVecLibTest : public ::testing::Test {
+protected:
+ LLVMContext Ctx;
+
+ /// Creates TLII using the given \p VD, and then runs the ReplaceWithVeclib
+ /// pass. The pass should not crash even when the replacement function
+ /// (derived from the \p VD mapping) does not match the function to be
+ /// replaced (from the input \p IR).
+ bool run(const VecDesc &VD, const char *IR) {
+ // Create TLII and register it with FAM so it's preserved when
+ // ReplaceWithVeclib pass runs.
+ TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
+ TLII.addVectorizableFunctions({VD});
+ FunctionAnalysisManager FAM;
+ FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
+
+ // Register and run the pass on the 'foo' function from the input IR.
+ FunctionPassManager FPM;
+ FPM.addPass(ReplaceWithVeclib());
+ std::unique_ptr<Module> M = parseIR(Ctx, IR);
+ PassBuilder PB;
+ PB.registerFunctionAnalyses(FAM);
+ FPM.run(*M->getFunction("foo"), FAM);
+
+ return true;
+ }
+};
+
+} // end anonymous namespace
+
+static const char *IR = R"IR(
+define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
+ %call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
+ ret <vscale x 4 x float> %call
+}
+
+declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
+)IR";
+
+// LLVM intrinsic 'powi' (in IR) has the same signature as the VecDesc.
+TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
+ VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
+ ElementCount::getScalable(4), true, "_ZGVsMxvu"};
+ EXPECT_TRUE(run(CorrectVD, IR));
+}
+
+// LLVM intrinsic 'powi' (in IR) has a different signature from the VecDesc.
+TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
+ VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
+ ElementCount::getScalable(4), true, "_ZGVsMxvv"};
+ /// TODO: test should avoid and not crash.
+ EXPECT_DEATH(run(IncorrectVD, IR), "");
+}
>From af5c683ec5427abf18f9b303824b91020cb3a414 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 5 Jan 2024 15:25:30 +0000
Subject: [PATCH 8/8] [TLI] Fix replace-with-veclib crashes with invalid
arguments.
replace-with-veclib used to crash when the arguments of the TLI mapping
did not match the arguments of the original call. Now, it simply ignores
such cases.
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 24 +++++++++++++++++--
.../Analysis/ReplaceWithVecLibTest.cpp | 3 +--
2 files changed, 23 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 56025aa5c45fb3..92f2d006fd79c2 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -108,15 +108,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
// and checks that all vector operands match the previously found EC.
- SmallVector<Type *, 8> ScalarArgTypes;
+ SmallVector<Type *, 8> ScalarArgTypes, OrigArgTypes;
std::string ScalarName;
Function *FuncToReplace = nullptr;
- if (auto *CI = dyn_cast<CallInst>(&I)) {
+ auto *CI = dyn_cast<CallInst>(&I);
+ if (CI) {
FuncToReplace = CI->getCalledFunction();
Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
for (auto Arg : enumerate(CI->args())) {
auto *ArgTy = Arg.value()->getType();
+ OrigArgTypes.push_back(ArgTy);
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -174,6 +176,24 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);
+
+  // For calls, bail out when their arguments do not match the TLI mapping.
+ if (CI) {
+ int IdxNonPred = 0;
+ for (auto [OrigTy, VFParam] :
+ zip(OrigArgTypes, OptInfo->Shape.Parameters)) {
+ if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
+ continue;
+ ++IdxNonPred;
+ if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
+ LLVM_DEBUG(dbgs() << DEBUG_TYPE
+ << ": Will not replace: wrong type at index: "
+ << IdxNonPred << ": " << *OrigTy << "\n");
+ return false;
+ }
+ }
+ }
+
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
index 8f80c67b2ed414..858f72894861c1 100644
--- a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -81,6 +81,5 @@ TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
ElementCount::getScalable(4), true, "_ZGVsMxvv"};
- /// TODO: test should avoid and not crash.
- EXPECT_DEATH(run(IncorrectVD, IR), "");
+ EXPECT_TRUE(run(IncorrectVD, IR));
}
More information about the llvm-commits
mailing list