[llvm-branch-commits] [llvm] [MergeFunctions] Add support to run the pass over a set of function pointers (PR #110996)

Rafael Eckstein via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Oct 3 06:54:29 PDT 2024


https://github.com/Casperento created https://github.com/llvm/llvm-project/pull/110996

This change makes it possible to use `MergeFunctions` as a standalone library. Currently, `MergeFunctions` can only be applied to an entire module. With this change, developers can reuse the `MergeFunctions` code in their own projects and choose which functions to merge, which promotes code reuse. Note that this change does not break backward compatibility: `MergeFunctions` still works as a pass after the change.

### Summary of Changes:
- Modified `MergeFunctionsPass` to allow running the pass over a set of function pointers, via a new static `runOnFunctions` entry point.
- This behavior is optional and doesn't interfere with the existing functionality of running the pass on the entire `Module`.
- Added unit tests to check the correctness of the updated implementation, ensuring that function merging works as expected both when run on a set of function pointers and when run on a full module.

>From 9b0073551ece0d22bf3378af2b03e456a26031b6 Mon Sep 17 00:00:00 2001
From: Casperento <44746868+Casperento at users.noreply.github.com>
Date: Tue, 24 Sep 2024 16:45:59 -0300
Subject: [PATCH] new runOn method

remove templates

unit tests added

format
---
 .../llvm/Transforms/IPO/MergeFunctions.h      |   7 +
 llvm/lib/Transforms/IPO/MergeFunctions.cpp    |  63 +++-
 .../unittests/Transforms/Utils/CMakeLists.txt |   1 +
 .../Transforms/Utils/MergeFunctionsTest.cpp   | 270 ++++++++++++++++++
 .../llvm/unittests/Transforms/Utils/BUILD.gn  |   1 +
 5 files changed, 340 insertions(+), 2 deletions(-)
 create mode 100644 llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp

diff --git a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
index 822f0fd99188d0..1b3b1d22f11e28 100644
--- a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
+++ b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
@@ -15,7 +15,10 @@
 #ifndef LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H
 #define LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H
 
+#include "llvm/IR/Function.h"
 #include "llvm/IR/PassManager.h"
+#include <map>
+#include <set>
 
 namespace llvm {
 
@@ -25,6 +28,10 @@ class Module;
 class MergeFunctionsPass : public PassInfoMixin<MergeFunctionsPass> {
 public:
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+
+  static bool runOnModule(Module &M);
+  static std::pair<bool, std::map<Function *, Function *>>
+  runOnFunctions(std::set<Function *> &F);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index feda5d6459cb47..2e775be4cab7c8 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -122,6 +122,7 @@
 #include <algorithm>
 #include <cassert>
 #include <iterator>
+#include <map>
 #include <set>
 #include <utility>
 #include <vector>
@@ -198,6 +199,8 @@ class MergeFunctions {
   }
 
   bool runOnModule(Module &M);
+  bool runOnFunctions(std::set<Function *> &F);
+  std::map<Function *, Function *> &getDelToNewMap();
 
 private:
   // The function comparison operator is provided here so that FunctionNodes do
@@ -291,17 +294,31 @@ class MergeFunctions {
   // dangling iterators into FnTree. The invariant that preserves this is that
   // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
   DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
+
+  /// Deleted-New functions mapping
+  std::map<Function *, Function *> DelToNewMap;
 };
 } // end anonymous namespace
 
 PreservedAnalyses MergeFunctionsPass::run(Module &M,
                                           ModuleAnalysisManager &AM) {
-  MergeFunctions MF;
-  if (!MF.runOnModule(M))
+  if (!MergeFunctionsPass::runOnModule(M))
     return PreservedAnalyses::all();
   return PreservedAnalyses::none();
 }
 
+bool MergeFunctionsPass::runOnModule(Module &M) {
+  MergeFunctions MF;
+  return MF.runOnModule(M);
+}
+
+std::pair<bool, std::map<Function *, Function *>>
+MergeFunctionsPass::runOnFunctions(std::set<Function *> &F) {
+  MergeFunctions MF;
+  bool MergeResult = MF.runOnFunctions(F);
+  return {MergeResult, MF.getDelToNewMap()};
+}
+
 #ifndef NDEBUG
 bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
   if (const unsigned Max = NumFunctionsForVerificationCheck) {
@@ -439,6 +456,47 @@ bool MergeFunctions::runOnModule(Module &M) {
   return Changed;
 }
 
+bool MergeFunctions::runOnFunctions(std::set<Function *> &F) {
+  bool Changed = false;
+  std::vector<std::pair<FunctionComparator::FunctionHash, Function *>> HashedFuncs;
+  for (Function *Func : F) {
+    if (isEligibleForMerging(*Func)) {
+      HashedFuncs.push_back({FunctionComparator::functionHash(*Func), Func});
+    }
+  }
+  llvm::stable_sort(HashedFuncs, less_first());
+  auto S = HashedFuncs.begin();
+  for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) {
+    if ((I != S && std::prev(I)->first == I->first) ||
+        (std::next(I) != IE && std::next(I)->first == I->first)) {
+      Deferred.push_back(WeakTrackingVH(I->second));
+    }
+  }
+  do {
+    std::vector<WeakTrackingVH> Worklist;
+    Deferred.swap(Worklist);
+    LLVM_DEBUG(dbgs() << "size of function: " << F.size() << '\n');
+    LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n');
+    for (WeakTrackingVH &I : Worklist) {
+      if (!I)
+        continue;
+      Function *F = cast<Function>(I);
+      if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) {
+        Changed |= insert(F);
+      }
+    }
+    LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n');
+  } while (!Deferred.empty());
+  FnTree.clear();
+  FNodesInTree.clear();
+  GlobalNumbers.clear();
+  return Changed;
+}
+
+std::map<Function *, Function *> &MergeFunctions::getDelToNewMap() {
+  return this->DelToNewMap;
+}
+
 // Replace direct callers of Old with New.
 void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
   Constant *BitcastNew = ConstantExpr::getBitCast(New, Old->getType());
@@ -917,6 +975,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
 
   Function *DeleteF = NewFunction;
   mergeTwoFunctions(OldF.getFunc(), DeleteF);
+  this->DelToNewMap.emplace(DeleteF, OldF.getFunc());
   return true;
 }
 
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index 751de6fc5becc9..eb38f6237916f8 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -24,6 +24,7 @@ add_llvm_unittest(UtilsTests
   LoopUtilsTest.cpp
   MemTransferLowering.cpp
   ModuleUtilsTest.cpp
+  MergeFunctionsTest.cpp
   ScalarEvolutionExpanderTest.cpp
   SizeOptsTest.cpp
   SSAUpdaterBulkTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
new file mode 100644
index 00000000000000..d19175f760b12f
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
@@ -0,0 +1,270 @@
+//===- MergeFunctionsTest.cpp - Unit tests for
+//MergeFunctionsPass-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/MergeFunctions.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+
+TEST(MergeFunctions, TrueOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) #0 {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...) #1
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #2 = { nounwind willreturn }
+    )invalid",
+                                                Err, Ctx));
+
+  // Expects true after merging _slice_add10 and _slice_add10_alt
+  EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, TrueOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) #0 {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...) #1
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #2 = { nounwind willreturn }
+    )invalid",
+                                                Err, Ctx));
+
+  std::set<Function *> FunctionsSet;
+  for (Function &F : *M)
+    FunctionsSet.insert(&F);
+
+  std::pair<bool, std::map<Function *, Function *>> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet);
+
+  // Expects true after merging _slice_add10 and _slice_add10_alt
+  EXPECT_TRUE(MergeResult.first);
+
+  // Expects that both functions (_slice_add10 and _slice_add10_alt)
+  // be mapped to the same new function
+  EXPECT_TRUE(MergeResult.second.size() > 0);
+  std::map<Function *, Function *> DelToNew = MergeResult.second;
+  Function *NewFunction = M->getFunction("_slice_add10");
+  for (auto P : DelToNew)
+    if (P.second)
+      EXPECT_EQ(P.second, NewFunction);
+}
+
+TEST(MergeFunctions, FalseOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) #0 {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...) #1
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %0
+        }
+
+        attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #2 = { nounwind willreturn }
+    )invalid",
+                                                Err, Ctx));
+
+  // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+  EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, FalseOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) #0 {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...) #1
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %0
+        }
+
+        attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #2 = { nounwind willreturn }
+    )invalid",
+                                                Err, Ctx));
+
+  std::set<Function *> FunctionsSet;
+  for (Function &F : *M)
+    FunctionsSet.insert(&F);
+
+  std::pair<bool, std::map<Function *, Function *>> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet);
+
+  for (auto P : MergeResult.second)
+    std::cout << P.first << " " << P.second << "\n";
+
+  // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+  EXPECT_FALSE(MergeResult.first);
+
+  // Expects empty map
+  EXPECT_EQ(MergeResult.second.size(), 0u);
+}
+
+} // namespace
\ No newline at end of file
diff --git a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
index 15e3518960b49d..411404dc11625a 100644
--- a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
@@ -26,6 +26,7 @@ unittest("UtilsTests") {
     "LoopUtilsTest.cpp",
     "MemTransferLowering.cpp",
     "ModuleUtilsTest.cpp",
+    "MergeFunctionsTest.cpp",
     "ProfDataUtilTest.cpp",
     "SSAUpdaterBulkTest.cpp",
     "ScalarEvolutionExpanderTest.cpp",



More information about the llvm-branch-commits mailing list