[llvm] llvm-reduce: Add values to return reduction (PR #132686)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 4 01:40:55 PDT 2025


================
@@ -0,0 +1,244 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Try to reduce a function by inserting new return instructions. Try to insert
+// an early return for each instruction value at that point. This requires
+// mutating the return type, or finding instructions with a compatible type.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "llvm-reduce"
+
+#include "ReduceValuesToReturn.h"
+
+#include "Delta.h"
+#include "Utils.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+using namespace llvm;
+
+/// Return true if a copy of \p F may be emitted with a non-void return type.
+///
+/// This is rejected for functions carrying an sret argument, and for functions
+/// whose calling convention does not support a non-void return type.
+static bool canUseNonVoidReturnType(const Function &F) {
+  // Functions with sret arguments must return void.
+  if (F.hasStructRetAttr())
+    return false;
+
+  return CallingConv::supportsNonVoidReturnType(F.getCallingConv());
+}
+
+/// Return true if \p Ty may be used as the new return type of a rewritten
+/// function. On top of FunctionType::isValidReturnType, this additionally
+/// rejects token types and any type that is not first-class.
+static bool isReallyValidReturnType(Type *Ty) {
+  if (!FunctionType::isValidReturnType(Ty))
+    return false;
+
+  if (Ty->isTokenTy())
+    return false;
+
+  return Ty->isFirstClassType();
+}
+
+/// Rewrite \p OldF so that it returns \p NewRetValue at the point where that
+/// value is defined, changing the function's return type to the value's type.
+///
+/// \p NewRetValue must be an Instruction (it is cast<> unconditionally). In
+/// its block, the outgoing CFG edges are removed, every instruction after it
+/// is deleted (uses replaced with getDefaultValue()), and a `ret` of
+/// \p NewRetValue is appended. If \p OldF returned void, the `ret` in every
+/// other block is replaced with a `ret` of getDefaultValue(NewRetTy).
+/// Unreachable blocks are then eliminated. A new function with the new type is
+/// created directly after \p OldF; it takes over the name, attributes, call
+/// sites and (cloned) body of \p OldF, which is then erased.
+///
+/// NOTE(review): call sites whose non-void type differs from the new return
+/// type are only caught by an assert below; callers are expected to have
+/// filtered them out first (e.g. via canReplaceFuncUsers).
+static void rewriteFuncWithReturnType(Function &OldF, Value *NewRetValue) {
+  Type *NewRetTy = NewRetValue->getType();
+  FunctionType *OldFuncTy = OldF.getFunctionType();
+
+  // Same parameters and varargs-ness as OldF; only the return type changes.
+  FunctionType *NewFuncTy =
+      FunctionType::get(NewRetTy, OldFuncTy->params(), OldFuncTy->isVarArg());
+
+  LLVMContext &Ctx = OldF.getContext();
+  Instruction *NewRetI = cast<Instruction>(NewRetValue);
+  BasicBlock *NewRetBlock = NewRetI->getParent();
+
+  // Position of the new return value; everything after it in NewRetBlock is
+  // deleted below.
+  BasicBlock::iterator NewValIt = NewRetI->getIterator();
+
+  // Hack up any return values in other blocks, we can't leave them as ret void.
+  // When OldF already returned a value, the existing returns are left as-is.
+  // NOTE(review): that presumes NewRetTy equals the old return type in that
+  // case (shouldForwardValueToReturn enforces this) — confirm for all callers.
+  if (OldFuncTy->getReturnType()->isVoidTy()) {
+    for (BasicBlock &OtherRetBB : OldF) {
+      if (&OtherRetBB != NewRetBlock) {
+        auto *OrigRI = dyn_cast<ReturnInst>(OtherRetBB.getTerminator());
+        if (!OrigRI)
+          continue;
+
+        OrigRI->eraseFromParent();
+        ReturnInst::Create(Ctx, getDefaultValue(NewRetTy), &OtherRetBB);
+      }
+    }
+  }
+
+  // Now prune any CFG edges we have to deal with.
+  //
+  // Use KeepOneInputPHIs in case the instruction we are using for the return is
+  // that phi.
+  // TODO: Could avoid this with fancier iterator management.
+  // successors() yields one entry per edge, so a successor reached by several
+  // edges gets one removePredecessor call per edge.
+  for (BasicBlock *Succ : successors(NewRetBlock))
+    Succ->removePredecessor(NewRetBlock, /*KeepOneInputPHIs=*/true);
+
+  // Now delete the tail of this block, in reverse to delete uses before defs.
+  // NewValIt.getReverse() refers to NewRetI itself, so the range stops just
+  // before it: NewRetI is kept and the old terminator is deleted.
+  for (Instruction &I : make_early_inc_range(
+           make_range(NewRetBlock->rbegin(), NewValIt.getReverse()))) {
+    Value *Replacement = getDefaultValue(I.getType());
+    I.replaceAllUsesWith(Replacement);
+    I.eraseFromParent();
+  }
+
+  // The block lost its terminator above; end it with the new return.
+  ReturnInst::Create(Ctx, NewRetValue, NewRetBlock);
+
+  // TODO: We may be eliminating blocks that were originally unreachable. We
+  // probably ought to only be pruning blocks that became dead directly as a
+  // result of our pruning here.
+  EliminateUnreachableBlocks(OldF);
+
+  // Create the replacement unnamed, then move it to sit directly after OldF
+  // in the module's function list before taking over OldF's name.
+  Function *NewF =
+      Function::Create(NewFuncTy, OldF.getLinkage(), OldF.getAddressSpace(), "",
+                       OldF.getParent());
+
+  NewF->removeFromParent();
+  OldF.getParent()->getFunctionList().insertAfter(OldF.getIterator(), NewF);
+  NewF->takeName(&OldF);
+  NewF->copyAttributesFrom(&OldF);
+
+  // Adjust the callsite uses to the new return type. We pre-filtered cases
+  // where the original call type was incorrectly non-void.
+  // Only direct calls (OldF as the callee operand) are touched here; calls
+  // whose type already equals NewRetTy, and non-call uses, are redirected by
+  // the replaceAllUsesWith at the end.
+  for (User *U : make_early_inc_range(OldF.users())) {
+    if (auto *CB = dyn_cast<CallBase>(U);
+        CB && CB->getCalledOperand() == &OldF) {
+      if (CB->getType()->isVoidTy()) {
+        FunctionType *CallType = CB->getFunctionType();
+
+        // The callsite may not match the new function type, in an undefined
+        // behavior way. Only mutate the local return type.
+        FunctionType *NewCallType = FunctionType::get(
+            NewRetTy, CallType->params(), CallType->isVarArg());
+
+        CB->mutateType(NewRetTy);
+        CB->setCalledFunction(NewCallType, NewF);
+      } else {
+        assert(CB->getType() == NewRetTy &&
+               "only handle exact return type match with non-void returns");
+      }
+    }
+  }
+
+  // Preserve the parameters of OldF.
+  ValueToValueMapTy VMap;
+  for (auto Z : zip_first(OldF.args(), NewF->args())) {
+    Argument &OldArg = std::get<0>(Z);
+    Argument &NewArg = std::get<1>(Z);
+
+    NewArg.setName(OldArg.getName()); // Copy the name over...
+    VMap[&OldArg] = &NewArg;          // Add mapping to VMap
+  }
+
+  // At this point OldF's body already uses NewRetTy in its returns while its
+  // signature still has the old return type; cloning moves that body into
+  // NewF, whose signature matches.
+  SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned.
+  CloneFunctionInto(NewF, &OldF, VMap,
+                    CloneFunctionChangeType::LocalChangesOnly, Returns, "",
+                    /*CodeInfo=*/nullptr);
+  // Redirect every remaining use of OldF (non-call uses and call sites not
+  // rewritten above) to NewF before deleting OldF.
+  OldF.replaceAllUsesWith(NewF);
+  OldF.eraseFromParent();
+}
+
+/// Return true if every direct call of \p F can be retargeted to a copy of
+/// \p F whose return type is \p NewRetTy. A direct call qualifies if it is
+/// void-typed, or if its type already happens to equal \p NewRetTy. Uses of
+/// \p F that are not the callee operand of a call are always accepted.
+//
+// TODO: We could make better effort to handle call type mismatches.
+static bool canReplaceFuncUsers(const Function &F, Type *NewRetTy) {
+  for (const Use &FnUse : F.uses()) {
+    const auto *Call = dyn_cast<CallBase>(FnUse.getUser());
+
+    // Non-call users, and calls that merely pass F as an operand instead of
+    // calling it, are plain pointer uses and trivially replaceable.
+    if (!Call || !Call->isCallee(&FnUse))
+      continue;
+
+    // A direct call is fine if it is void, or if its result type already
+    // matches the new return type.
+    Type *CallTy = Call->getType();
+    if (CallTy->isVoidTy() || CallTy == NewRetTy)
+      continue;
+
+    LLVM_DEBUG(dbgs() << "Cannot replace callsite with wrong type: " << *Call
+                      << '\n');
+    return false;
+  }
+
+  return true;
+}
+
+/// Return true if making \p BB return \p Replacement would change anything:
+/// that is, \p BB does not end in a `ret`, or its `ret` returns some other
+/// value.
+static bool shouldReplaceNonVoidReturnValue(const BasicBlock &BB,
+                                            const Value *Replacement) {
+  const auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator());
+  return !Ret || Ret->getReturnValue() != Replacement;
+}
+
+/// Return true if \p BB ends in a ret, unreachable, br or switch, and every
+/// successor of \p BB reports canSplitPredecessors().
+static bool canHandleSuccessors(const BasicBlock &BB) {
+  const Instruction *Term = BB.getTerminator();
+
+  // TODO: Handle invoke and other exotic terminators
+  const bool IsSupportedTerminator =
+      isa<ReturnInst, UnreachableInst, BranchInst, SwitchInst>(Term);
+  if (!IsSupportedTerminator)
+    return false;
+
+  for (const BasicBlock *Succ : successors(&BB))
+    if (!Succ->canSplitPredecessors())
+      return false;
+
+  return true;
+}
+
+/// Return true if \p V (defined in \p BB) is a candidate to become the return
+/// value of \p BB's function, whose current return type is \p RetTy.
+static bool shouldForwardValueToReturn(const BasicBlock &BB, const Value *V,
+                                       Type *RetTy) {
+  Type *ValTy = V->getType();
+  if (!isReallyValidReturnType(ValTy))
+    return false;
+
+  // A void function may start returning any valid type. A function that
+  // already returns a value may only return another value of that same type,
+  // and only when that changes what this block returns.
+  const bool TypeFits =
+      RetTy->isVoidTy() ||
+      (RetTy == ValTy && shouldReplaceNonVoidReturnValue(BB, V));
+  if (!TypeFits)
+    return false;
+
+  return canReplaceFuncUsers(*BB.getParent(), ValTy);
+}
+
+static bool tryForwardingInstructionsToReturn(
+    Function &F, Oracle &O,
+    std::vector<std::pair<Function *, Value *>> &FuncsToReplace) {
+
+  // TODO: Should we try to expand returns to aggregate for function that
+  // already have a return value?
+  Type *RetTy = F.getReturnType();
+
+  for (BasicBlock &BB : F) {
+    if (!canHandleSuccessors(BB))
+      continue;
+
+    for (Instruction &I : BB) {
+      if (shouldForwardValueToReturn(BB, &I, RetTy) && !O.shouldKeep()) {
+        FuncsToReplace.emplace_back(&F, &I);
+        return true;
----------------
nikic wrote:

Hm, does the oracle properly handle this kind of early return?

Basically, if O.shouldKeep() returns false and we now return true, we'll perform an extra query (for the next instruction in the block), effectively shifting all the indices?

https://github.com/llvm/llvm-project/pull/132686


More information about the llvm-commits mailing list