[llvm] [TLI] Fix replace-with-veclib crash with invalid arguments (PR #77112)

Paschalis Mpeis via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 06:16:36 PST 2024


https://github.com/paschalis-mpeis updated https://github.com/llvm/llvm-project/pull/77112

>From 85f89de6caaa6f6b86593d8057302b95120ca125 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 5 Jan 2024 15:13:21 +0000
Subject: [PATCH 1/8] Add test showing the replace-with-veclib pass crashes with invalid arguments

---
 llvm/unittests/Analysis/CMakeLists.txt        |  1 +
 .../Analysis/ReplaceWithVecLibTest.cpp        | 86 +++++++++++++++++++
 2 files changed, 87 insertions(+)
 create mode 100644 llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp

diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 847430bf17697a..e7505f2633d92d 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ set(ANALYSIS_TEST_SOURCES
   PluginInlineAdvisorAnalysisTest.cpp
   PluginInlineOrderAnalysisTest.cpp
   ProfileSummaryInfoTest.cpp
+  ReplaceWithVecLibTest.cpp
   ScalarEvolutionTest.cpp
   VectorFunctionABITest.cpp
   SparsePropagation.cpp
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
new file mode 100644
index 00000000000000..8f80c67b2ed414
--- /dev/null
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -0,0 +1,86 @@
+//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ReplaceWithVeclib.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("ReplaceWithVecLibTest", errs());
+  return Mod;
+}
+
+/// Runs ReplaceWithVecLib with different TLIIs that have custom VecDescs. This
+/// allows checking that the pass won't crash when the function to replace (from
+/// the input IR) does not match the replacement function (derived from the
+/// VecDesc mapping).
+class ReplaceWithVecLibTest : public ::testing::Test {
+protected:
+  LLVMContext Ctx;
+
+  /// Creates TLII using the given \p VD, and then runs the ReplaceWithVeclib
+  /// pass. The pass should not crash even when the replacement function
+  /// (derived from the \p VD mapping) does not match the function to be
+  /// replaced (from the input \p IR).
+  bool run(const VecDesc &VD, const char *IR) {
+    // Create TLII and register it with FAM so it's preserved when
+    // ReplaceWithVeclib pass runs.
+    TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
+    TLII.addVectorizableFunctions({VD});
+    FunctionAnalysisManager FAM;
+    FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
+
+    // Register and run the pass on the 'foo' function from the input IR.
+    FunctionPassManager FPM;
+    FPM.addPass(ReplaceWithVeclib());
+    std::unique_ptr<Module> M = parseIR(Ctx, IR);
+    PassBuilder PB;
+    PB.registerFunctionAnalyses(FAM);
+    FPM.run(*M->getFunction("foo"), FAM);
+
+    return true;
+  }
+};
+
+} // end anonymous namespace
+
+static const char *IR = R"IR(
+define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
+  %call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
+  ret <vscale x 4 x float> %call
+}
+
+declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
+)IR";
+
+// LLVM intrinsic 'powi' (in IR) has the same signature with the VecDesc.
+TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
+  VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
+                       ElementCount::getScalable(4), true, "_ZGVsMxvu"};
+  EXPECT_TRUE(run(CorrectVD, IR));
+}
+
+// LLVM intrinsic 'powi' (in IR) has different signature with the VecDesc.
+TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
+  VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
+                         ElementCount::getScalable(4), true, "_ZGVsMxvv"};
+  /// TODO: test should avoid and not crash.
+  EXPECT_DEATH(run(IncorrectVD, IR), "");
+}

>From 9dfc955fa13264e72a85b4dc93c6582cafc486d2 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Fri, 5 Jan 2024 15:25:30 +0000
Subject: [PATCH 2/8] [TLI] Fix replace-with-veclib crashes with invalid
 arguments.

replace-with-veclib used to crash when the arguments described by the TLI
mapping did not match the arguments of the call being replaced. Now, it
skips such calls and leaves them unchanged.
---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp        | 24 +++++++++++++++++--
 .../Analysis/ReplaceWithVecLibTest.cpp        |  3 +--
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 56025aa5c45fb3..92f2d006fd79c2 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -108,15 +108,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   // Compute the argument types of the corresponding scalar call and the scalar
   // function name. For calls, it additionally finds the function to replace
   // and checks that all vector operands match the previously found EC.
-  SmallVector<Type *, 8> ScalarArgTypes;
+  SmallVector<Type *, 8> ScalarArgTypes, OrigArgTypes;
   std::string ScalarName;
   Function *FuncToReplace = nullptr;
-  if (auto *CI = dyn_cast<CallInst>(&I)) {
+  auto *CI = dyn_cast<CallInst>(&I);
+  if (CI) {
     FuncToReplace = CI->getCalledFunction();
     Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
     assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
     for (auto Arg : enumerate(CI->args())) {
       auto *ArgTy = Arg.value()->getType();
+      OrigArgTypes.push_back(ArgTy);
       if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
         ScalarArgTypes.push_back(ArgTy);
       } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -174,6 +176,24 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
 
   Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
                                      VD->getVectorFnName(), FuncToReplace);
+
+  // For calls, bail out when their arguments do not match with the TLI mapping.
+  if (CI) {
+    int IdxNonPred = 0;
+    for (auto [OrigTy, VFParam] :
+         zip(OrigArgTypes, OptInfo->Shape.Parameters)) {
+      if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
+        continue;
+      ++IdxNonPred;
+      if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE
+                          << ": Will not replace: wrong type at index: "
+                          << IdxNonPred << ": " << *OrigTy << "\n");
+        return false;
+      }
+    }
+  }
+
   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
                     << "` with call to `" << TLIFunc->getName() << "`.\n");
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
index 8f80c67b2ed414..858f72894861c1 100644
--- a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -81,6 +81,5 @@ TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
 TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
   VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
                          ElementCount::getScalable(4), true, "_ZGVsMxvv"};
-  /// TODO: test should avoid and not crash.
-  EXPECT_DEATH(run(IncorrectVD, IR), "");
+  EXPECT_TRUE(run(IncorrectVD, IR));
 }

>From c93f1c097773d8ffb086990a546a517919ad6a54 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Tue, 9 Jan 2024 10:17:16 +0000
Subject: [PATCH 3/8] Use VFParam.ParamPos to look up argument types

---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp        | 27 ++++++++++---------
 .../Analysis/ReplaceWithVecLibTest.cpp        | 12 ++++++---
 2 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 92f2d006fd79c2..64025bfe031dd4 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -170,30 +170,31 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   if (!OptInfo)
     return false;
 
-  FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
-  if (!VectorFTy)
-    return false;
-
-  Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
-                                     VD->getVectorFnName(), FuncToReplace);
-
-  // For calls, bail out when their arguments do not match with the TLI mapping.
+  // There is no guarantee that the vectorized instructions followed the VFABI
+  // specification when being created, this is why we need to add extra check to
+  // make sure that the operands of the vector function obtained via VFABI match
+  // the operands of the original vector instruction.
   if (CI) {
-    int IdxNonPred = 0;
-    for (auto [OrigTy, VFParam] :
-         zip(OrigArgTypes, OptInfo->Shape.Parameters)) {
+    for (auto VFParam : OptInfo->Shape.Parameters) {
       if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
         continue;
-      ++IdxNonPred;
+      Type *OrigTy = OrigArgTypes[VFParam.ParamPos];
       if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
         LLVM_DEBUG(dbgs() << DEBUG_TYPE
                           << ": Will not replace: wrong type at index: "
-                          << IdxNonPred << ": " << *OrigTy << "\n");
+                          << VFParam.ParamPos << ": " << *OrigTy << "\n");
         return false;
       }
     }
   }
 
+  FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
+  if (!VectorFTy)
+    return false;
+
+  Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
+                                     VD->getVectorFnName(), FuncToReplace);
+
   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
                     << "` with call to `" << TLIFunc->getName() << "`.\n");
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
index 858f72894861c1..400743af97c8cf 100644
--- a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -70,16 +70,20 @@ define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
 declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
 )IR";
 
-// LLVM intrinsic 'powi' (in IR) has the same signature with the VecDesc.
+// The VFABI prefix in TLI describes signature which is matching the powi
+// intrinsic declaration.
 TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
   VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
-                       ElementCount::getScalable(4), true, "_ZGVsMxvu"};
+                       ElementCount::getScalable(4), /*Masked*/ true,
+                       "_ZGVsMxvu"};
   EXPECT_TRUE(run(CorrectVD, IR));
 }
 
-// LLVM intrinsic 'powi' (in IR) has different signature with the VecDesc.
+// The VFABI prefix in TLI describes signature which is not matching the powi
+// intrinsic declaration.
 TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
   VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
-                         ElementCount::getScalable(4), true, "_ZGVsMxvv"};
+                         ElementCount::getScalable(4), /*Masked*/ true,
+                         "_ZGVsMxvv"};
   EXPECT_TRUE(run(IncorrectVD, IR));
 }

>From 39cb3532849edc0b46150d45ecc9f10965525ecd Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Tue, 9 Jan 2024 12:47:12 +0000
Subject: [PATCH 4/8] Address review comments: assert ParamPos is in range

---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 64025bfe031dd4..7dccff1bbac746 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -178,6 +178,11 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
     for (auto VFParam : OptInfo->Shape.Parameters) {
       if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
         continue;
+
+      // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
+      // a bug in the VFABI parser.
+      assert(VFParam.ParamPos < OrigArgTypes.size() &&
+             "ParamPos has invalid range.");
       Type *OrigTy = OrigArgTypes[VFParam.ParamPos];
       if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
         LLVM_DEBUG(dbgs() << DEBUG_TYPE

>From 407a68379ac43f8dbf1d77b7fc1e2851084d4fc7 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 10 Jan 2024 12:18:27 +0000
Subject: [PATCH 5/8] Address review comments (2): read argument types from the call

---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 7dccff1bbac746..f581cf9fbe9a92 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -108,7 +108,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   // Compute the argument types of the corresponding scalar call and the scalar
   // function name. For calls, it additionally finds the function to replace
   // and checks that all vector operands match the previously found EC.
-  SmallVector<Type *, 8> ScalarArgTypes, OrigArgTypes;
+  SmallVector<Type *, 8> ScalarArgTypes;
   std::string ScalarName;
   Function *FuncToReplace = nullptr;
   auto *CI = dyn_cast<CallInst>(&I);
@@ -118,7 +118,6 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
     assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
     for (auto Arg : enumerate(CI->args())) {
       auto *ArgTy = Arg.value()->getType();
-      OrigArgTypes.push_back(ArgTy);
       if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
         ScalarArgTypes.push_back(ArgTy);
       } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -181,9 +180,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
 
       // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
       // a bug in the VFABI parser.
-      assert(VFParam.ParamPos < OrigArgTypes.size() &&
+      assert(VFParam.ParamPos < CI->arg_size() &&
              "ParamPos has invalid range.");
-      Type *OrigTy = OrigArgTypes[VFParam.ParamPos];
+      Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
       if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
         LLVM_DEBUG(dbgs() << DEBUG_TYPE
                           << ": Will not replace: wrong type at index: "

>From 254025f966132bbfb0b5dcd406a9bd1a4697804a Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Wed, 10 Jan 2024 16:37:22 +0000
Subject: [PATCH 6/8] Improve tests by comparing the last line of stderr

---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp        |  5 +++-
 .../Analysis/ReplaceWithVecLibTest.cpp        | 29 +++++++++++++++----
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index f581cf9fbe9a92..7d8686a1681a7e 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -185,7 +185,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
       Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
       if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
         LLVM_DEBUG(dbgs() << DEBUG_TYPE
-                          << ": Will not replace: wrong type at index: "
+                          << ": Will not replace. Wrong type at index "
                           << VFParam.ParamPos << ": " << *OrigTy << "\n");
         return false;
       }
@@ -245,6 +245,9 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
   auto Changed = runImpl(TLI, F);
   if (Changed) {
+    LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
+                      << NumCallsReplaced << "\n");
+
     PreservedAnalyses PA;
     PA.preserveSet<CFGAnalyses>();
     PA.preserve<TargetLibraryAnalysis>();
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
index 400743af97c8cf..2157e5ee0a6981 100644
--- a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -32,6 +32,16 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
 /// the input IR) does not match the replacement function (derived from the
 /// VecDesc mapping).
 class ReplaceWithVecLibTest : public ::testing::Test {
+
+  std::string getLastLine(std::string Out) {
+    // remove ending '\n' if it exists
+    if (!Out.empty() && *(Out.cend() - 1) == '\n')
+      Out.pop_back();
+
+    size_t LastNL = Out.find_last_of('\n');
+    return (LastNL == std::string::npos) ? Out : Out.substr(LastNL + 1);
+  }
+
 protected:
   LLVMContext Ctx;
 
@@ -39,7 +49,10 @@ class ReplaceWithVecLibTest : public ::testing::Test {
   /// pass. The pass should not crash even when the replacement function
   /// (derived from the \p VD mapping) does not match the function to be
   /// replaced (from the input \p IR).
-  bool run(const VecDesc &VD, const char *IR) {
+  ///
+  /// \returns the last line of the standard error to be compared for
+  /// correctness.
+  std::string run(const VecDesc &VD, const char *IR) {
     // Create TLII and register it with FAM so it's preserved when
     // ReplaceWithVeclib pass runs.
     TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
@@ -53,9 +66,12 @@ class ReplaceWithVecLibTest : public ::testing::Test {
     std::unique_ptr<Module> M = parseIR(Ctx, IR);
     PassBuilder PB;
     PB.registerFunctionAnalyses(FAM);
-    FPM.run(*M->getFunction("foo"), FAM);
 
-    return true;
+    // Enable debugging and capture std error
+    llvm::DebugFlag = true;
+    testing::internal::CaptureStderr();
+    FPM.run(*M->getFunction("foo"), FAM);
+    return getLastLine(testing::internal::GetCapturedStderr());
   }
 };
 
@@ -76,7 +92,8 @@ TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
   VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
                        ElementCount::getScalable(4), /*Masked*/ true,
                        "_ZGVsMxvu"};
-  EXPECT_TRUE(run(CorrectVD, IR));
+  EXPECT_EQ(run(CorrectVD, IR),
+            "Instructions replaced with vector libraries: 1");
 }
 
 // The VFABI prefix in TLI describes signature which is not matching the powi
@@ -85,5 +102,7 @@ TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
   VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
                          ElementCount::getScalable(4), /*Masked*/ true,
                          "_ZGVsMxvv"};
-  EXPECT_TRUE(run(IncorrectVD, IR));
+  EXPECT_EQ(
+      run(IncorrectVD, IR),
+      "replace-with-veclib: Will not replace. Wrong type at index 1: i32");
 }

>From b4a0b6ef020cd66b20315ba7a36d9483eda37771 Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Thu, 11 Jan 2024 10:07:02 +0000
Subject: [PATCH 7/8] Fix test to work with and without assertions.

---
 .../Analysis/ReplaceWithVecLibTest.cpp        | 27 ++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
index 2157e5ee0a6981..8905bd94624684 100644
--- a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -34,7 +34,7 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
 class ReplaceWithVecLibTest : public ::testing::Test {
 
   std::string getLastLine(std::string Out) {
-    // remove ending '\n' if it exists
+    // remove any trailing '\n'
     if (!Out.empty() && *(Out.cend() - 1) == '\n')
       Out.pop_back();
 
@@ -86,6 +86,9 @@ define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
 declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
 )IR";
 
+// With assertions on, perform stricter checks by verifying the debug output.
+#ifndef NDEBUG
+
 // The VFABI prefix in TLI describes signature which is matching the powi
 // intrinsic declaration.
 TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
@@ -106,3 +109,25 @@ TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
       run(IncorrectVD, IR),
       "replace-with-veclib: Will not replace. Wrong type at index 1: i32");
 }
+
+// Without assertions, check that tests don't crash.
+#else
+
+// The VFABI prefix in TLI describes signature which is matching the powi
+// intrinsic declaration.
+TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
+  VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
+                       ElementCount::getScalable(4), /*Masked*/ true,
+                       "_ZGVsMxvu"};
+  EXPECT_EQ(run(CorrectVD, IR), "");
+}
+
+// The VFABI prefix in TLI describes signature which is not matching the powi
+// intrinsic declaration.
+TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
+  VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
+                         ElementCount::getScalable(4), /*Masked*/ true,
+                         "_ZGVsMxvv"};
+  EXPECT_EQ(run(IncorrectVD, IR), "");
+}
+#endif
\ No newline at end of file

>From e40d08205ee3f98efd2161e8de9436aab82d8eec Mon Sep 17 00:00:00 2001
From: Paschalis Mpeis <Paschalis.Mpeis at arm.com>
Date: Thu, 11 Jan 2024 14:06:02 +0000
Subject: [PATCH 8/8] Address review comments: name the call in the debug message

---
 llvm/lib/CodeGen/ReplaceWithVeclib.cpp        |  6 ++--
 .../Analysis/ReplaceWithVecLibTest.cpp        | 31 ++++---------------
 2 files changed, 9 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 7d8686a1681a7e..7b0215535a92c8 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -184,9 +184,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
              "ParamPos has invalid range.");
       Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
       if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
-        LLVM_DEBUG(dbgs() << DEBUG_TYPE
-                          << ": Will not replace. Wrong type at index "
-                          << VFParam.ParamPos << ": " << *OrigTy << "\n");
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
+                          << ". Wrong type at index " << VFParam.ParamPos
+                          << ": " << *OrigTy << "\n");
         return false;
       }
     }
diff --git a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
index 8905bd94624684..9447fbc4106e30 100644
--- a/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
+++ b/llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp
@@ -31,6 +31,8 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
 /// allows checking that the pass won't crash when the function to replace (from
 /// the input IR) does not match the replacement function (derived from the
 /// VecDesc mapping).
+///
+/// NOTE: Assertions must be enabled for these tests to run.
 class ReplaceWithVecLibTest : public ::testing::Test {
 
   std::string getLastLine(std::string Out) {
@@ -86,7 +88,7 @@ define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
 declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
 )IR";
 
-// With assertions on, perform stricter checks by verifying the debug output.
+// Need assertions enabled for running the tests.
 #ifndef NDEBUG
 
 // The VFABI prefix in TLI describes signature which is matching the powi
@@ -105,29 +107,8 @@ TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
   VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
                          ElementCount::getScalable(4), /*Masked*/ true,
                          "_ZGVsMxvv"};
-  EXPECT_EQ(
-      run(IncorrectVD, IR),
-      "replace-with-veclib: Will not replace. Wrong type at index 1: i32");
-}
-
-// Without assertions, check that tests don't crash.
-#else
-
-// The VFABI prefix in TLI describes signature which is matching the powi
-// intrinsic declaration.
-TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
-  VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
-                       ElementCount::getScalable(4), /*Masked*/ true,
-                       "_ZGVsMxvu"};
-  EXPECT_EQ(run(CorrectVD, IR), "");
-}
-
-// The VFABI prefix in TLI describes signature which is not matching the powi
-// intrinsic declaration.
-TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
-  VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
-                         ElementCount::getScalable(4), /*Masked*/ true,
-                         "_ZGVsMxvv"};
-  EXPECT_EQ(run(IncorrectVD, IR), "");
+  EXPECT_EQ(run(IncorrectVD, IR),
+            "replace-with-veclib: Will not replace: llvm.powi.f32.i32. Wrong "
+            "type at index 1: i32");
 }
 #endif
\ No newline at end of file



More information about the llvm-commits mailing list