[llvm] Reapply [TLI] Fix replace-with-veclib crash with invalid arguments (#77112) (PR #77945)

Paschalis Mpeis via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 12 09:13:33 PST 2024


https://github.com/paschalis-mpeis updated https://github.com/llvm/llvm-project/pull/77945

>From 8ed527c2a5dc9dd1b899caa5d1663e504b072069 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <paschalis.mpeis at arm.com>
Date: Fri, 12 Jan 2024 17:19:52 +0200
Subject: [PATCH] Reapply [TLI] Fix replace-with-veclib crash with invalid
 arguments (#77112)

Fix a crash in the `replace-with-veclib` pass when the arguments of the TLI
mapping do not match the original call.
Such calls are now simply left unchanged.

The test requires assertions, as it programmatically inspects the debug log.

NOTE: Originally submitted as commit 9fdc568824b0, which was reverted by
a300b2403784 because it caused linking issues:
https://lab.llvm.org/buildbot/#/builders/234/builds/17734
---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp        |  30 ++++-
 llvm/unittests/Analysis/CMakeLists.txt        |   2 +
 .../Analysis/ReplaceWithVecLibTest.cpp        | 113 ++++++++++++++++++
 3 files changed, 144 insertions(+), 1 deletion(-)
 create mode 100644 llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 56025aa5c45fb3..7b0215535a92c8 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -111,7 +111,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   SmallVector<Type *, 8> ScalarArgTypes;
   std::string ScalarName;
   Function *FuncToReplace = nullptr;
-  if (auto *CI = dyn_cast<CallInst>(&I)) {
+  auto *CI = dyn_cast<CallInst>(&I);
+  if (CI) {
     FuncToReplace = CI->getCalledFunction();
     Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
     assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
@@ -168,12 +169,36 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   if (!OptInfo)
     return false;
 
+  // There is no guarantee that the vectorized instructions followed the VFABI
+  // specification when being created, this is why we need to add extra check to
+  // make sure that the operands of the vector function obtained via VFABI match
+  // the operands of the original vector instruction.
+  if (CI) {
+    for (auto VFParam : OptInfo->Shape.Parameters) {
+      if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
+        continue;
+
+      // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
+      // a bug in the VFABI parser.
+      assert(VFParam.ParamPos < CI->arg_size() &&
+             "ParamPos has invalid range.");
+      Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
+      if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
+                          << ". Wrong type at index " << VFParam.ParamPos
+                          << ": " << *OrigTy << "\n");
+        return false;
+      }
+    }
+  }
+
   FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
   if (!VectorFTy)
     return false;
 
   Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
                                      VD->getVectorFnName(), FuncToReplace);
+
   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
                     << "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -220,6 +245,9 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
   auto Changed = runImpl(TLI, F);
   if (Changed) {
+    LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
+                      << NumCallsReplaced << "\n");
+
     PreservedAnalyses PA;
     PA.preserveSet<CFGAnalyses>();
     PA.preserve<TargetLibraryAnalysis>();
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 847430bf17697a..1f9b7da5f4b1d1 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -1,6 +1,7 @@
 set(LLVM_LINK_COMPONENTS
   Analysis
   AsmParser
+  CodeGen
   Core
   Passes
   Support
@@ -40,6 +41,7 @@ set(ANALYSIS_TEST_SOURCES
   PluginInlineAdvisorAnalysisTest.cpp
   PluginInlineOrderAnalysisTest.cpp
   ProfileSummaryInfoTest.cpp
+  ReplaceWithVecLibTest.cpp
   ScalarEvolutionTest.cpp
   VectorFunctionABITest.cpp
   SparsePropagation.cpp
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
new file mode 100644
index 00000000000000..a1f0a4a894c8d1
--- /dev/null
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -0,0 +1,113 @@
+//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ReplaceWithVeclib.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+/// NOTE: Assertions must be enabled for these tests to run.
+#ifndef NDEBUG
+
+/// Parses \p IR into a module owned by \p C. On failure, the parser
+/// diagnostic is printed to errs() and nullptr is returned.
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("ReplaceWithVecLibTest", errs());
+  return Mod;
+}
+
+namespace {
+
+/// Runs ReplaceWithVeclib with different TLIIs that have custom VecDescs. This
+/// allows checking that the pass won't crash when the function to replace (from
+/// the input IR) does not match the replacement function (derived from the
+/// VecDesc mapping).
+class ReplaceWithVecLibTest : public ::testing::Test {
+  /// Returns the last line of \p Out, ignoring one trailing newline.
+  static std::string getLastLine(std::string Out) {
+    if (!Out.empty() && Out.back() == '\n')
+      Out.pop_back();
+
+    size_t LastNL = Out.find_last_of('\n');
+    return (LastNL == std::string::npos) ? Out : Out.substr(LastNL + 1);
+  }
+
+protected:
+  LLVMContext Ctx;
+
+  /// Creates a TLII using the given \p VD, and then runs the ReplaceWithVeclib
+  /// pass on function 'foo' of the input \p IR. The pass should not crash even
+  /// when the replacement function (derived from the \p VD mapping) does not
+  /// match the function to be replaced.
+  ///
+  /// \returns the last line of the captured standard error, to be compared
+  /// for correctness. If \p IR cannot be parsed or does not define 'foo', a
+  /// non-fatal test failure is recorded and an empty string is returned.
+  std::string run(const VecDesc &VD, const char *IR) {
+    // Register the custom TLII with FAM *before* PassBuilder registers the
+    // default analyses: registerPass keeps the first registration, so this is
+    // the TargetLibraryAnalysis that ReplaceWithVeclib will query.
+    TargetLibraryInfoImpl TLII{Triple()};
+    TLII.addVectorizableFunctions({VD});
+    FunctionAnalysisManager FAM;
+    FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
+    PassBuilder PB;
+    PB.registerFunctionAnalyses(FAM);
+
+    FunctionPassManager FPM;
+    FPM.addPass(ReplaceWithVeclib());
+
+    std::unique_ptr<Module> M = parseIR(Ctx, IR);
+    if (!M) {
+      ADD_FAILURE() << "failed to parse the input IR";
+      return "";
+    }
+    Function *F = M->getFunction("foo");
+    if (!F) {
+      ADD_FAILURE() << "input IR does not define function 'foo'";
+      return "";
+    }
+
+    // Enable debug output and capture stderr. DebugFlag is a global, so
+    // restore its previous value afterwards; otherwise every test that runs
+    // later in this binary would also emit debug output.
+    const bool SavedDebugFlag = llvm::DebugFlag;
+    llvm::DebugFlag = true;
+    testing::internal::CaptureStderr();
+    FPM.run(*F, FAM);
+    std::string Captured = testing::internal::GetCapturedStderr();
+    llvm::DebugFlag = SavedDebugFlag;
+    return getLastLine(Captured);
+  }
+};
+
+} // end anonymous namespace
+
+/// Input IR shared by the tests: function 'foo' calls llvm.powi.f32.i32 with a
+/// <vscale x 4 x float> operand and a scalar i32 exponent (3). The pointer is
+/// const so the shared test input cannot be reassigned.
+static const char *const IR = R"IR(
+define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
+  %call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
+  ret <vscale x 4 x float> %call
+}
+
+declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
+)IR";
+
+// The VFABI prefix in the TLI mapping ("_ZGVsMxvu": a vector first parameter
+// and a uniform second parameter) matches the powi call in the input IR, so
+// the call is expected to be replaced.
+TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
+  const VecDesc ValidVD{"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
+                        ElementCount::getScalable(4), /*Masked=*/true,
+                        "_ZGVsMxvu"};
+  const std::string LastLine = run(ValidVD, IR);
+  EXPECT_EQ(LastLine, "Instructions replaced with vector libraries: 1");
+}
+
+// The VFABI prefix in the TLI mapping ("_ZGVsMxvv") declares both parameters
+// as vectors, but the powi call passes a scalar i32 exponent at index 1. The
+// pass must refuse the replacement and report why, instead of crashing.
+TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
+  const VecDesc InvalidVD{"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
+                          ElementCount::getScalable(4), /*Masked=*/true,
+                          "_ZGVsMxvv"};
+  const std::string LastLine = run(InvalidVD, IR);
+  EXPECT_EQ(LastLine,
+            "replace-with-veclib: Will not replace: llvm.powi.f32.i32. Wrong "
+            "type at index 1: i32");
+}
+#endif
\ No newline at end of file



More information about the llvm-commits mailing list