[llvm] [CGData] Global Merge Functions (PR #112671)

Zhaoxuan Jiang via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 7 22:57:50 PST 2024


================
@@ -0,0 +1,744 @@
+//===---- GlobalMergeFunctions.cpp - Global merge functions -------*- C++ -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass defines the implementation of a function merging mechanism
+// that utilizes a stable function hash to track differences in constants and
+// create potential merge candidates. The process involves two rounds:
+// 1. The first round collects stable function hashes and identifies merge
+//    candidates with matching hashes. It also computes the set of parameters
+//    that point to different constants during the stable function merge.
+// 2. The second round leverages this collected global function information to
+//    optimistically create a merged function in each module context, ensuring
+//    correct transformation.
+// Similar to the global outliner, this approach uses the linker's deduplication
+// (ICF) to fold identical merged functions, thereby reducing the final binary
+// size. The work is inspired by the concepts discussed in the following paper:
+// https://dl.acm.org/doi/pdf/10.1145/3652032.3657575.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/GlobalMergeFunctions.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/ModuleSummaryAnalysis.h"
+#include "llvm/CGData/CodeGenData.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/StructuralHash.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+#define DEBUG_TYPE "global-merge-func"
+
+using namespace llvm;
+using namespace llvm::support;
+
+// Hidden command-line options for this file: one switch that disables the
+// codegen-data based (global) merging, followed by size/cost knobs (minimum
+// instruction count, maximum parameter count, per-parameter overhead, per-call
+// overhead, and an extra threshold). The code that reads these options is not
+// part of the visible portion of this file.
+static cl::opt<bool>
+    DisableGlobalMerging("disable-global-merging", cl::Hidden,
+                         cl::desc("Disable global merging only by ignoring "
+                                  "the codegen data generation or use. Local "
+                                  "merging is still enabled within a module."),
+                         cl::init(false));
+static cl::opt<unsigned> GlobalMergingMinInstrs(
+    "global-merging-min-instrs",
+    cl::desc("The minimum instruction count required when merging functions."),
+    cl::init(1), cl::Hidden);
+static cl::opt<unsigned> GlobalMergingMaxParams(
+    "global-merging-max-params",
+    cl::desc(
+        "The maximum number of parameters allowed when merging functions."),
+    cl::init(std::numeric_limits<unsigned>::max()), cl::Hidden);
+static cl::opt<unsigned> GlobalMergingParamOverhead(
+    "global-merging-param-overhead",
+    cl::desc("The overhead cost associated with each parameter when merging "
+             "functions."),
+    cl::init(2), cl::Hidden);
+static cl::opt<unsigned>
+    GlobalMergingCallOverhead("global-merging-call-overhead",
+                              cl::desc("The overhead cost associated with each "
+                                       "function call when merging functions."),
+                              cl::init(1), cl::Hidden);
+static cl::opt<unsigned> GlobalMergingExtraThreshold(
+    "global-merging-extra-threshold",
+    cl::desc("An additional cost threshold that must be exceeded for merging "
+             "to be considered beneficial."),
+    cl::init(0), cl::Hidden);
+
+// Declaration only; the option itself is not defined in the visible portion
+// of this file.
+extern cl::opt<bool> EnableGlobalMergeFunc;
+
+// Statistic counters for this pass. The last three are incremented in
+// GlobalMergeFunc::analyze(); the others are updated outside the visible
+// portion of this file.
+// NOTE(review): `NumAnalyzedModues` looks like a misspelling of
+// `NumAnalyzedModules`; renaming it requires updating its use in analyze().
+STATISTIC(NumMismatchedFunctionHashGlobalMergeFunction,
+          "Number of mismatched function hash for global merge function");
+STATISTIC(NumMismatchedInstCountGlobalMergeFunction,
+          "Number of mismatched instruction count for global merge function");
+STATISTIC(NumMismatchedConstHashGlobalMergeFunction,
+          "Number of mismatched const hash for global merge function");
+STATISTIC(NumMismatchedModuleIdGlobalMergeFunction,
+          "Number of mismatched Module Id for global merge function");
+STATISTIC(NumGlobalMergeFunctions,
+          "Number of functions that are actually merged using function hash");
+STATISTIC(NumAnalyzedModues, "Number of modules that are analyzed");
+STATISTIC(NumAnalyzedFunctions, "Number of functions that are analyzed");
+STATISTIC(NumEligibleFunctions, "Number of functions that are eligible");
+
+/// Returns true if the \p OpIdx operand of \p CI is the callee operand.
+static bool isCalleeOperand(const CallBase *CI, unsigned OpIdx) {
+  return &CI->getCalledOperandUse() == &CI->getOperandUse(OpIdx);
+}
+
+/// Returns true if operand \p OpIdx of the call \p CI may be parameterized.
+/// Inline asm calls, calls to intrinsics, calls to objc_msgSend stubs, and an
+/// already ptrauth-signed callee operand are all rejected.
+static bool canParameterizeCallOperand(const CallBase *CI, unsigned OpIdx) {
+  if (CI->isInlineAsm())
+    return false;
+
+  // Look through pointer casts to find a directly called function, if any.
+  Function *DirectCallee = nullptr;
+  if (Value *Target = CI->getCalledOperand())
+    DirectCallee = dyn_cast_or_null<Function>(Target->stripPointerCasts());
+
+  // Intrinsics are rejected. objc_msgSend stubs must be called, and can't have
+  // their address taken.
+  if (DirectCallee && (DirectCallee->isIntrinsic() ||
+                       DirectCallee->getName().starts_with("objc_msgSend$")))
+    return false;
+
+  // A callee operand that has already been signed is left alone because we
+  // cannot add another ptrauth bundle to the call instruction.
+  const bool IsSignedCallee =
+      isCalleeOperand(CI, OpIdx) &&
+      CI->getOperandBundle(LLVMContext::OB_ptrauth).has_value();
+  return !IsSignedCallee;
+}
+
+/// Returns true if \p I is a load, store, call, or invoke instruction.
+bool isEligibleInstrunctionForConstantSharing(const Instruction *I) {
+  const unsigned Opcode = I->getOpcode();
+  return Opcode == Instruction::Load || Opcode == Instruction::Store ||
+         Opcode == Instruction::Call || Opcode == Instruction::Invoke;
+}
+
+/// Returns true if operand \p OpIdx of \p I is a constant operand of an
+/// eligible instruction (see isEligibleInstrunctionForConstantSharing) and,
+/// when \p I is a call-like instruction, the operand can be parameterized.
+bool isEligibleOperandForConstantSharing(const Instruction *I, unsigned OpIdx) {
+  assert(OpIdx < I->getNumOperands() && "Invalid operand index");
+
+  const bool IsConstantOfEligibleInst =
+      isEligibleInstrunctionForConstantSharing(I) &&
+      isa<Constant>(I->getOperand(OpIdx));
+  if (!IsConstantOfEligibleInst)
+    return false;
+
+  // Call-like instructions have additional restrictions on their operands.
+  const auto *Call = dyn_cast<CallBase>(I);
+  return !Call || canParameterizeCallOperand(Call, OpIdx);
+}
+
+/// Returns true if function \p F is eligible for merging.
+bool isEligibleFunction(Function *F) {
+  // Reject declarations, functions marked nomerge, available_externally
+  // definitions, variadic functions, and the SwiftTail calling convention.
+  const bool IsUnsupported =
+      F->isDeclaration() || F->hasFnAttribute(llvm::Attribute::NoMerge) ||
+      F->hasAvailableExternallyLinkage() ||
+      F->getFunctionType()->isVarArg() ||
+      F->getCallingConv() == CallingConv::SwiftTail;
+  if (IsUnsupported)
+    return false;
+
+  // A merged function keeps any musttail callsite of the original, but its
+  // parameter count can change, so the callsite's parameter count would no
+  // longer match the function itself. Reject functions with such callsites.
+  for (const BasicBlock &Block : *F)
+    for (const Instruction &Inst : Block)
+      if (const auto *Call = dyn_cast<CallBase>(&Inst))
+        if (Call->isMustTailCall())
+          return false;
+
+  return true;
+}
+
+/// File-local predicate: returns true if \p I is a load, store, call, or
+/// invoke instruction.
+static bool
+isEligibleInstrunctionForConstantSharingLocal(const Instruction *I) {
+  const unsigned Opcode = I->getOpcode();
+  return Opcode == Instruction::Load || Opcode == Instruction::Store ||
+         Opcode == Instruction::Call || Opcode == Instruction::Invoke;
+}
+
+/// Returns true if operand \p OpIdx of \p I is a constant operand of a load,
+/// store, call, or invoke and, for call-like instructions, the operand can be
+/// parameterized. Passed as the callback to StructuralHashWithDifferences in
+/// GlobalMergeFunc::analyze().
+static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
+  assert(OpIdx < I->getNumOperands() && "Invalid operand index");
+
+  const bool IsConstantOfEligibleInst =
+      isEligibleInstrunctionForConstantSharingLocal(I) &&
+      isa<Constant>(I->getOperand(OpIdx));
+  if (!IsConstantOfEligibleInst)
+    return false;
+
+  // Call-like instructions have additional restrictions on their operands.
+  const auto *Call = dyn_cast<CallBase>(I);
+  return !Call || canParameterizeCallOperand(Call, OpIdx);
+}
+
+/// Emits, via \p Builder, the IR that converts \p V to type \p DestTy and
+/// returns the converted value.
+///
+/// Struct and array values are converted recursively, element by element; the
+/// source and destination must be the same kind of aggregate with the same
+/// number of elements (checked by assertions). Other values are converted with
+/// inttoptr (integer to pointer), ptrtoint (pointer to integer), or a bitcast.
+static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
+  Type *SrcTy = V->getType();
+  if (SrcTy->isStructTy()) {
+    assert(DestTy->isStructTy());
+    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
+    // Start from poison; every element is overwritten by the loop below.
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
+      Value *Element =
+          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
+                     DestTy->getStructElementType(I));
+
+      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+  assert(!DestTy->isStructTy());
+  if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
+    // cast<> asserts that the destination is an array type as well.
+    auto *DestAT = cast<ArrayType>(DestTy);
+    assert(SrcAT->getNumElements() == DestAT->getNumElements());
+    // Use poison as the placeholder, consistent with the struct case above.
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
+      Value *Element =
+          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
+                     DestAT->getElementType());
+
+      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+  assert(!DestTy->isArrayTy());
+  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
+    return Builder.CreateIntToPtr(V, DestTy);
+  if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
+    return Builder.CreatePtrToInt(V, DestTy);
+  return Builder.CreateBitCast(V, DestTy);
+}
+
+/// Builds a StableFunction record for every eligible function in \p M and
+/// inserts it into LocalFunctionMap, updating the analysis statistics.
+void GlobalMergeFunc::analyze(Module &M) {
+  ++NumAnalyzedModues;
+  for (Function &Func : M) {
+    ++NumAnalyzedFunctions;
+    if (!isEligibleFunction(&Func))
+      continue;
+    ++NumEligibleFunctions;
+
+    // ignoreOp is the operand callback handed to the structural hash.
+    auto HashInfo = llvm::StructuralHashWithDifferences(Func, ignoreOp);
+
+    // Flatten the operand map into a vector, which is a
+    // serialization-friendly format.
+    IndexOperandHashVecType OperandHashVec;
+    for (auto &Entry : *HashInfo.IndexOperandHashMap)
+      OperandHashVec.emplace_back(Entry);
+
+    StableFunction Record(
+        HashInfo.FunctionHash, get_stable_name(Func.getName()).str(),
+        M.getModuleIdentifier(), HashInfo.IndexInstruction->size(),
+        std::move(OperandHashVec));
+
+    LocalFunctionMap->insert(Record);
+  }
+}
+
+/// Tuple to hold function info to process merging.
+struct FuncMergeInfo {
+  /// Stable function map entry associated with this function (presumably the
+  /// entry whose hash matched -- confirm against the code that fills it in).
+  StableFunctionMap::StableFunctionEntry *SF;
+  /// The IR function; createMergedFunction() uses it as the function to merge.
+  Function *F;
+  /// Owned index-to-instruction map for \c F (presumably the one produced by
+  /// the structural hash -- confirm against the code that fills it in).
+  std::unique_ptr<IndexInstrMap> IndexInstruction;
+};
+
+// Given the func info, and the parameterized locations, create and return
+// a new merged function by replacing the original constants with the new
+// parameters.
+static Function *createMergedFunction(FuncMergeInfo &FI,
+                                      ArrayRef<Type *> ConstParamTypes,
+                                      const ParamLocsVecTy &ParamLocsVec) {
+  // Synthesize a new merged function name by appending ".Tgm" to the root
+  // function's name.
+  auto *MergedFunc = FI.F;
+  auto NewFunctionName = MergedFunc->getName().str() + ".Tgm";
+  auto *M = MergedFunc->getParent();
+  assert(!M->getFunction(NewFunctionName));
+
+  FunctionType *OrigTy = MergedFunc->getFunctionType();
+  // Get the original params' types.
+  SmallVector<Type *> ParamTypes(OrigTy->param_begin(), OrigTy->param_end());
+  // Append const parameter types that are passed in.
+  ParamTypes.append(ConstParamTypes.begin(), ConstParamTypes.end());
+  FunctionType *FuncType =
+      FunctionType::get(OrigTy->getReturnType(), ParamTypes, false);
+
+  // Declare a new function
+  Function *NewFunction =
+      Function::Create(FuncType, MergedFunc->getLinkage(), NewFunctionName);
+  if (auto *SP = MergedFunc->getSubprogram())
+    NewFunction->setSubprogram(SP);
+  NewFunction->copyAttributesFrom(MergedFunc);
+  NewFunction->setDLLStorageClass(GlobalValue::DefaultStorageClass);
+
+  NewFunction->setLinkage(GlobalValue::InternalLinkage);
+  NewFunction->addFnAttr(Attribute::NoInline);
----------------
nocchijiang wrote:

If `MergedFunc` is `alwaysinline`, we should remove the attribute from `NewFunction`. The verifier will complain about it if both `noinline` and `alwaysinline` are present.

https://github.com/llvm/llvm-project/pull/112671


More information about the llvm-commits mailing list