[llvm] 2a6e589 - [MergeFunctions] Add support to run the pass over a set of function pointers (#111045)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 28 07:18:56 PST 2024
Author: Rafael Eckstein
Date: 2024-11-28T16:18:52+01:00
New Revision: 2a6e5896a572c3be47ffe78afed8ba6ef278d336
URL: https://github.com/llvm/llvm-project/commit/2a6e5896a572c3be47ffe78afed8ba6ef278d336
DIFF: https://github.com/llvm/llvm-project/commit/2a6e5896a572c3be47ffe78afed8ba6ef278d336.diff
LOG: [MergeFunctions] Add support to run the pass over a set of function pointers (#111045)
This change makes `MergeFunctions` usable as a standalone library.
Currently, `MergeFunctions` can only be applied to an entire module.
With this change, developers can reuse the `MergeFunctions` code within
their own projects and choose which functions to merge, which promotes
code reuse. Note that this change does not break backward
compatibility, because `MergeFunctions` still works as a pass after
the change.
Added:
llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
Modified:
llvm/include/llvm/Transforms/IPO/MergeFunctions.h
llvm/lib/Transforms/IPO/MergeFunctions.cpp
llvm/unittests/Transforms/Utils/CMakeLists.txt
llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
index 822f0fd99188d0..71f175c6472b44 100644
--- a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
+++ b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
@@ -20,11 +20,16 @@
namespace llvm {
class Module;
+class Function;
/// Merge identical functions.
class MergeFunctionsPass : public PassInfoMixin<MergeFunctionsPass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+
+ static bool runOnModule(Module &M);
+ static DenseMap<Function *, Function *>
+ runOnFunctions(ArrayRef<Function *> F);
};
} // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index b7d0c3d741deb9..e8508416f54275 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -196,7 +196,10 @@ class MergeFunctions {
MergeFunctions() : FnTree(FunctionNodeCmp(&GlobalNumbers)) {
}
- bool runOnModule(Module &M);
+ template <typename FuncContainer> bool run(FuncContainer &Functions);
+ DenseMap<Function *, Function *> runOnFunctions(ArrayRef<Function *> F);
+
+ SmallPtrSet<GlobalValue *, 4> &getUsed();
private:
// The function comparison operator is provided here so that FunctionNodes do
@@ -297,17 +300,36 @@ class MergeFunctions {
// dangling iterators into FnTree. The invariant that preserves this is that
// there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
+
+ /// Maps each function deleted by a merge to the function it was merged into.
+ DenseMap<Function *, Function *> DelToNewMap;
};
} // end anonymous namespace
PreservedAnalyses MergeFunctionsPass::run(Module &M,
ModuleAnalysisManager &AM) {
- MergeFunctions MF;
- if (!MF.runOnModule(M))
+ if (!MergeFunctionsPass::runOnModule(M))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}
+/// Gives callers mutable access to the Used set so that it can be populated
+/// before the merge is run.
+SmallPtrSet<GlobalValue *, 4> &MergeFunctions::getUsed() {
+  return this->Used;
+}
+
+/// Runs function merging over every function in \p M.
+///
+/// The module's used globals are collected twice (CompilerUsed=false, then
+/// CompilerUsed=true) and handed to the merger before it runs.
+///
+/// \returns the result of MergeFunctions::run on the module.
+bool MergeFunctionsPass::runOnModule(Module &M) {
+  SmallVector<GlobalValue *, 4> UsedGlobals;
+  collectUsedGlobalVariables(M, UsedGlobals, /*CompilerUsed=*/false);
+  collectUsedGlobalVariables(M, UsedGlobals, /*CompilerUsed=*/true);
+
+  MergeFunctions Merger;
+  Merger.getUsed().insert(UsedGlobals.begin(), UsedGlobals.end());
+  return Merger.run(M);
+}
+
+/// Runs function merging over the explicit list of functions \p F, using a
+/// fresh MergeFunctions instance.
+///
+/// \returns the map produced by MergeFunctions::runOnFunctions.
+DenseMap<Function *, Function *>
+MergeFunctionsPass::runOnFunctions(ArrayRef<Function *> F) {
+  return MergeFunctions().runOnFunctions(F);
+}
+
#ifndef NDEBUG
bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
if (const unsigned Max = NumFunctionsForVerificationCheck) {
@@ -409,20 +431,19 @@ static bool isEligibleForMerging(Function &F) {
!hasDistinctMetadataIntrinsic(F);
}
-bool MergeFunctions::runOnModule(Module &M) {
- bool Changed = false;
+// Normalise a container element to a Function pointer so that
+// MergeFunctions::run can iterate both containers of Function (which yield
+// references) and containers of Function * (which yield pointers).
+// Declared static: these helpers are private to this translation unit.
+static Function *asPtr(Function *Fn) { return Fn; }
+static Function *asPtr(Function &Fn) { return &Fn; }
- SmallVector<GlobalValue *, 4> UsedV;
- collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false);
- collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true);
- Used.insert(UsedV.begin(), UsedV.end());
+template <typename FuncContainer> bool MergeFunctions::run(FuncContainer &M) {
+ bool Changed = false;
// All functions in the module, ordered by hash. Functions with a unique
// hash value are easily eliminated.
std::vector<std::pair<stable_hash, Function *>> HashedFuncs;
- for (Function &Func : M) {
- if (isEligibleForMerging(Func)) {
- HashedFuncs.push_back({StructuralHash(Func), &Func});
+ for (auto &Func : M) {
+ Function *FuncPtr = asPtr(Func);
+ if (isEligibleForMerging(*FuncPtr)) {
+ HashedFuncs.push_back({StructuralHash(*FuncPtr), FuncPtr});
}
}
@@ -433,7 +454,7 @@ bool MergeFunctions::runOnModule(Module &M) {
// If the hash value matches the previous value or the next one, we must
// consider merging it. Otherwise it is dropped and never considered again.
if ((I != S && std::prev(I)->first == I->first) ||
- (std::next(I) != IE && std::next(I)->first == I->first) ) {
+ (std::next(I) != IE && std::next(I)->first == I->first)) {
Deferred.push_back(WeakTrackingVH(I->second));
}
}
@@ -467,9 +488,16 @@ bool MergeFunctions::runOnModule(Module &M) {
return Changed;
}
+/// Runs the merge over the functions in \p F only.
+///
+/// \returns a copy of DelToNewMap, which insert() fills with one
+/// {deleted function, function it was merged into} entry per merge.
+DenseMap<Function *, Function *>
+MergeFunctions::runOnFunctions(ArrayRef<Function *> F) {
+  [[maybe_unused]] bool MergeResult = run(F);
+  // run() reports a change exactly when insert() merged something, and every
+  // such merge records an entry in DelToNewMap.
+  assert(MergeResult == !DelToNewMap.empty() &&
+         "run() result disagrees with the recorded merges");
+  return DelToNewMap;
+}
+
// Replace direct callers of Old with New.
void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
- for (Use &U : llvm::make_early_inc_range(Old->uses())) {
+ for (Use &U : make_early_inc_range(Old->uses())) {
CallBase *CB = dyn_cast<CallBase>(U.getUser());
if (CB && CB->isCallee(&U)) {
// Do not copy attributes from the called function to the call-site.
@@ -768,8 +796,8 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
ReturnInst *RI = nullptr;
bool isSwiftTailCall = F->getCallingConv() == CallingConv::SwiftTail &&
G->getCallingConv() == CallingConv::SwiftTail;
- CI->setTailCallKind(isSwiftTailCall ? llvm::CallInst::TCK_MustTail
- : llvm::CallInst::TCK_Tail);
+ CI->setTailCallKind(isSwiftTailCall ? CallInst::TCK_MustTail
+ : CallInst::TCK_Tail);
CI->setCallingConv(F->getCallingConv());
CI->setAttributes(F->getAttributes());
if (H->getReturnType()->isVoidTy()) {
@@ -1003,6 +1031,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
Function *DeleteF = NewFunction;
mergeTwoFunctions(OldF.getFunc(), DeleteF);
+ this->DelToNewMap.insert({DeleteF, OldF.getFunc()});
return true;
}
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index 5c7ec28709c169..7effa5d8e7d6d2 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -26,6 +26,7 @@ add_llvm_unittest(UtilsTests
LoopUtilsTest.cpp
MemTransferLowering.cpp
ModuleUtilsTest.cpp
+ MergeFunctionsTest.cpp
ScalarEvolutionExpanderTest.cpp
SizeOptsTest.cpp
SSAUpdaterBulkTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
new file mode 100644
index 00000000000000..56d119878a9ab2
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
@@ -0,0 +1,246 @@
+//===- MergeFunctionsTest.cpp - Unit tests for MergeFunctionsPass ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/MergeFunctions.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+
+// Module-level entry point: _slice_add10 and _slice_add10_alt have identical
+// bodies, so runOnModule must report that the module changed.
+TEST(MergeFunctions, TrueOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...)
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+ )invalid",
+ Err, Ctx));
+  // Stop here instead of dereferencing a null module if the IR is rejected.
+  ASSERT_TRUE(M) << "test IR failed to parse";
+
+  // Expects true after merging _slice_add10 and _slice_add10_alt
+  EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M));
+}
+
+// Function-list entry point: with every function of the module passed in,
+// the two identical helpers must be merged and the returned map must point
+// at the surviving function.
+TEST(MergeFunctions, TrueOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...)
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+ )invalid",
+ Err, Ctx));
+  // Stop here instead of dereferencing a null module if the IR is rejected.
+  ASSERT_TRUE(M) << "test IR failed to parse";
+
+  SetVector<Function *> FunctionsSet;
+  for (Function &F : *M)
+    FunctionsSet.insert(&F);
+
+  DenseMap<Function *, Function *> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet.getArrayRef());
+
+  // Expects at least one merge, and that every deleted function is mapped to
+  // the surviving _slice_add10.
+  EXPECT_FALSE(MergeResult.empty());
+  Function *NewFunction = M->getFunction("_slice_add10");
+  ASSERT_TRUE(NewFunction != nullptr) << "_slice_add10 should survive";
+  // A null replacement is itself a failure, so compare unconditionally.
+  for (const auto &P : MergeResult)
+    EXPECT_EQ(P.second, NewFunction);
+}
+
+// Module-level entry point, negative case: _slice_add10_alt returns %0
+// instead of %4, so the two helpers differ and runOnModule must report that
+// nothing changed.
+TEST(MergeFunctions, FalseOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...)
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %0
+ }
+ )invalid",
+ Err, Ctx));
+  // Stop here instead of dereferencing a null module if the IR is rejected.
+  ASSERT_TRUE(M) << "test IR failed to parse";
+
+  // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+  EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M));
+}
+
+// Function-list entry point, negative case: _slice_add10_alt returns %0
+// instead of %4, so nothing is mergeable and the returned map must be empty.
+TEST(MergeFunctions, FalseOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+ @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+ @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+ define dso_local i32 @f(i32 noundef %arg) {
+ entry:
+ %add109 = call i32 @_slice_add10(i32 %arg)
+ %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+ ret i32 %add109
+ }
+
+ declare i32 @printf(ptr noundef, ...)
+
+ define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) {
+ entry:
+ %add99 = call i32 @_slice_add10(i32 %argc)
+ %call = call i32 @f(i32 noundef 2)
+ %sub = sub nsw i32 %call, 6
+ %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+ ret i32 %add99
+ }
+
+ define internal i32 @_slice_add10(i32 %arg) {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %4
+ }
+
+ define internal i32 @_slice_add10_alt(i32 %arg) {
+ sliceclone_entry:
+ %0 = mul nsw i32 %arg, %arg
+ %1 = mul nsw i32 %0, 2
+ %2 = mul nsw i32 %1, 2
+ %3 = mul nsw i32 %2, 2
+ %4 = add nsw i32 %3, 2
+ ret i32 %0
+ }
+ )invalid",
+ Err, Ctx));
+  // Stop here instead of dereferencing a null module if the IR is rejected.
+  ASSERT_TRUE(M) << "test IR failed to parse";
+
+  SetVector<Function *> FunctionsSet;
+  for (Function &F : *M)
+    FunctionsSet.insert(&F);
+
+  DenseMap<Function *, Function *> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet.getArrayRef());
+
+  // Expects empty map
+  EXPECT_TRUE(MergeResult.empty());
+}
+
+} // namespace
\ No newline at end of file
diff --git a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
index 380ed71a2bc010..fcea55c91f083c 100644
--- a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
@@ -27,6 +27,7 @@ unittest("UtilsTests") {
"LoopUtilsTest.cpp",
"MemTransferLowering.cpp",
"ModuleUtilsTest.cpp",
+ "MergeFunctionsTest.cpp",
"ProfDataUtilTest.cpp",
"SSAUpdaterBulkTest.cpp",
"ScalarEvolutionExpanderTest.cpp",
More information about the llvm-commits
mailing list