[llvm] [CGData] Global Merge Functions (PR #112671)

Kyungwoo Lee via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 10 10:00:46 PST 2024


================
@@ -0,0 +1,744 @@
+//===---- GlobalMergeFunctions.cpp - Global merge functions -------*- C++ -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass defines the implementation of a function merging mechanism
+// that utilizes a stable function hash to track differences in constants and
+// create potential merge candidates. The process involves two rounds:
+// 1. The first round collects stable function hashes and identifies merge
+//    candidates with matching hashes. It also computes the set of parameters
+//    that point to different constants during the stable function merge.
+// 2. The second round leverages this collected global function information to
+//    optimistically create a merged function in each module context, ensuring
+//    correct transformation.
+// Similar to the global outliner, this approach uses the linker's deduplication
+// (ICF) to fold identical merged functions, thereby reducing the final binary
+// size. The work is inspired by the concepts discussed in the following paper:
+// https://dl.acm.org/doi/pdf/10.1145/3652032.3657575.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/GlobalMergeFunctions.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/ModuleSummaryAnalysis.h"
+#include "llvm/CGData/CodeGenData.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/StructuralHash.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+#define DEBUG_TYPE "global-merge-func"
+
+using namespace llvm;
+using namespace llvm::support;
+
+// Command-line knobs for the global function merger. Every option below is
+// hidden and carries the default given by its cl::init value.
+static cl::opt<bool> DisableGlobalMerging(
+    "disable-global-merging", cl::Hidden, cl::init(false),
+    cl::desc("Disable global merging only by ignoring "
+             "the codegen data generation or use. Local "
+             "merging is still enabled within a module."));
+static cl::opt<unsigned> GlobalMergingMinInstrs(
+    "global-merging-min-instrs", cl::Hidden, cl::init(1),
+    cl::desc("The minimum instruction count required when merging functions."));
+static cl::opt<unsigned> GlobalMergingMaxParams(
+    "global-merging-max-params", cl::Hidden,
+    cl::init(std::numeric_limits<unsigned>::max()),
+    cl::desc(
+        "The maximum number of parameters allowed when merging functions."));
+static cl::opt<unsigned> GlobalMergingParamOverhead(
+    "global-merging-param-overhead", cl::Hidden, cl::init(2),
+    cl::desc("The overhead cost associated with each parameter when merging "
+             "functions."));
+static cl::opt<unsigned> GlobalMergingCallOverhead(
+    "global-merging-call-overhead", cl::Hidden, cl::init(1),
+    cl::desc("The overhead cost associated with each "
+             "function call when merging functions."));
+static cl::opt<unsigned> GlobalMergingExtraThreshold(
+    "global-merging-extra-threshold", cl::Hidden, cl::init(0),
+    cl::desc("An additional cost threshold that must be exceeded for merging "
+             "to be considered beneficial."));
+
+extern cl::opt<bool> EnableGlobalMergeFunc; // Declaration only; the definition is not in the visible part of this file.
+
+STATISTIC(NumMismatchedFunctionHashGlobalMergeFunction,
+          "Number of mismatched function hash for global merge function");
+STATISTIC(NumMismatchedInstCountGlobalMergeFunction,
+          "Number of mismatched instruction count for global merge function");
+STATISTIC(NumMismatchedConstHashGlobalMergeFunction,
+          "Number of mismatched const hash for global merge function");
+STATISTIC(NumMismatchedModuleIdGlobalMergeFunction,
+          "Number of mismatched Module Id for global merge function");
+STATISTIC(NumGlobalMergeFunctions,
+          "Number of functions that are actually merged using function hash");
+STATISTIC(NumAnalyzedModues, "Number of modules that are analyzed"); // Incremented once per analyze() call.
+STATISTIC(NumAnalyzedFunctions, "Number of functions that are analyzed"); // Incremented for every function analyze() visits.
+STATISTIC(NumEligibleFunctions, "Number of functions that are eligible"); // Incremented when isEligibleFunction() accepts a function.
+
+/// Returns true if operand \p OpIdx of \p CI is the called (callee) operand,
+/// determined by comparing the addresses of the two operand uses.
+static bool isCalleeOperand(const CallBase *CI, unsigned OpIdx) {
+  const Use &CalleeUse = CI->getCalledOperandUse();
+  const Use &CandidateUse = CI->getOperandUse(OpIdx);
+  return &CalleeUse == &CandidateUse;
+}
+
+/// Returns true if operand \p OpIdx of the call \p CI may be turned into a
+/// parameter. Inline asm calls, calls to intrinsics, calls to objc_msgSend
+/// stubs, and callee operands of calls carrying a ptrauth bundle are rejected.
+static bool canParameterizeCallOperand(const CallBase *CI, unsigned OpIdx) {
+  if (CI->isInlineAsm())
+    return false;
+
+  // Resolve the direct callee, looking through pointer casts; it stays null
+  // when the called operand is not a function.
+  Function *Callee = nullptr;
+  if (Value *CalledOp = CI->getCalledOperand())
+    Callee = dyn_cast_or_null<Function>(CalledOp->stripPointerCasts());
+
+  // Intrinsics are never parameterized. objc_msgSend stubs must be called,
+  // and can't have their address taken.
+  if (Callee && (Callee->isIntrinsic() ||
+                 Callee->getName().starts_with("objc_msgSend$")))
+    return false;
+
+  // A callee operand that has already been signed is left alone because we
+  // cannot add another ptrauth bundle to the call instruction.
+  const bool IsSignedCallee =
+      isCalleeOperand(CI, OpIdx) &&
+      CI->getOperandBundle(LLVMContext::OB_ptrauth).has_value();
+  return !IsSignedCallee;
+}
+
+/// Returns true if \p I is a load, store, call, or invoke -- the only
+/// instruction kinds whose constant operands are considered for sharing.
+bool isEligibleInstrunctionForConstantSharing(const Instruction *I) {
+  const unsigned Opcode = I->getOpcode();
+  return Opcode == Instruction::Load || Opcode == Instruction::Store ||
+         Opcode == Instruction::Call || Opcode == Instruction::Invoke;
+}
+
+/// Returns true if operand \p OpIdx of \p I is a constant on an eligible
+/// instruction and, for calls, is an operand that may be parameterized.
+/// \p OpIdx must be a valid operand index of \p I.
+bool isEligibleOperandForConstantSharing(const Instruction *I, unsigned OpIdx) {
+  assert(OpIdx < I->getNumOperands() && "Invalid operand index");
+
+  const bool IsConstantOnEligibleInst =
+      isEligibleInstrunctionForConstantSharing(I) &&
+      isa<Constant>(I->getOperand(OpIdx));
+  if (!IsConstantOnEligibleInst)
+    return false;
+
+  // Call-like instructions have additional restrictions per operand.
+  const auto *Call = dyn_cast<CallBase>(I);
+  return Call ? canParameterizeCallOperand(Call, OpIdx) : true;
+}
+
+/// Returns true if function \p F is eligible for merging. Declarations,
+/// nomerge functions, available_externally functions, variadic functions,
+/// SwiftTail functions, and functions containing a musttail call are not.
+bool isEligibleFunction(Function *F) {
+  if (F->isDeclaration() || F->hasFnAttribute(llvm::Attribute::NoMerge) ||
+      F->hasAvailableExternallyLinkage())
+    return false;
+
+  if (F->getFunctionType()->isVarArg() ||
+      F->getCallingConv() == CallingConv::SwiftTail)
+    return false;
+
+  // Merging can change the number of parameters of the merged function. A
+  // musttail call site inside it would then no longer match the parameter
+  // count of its enclosing function, so reject any function that has one.
+  for (const BasicBlock &Block : *F)
+    for (const Instruction &Inst : Block)
+      if (const auto *Call = dyn_cast<CallBase>(&Inst);
+          Call && Call->isMustTailCall())
+        return false;
+
+  return true;
+}
+
+/// File-local check: returns true if \p I is a load, store, call, or invoke.
+static bool
+isEligibleInstrunctionForConstantSharingLocal(const Instruction *I) {
+  const unsigned Opcode = I->getOpcode();
+  return Opcode == Instruction::Load || Opcode == Instruction::Store ||
+         Opcode == Instruction::Call || Opcode == Instruction::Invoke;
+}
+
+/// Predicate handed to StructuralHashWithDifferences by analyze(). Returns
+/// true if operand \p OpIdx of \p I is a constant on a load, store, call, or
+/// invoke and, for calls, is an operand that may be parameterized.
+/// \p OpIdx must be a valid operand index of \p I.
+static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
+  assert(OpIdx < I->getNumOperands() && "Invalid operand index");
+
+  const bool IsConstantOnEligibleInst =
+      isEligibleInstrunctionForConstantSharingLocal(I) &&
+      isa<Constant>(I->getOperand(OpIdx));
+  if (!IsConstantOnEligibleInst)
+    return false;
+
+  // Call-like instructions have additional restrictions per operand.
+  const auto *Call = dyn_cast<CallBase>(I);
+  return Call ? canParameterizeCallOperand(Call, OpIdx) : true;
+}
+
+/// Emits instructions through \p Builder that convert \p V to type \p DestTy
+/// and returns the converted value.
+///
+/// Struct and array values are converted element by element: each element is
+/// extracted, cast recursively, and inserted into a fresh aggregate of type
+/// \p DestTy. The asserts require the source and destination to be the same
+/// kind of aggregate with the same element count. Non-aggregate values are
+/// converted with inttoptr, ptrtoint, or bitcast depending on whether the
+/// source/destination types are integers or pointers.
+static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
+  Type *SrcTy = V->getType();
+  if (SrcTy->isStructTy()) {
+    assert(DestTy->isStructTy());
+    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
+    // Every element is overwritten below, so the initial aggregate is only a
+    // placeholder.
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
+      Value *Element =
+          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
+                     DestTy->getStructElementType(I));
+
+      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+  assert(!DestTy->isStructTy());
+  if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
+    auto *DestAT = dyn_cast<ArrayType>(DestTy);
+    assert(DestAT);
+    assert(SrcAT->getNumElements() == DestAT->getNumElements());
+    // As in the struct case, every element is overwritten below; use a poison
+    // placeholder here too rather than undef so both aggregate paths agree.
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
+      Value *Element =
+          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
+                     DestAT->getElementType());
+
+      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+  assert(!DestTy->isArrayTy());
+  // Scalar conversions: pick the cast that matches the int/pointer pairing,
+  // falling back to a bitcast for everything else.
+  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
+    return Builder.CreateIntToPtr(V, DestTy);
+  if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
+    return Builder.CreatePtrToInt(V, DestTy);
+  return Builder.CreateBitCast(V, DestTy);
+}
+
+/// Records a StableFunction entry in LocalFunctionMap for every eligible
+/// function of \p M, and updates the analysis statistics along the way.
+void GlobalMergeFunc::analyze(Module &M) {
+  ++NumAnalyzedModues;
+  for (Function &Func : M) {
+    ++NumAnalyzedFunctions;
+    if (!isEligibleFunction(&Func))
+      continue;
+    ++NumEligibleFunctions;
+
+    // Hash the function; ignoreOp selects the operands handled separately.
+    auto FI = llvm::StructuralHashWithDifferences(Func, ignoreOp);
+
+    // Flatten the operand map into a vector, which is a
+    // serialization-friendly format.
+    IndexOperandHashVecType OperandHashVec;
+    for (auto &Entry : *FI.IndexOperandHashMap)
+      OperandHashVec.emplace_back(Entry);
+
+    auto StableName = get_stable_name(Func.getName()).str();
+    StableFunction SF(FI.FunctionHash, std::move(StableName),
+                      M.getModuleIdentifier(), FI.IndexInstruction->size(),
+                      std::move(OperandHashVec));
+
+    LocalFunctionMap->insert(SF);
+  }
+}
+
+/// Tuple to hold the per-function info needed to process merging.
+struct FuncMergeInfo {
+  StableFunctionMap::StableFunctionEntry *SF; // Presumably the stable-function entry matched for F; not used in the visible code -- confirm.
+  Function *F; // The function being merged; createMergedFunction moves its body into the new merged function.
+  std::unique_ptr<IndexInstrMap> IndexInstruction; // Looked up by instruction index to find the instruction whose operand is replaced.
+};
+
+// Builds the merged function for the root function FI.F and returns it.
+//
+// The new function is named "<root name>.Tgm", has internal linkage, is
+// marked noinline, and takes the root's parameters followed by one extra
+// parameter per entry of ConstParamTypes. The root's body is moved into it,
+// and every (instruction, operand) location listed in ParamLocsVec[i] is
+// rewritten to use the i-th extra parameter instead of the original constant.
+static Function *createMergedFunction(FuncMergeInfo &FI,
+                                      ArrayRef<Type *> ConstParamTypes,
+                                      const ParamLocsVecTy &ParamLocsVec) {
+  Function *RootFunc = FI.F;
+  // Synthesize the merged function's name by appending ".Tgm" to the root
+  // function's name; it must not exist in the module yet.
+  const std::string MergedName = RootFunc->getName().str() + ".Tgm";
+  Module *ParentModule = RootFunc->getParent();
+  assert(!ParentModule->getFunction(MergedName));
+
+  // Signature: the original parameter types, then the constant parameters.
+  FunctionType *RootTy = RootFunc->getFunctionType();
+  SmallVector<Type *> AllParamTypes(RootTy->param_begin(),
+                                    RootTy->param_end());
+  AllParamTypes.append(ConstParamTypes.begin(), ConstParamTypes.end());
+  FunctionType *MergedTy =
+      FunctionType::get(RootTy->getReturnType(), AllParamTypes, false);
+
+  // Create the new function, carrying over the debug subprogram (if any) and
+  // the attributes of the root function.
+  Function *NewFunction =
+      Function::Create(MergedTy, RootFunc->getLinkage(), MergedName);
+  if (auto *SP = RootFunc->getSubprogram())
+    NewFunction->setSubprogram(SP);
+  NewFunction->copyAttributesFrom(RootFunc);
+  NewFunction->setDLLStorageClass(GlobalValue::DefaultStorageClass);
+
+  NewFunction->setLinkage(GlobalValue::InternalLinkage);
+  NewFunction->addFnAttr(Attribute::NoInline);
+
+  // Place the new function right before the root function, then move the
+  // root's body into it.
+  ParentModule->getFunctionList().insert(RootFunc->getIterator(), NewFunction);
+  NewFunction->splice(NewFunction->begin(), RootFunc);
+
+  // Redirect uses of each original argument to the new function's argument
+  // at the same position.
+  const unsigned NumOrigArgs = RootFunc->arg_size();
+  for (unsigned ArgIdx = 0; ArgIdx != NumOrigArgs; ++ArgIdx)
+    RootFunc->getArg(ArgIdx)->replaceAllUsesWith(NewFunction->getArg(ArgIdx));
+
+  // Swap each parameterized constant for its new argument, casting the
+  // argument to the constant's type when the two types differ.
+  for (unsigned ParamIdx = 0, NumParams = ParamLocsVec.size();
+       ParamIdx != NumParams; ++ParamIdx) {
+    Argument *NewArg = NewFunction->getArg(NumOrigArgs + ParamIdx);
+    for (auto [InstIndex, OpndIndex] : ParamLocsVec[ParamIdx]) {
+      auto *Inst = FI.IndexInstruction->lookup(InstIndex);
+      auto *OrigC = Inst->getOperand(OpndIndex);
+      Value *Replacement = NewArg;
+      if (OrigC->getType() != NewArg->getType()) {
+        IRBuilder<> Builder(Inst->getParent(), Inst->getIterator());
+        Replacement = createCast(Builder, NewArg, OrigC->getType());
+      }
+      Inst->setOperand(OpndIndex, Replacement);
+    }
+  }
+
+  return NewFunction;
+}
+
+// Given the original function (Thunk) and the merged function (ToFunc), create
+// a thunk to the merged function.
+
+static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
+                        Function *ToFunc) {
+  auto *Thunk = FI.F;
+
+  assert(Thunk->arg_size() + Params.size() ==
+         ToFunc->getFunctionType()->getNumParams());
+  Thunk->dropAllReferences();
+
+  BasicBlock *BB = BasicBlock::Create(Thunk->getContext(), "", Thunk);
+  IRBuilder<> Builder(BB);
+
+  SmallVector<Value *> Args;
+  unsigned ParamIdx = 0;
+  FunctionType *ToFuncTy = ToFunc->getFunctionType();
+
+  // Add arguments which are passed through Thunk.
+  for (Argument &AI : Thunk->args()) {
+    Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx)));
+    ++ParamIdx;
+  }
+
+  // Add new arguments defined by Params.
+  for (auto *Param : Params) {
+    assert(ParamIdx < ToFuncTy->getNumParams());
+    // FIXME: do not support signing
----------------
kyulee-com wrote:

This is about support for a target that enables pointer authentication, which appears to be supported by the Swift function merger. In the current implementation, I've opted to bypass this case for now. Nonetheless, I've removed the comment since the current implementation doesn't support it anyway.

https://github.com/llvm/llvm-project/pull/112671


More information about the llvm-commits mailing list