[llvm] [MergeFunctions] Add support to run the pass over a set of function pointers (PR #111045)
Rafael Eckstein via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 1 13:55:19 PDT 2024
https://github.com/Casperento updated https://github.com/llvm/llvm-project/pull/111045
>From 69e1cc7d02df7d04b341c87abaa84cf4b6ab309d Mon Sep 17 00:00:00 2001
From: Casperento <44746868+Casperento at users.noreply.github.com>
Date: Tue, 24 Sep 2024 16:45:59 -0300
Subject: [PATCH 1/2] [MergeFunctions] Add runOnFunctions method
Squashed commits:
- remove templates
- add unit tests
- format
- update data structures
- format
---
.../llvm/Transforms/IPO/MergeFunctions.h | 7 +
llvm/lib/Transforms/IPO/MergeFunctions.cpp | 63 +++-
.../unittests/Transforms/Utils/CMakeLists.txt | 1 +
.../Transforms/Utils/MergeFunctionsTest.cpp | 271 ++++++++++++++++++
.../llvm/unittests/Transforms/Utils/BUILD.gn | 1 +
5 files changed, 341 insertions(+), 2 deletions(-)
create mode 100644 llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
diff --git a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
index 822f0fd99188d0..1b3b1d22f11e28 100644
--- a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
+++ b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
@@ -15,7 +15,10 @@
#ifndef LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H
#define LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H
+#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
+#include <map>
+#include <set>
namespace llvm {
@@ -25,6 +28,10 @@ class Module;
class MergeFunctionsPass : public PassInfoMixin<MergeFunctionsPass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+
+ /// Run MergeFunctions over every function in \p M. Returns true if the
+ /// module was changed (run() then reports no preserved analyses).
+ static bool runOnModule(Module &M);
+ /// Run MergeFunctions over the functions in \p F only; declarations and
+ /// functions not eligible for merging are ignored. Returns whether anything
+ /// changed, together with a map from each function that was merged into
+ /// another one to the function it was merged into.
+ /// NOTE(review): a key may already have been erased by the merge; treat keys
+ /// as opaque identities and confirm before dereferencing them.
+ static std::pair<bool, std::map<Function *, Function *>>
+ runOnFunctions(std::set<Function *> &F);
};
} // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index b50a700e09038f..a434d7920b6ccf 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -123,6 +123,7 @@
#include <algorithm>
#include <cassert>
#include <iterator>
+#include <map>
#include <set>
#include <utility>
#include <vector>
@@ -198,6 +199,8 @@ class MergeFunctions {
}
bool runOnModule(Module &M);
+ /// Merge equivalent functions among those in \p F. Returns true if
+ /// insert() reported a change for any of them.
+ bool runOnFunctions(std::set<Function *> &F);
+ /// Access the DeleteF -> kept-function mapping that insert() records for
+ /// each merge it performs.
+ std::map<Function *, Function *> &getDelToNewMap();
private:
// The function comparison operator is provided here so that FunctionNodes do
@@ -298,17 +301,31 @@ class MergeFunctions {
// dangling iterators into FnTree. The invariant that preserves this is that
// there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
+
+ /// For every merge done by insert(): the function handed to
+ /// mergeTwoFunctions as the one to drop (DeleteF) mapped to the function it
+ /// was merged into. NOTE(review): the key may already have been erased by
+ /// the merge, so treat it as an opaque identity rather than dereferencing it.
+ std::map<Function *, Function *> DelToNewMap;
};
} // end anonymous namespace
+/// Pass-manager entry point: merges functions across the whole module and
+/// reports that no analyses are preserved if anything changed.
PreservedAnalyses MergeFunctionsPass::run(Module &M,
ModuleAnalysisManager &AM) {
- MergeFunctions MF;
- if (!MF.runOnModule(M))
+ if (!MergeFunctionsPass::runOnModule(M))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
+/// Run a fresh MergeFunctions instance over every function in \p M.
+/// Returns true if the module was changed.
+bool MergeFunctionsPass::runOnModule(Module &M) {
+  return MergeFunctions().runOnModule(M);
+}
+
+/// Run a fresh MergeFunctions instance over the functions in \p F only.
+/// Returns whether any function was merged, together with the mapping that
+/// insert() recorded: each function handed over as the one to drop, mapped to
+/// the function it was merged into.
+std::pair<bool, std::map<Function *, Function *>>
+MergeFunctionsPass::runOnFunctions(std::set<Function *> &F) {
+  MergeFunctions MF;
+  bool MergeResult = MF.runOnFunctions(F);
+  // MF is destroyed on return, so move the map out instead of copying it.
+  return {MergeResult, std::move(MF.getDelToNewMap())};
+}
+
#ifndef NDEBUG
bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
if (const unsigned Max = NumFunctionsForVerificationCheck) {
@@ -468,6 +485,47 @@ bool MergeFunctions::runOnModule(Module &M) {
return Changed;
}
+/// Merge equivalent functions among the ones in \p F.
+///
+/// Same scheme as the whole-module path: ineligible functions are dropped,
+/// the rest are ordered by StructuralHash, and only functions that share a
+/// hash with another candidate are queued on Deferred for the full comparison
+/// performed by insert(). Returns true if insert() reported a change for any
+/// function.
+///
+/// NOTE(review): std::set<Function *> iterates in pointer order, so the order
+/// in which equal-hash candidates reach insert() can vary between runs; an
+/// ordered sequence (e.g. ArrayRef<Function *>) would make it deterministic.
+/// NOTE(review): confirm whether any module-level state that runOnModule sets
+/// up before merging is also required on this path.
+bool MergeFunctions::runOnFunctions(std::set<Function *> &F) {
+  bool Changed = false;
+
+  // Candidates ordered by structural hash; stable_sort keeps the set's order
+  // among functions with equal hashes.
+  std::vector<std::pair<IRHash, Function *>> HashedFuncs;
+  HashedFuncs.reserve(F.size());
+  for (Function *Func : F)
+    if (isEligibleForMerging(*Func))
+      HashedFuncs.push_back({StructuralHash(*Func), Func});
+  llvm::stable_sort(HashedFuncs, less_first());
+
+  // A function whose hash matches neither neighbour has no possible partner,
+  // so it is never queued.
+  auto S = HashedFuncs.begin();
+  for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) {
+    if ((I != S && std::prev(I)->first == I->first) ||
+        (std::next(I) != IE && std::next(I)->first == I->first))
+      Deferred.push_back(WeakTrackingVH(I->second));
+  }
+
+  // Processing can queue further candidates on Deferred, so keep draining it.
+  do {
+    std::vector<WeakTrackingVH> Worklist;
+    Deferred.swap(Worklist);
+
+    LLVM_DEBUG(doFunctionalCheck(Worklist));
+
+    LLVM_DEBUG(dbgs() << "size of function set: " << F.size() << '\n');
+    LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n');
+
+    for (WeakTrackingVH &VH : Worklist) {
+      // The handle is null once the function it tracked has been deleted.
+      if (!VH)
+        continue;
+      // Named Fn (not F) so it does not shadow the parameter.
+      Function *Fn = cast<Function>(VH);
+      if (!Fn->isDeclaration() && !Fn->hasAvailableExternallyLinkage())
+        Changed |= insert(Fn);
+    }
+    LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n');
+  } while (!Deferred.empty());
+
+  // Reset per-run state.
+  FnTree.clear();
+  FNodesInTree.clear();
+  GlobalNumbers.clear();
+  return Changed;
+}
+
+/// Returns a mutable reference to the merge record built by insert().
+std::map<Function *, Function *> &MergeFunctions::getDelToNewMap() {
+  return DelToNewMap;
+}
+
// Replace direct callers of Old with New.
void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
for (Use &U : llvm::make_early_inc_range(Old->uses())) {
@@ -1004,6 +1062,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
Function *DeleteF = NewFunction;
mergeTwoFunctions(OldF.getFunc(), DeleteF);
+ this->DelToNewMap.emplace(DeleteF, OldF.getFunc());
return true;
}
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index 5c7ec28709c169..7effa5d8e7d6d2 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -26,6 +26,7 @@ add_llvm_unittest(UtilsTests
LoopUtilsTest.cpp
MemTransferLowering.cpp
+ MergeFunctionsTest.cpp
ModuleUtilsTest.cpp
ScalarEvolutionExpanderTest.cpp
SizeOptsTest.cpp
SSAUpdaterBulkTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
new file mode 100644
index 00000000000000..696c5391ef4f68
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
@@ -0,0 +1,271 @@
+//===- MergeFunctionsTest.cpp - Unit tests for
+//MergeFunctionsPass-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/MergeFunctions.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+
+// IR shared by all tests. @_slice_add10 and @_slice_add10_alt are identical
+// except for the value @_slice_add10_alt returns, which parseTestModule
+// splices in between IRHead and IRTail: "%4" makes the two functions
+// identical, "%0" makes them differ.
+constexpr const char *IRHead = R"IR(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) #0 {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...) #1
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 )IR";
+
+constexpr const char *IRTail = R"IR(
+ }
+
+ attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+ attributes #2 = { nounwind willreturn }
+ )IR";
+
+// Parses the shared test IR with @_slice_add10_alt returning \p AltRetVal.
+// Returns null on a parse error; \p Err then describes the problem.
+std::unique_ptr<Module> parseTestModule(LLVMContext &Ctx, SMDiagnostic &Err,
+                                        StringRef AltRetVal) {
+  std::string IR = std::string(IRHead) + AltRetVal.str() + IRTail;
+  return parseAssemblyString(IR, Err, Ctx);
+}
+
+// Returns every function of \p M (declarations included) as a set.
+std::set<Function *> collectFunctions(Module &M) {
+  std::set<Function *> Functions;
+  for (Function &F : M)
+    Functions.insert(&F);
+  return Functions;
+}
+
+TEST(MergeFunctions, TrueOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseTestModule(Ctx, Err, "%4");
+  ASSERT_TRUE(M) << Err.getMessage().str();
+
+  // Expects true after merging _slice_add10 and _slice_add10_alt
+  EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, TrueOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseTestModule(Ctx, Err, "%4");
+  ASSERT_TRUE(M) << Err.getMessage().str();
+
+  std::set<Function *> FunctionsSet = collectFunctions(*M);
+  std::pair<bool, std::map<Function *, Function *>> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet);
+
+  // Expects true after merging _slice_add10 and _slice_add10_alt
+  EXPECT_TRUE(MergeResult.first);
+
+  // The function that was merged away must be mapped to the surviving
+  // _slice_add10. Keys may refer to erased functions, so only values are
+  // compared.
+  const std::map<Function *, Function *> &DelToNew = MergeResult.second;
+  EXPECT_FALSE(DelToNew.empty());
+  Function *NewFunction = M->getFunction("_slice_add10");
+  ASSERT_NE(NewFunction, nullptr);
+  for (const auto &Entry : DelToNew)
+    EXPECT_EQ(Entry.second, NewFunction);
+}
+
+TEST(MergeFunctions, FalseOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseTestModule(Ctx, Err, "%0");
+  ASSERT_TRUE(M) << Err.getMessage().str();
+
+  // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+  EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, FalseOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseTestModule(Ctx, Err, "%0");
+  ASSERT_TRUE(M) << Err.getMessage().str();
+
+  std::set<Function *> FunctionsSet = collectFunctions(*M);
+  std::pair<bool, std::map<Function *, Function *>> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet);
+
+  // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+  EXPECT_FALSE(MergeResult.first);
+
+  // Nothing was merged, so nothing may be recorded.
+  EXPECT_TRUE(MergeResult.second.empty());
+}
+
+} // namespace
diff --git a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
index 380ed71a2bc010..fcea55c91f083c 100644
--- a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
@@ -27,6 +27,7 @@ unittest("UtilsTests") {
"LoopUtilsTest.cpp",
"MemTransferLowering.cpp",
+ "MergeFunctionsTest.cpp",
"ModuleUtilsTest.cpp",
"ProfDataUtilTest.cpp",
"SSAUpdaterBulkTest.cpp",
"ScalarEvolutionExpanderTest.cpp",
>From 56bbc8cd81b55f87afe35eee9a74da22e8577310 Mon Sep 17 00:00:00 2001
From: Casperento <44746868+Casperento at users.noreply.github.com>
Date: Fri, 1 Nov 2024 17:55:06 -0300
Subject: [PATCH 2/2] format fix
---
llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
index 696c5391ef4f68..2f86ec6ed77fcd 100644
--- a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
@@ -1,5 +1,5 @@
//===- MergeFunctionsTest.cpp - Unit tests for
-//MergeFunctionsPass-----------===//
+// MergeFunctionsPass-----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
More information about the llvm-commits
mailing list