[llvm] [CGData] Global Merge Functions (PR #112671)
Ellis Hoag via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 7 10:10:06 PST 2024
================
@@ -0,0 +1,744 @@
+//===---- GlobalMergeFunctions.cpp - Global merge functions -------*- C++ -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass defines the implementation of a function merging mechanism
+// that utilizes a stable function hash to track differences in constants and
+// create potential merge candidates. The process involves two rounds:
+// 1. The first round collects stable function hashes and identifies merge
+// candidates with matching hashes. It also computes the set of parameters
+// that point to different constants during the stable function merge.
+// 2. The second round leverages this collected global function information to
+// optimistically create a merged function in each module context, ensuring
+// correct transformation.
+// Similar to the global outliner, this approach uses the linker's deduplication
+// (ICF) to fold identical merged functions, thereby reducing the final binary
+// size. The work is inspired by the concepts discussed in the following paper:
+// https://dl.acm.org/doi/pdf/10.1145/3652032.3657575.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/GlobalMergeFunctions.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/ModuleSummaryAnalysis.h"
+#include "llvm/CGData/CodeGenData.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/StructuralHash.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+#define DEBUG_TYPE "global-merge-func"
+
+using namespace llvm;
+using namespace llvm::support;
+
+static cl::opt<bool> // Off by default; only the cross-module (codegen data) path is affected.
+ DisableGlobalMerging("disable-global-merging", cl::Hidden,
+ cl::desc("Disable global merging only by ignoring "
+ "the codegen data generation or use. Local "
+ "merging is still enabled within a module."),
+ cl::init(false));
+static cl::opt<unsigned> GlobalMergingMinInstrs( // Default: 1.
+ "global-merging-min-instrs",
+ cl::desc("The minimum instruction count required when merging functions."),
+ cl::init(1), cl::Hidden);
+static cl::opt<unsigned> GlobalMergingMaxParams( // Default: largest unsigned value, i.e. no practical cap.
+ "global-merging-max-params",
+ cl::desc(
+ "The maximum number of parameters allowed when merging functions."),
+ cl::init(std::numeric_limits<unsigned>::max()), cl::Hidden); // NOTE(review): <limits> is not among the includes above; presumably pulled in transitively.
+static cl::opt<unsigned> GlobalMergingParamOverhead( // Default: 2.
+ "global-merging-param-overhead",
+ cl::desc("The overhead cost associated with each parameter when merging "
+ "functions."),
+ cl::init(2), cl::Hidden);
+static cl::opt<unsigned> // Default: 1.
+ GlobalMergingCallOverhead("global-merging-call-overhead",
+ cl::desc("The overhead cost associated with each "
+ "function call when merging functions."),
+ cl::init(1), cl::Hidden);
+static cl::opt<unsigned> GlobalMergingExtraThreshold( // Default: 0.
+ "global-merging-extra-threshold",
+ cl::desc("An additional cost threshold that must be exceeded for merging "
+ "to be considered beneficial."),
+ cl::init(0), cl::Hidden);
+
+extern cl::opt<bool> EnableGlobalMergeFunc; // Declaration only; no definition appears in this excerpt.
+
+STATISTIC(NumMismatchedFunctionHashGlobalMergeFunction, // First of four "mismatch" counters, one per kind of disagreement named in its description.
+ "Number of mismatched function hash for global merge function");
+STATISTIC(NumMismatchedInstCountGlobalMergeFunction,
+ "Number of mismatched instruction count for global merge function");
+STATISTIC(NumMismatchedConstHashGlobalMergeFunction,
+ "Number of mismatched const hash for global merge function");
+STATISTIC(NumMismatchedModuleIdGlobalMergeFunction,
+ "Number of mismatched Module Id for global merge function");
+STATISTIC(NumGlobalMergeFunctions,
+ "Number of functions that are actually merged using function hash");
+STATISTIC(NumAnalyzedModues, "Number of modules that are analyzed"); // NOTE(review): "Modues" is presumably a typo for "Modules"; fixing it means renaming every use.
+STATISTIC(NumAnalyzedFunctions, "Number of functions that are analyzed");
+STATISTIC(NumEligibleFunctions, "Number of functions that are eligible");
+
+/// Reports whether operand number \p OpIdx of \p CI occupies the same Use slot
+/// as the callee operand of the call.
+static bool isCalleeOperand(const CallBase *CI, unsigned OpIdx) {
+  const auto &CalleeSlot = CI->getCalledOperandUse();
+  const auto &QueriedSlot = CI->getOperandUse(OpIdx);
+  return &QueriedSlot == &CalleeSlot;
+}
+
+/// Decides whether operand \p OpIdx of the call \p CI may be parameterized.
+///
+/// Returns false for inline-asm calls, for calls whose callee (after stripping
+/// pointer casts) is an intrinsic or an objc_msgSend stub, and for the callee
+/// operand of a call that carries a ptrauth operand bundle. Returns true
+/// otherwise.
+static bool canParameterizeCallOperand(const CallBase *CI, unsigned OpIdx) {
+  if (CI->isInlineAsm())
+    return false;
+
+  // Resolve the directly called function, if the called operand is one.
+  Function *Callee = nullptr;
+  if (auto *Target = CI->getCalledOperand())
+    Callee = dyn_cast_or_null<Function>(Target->stripPointerCasts());
+
+  if (Callee && Callee->isIntrinsic())
+    return false;
+  // objc_msgSend stubs must be called, and can't have their address taken.
+  if (Callee && Callee->getName().starts_with("objc_msgSend$"))
+    return false;
+
+  // A callee operand that has already been signed is left alone, because a
+  // second ptrauth bundle cannot be added to the call instruction.
+  const bool IsSignedCallee =
+      isCalleeOperand(CI, OpIdx) &&
+      CI->getOperandBundle(LLVMContext::OB_ptrauth).has_value();
+  return !IsSignedCallee;
+}
+
+/// Returns true when \p I is a load, store, call, or invoke instruction; every
+/// other opcode yields false.
+bool isEligibleInstrunctionForConstantSharing(const Instruction *I) {
+  const unsigned Opcode = I->getOpcode();
+  if (Opcode == Instruction::Load || Opcode == Instruction::Store)
+    return true;
+  return Opcode == Instruction::Call || Opcode == Instruction::Invoke;
+}
+
+/// Returns true when operand \p OpIdx of \p I qualifies for constant sharing:
+/// \p I must be a load, store, call, or invoke, the operand must be a
+/// Constant, and for call-like instructions the operand must also pass
+/// canParameterizeCallOperand().
+bool isEligibleOperandForConstantSharing(const Instruction *I, unsigned OpIdx) {
+  assert(OpIdx < I->getNumOperands() && "Invalid operand index");
+
+  const bool IsConstantOfEligibleInst =
+      isEligibleInstrunctionForConstantSharing(I) &&
+      isa<Constant>(I->getOperand(OpIdx));
+  if (!IsConstantOfEligibleInst)
+    return false;
+
+  // Only call-like instructions place further restrictions on the operand.
+  const auto *Call = dyn_cast<CallBase>(I);
+  return !Call || canParameterizeCallOperand(Call, OpIdx);
+}
+
+/// Returns true if function \p F is a candidate for merging.
+///
+/// A function is rejected when it is a declaration, is marked NoMerge, has
+/// available_externally linkage, is variadic, uses the SwiftTail calling
+/// convention, or contains a musttail call site.
+bool isEligibleFunction(Function *F) {
+  const bool HasDisqualifyingProperty =
+      F->isDeclaration() || F->hasFnAttribute(llvm::Attribute::NoMerge) ||
+      F->hasAvailableExternallyLinkage() ||
+      F->getFunctionType()->isVarArg() ||
+      F->getCallingConv() == CallingConv::SwiftTail;
+  if (HasDisqualifyingProperty)
+    return false;
+
+  // Merging can change the parameter count of the function. A musttail call
+  // site inside it would then no longer match the signature of its enclosing
+  // function, so any function containing one is rejected.
+  for (const BasicBlock &Block : *F)
+    for (const Instruction &Inst : Block)
+      if (const auto *Call = dyn_cast<CallBase>(&Inst))
+        if (Call->isMustTailCall())
+          return false;
+
+  return true;
+}
+
+/// File-local predicate: true when \p I is a load, store, call, or invoke
+/// instruction, false for any other opcode.
+static bool
+isEligibleInstrunctionForConstantSharingLocal(const Instruction *I) {
+  const unsigned Opcode = I->getOpcode();
+  const bool IsMemoryAccess =
+      Opcode == Instruction::Load || Opcode == Instruction::Store;
+  const bool IsCallLike =
+      Opcode == Instruction::Call || Opcode == Instruction::Invoke;
+  return IsMemoryAccess || IsCallLike;
+}
+
+/// Returns true when operand \p OpIdx of \p I is a Constant on a load, store,
+/// call, or invoke instruction and, for call-like instructions, the operand
+/// also passes canParameterizeCallOperand(). Returns false otherwise.
+static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
+  assert(OpIdx < I->getNumOperands() && "Invalid operand index");
+
+  if (!isEligibleInstrunctionForConstantSharingLocal(I))
+    return false;
+
+  const auto *Operand = I->getOperand(OpIdx);
+  if (!isa<Constant>(Operand))
+    return false;
+
+  // Non-call instructions have no further restrictions.
+  const auto *Call = dyn_cast<CallBase>(I);
+  return Call ? canParameterizeCallOperand(Call, OpIdx) : true;
+}
+
+static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
+ Type *SrcTy = V->getType();
+ if (SrcTy->isStructTy()) {
+ assert(DestTy->isStructTy());
+ assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
+ Value *Result = PoisonValue::get(DestTy);
+ for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
+ Value *Element =
+ createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
+ DestTy->getStructElementType(I));
+
+ Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
+ }
+ return Result;
+ }
+ assert(!DestTy->isStructTy());
+ if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
+ auto *DestAT = dyn_cast<ArrayType>(DestTy);
+ assert(DestAT);
+ assert(SrcAT->getNumElements() == DestAT->getNumElements());
+ Value *Result = UndefValue::get(DestTy);
+ for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
+ Value *Element =
+ createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
+ DestAT->getElementType());
+
+ Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
+ }
+ return Result;
+ }
+ assert(!DestTy->isArrayTy());
+ if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
+ return Builder.CreateIntToPtr(V, DestTy);
+ else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
+ return Builder.CreatePtrToInt(V, DestTy);
+ else
+ return Builder.CreateBitCast(V, DestTy);
----------------
ellishg wrote:
```suggestion
if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
return Builder.CreatePtrToInt(V, DestTy);
return Builder.CreateBitCast(V, DestTy);
```
https://github.com/llvm/llvm-project/pull/112671
More information about the llvm-commits
mailing list