[llvm] 2a6e589 - [MergeFunctions] Add support to run the pass over a set of function pointers (#111045)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 28 07:18:56 PST 2024


Author: Rafael Eckstein
Date: 2024-11-28T16:18:52+01:00
New Revision: 2a6e5896a572c3be47ffe78afed8ba6ef278d336

URL: https://github.com/llvm/llvm-project/commit/2a6e5896a572c3be47ffe78afed8ba6ef278d336
DIFF: https://github.com/llvm/llvm-project/commit/2a6e5896a572c3be47ffe78afed8ba6ef278d336.diff

LOG: [MergeFunctions] Add support to run the pass over a set of function pointers (#111045)

This modification enables the use of `MergeFunctions` as a
standalone library. Currently, `MergeFunctions` can only be applied to
an entire module. With this change, developers gain the
flexibility to reuse the `MergeFunctions` code within their own
projects and to choose which functions to merge, thereby promoting code
reusability. Note that this modification does not break backward
compatibility, because `MergeFunctions` still works as a pass after
the modification.

Added: 
    llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp

Modified: 
    llvm/include/llvm/Transforms/IPO/MergeFunctions.h
    llvm/lib/Transforms/IPO/MergeFunctions.cpp
    llvm/unittests/Transforms/Utils/CMakeLists.txt
    llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
index 822f0fd99188d0..71f175c6472b44 100644
--- a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
+++ b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
@@ -20,11 +20,16 @@
 namespace llvm {
 
 class Module;
+class Function;
 
 /// Merge identical functions.
 class MergeFunctionsPass : public PassInfoMixin<MergeFunctionsPass> {
 public:
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+
+  static bool runOnModule(Module &M);
+  static DenseMap<Function *, Function *>
+  runOnFunctions(ArrayRef<Function *> F);
 };
 
 } // end namespace llvm

diff  --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index b7d0c3d741deb9..e8508416f54275 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -196,7 +196,10 @@ class MergeFunctions {
   MergeFunctions() : FnTree(FunctionNodeCmp(&GlobalNumbers)) {
   }
 
-  bool runOnModule(Module &M);
+  template <typename FuncContainer> bool run(FuncContainer &Functions);
+  DenseMap<Function *, Function *> runOnFunctions(ArrayRef<Function *> F);
+
+  SmallPtrSet<GlobalValue *, 4> &getUsed();
 
 private:
   // The function comparison operator is provided here so that FunctionNodes do
@@ -297,17 +300,36 @@ class MergeFunctions {
   // dangling iterators into FnTree. The invariant that preserves this is that
   // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
   DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
+
+  /// Deleted-New functions mapping
+  DenseMap<Function *, Function *> DelToNewMap;
 };
 } // end anonymous namespace
 
 PreservedAnalyses MergeFunctionsPass::run(Module &M,
                                           ModuleAnalysisManager &AM) {
-  MergeFunctions MF;
-  if (!MF.runOnModule(M))
+  if (!MergeFunctionsPass::runOnModule(M))
     return PreservedAnalyses::all();
   return PreservedAnalyses::none();
 }
 
+SmallPtrSet<GlobalValue *, 4> &MergeFunctions::getUsed() { return Used; }
+
+bool MergeFunctionsPass::runOnModule(Module &M) {
+  MergeFunctions MF;
+  SmallVector<GlobalValue *, 4> UsedV;
+  collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false);
+  collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true);
+  MF.getUsed().insert(UsedV.begin(), UsedV.end());
+  return MF.run(M);
+}
+
+DenseMap<Function *, Function *>
+MergeFunctionsPass::runOnFunctions(ArrayRef<Function *> F) {
+  MergeFunctions MF;
+  return MF.runOnFunctions(F);
+}
+
 #ifndef NDEBUG
 bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
   if (const unsigned Max = NumFunctionsForVerificationCheck) {
@@ -409,20 +431,19 @@ static bool isEligibleForMerging(Function &F) {
          !hasDistinctMetadataIntrinsic(F);
 }
 
-bool MergeFunctions::runOnModule(Module &M) {
-  bool Changed = false;
+inline Function *asPtr(Function *Fn) { return Fn; }
+inline Function *asPtr(Function &Fn) { return &Fn; }
 
-  SmallVector<GlobalValue *, 4> UsedV;
-  collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false);
-  collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true);
-  Used.insert(UsedV.begin(), UsedV.end());
+template <typename FuncContainer> bool MergeFunctions::run(FuncContainer &M) {
+  bool Changed = false;
 
   // All functions in the module, ordered by hash. Functions with a unique
   // hash value are easily eliminated.
   std::vector<std::pair<stable_hash, Function *>> HashedFuncs;
-  for (Function &Func : M) {
-    if (isEligibleForMerging(Func)) {
-      HashedFuncs.push_back({StructuralHash(Func), &Func});
+  for (auto &Func : M) {
+    Function *FuncPtr = asPtr(Func);
+    if (isEligibleForMerging(*FuncPtr)) {
+      HashedFuncs.push_back({StructuralHash(*FuncPtr), FuncPtr});
     }
   }
 
@@ -433,7 +454,7 @@ bool MergeFunctions::runOnModule(Module &M) {
     // If the hash value matches the previous value or the next one, we must
     // consider merging it. Otherwise it is dropped and never considered again.
     if ((I != S && std::prev(I)->first == I->first) ||
-        (std::next(I) != IE && std::next(I)->first == I->first) ) {
+        (std::next(I) != IE && std::next(I)->first == I->first)) {
       Deferred.push_back(WeakTrackingVH(I->second));
     }
   }
@@ -467,9 +488,16 @@ bool MergeFunctions::runOnModule(Module &M) {
   return Changed;
 }
 
+DenseMap<Function *, Function *>
+MergeFunctions::runOnFunctions(ArrayRef<Function *> F) {
+  [[maybe_unused]] bool MergeResult = this->run(F);
+  assert(MergeResult == !DelToNewMap.empty());
+  return this->DelToNewMap;
+}
+
 // Replace direct callers of Old with New.
 void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
-  for (Use &U : llvm::make_early_inc_range(Old->uses())) {
+  for (Use &U : make_early_inc_range(Old->uses())) {
     CallBase *CB = dyn_cast<CallBase>(U.getUser());
     if (CB && CB->isCallee(&U)) {
       // Do not copy attributes from the called function to the call-site.
@@ -768,8 +796,8 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   ReturnInst *RI = nullptr;
   bool isSwiftTailCall = F->getCallingConv() == CallingConv::SwiftTail &&
                          G->getCallingConv() == CallingConv::SwiftTail;
-  CI->setTailCallKind(isSwiftTailCall ? llvm::CallInst::TCK_MustTail
-                                      : llvm::CallInst::TCK_Tail);
+  CI->setTailCallKind(isSwiftTailCall ? CallInst::TCK_MustTail
+                                      : CallInst::TCK_Tail);
   CI->setCallingConv(F->getCallingConv());
   CI->setAttributes(F->getAttributes());
   if (H->getReturnType()->isVoidTy()) {
@@ -1003,6 +1031,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
 
   Function *DeleteF = NewFunction;
   mergeTwoFunctions(OldF.getFunc(), DeleteF);
+  this->DelToNewMap.insert({DeleteF, OldF.getFunc()});
   return true;
 }
 

diff  --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index 5c7ec28709c169..7effa5d8e7d6d2 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -26,6 +26,7 @@ add_llvm_unittest(UtilsTests
   LoopUtilsTest.cpp
   MemTransferLowering.cpp
   ModuleUtilsTest.cpp
+  MergeFunctionsTest.cpp
   ScalarEvolutionExpanderTest.cpp
   SizeOptsTest.cpp
   SSAUpdaterBulkTest.cpp

diff  --git a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
new file mode 100644
index 00000000000000..56d119878a9ab2
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
@@ -0,0 +1,246 @@
+//===- MergeFunctionsTest.cpp - Unit tests for MergeFunctionsPass ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/MergeFunctions.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+
+TEST(MergeFunctions, TrueOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...)
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+    )invalid",
+                                                Err, Ctx));
+
+  // Expects true after merging _slice_add10 and _slice_add10_alt
+  EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, TrueOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...)
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+    )invalid",
+                                                Err, Ctx));
+
+  SetVector<Function *> FunctionsSet;
+  for (Function &F : *M)
+    FunctionsSet.insert(&F);
+
+  DenseMap<Function *, Function *> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet.getArrayRef());
+
+  // Expects that both functions (_slice_add10 and _slice_add10_alt)
+  // be mapped to the same new function
+  EXPECT_TRUE(!MergeResult.empty());
+  Function *NewFunction = M->getFunction("_slice_add10");
+  for (auto P : MergeResult)
+    if (P.second)
+      EXPECT_EQ(P.second, NewFunction);
+}
+
+TEST(MergeFunctions, FalseOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...)
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %0
+        }
+    )invalid",
+                                                Err, Ctx));
+
+  // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+  EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, FalseOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...)
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %0
+        }
+    )invalid",
+                                                Err, Ctx));
+
+  SetVector<Function *> FunctionsSet;
+  for (Function &F : *M)
+    FunctionsSet.insert(&F);
+
+  DenseMap<Function *, Function *> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet.getArrayRef());
+
+  // Expects empty map
+  EXPECT_EQ(MergeResult.size(), 0u);
+}
+
+} // namespace
\ No newline at end of file

diff  --git a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
index 380ed71a2bc010..fcea55c91f083c 100644
--- a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
@@ -27,6 +27,7 @@ unittest("UtilsTests") {
     "LoopUtilsTest.cpp",
     "MemTransferLowering.cpp",
     "ModuleUtilsTest.cpp",
+    "MergeFunctionsTest.cpp",
     "ProfDataUtilTest.cpp",
     "SSAUpdaterBulkTest.cpp",
     "ScalarEvolutionExpanderTest.cpp",


        


More information about the llvm-commits mailing list