[llvm] 9fdc568 - [TLI] Fix replace-with-veclib crash with invalid arguments (#77112)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 12 07:19:56 PST 2024
Author: Paschalis Mpeis
Date: 2024-01-12T15:19:52Z
New Revision: 9fdc568824b0992d48704dfa530a12073cc02f5e
URL: https://github.com/llvm/llvm-project/commit/9fdc568824b0992d48704dfa530a12073cc02f5e
DIFF: https://github.com/llvm/llvm-project/commit/9fdc568824b0992d48704dfa530a12073cc02f5e.diff
LOG: [TLI] Fix replace-with-veclib crash with invalid arguments (#77112)
Fix a crash in the `replace-with-veclib` pass when the arguments of the TLI
mapping do not match those of the original call.
The pass now skips such calls instead of replacing them.
The test requires assertions because it programmatically inspects the debug log.
Added:
llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
Modified:
llvm/lib/CodeGen/ReplaceWithVeclib.cpp
llvm/unittests/Analysis/CMakeLists.txt
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 56025aa5c45fb3..7b0215535a92c8 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -111,7 +111,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
SmallVector<Type *, 8> ScalarArgTypes;
std::string ScalarName;
Function *FuncToReplace = nullptr;
- if (auto *CI = dyn_cast<CallInst>(&I)) {
+ auto *CI = dyn_cast<CallInst>(&I);
+ if (CI) {
FuncToReplace = CI->getCalledFunction();
Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
@@ -168,12 +169,36 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
if (!OptInfo)
return false;
+ // There is no guarantee that the vectorized instructions followed the VFABI
+ // specification when being created, this is why we need to add extra check to
+ // make sure that the operands of the vector function obtained via VFABI match
+ // the operands of the original vector instruction.
+ if (CI) {
+ for (auto VFParam : OptInfo->Shape.Parameters) {
+ if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
+ continue;
+
+ // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
+ // a bug in the VFABI parser.
+ assert(VFParam.ParamPos < CI->arg_size() &&
+ "ParamPos has invalid range.");
+ Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
+ if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
+ LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
+ << ". Wrong type at index " << VFParam.ParamPos
+ << ": " << *OrigTy << "\n");
+ return false;
+ }
+ }
+ }
+
FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
if (!VectorFTy)
return false;
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);
+
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -220,6 +245,9 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto Changed = runImpl(TLI, F);
if (Changed) {
+ LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
+ << NumCallsReplaced << "\n");
+
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
PA.preserve<TargetLibraryAnalysis>();
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 847430bf17697a..e7505f2633d92d 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ set(ANALYSIS_TEST_SOURCES
PluginInlineAdvisorAnalysisTest.cpp
PluginInlineOrderAnalysisTest.cpp
ProfileSummaryInfoTest.cpp
+ ReplaceWithVecLibTest.cpp
ScalarEvolutionTest.cpp
VectorFunctionABITest.cpp
SparsePropagation.cpp
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
new file mode 100644
index 00000000000000..a1f0a4a894c8d1
--- /dev/null
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -0,0 +1,123 @@
+//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ReplaceWithVeclib.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+/// NOTE: Assertions must be enabled for these tests to run: they read the
+/// pass's LLVM_DEBUG output, which is only emitted in assertion-enabled builds.
+#ifndef NDEBUG
+
+namespace {
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("ReplaceWithVecLibTest", errs());
+  return Mod;
+}
+
+/// Runs ReplaceWithVecLib with different TLIIs that have custom VecDescs. This
+/// allows checking that the pass won't crash when the function to replace (from
+/// the input IR) does not match the replacement function (derived from the
+/// VecDesc mapping).
+///
+class ReplaceWithVecLibTest : public ::testing::Test {
+
+  std::string getLastLine(std::string Out) {
+    // remove any trailing '\n'
+    if (!Out.empty() && *(Out.cend() - 1) == '\n')
+      Out.pop_back();
+
+    size_t LastNL = Out.find_last_of('\n');
+    return (LastNL == std::string::npos) ? Out : Out.substr(LastNL + 1);
+  }
+
+protected:
+  LLVMContext Ctx;
+
+  /// Creates TLII using the given \p VD, and then runs the ReplaceWithVeclib
+  /// pass. The pass should not crash even when the replacement function
+  /// (derived from the \p VD mapping) does not match the function to be
+  /// replaced (from the input \p IR).
+  ///
+  /// \returns the last line of the captured standard error, to be compared
+  /// for correctness, or an empty string if \p IR fails to parse.
+  std::string run(const VecDesc &VD, const char *IR) {
+    // Create TLII and register it with FAM so it's preserved when
+    // ReplaceWithVeclib pass runs.
+    TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
+    TLII.addVectorizableFunctions({VD});
+    FunctionAnalysisManager FAM;
+    FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
+
+    // Register and run the pass on the 'foo' function from the input IR.
+    FunctionPassManager FPM;
+    FPM.addPass(ReplaceWithVeclib());
+    std::unique_ptr<Module> M = parseIR(Ctx, IR);
+    // parseIR has already printed the diagnostic; return early instead of
+    // dereferencing a null module so the caller's EXPECT_EQ reports failure.
+    if (!M)
+      return "";
+    PassBuilder PB;
+    PB.registerFunctionAnalyses(FAM);
+
+    // Enable debugging and capture std error. DebugFlag is a process-wide
+    // global, so restore its previous value afterwards to keep debug output
+    // from leaking into other tests in the same binary.
+    bool DebugFlagPrev = llvm::DebugFlag;
+    llvm::DebugFlag = true;
+    testing::internal::CaptureStderr();
+    FPM.run(*M->getFunction("foo"), FAM);
+    llvm::DebugFlag = DebugFlagPrev;
+    return getLastLine(testing::internal::GetCapturedStderr());
+  }
+};
+
+} // end anonymous namespace
+
+static const char *IR = R"IR(
+define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
+  %call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
+  ret <vscale x 4 x float> %call
+}
+
+declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
+)IR";
+
+// The VFABI prefix in TLI describes a signature that matches the powi
+// intrinsic declaration.
+TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
+  VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
+                       ElementCount::getScalable(4), /*Masked*/ true,
+                       "_ZGVsMxvu"};
+  EXPECT_EQ(run(CorrectVD, IR),
+            "Instructions replaced with vector libraries: 1");
+}
+
+// The VFABI prefix in TLI describes a signature that does not match the powi
+// intrinsic declaration.
+TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
+  VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
+                         ElementCount::getScalable(4), /*Masked*/ true,
+                         "_ZGVsMxvv"};
+  EXPECT_EQ(run(IncorrectVD, IR),
+            "replace-with-veclib: Will not replace: llvm.powi.f32.i32. Wrong "
+            "type at index 1: i32");
+}
+#endif
\ No newline at end of file
More information about the llvm-commits
mailing list