[llvm-branch-commits] [llvm] [MergeFunctions] Add support to run the pass over a set of function pointers (PR #110996)
Rafael Eckstein via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Oct 3 06:54:29 PDT 2024
https://github.com/Casperento created https://github.com/llvm/llvm-project/pull/110996
This change makes it possible to use `MergeFunctions` as a standalone library. Currently, `MergeFunctions` can only be applied to an entire module. With this change, developers can reuse the `MergeFunctions` code within their own projects and choose which functions to merge, which promotes code reuse. Note that this change does not break backward compatibility: `MergeFunctions` still works as a pass after the modification.
### Summary of Changes:
- Modified the `MergeFunctionsPass` to allow running the pass over a set of function pointers.
- This behavior is optional and doesn't interfere with the existing functionality of running the pass on the entire `Module`.
- Added unit tests to assert the correctness of the updated implementation, ensuring that function merging works as expected both when run on a set of function pointers and when run on a full module.
>From 9b0073551ece0d22bf3378af2b03e456a26031b6 Mon Sep 17 00:00:00 2001
From: Casperento <44746868+Casperento at users.noreply.github.com>
Date: Tue, 24 Sep 2024 16:45:59 -0300
Subject: [PATCH] new runOn method
remove templates
unit tests added
format
---
.../llvm/Transforms/IPO/MergeFunctions.h | 7 +
llvm/lib/Transforms/IPO/MergeFunctions.cpp | 63 +++-
.../unittests/Transforms/Utils/CMakeLists.txt | 1 +
.../Transforms/Utils/MergeFunctionsTest.cpp | 270 ++++++++++++++++++
.../llvm/unittests/Transforms/Utils/BUILD.gn | 1 +
5 files changed, 340 insertions(+), 2 deletions(-)
create mode 100644 llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
diff --git a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
index 822f0fd99188d0..1b3b1d22f11e28 100644
--- a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
+++ b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
@@ -15,7 +15,10 @@
#ifndef LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H
#define LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H
+#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
+#include <map>
+#include <set>
namespace llvm {
@@ -25,6 +28,10 @@ class Module;
class MergeFunctionsPass : public PassInfoMixin<MergeFunctionsPass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+
+ static bool runOnModule(Module &M);
+ static std::pair<bool, std::map<Function *, Function *>>
+ runOnFunctions(std::set<Function *> &F);
};
} // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index feda5d6459cb47..2e775be4cab7c8 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -122,6 +122,7 @@
#include <algorithm>
#include <cassert>
#include <iterator>
+#include <map>
#include <set>
#include <utility>
#include <vector>
@@ -198,6 +199,8 @@ class MergeFunctions {
}
bool runOnModule(Module &M);
+ bool runOnFunctions(std::set<Function *> &F);
+ std::map<Function *, Function *> &getDelToNewMap();
private:
// The function comparison operator is provided here so that FunctionNodes do
@@ -291,17 +294,31 @@ class MergeFunctions {
// dangling iterators into FnTree. The invariant that preserves this is that
// there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
+
+ /// Maps each deleted (merged-away) function to the function that replaced it.
+ std::map<Function *, Function *> DelToNewMap;
};
} // end anonymous namespace
PreservedAnalyses MergeFunctionsPass::run(Module &M,
ModuleAnalysisManager &AM) {
- MergeFunctions MF;
- if (!MF.runOnModule(M))
+ if (!MergeFunctionsPass::runOnModule(M))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
+bool MergeFunctionsPass::runOnModule(Module &M) {
+ MergeFunctions MF;
+ return MF.runOnModule(M);
+}
+
+std::pair<bool, std::map<Function *, Function *>>
+MergeFunctionsPass::runOnFunctions(std::set<Function *> &F) {
+ MergeFunctions MF;
+ bool MergeResult = MF.runOnFunctions(F);
+ return {MergeResult, MF.getDelToNewMap()};
+}
+
#ifndef NDEBUG
bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
if (const unsigned Max = NumFunctionsForVerificationCheck) {
@@ -439,6 +456,47 @@ bool MergeFunctions::runOnModule(Module &M) {
return Changed;
}
+bool MergeFunctions::runOnFunctions(std::set<Function *> &F) {
+ bool Changed = false;
+ std::vector<std::pair<FunctionComparator::FunctionHash, Function *>> HashedFuncs;
+ for (Function *Func : F) {
+ if (isEligibleForMerging(*Func)) {
+ HashedFuncs.push_back({FunctionComparator::functionHash(*Func), Func});
+ }
+ }
+ llvm::stable_sort(HashedFuncs, less_first());
+ auto S = HashedFuncs.begin();
+ for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) {
+ if ((I != S && std::prev(I)->first == I->first) ||
+ (std::next(I) != IE && std::next(I)->first == I->first)) {
+ Deferred.push_back(WeakTrackingVH(I->second));
+ }
+ }
+ do {
+ std::vector<WeakTrackingVH> Worklist;
+ Deferred.swap(Worklist);
+ LLVM_DEBUG(dbgs() << "size of function: " << F.size() << '\n');
+ LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n');
+ for (WeakTrackingVH &I : Worklist) {
+ if (!I)
+ continue;
+ Function *F = cast<Function>(I);
+ if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) {
+ Changed |= insert(F);
+ }
+ }
+ LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n');
+ } while (!Deferred.empty());
+ FnTree.clear();
+ FNodesInTree.clear();
+ GlobalNumbers.clear();
+ return Changed;
+}
+
+std::map<Function *, Function *> &MergeFunctions::getDelToNewMap() {
+ return this->DelToNewMap;
+}
+
// Replace direct callers of Old with New.
void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
Constant *BitcastNew = ConstantExpr::getBitCast(New, Old->getType());
@@ -917,6 +975,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
Function *DeleteF = NewFunction;
mergeTwoFunctions(OldF.getFunc(), DeleteF);
+ this->DelToNewMap.emplace(DeleteF, OldF.getFunc());
return true;
}
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index 751de6fc5becc9..eb38f6237916f8 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -24,6 +24,7 @@ add_llvm_unittest(UtilsTests
LoopUtilsTest.cpp
MemTransferLowering.cpp
ModuleUtilsTest.cpp
+ MergeFunctionsTest.cpp
ScalarEvolutionExpanderTest.cpp
SizeOptsTest.cpp
SSAUpdaterBulkTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
new file mode 100644
index 00000000000000..d19175f760b12f
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
@@ -0,0 +1,270 @@
+//===- MergeFunctionsTest.cpp - Unit tests for
+//MergeFunctionsPass-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/MergeFunctions.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+
+TEST(MergeFunctions, TrueOutputModuleTest) {
+ LLVMContext Ctx;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) #0 {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...) #1
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #2 = { nounwind willreturn }
+ )invalid",
+ Err, Ctx));
+
+ // Expects true after merging _slice_add10 and _slice_add10_alt
+ EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, TrueOutputFunctionsTest) {
+ LLVMContext Ctx;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) #0 {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...) #1
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #2 = { nounwind willreturn }
+ )invalid",
+ Err, Ctx));
+
+ std::set<Function *> FunctionsSet;
+ for (Function &F : *M)
+ FunctionsSet.insert(&F);
+
+ std::pair<bool, std::map<Function *, Function *>> MergeResult =
+ MergeFunctionsPass::runOnFunctions(FunctionsSet);
+
+ // Expects true after merging _slice_add10 and _slice_add10_alt
+ EXPECT_TRUE(MergeResult.first);
+
+ // Expects every function recorded as merged away to be mapped to the
+ // retained function, _slice_add10
+ EXPECT_TRUE(MergeResult.second.size() > 0);
+ std::map<Function *, Function *> DelToNew = MergeResult.second;
+ Function *NewFunction = M->getFunction("_slice_add10");
+ for (auto P : DelToNew)
+ if (P.second)
+ EXPECT_EQ(P.second, NewFunction);
+}
+
+TEST(MergeFunctions, FalseOutputModuleTest) {
+ LLVMContext Ctx;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) #0 {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...) #1
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %0
+ }
+
+ attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #2 = { nounwind willreturn }
+ )invalid",
+ Err, Ctx));
+
+ // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+ EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, FalseOutputFunctionsTest) {
+ LLVMContext Ctx;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) #0 {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...) #1
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %0
+ }
+
+ attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #2 = { nounwind willreturn }
+ )invalid",
+ Err, Ctx));
+
+ std::set<Function *> FunctionsSet;
+ for (Function &F : *M)
+ FunctionsSet.insert(&F);
+
+ std::pair<bool, std::map<Function *, Function *>> MergeResult =
+ MergeFunctionsPass::runOnFunctions(FunctionsSet);
+
+ for (auto P : MergeResult.second)
+ std::cout << P.first << " " << P.second << "\n";
+
+ // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+ EXPECT_FALSE(MergeResult.first);
+
+ // Expects empty map
+ EXPECT_EQ(MergeResult.second.size(), 0u);
+}
+
+} // namespace
\ No newline at end of file
diff --git a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
index 15e3518960b49d..411404dc11625a 100644
--- a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
@@ -26,6 +26,7 @@ unittest("UtilsTests") {
"LoopUtilsTest.cpp",
"MemTransferLowering.cpp",
"ModuleUtilsTest.cpp",
+ "MergeFunctionsTest.cpp",
"ProfDataUtilTest.cpp",
"SSAUpdaterBulkTest.cpp",
"ScalarEvolutionExpanderTest.cpp",
More information about the llvm-branch-commits
mailing list