[llvm] [TLI] Fix replace-with-veclib crash with invalid arguments (PR #77112)
Paschalis Mpeis via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 9 05:11:40 PST 2024
https://github.com/paschalis-mpeis updated https://github.com/llvm/llvm-project/pull/77112
>From 85f89de6caaa6f6b86593d8057302b95120ca125 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 5 Jan 2024 15:13:21 +0000
Subject: [PATCH 1/4] Add test showing the replace-with-veclib pass crashes on invalid arguments
---
llvm/unittests/Analysis/CMakeLists.txt | 1 +
.../Analysis/ReplaceWithVecLibTest.cpp | 86 +++++++++++++++++++
2 files changed, 87 insertions(+)
create mode 100644 llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 847430bf17697a..e7505f2633d92d 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ set(ANALYSIS_TEST_SOURCES
PluginInlineAdvisorAnalysisTest.cpp
PluginInlineOrderAnalysisTest.cpp
ProfileSummaryInfoTest.cpp
+ ReplaceWithVecLibTest.cpp
ScalarEvolutionTest.cpp
VectorFunctionABITest.cpp
SparsePropagation.cpp
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
new file mode 100644
index 00000000000000..8f80c67b2ed414
--- /dev/null
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -0,0 +1,86 @@
+//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ReplaceWithVeclib.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+/// Parses \p IR in \p C; prints the diagnostic and returns null on failure.
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+ SMDiagnostic Err;
+ std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+ if (!Mod)
+ Err.print("ReplaceWithVecLibTest", errs());
+ return Mod;
+}
+
+/// Test fixture for the ReplaceWithVeclib pass. Each test runs the pass with a
+/// TLII that holds one custom VecDesc, checking that the pass does not crash
+/// when the vector variant described by that VecDesc does not match the call
+/// being replaced in the input IR.
+class ReplaceWithVecLibTest : public ::testing::Test {
+protected:
+ LLVMContext Ctx;
+
+ /// Builds a TLII containing only \p VD, parses \p IR and runs the
+ /// ReplaceWithVeclib pass on its function 'foo'. Always returns true: the
+ /// only property checked is that the pass did not crash, even when \p VD
+ /// does not match the call in \p IR.
+ bool run(const VecDesc &VD, const char *IR) {
+ // Register TLII with FAM before PB.registerFunctionAnalyses() below so the
+ // pass sees this TLII (re-registering an analysis is presumably a no-op).
+ TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
+ TLII.addVectorizableFunctions({VD});
+ FunctionAnalysisManager FAM;
+ FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
+
+ // Build a pipeline containing only ReplaceWithVeclib and run it on 'foo'.
+ FunctionPassManager FPM;
+ FPM.addPass(ReplaceWithVeclib());
+ std::unique_ptr<Module> M = parseIR(Ctx, IR);
+ PassBuilder PB;
+ PB.registerFunctionAnalyses(FAM);
+ FPM.run(*M->getFunction("foo"), FAM);
+ // NOTE(review): neither M nor M->getFunction("foo") is null-checked above.
+ return true;
+ }
+};
+
+} // end anonymous namespace
+// Input IR: @foo calls llvm.powi with a <vscale x 4 x float> and an i32 3.
+static const char *IR = R"IR(
+define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
+ %call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
+ ret <vscale x 4 x float> %call
+}
+
+declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
+)IR";
+
+// LLVM intrinsic 'powi' (in IR) has the same signature with the VecDesc.
+TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
+ VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
+ ElementCount::getScalable(4), true, "_ZGVsMxvu"};
+ EXPECT_TRUE(run(CorrectVD, IR));
+}
+
+// LLVM intrinsic 'powi' (in IR) has different signature with the VecDesc.
+TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
+ VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
+ ElementCount::getScalable(4), true, "_ZGVsMxvv"};
+ /// TODO: test should avoid and not crash.
+ EXPECT_DEATH(run(IncorrectVD, IR), "");
+}
>From 9dfc955fa13264e72a85b4dc93c6582cafc486d2 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 5 Jan 2024 15:25:30 +0000
Subject: [PATCH 2/4] [TLI] Fix replace-with-veclib crashes with invalid
arguments.
replace-with-veclib used to crash when the arguments of the TLI mapping
did not match the arguments of the original call. Now, it simply skips
such cases.
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 24 +++++++++++++++++--
.../Analysis/ReplaceWithVecLibTest.cpp | 3 +--
2 files changed, 23 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 56025aa5c45fb3..92f2d006fd79c2 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -108,15 +108,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
// and checks that all vector operands match the previously found EC.
- SmallVector<Type *, 8> ScalarArgTypes;
+ SmallVector<Type *, 8> ScalarArgTypes, OrigArgTypes;
std::string ScalarName;
Function *FuncToReplace = nullptr;
- if (auto *CI = dyn_cast<CallInst>(&I)) {
+ auto *CI = dyn_cast<CallInst>(&I);
+ if (CI) {
FuncToReplace = CI->getCalledFunction();
Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
for (auto Arg : enumerate(CI->args())) {
auto *ArgTy = Arg.value()->getType();
+ OrigArgTypes.push_back(ArgTy);
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -174,6 +176,24 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);
+
+ // For calls, bail out when their arguments do not match with the TLI mapping.
+ if (CI) {
+ int IdxNonPred = 0;
+ for (auto [OrigTy, VFParam] :
+ zip(OrigArgTypes, OptInfo->Shape.Parameters)) {
+ if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
+ continue;
+ ++IdxNonPred;
+ if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
+ LLVM_DEBUG(dbgs() << DEBUG_TYPE
+ << ": Will not replace: wrong type at index: "
+ << IdxNonPred << ": " << *OrigTy << "\n");
+ return false;
+ }
+ }
+ }
+
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
index 8f80c67b2ed414..858f72894861c1 100644
--- a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -81,6 +81,5 @@ TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
ElementCount::getScalable(4), true, "_ZGVsMxvv"};
- /// TODO: test should avoid and not crash.
- EXPECT_DEATH(run(IncorrectVD, IR), "");
+ EXPECT_TRUE(run(IncorrectVD, IR));
}
>From c93f1c097773d8ffb086990a546a517919ad6a54 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Tue, 9 Jan 2024 10:17:16 +0000
Subject: [PATCH 3/4] Use VFParam.ParamPos to access argument types
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 27 ++++++++++---------
.../Analysis/ReplaceWithVecLibTest.cpp | 12 ++++++---
2 files changed, 22 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 92f2d006fd79c2..64025bfe031dd4 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -170,30 +170,31 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
if (!OptInfo)
return false;
- FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
- if (!VectorFTy)
- return false;
-
- Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
- VD->getVectorFnName(), FuncToReplace);
-
- // For calls, bail out when their arguments do not match with the TLI mapping.
+ // There is no guarantee that the vectorized instructions followed the VFABI
+ // specification when they were created. Hence, extra checks are needed to
+ // make sure that the operands of the vector function obtained via VFABI match
+ // the operands of the original vector instruction.
if (CI) {
- int IdxNonPred = 0;
- for (auto [OrigTy, VFParam] :
- zip(OrigArgTypes, OptInfo->Shape.Parameters)) {
+ for (auto VFParam : OptInfo->Shape.Parameters) {
if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
continue;
- ++IdxNonPred;
+ Type *OrigTy = OrigArgTypes[VFParam.ParamPos];
if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE
<< ": Will not replace: wrong type at index: "
- << IdxNonPred << ": " << *OrigTy << "\n");
+ << VFParam.ParamPos << ": " << *OrigTy << "\n");
return false;
}
}
}
+ FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
+ if (!VectorFTy)
+ return false;
+
+ Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
+ VD->getVectorFnName(), FuncToReplace);
+
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
index 858f72894861c1..400743af97c8cf 100644
--- a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -70,16 +70,20 @@ define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
)IR";
-// LLVM intrinsic 'powi' (in IR) has the same signature with the VecDesc.
+// The VFABI prefix in the TLI mapping describes a signature that matches the
+// powi intrinsic declaration.
TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
- ElementCount::getScalable(4), true, "_ZGVsMxvu"};
+ ElementCount::getScalable(4), /*Masked*/ true,
+ "_ZGVsMxvu"};
EXPECT_TRUE(run(CorrectVD, IR));
}
-// LLVM intrinsic 'powi' (in IR) has different signature with the VecDesc.
+// The VFABI prefix in the TLI mapping describes a signature that does not
+// match the powi intrinsic declaration.
TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
- ElementCount::getScalable(4), true, "_ZGVsMxvv"};
+ ElementCount::getScalable(4), /*Masked*/ true,
+ "_ZGVsMxvv"};
EXPECT_TRUE(run(IncorrectVD, IR));
}
>From 39cb3532849edc0b46150d45ecc9f10965525ecd Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Tue, 9 Jan 2024 12:47:12 +0000
Subject: [PATCH 4/4] Address review comments
---
llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 64025bfe031dd4..7dccff1bbac746 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -178,6 +178,11 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
for (auto VFParam : OptInfo->Shape.Parameters) {
if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
continue;
+
+ // tryDemangleForVFABI must return a valid ParamPos; an out-of-range value
+ // indicates a bug in the VFABI parser.
+ assert(VFParam.ParamPos < OrigArgTypes.size() &&
+ "ParamPos has invalid range.");
Type *OrigTy = OrigArgTypes[VFParam.ParamPos];
if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE
More information about the llvm-commits
mailing list