[llvm] [TLI] Fix replace-with-veclib crash with invalid arguments (PR #77112)

Alexandros Lamprineas via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 10 05:56:03 PST 2024


================
@@ -0,0 +1,89 @@
+//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ReplaceWithVeclib.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("ReplaceWithVecLibTest", errs());
+  return Mod;
+}
+
+/// Runs ReplaceWithVecLib with different TLIIs that have custom VecDescs. This
+/// allows checking that the pass won't crash when the function to replace (from
+/// the input IR) does not match the replacement function (derived from the
+/// VecDesc mapping).
+class ReplaceWithVecLibTest : public ::testing::Test {
+protected:
+  LLVMContext Ctx;
+
+  /// Creates TLII using the given \p VD, and then runs the ReplaceWithVeclib
+  /// pass. The pass should not crash even when the replacement function
+  /// (derived from the \p VD mapping) does not match the function to be
+  /// replaced (from the input \p IR).
+  bool run(const VecDesc &VD, const char *IR) {
+    // Create TLII and register it with FAM so it's preserved when the
+    // ReplaceWithVeclib pass runs.
+    TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
+    TLII.addVectorizableFunctions({VD});
+    FunctionAnalysisManager FAM;
+    FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
+
+    // Register and run the pass on the 'foo' function from the input IR.
+    FunctionPassManager FPM;
+    FPM.addPass(ReplaceWithVeclib());
+    std::unique_ptr<Module> M = parseIR(Ctx, IR);
+    PassBuilder PB;
+    PB.registerFunctionAnalyses(FAM);
+    FPM.run(*M->getFunction("foo"), FAM);
+
+    return true;
+  }
+};
+
+} // end anonymous namespace
+
+static const char *IR = R"IR(
+define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
+  %call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
+  ret <vscale x 4 x float> %call
+}
+
+declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
+)IR";
+
+// The VFABI prefix in TLI describes a signature that matches the powi
+// intrinsic declaration.
+TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
+  VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
+                       ElementCount::getScalable(4), /*Masked*/ true,
+                       "_ZGVsMxvu"};
+  EXPECT_TRUE(run(CorrectVD, IR));
+}
+
+// The VFABI prefix in TLI describes a signature that does not match the powi
+// intrinsic declaration.
+TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
+  VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
+                         ElementCount::getScalable(4), /*Masked*/ true,
+                         "_ZGVsMxvv"};
+  EXPECT_TRUE(run(IncorrectVD, IR));
----------------
labrinea wrote:

I don't see a difference in the return value between the EXPECT lines of the valid and invalid mapping tests. I would expect a successful invocation of the pass to differ from an unsuccessful one. I think a regression test would be sufficient to prove that the pass no longer crashes on the reproducer. Unit tests are meant to check that an API behaves as intended. That said, I would suggest either moving this to a regression test or, even better, checking the return value of ```replaceWithTLIFunction``` (which must be made non-static for this to be possible).
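To illustrate the distinction the review asks for, here is a minimal, self-contained sketch of a replacement routine that reports failure instead of crashing, so that a valid and an invalid mapping produce different results. It is not LLVM code: `tryReplaceWithMapping`, `ArgKind` and `PowiOperands` are hypothetical names, and the prefix parser handles only a simplified subset of the Vector Function ABI mangling (one ISA character, an `M`/`N` mask token, an `x` or numeric vector length, and `v`/`u` parameter tokens).

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <vector>

// Hypothetical operand kinds for the call being replaced (not an LLVM type).
enum class ArgKind { Vector, Scalar };

// Operands of the powi call in the reproducer: a vector base and a scalar
// i32 exponent, which stays uniform across lanes.
const std::vector<ArgKind> PowiOperands = {ArgKind::Vector, ArgKind::Scalar};

// Hypothetical stand-in for a replacement routine that reports failure
// instead of crashing. Returns true only if the parameter tokens in the
// VFABI-style prefix describe exactly the operands in CallArgs.
// Simplified grammar: "_ZGV" <isa:1 char> <mask:'M'|'N'> <vlen:'x'|digits>
// followed by parameter tokens 'v' (vector) or 'u' (uniform) only.
bool tryReplaceWithMapping(const std::string &Prefix,
                           const std::vector<ArgKind> &CallArgs) {
  const std::string Tag = "_ZGV";
  if (Prefix.compare(0, Tag.size(), Tag) != 0)
    return false;
  size_t Pos = Tag.size() + 1; // skip the ISA character, e.g. 's' for SVE
  if (Pos >= Prefix.size() || (Prefix[Pos] != 'M' && Prefix[Pos] != 'N'))
    return false; // missing or unknown mask token
  ++Pos;
  if (Pos < Prefix.size() && Prefix[Pos] == 'x') {
    ++Pos; // scalable vector length
  } else {
    size_t Start = Pos;
    while (Pos < Prefix.size() &&
           std::isdigit(static_cast<unsigned char>(Prefix[Pos])))
      ++Pos;
    if (Pos == Start)
      return false; // no vector length
  }
  std::vector<ArgKind> Params;
  for (; Pos < Prefix.size(); ++Pos) {
    if (Prefix[Pos] == 'v')
      Params.push_back(ArgKind::Vector);
    else if (Prefix[Pos] == 'u')
      Params.push_back(ArgKind::Scalar);
    else
      return false; // token not modelled by this sketch
  }
  // A masked variant ('M') takes an extra mask parameter in the vector
  // function, but that is not an operand of the call being replaced, so
  // it is not compared here.
  return Params == CallArgs;
}
```

With a boolean result like this, the valid-mapping test can expect `true` for `"_ZGVsMxvu"` and the invalid-mapping test `false` for `"_ZGVsMxvv"`, which is the difference the current `run` helper, always returning `true`, cannot express.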

https://github.com/llvm/llvm-project/pull/77112

