[llvm] f6795e6 - [CodeExtractor] Refactor extractCodeRegion, fix alloca emission. (#114419)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 11:12:26 PST 2024


Author: Michael Kruse
Date: 2024-11-12T20:12:22+01:00
New Revision: f6795e6b4f619cbecc59a92f7e5fad7ca90ece54

URL: https://github.com/llvm/llvm-project/commit/f6795e6b4f619cbecc59a92f7e5fad7ca90ece54
DIFF: https://github.com/llvm/llvm-project/commit/f6795e6b4f619cbecc59a92f7e5fad7ca90ece54.diff

LOG: [CodeExtractor] Refactor extractCodeRegion, fix alloca emission. (#114419)

Reorganize the code into phases:

 * Analyze/normalize
 * Create extracted function prototype
 * Generate the new function's implementation
 * Generate call to new function
 * Connect call to original function's CFG
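
The phase split can be illustrated with a toy C++ model. This is not LLVM code. Each phase receives what earlier phases produced as explicit arguments, so the required ordering shows up in the signatures. The struct and function names are hypothetical and only loosely correspond to normalizeCFGForExtraction, constructFunctionDeclaration, emitFunctionBody, emitReplacerCall and insertReplacerCall from the header diff below. The IR-like strings are placeholders.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy model, not LLVM code: every phase receives what earlier phases produced
// as explicit arguments, so the required ordering is visible in the signatures.
struct Region { // produced by analyze/normalize
  std::vector<std::string> Inputs, Outputs, ExitBlocks;
};
struct Prototype { // produced by the prototype phase; no body yet
  std::string Name, RetTy;
  std::size_t NumParams;
};

// Phase 1: hard-coded example region (the real pass also normalizes the CFG).
Region analyzeAndNormalize() {
  return {{"a", "b"}, {"sum"}, {"exit.ok", "exit.err"}};
}

// Phase 2: the signature depends only on inputs, outputs and the exit count.
Prototype createPrototype(const Region &R, const std::string &Name) {
  std::size_t N = R.ExitBlocks.size();
  assert(N < 0xffff && "too many exit blocks for switch");
  std::string RetTy = N <= 1 ? "void" : (N == 2 ? "i1" : "i16");
  // One parameter per input value, one pointer parameter per output value.
  return {Name, RetTy, R.Inputs.size() + R.Outputs.size()};
}

// Phase 3: the body ends in one return stub per exit block.
std::vector<std::string> generateBody(const Region &R) {
  std::vector<std::string> Stubs;
  for (const std::string &Exit : R.ExitBlocks)
    Stubs.push_back(Exit + ".exitStub");
  return Stubs;
}

// Phase 4: the call only needs the prototype.
std::string generateCall(const Prototype &P) {
  std::string Call = "call " + P.RetTy + " @" + P.Name + "(";
  for (std::size_t I = 0; I < P.NumParams; ++I) {
    if (I != 0)
      Call += ", ";
    Call += "arg" + std::to_string(I);
  }
  return Call + ")";
}

// Phase 5: the caller branches to the original exits in the same order in
// which generateBody created the stubs, because both read R.ExitBlocks.
std::vector<std::string> connectToCFG(const Region &R) { return R.ExitBlocks; }
```

Because phase 4 takes only a Prototype and phase 3 takes only a Region, the body and the call can be produced in either order. This is the kind of independence the refactoring aims for.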

The motivation is #114669: optionally cloning the selected code region
into the new function instead of moving it. The current structure made
such functionality difficult to add because there was no obvious place
for it, and some functions doing more than their names suggest made it
harder still. For instance, constructFunction modifies code outside the
constructed function, while some properties of the new function, such as
the personality set via setPersonalityFn, are derived elsewhere. Another
example is emitCallAndSwitchStatement, which despite its name also
inserts the stores for output parameters.

Many operations also implicitly depend on the order in which they are
applied; this patch tries to reduce such dependencies. For instance,
ExtractedFuncRetVals becomes the list of exit blocks, whose order also
defines the return value when leaving via each block. It is computed
early so that the new function's return instructions and the caller's
switch can be generated independently. ExtractedFuncRetVals also
combines the lists ExitBlocks and OldTargets, which were not always kept
consistent with each other or with NumExitBlocks. The method
computeExtractedFuncRetVals() updates it when necessary.
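
The return-value convention that ExtractedFuncRetVals encodes, as documented in the CodeExtractor.h hunk below, can be modeled in a few lines. This is a standalone sketch, not the LLVM implementation. Exit blocks are plain strings, and `switchTypeFor`, `retValFor` and `exitFor` are hypothetical helpers. `switchTypeFor` follows the void/i1/i16 rule of the removed `switch (NumExitBlocks)`, which getSwitchType() presumably preserves. Callee and caller both index the same list, so the return stubs and the switch cannot disagree.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Illustrative model of the ExtractedFuncRetVals convention; exit blocks are
// plain strings here instead of BasicBlock pointers.
enum class SwitchTy { Void, I1, I16 };

// The return type depends only on the number of exit blocks.
SwitchTy switchTypeFor(std::size_t NumExits) {
  if (NumExits <= 1)
    return SwitchTy::Void;
  if (NumExits == 2)
    return SwitchTy::I1;
  return SwitchTy::I16;
}

// Callee side: value returned when leaving via exit number Idx, or
// std::nullopt when the extracted function returns void.
std::optional<unsigned> retValFor(std::size_t NumExits, std::size_t Idx) {
  assert(Idx < NumExits && "exit index out of range");
  switch (switchTypeFor(NumExits)) {
  case SwitchTy::Void:
    return std::nullopt;
  case SwitchTy::I1:
    return Idx == 0 ? 1u : 0u; // the first exit returns 'true' (i1 1)
  case SwitchTy::I16:
    return static_cast<unsigned>(Idx); // the n-th exit returns n
  }
  return std::nullopt;
}

// Caller side: map a returned value back to the exit block to branch to.
const std::string &exitFor(const std::vector<std::string> &Exits,
                           unsigned Ret) {
  if (Exits.size() == 2)
    return Ret ? Exits[0] : Exits[1];
  return Exits.at(Ret);
}
```

The two-exit case is the only one where the first block maps to 1 rather than 0, which is the inversion the header comment warns about.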

The coding style partially contradicts the current coding standard. For
instance, some local variables start with lowercase letters. I updated
some, but not all, occurrences so that at least some lines of the diff
still match as unchanged.

The patch [D96854](https://reviews.llvm.org/D96854) introduced some
confusion of function argument indexes; this is fixed here as well, hence
the patch is no longer NFC. The fix is tested in the modified
CodeExtractorTest.cpp. Patch [D121061](https://reviews.llvm.org/D121061)
introduced AllocationBlock, but not all allocas were inserted there.
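
The sketch below models the argument numbering the new code follows. It assumes the parameter layout visible in constructFunctionDeclaration: scalar inputs first, then scalar output pointers, then one pointer to the aggregate struct if any value was aggregated. `Param`, `scalarArgNo` and `structArgNo` are hypothetical names. In the removed emitCallAndSwitchStatement, ScalarInputArgNo was incremented for every input, including aggregated ones. It was then used both as the swifterror attribute position and to locate the first output argument. That is one place where the indexes could diverge. Counting only scalar values, as below, avoids that mismatch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical model: values listed inputs first, then outputs; each value is
// either passed as its own argument or packed into one aggregate struct.
struct Param {
  bool InStruct; // true if the value is stored in the aggregate struct
};

// Argument number of Params[Idx] if it is passed individually, or
// std::nullopt if it lives in the struct. Only preceding scalar values
// occupy argument slots, so aggregated values must not be counted.
std::optional<std::size_t> scalarArgNo(const std::vector<Param> &Params,
                                       std::size_t Idx) {
  assert(Idx < Params.size() && "parameter index out of range");
  if (Params[Idx].InStruct)
    return std::nullopt;
  std::size_t ArgNo = 0;
  for (std::size_t I = 0; I < Idx; ++I)
    if (!Params[I].InStruct)
      ++ArgNo;
  return ArgNo;
}

// The struct pointer, if any, is appended after all scalar arguments.
std::optional<std::size_t> structArgNo(const std::vector<Param> &Params) {
  std::size_t NumScalar = 0;
  bool AnyInStruct = false;
  for (const Param &P : Params) {
    if (P.InStruct)
      AnyInStruct = true;
    else
      ++NumScalar;
  }
  if (!AnyInStruct)
    return std::nullopt;
  return NumScalar;
}
```

With an aggregated first input followed by a scalar input and a scalar output, the scalar input is argument 0, not 1. An index that also counts the aggregated input would be off by one.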

Effectively includes the following fixes:
1. https://github.com/llvm/llvm-project/commit/ce73b1672a6053d5974dc2342881aac02efe2dbb
2. https://github.com/llvm/llvm-project/commit/4aaa92578686176243a294eeb2ca5697a99edcaa
3. Missing allocas, still unfixed

Originally submitted as https://reviews.llvm.org/D115218

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/CodeExtractor.h
    llvm/lib/Transforms/Utils/CodeExtractor.cpp
    llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index 826347e79f7195..60c5def3472b6b 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -35,6 +35,7 @@ class Instruction;
 class Module;
 class Type;
 class Value;
+class StructType;
 
 /// A cache for the CodeExtractor analysis. The operation \ref
 /// CodeExtractor::extractCodeRegion is guaranteed not to invalidate this
@@ -101,12 +102,25 @@ class CodeExtractorAnalysisCache {
 
     // Bits of intermediate state computed at various phases of extraction.
     SetVector<BasicBlock *> Blocks;
-    unsigned NumExitBlocks = std::numeric_limits<unsigned>::max();
-    Type *RetTy;
 
-    // Mapping from the original exit blocks, to the new blocks inside
-    // the function.
-    SmallVector<BasicBlock *, 4> OldTargets;
+    /// Lists of blocks that are branched from the code region to be extracted,
+    /// also called the exit blocks. Each block is contained at most once. Its
+    /// order defines the return value of the extracted function.
+    ///
+    /// When there is just one (or no) exit block, the return value is
+    /// irrelevant.
+    ///
+    /// When there are exactly two exit blocks, the extracted function returns a
+    /// boolean. For ExtractedFuncRetVals[0], it returns 'true'. For
+    /// ExtractedFuncRetVals[1] it returns 'false'.
+    /// NOTE: Since a boolean is represented by i1, ExtractedFuncRetVals[0]
+    ///       returns 1 and ExtractedFuncRetVals[1] returns 0, which opposite
+    ///       of the regular pattern below.
+    ///
+    /// When there are 3 or more exit blocks, leaving the extracted function via
+    /// the first block it returns 0. When leaving via the second entry it
+    /// returns 1, etc.
+    SmallVector<BasicBlock *> ExtractedFuncRetVals;
 
     // Suffix to use when creating extracted function (appended to the original
     // function name + "."). If empty, the default is to use the entry block
@@ -238,26 +252,61 @@ class CodeExtractorAnalysisCache {
     getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
                        Instruction *Addr, BasicBlock *ExitBlock) const;
 
+    /// Updates the list of SwitchCases (corresponding to exit blocks) after
+    /// changes of the control flow or the Blocks list.
+    void computeExtractedFuncRetVals();
+
+    /// Return the type used for the return code of the extracted function to
+    /// indicate which exit block to jump to.
+    Type *getSwitchType();
+
     void severSplitPHINodesOfEntry(BasicBlock *&Header);
-    void severSplitPHINodesOfExits(const SetVector<BasicBlock *> &Exits);
+    void severSplitPHINodesOfExits();
     void splitReturnBlocks();
 
-    Function *constructFunction(const ValueSet &inputs,
-                                const ValueSet &outputs,
-                                BasicBlock *header,
-                                BasicBlock *newRootNode, BasicBlock *newHeader,
-                                Function *oldFunction, Module *M);
-
     void moveCodeToFunction(Function *newFunction);
 
     void calculateNewCallTerminatorWeights(
         BasicBlock *CodeReplacer,
-        DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+        const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
         BranchProbabilityInfo *BPI);
 
-    CallInst *emitCallAndSwitchStatement(Function *newFunction,
-                                         BasicBlock *newHeader,
-                                         ValueSet &inputs, ValueSet &outputs);
+    /// Normalizes the control flow of the extracted regions, such as ensuring
+    /// that the extracted region does not contain a return instruction.
+    void normalizeCFGForExtraction(BasicBlock *&header);
+
+    /// Generates the function declaration for the function containing the
+    /// extracted code.
+    Function *constructFunctionDeclaration(const ValueSet &inputs,
+                                           const ValueSet &outputs,
+                                           BlockFrequency EntryFreq,
+                                           const Twine &Name,
+                                           ValueSet &StructValues,
+                                           StructType *&StructTy);
+
+    /// Generates the code for the extracted function. That is: a prolog, the
+    /// moved or copied code from the original function, and epilogs for each
+    /// exit.
+    void emitFunctionBody(const ValueSet &inputs, const ValueSet &outputs,
+                          const ValueSet &StructValues, Function *newFunction,
+                          StructType *StructArgTy, BasicBlock *header,
+                          const ValueSet &SinkingCands);
+
+    /// Generates a Basic Block that calls the extracted function.
+    CallInst *emitReplacerCall(const ValueSet &inputs, const ValueSet &outputs,
+                               const ValueSet &StructValues,
+                               Function *newFunction, StructType *StructArgTy,
+                               Function *oldFunction, BasicBlock *ReplIP,
+                               BlockFrequency EntryFreq,
+                               ArrayRef<Value *> LifetimesStart,
+                               std::vector<Value *> &Reloads);
+
+    /// Connects the basic block containing the call to the extracted function
+    /// into the original function's control flow.
+    void insertReplacerCall(
+        Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer,
+        const ValueSet &outputs, ArrayRef<Value *> Reloads,
+        const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights);
   };
 
 } // end namespace llvm

diff  --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index ccca88d6c8e7d3..ed4c9c3c30f4f3 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -421,7 +421,6 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
   }
   // Now add the old exit block to the outline region.
   Blocks.insert(CommonExitBlock);
-  OldTargets.push_back(NewExitBlock);
   return CommonExitBlock;
 }
 
@@ -735,9 +734,8 @@ void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
 /// outlined region, we split these PHIs on two: one with inputs from region
 /// and other with remaining incoming blocks; then first PHIs are placed in
 /// outlined region.
-void CodeExtractor::severSplitPHINodesOfExits(
-    const SetVector<BasicBlock *> &Exits) {
-  for (BasicBlock *ExitBB : Exits) {
+void CodeExtractor::severSplitPHINodesOfExits() {
+  for (BasicBlock *ExitBB : ExtractedFuncRetVals) {
     BasicBlock *NewBB = nullptr;
 
     for (PHINode &PN : ExitBB->phis()) {
@@ -801,44 +799,28 @@ void CodeExtractor::splitReturnBlocks() {
     }
 }
 
-/// constructFunction - make a function based on inputs and outputs, as follows:
-/// f(in0, ..., inN, out0, ..., outN)
-Function *CodeExtractor::constructFunction(const ValueSet &inputs,
-                                           const ValueSet &outputs,
-                                           BasicBlock *header,
-                                           BasicBlock *newRootNode,
-                                           BasicBlock *newHeader,
-                                           Function *oldFunction,
-                                           Module *M) {
+Function *CodeExtractor::constructFunctionDeclaration(
+    const ValueSet &inputs, const ValueSet &outputs, BlockFrequency EntryFreq,
+    const Twine &Name, ValueSet &StructValues, StructType *&StructTy) {
   LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
   LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
 
-  // This function returns unsigned, outputs will go back by reference.
-  switch (NumExitBlocks) {
-  case 0:
-  case 1: RetTy = Type::getVoidTy(header->getContext()); break;
-  case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
-  default: RetTy = Type::getInt16Ty(header->getContext()); break;
-  }
+  Function *oldFunction = Blocks.front()->getParent();
+  Module *M = Blocks.front()->getModule();
 
+  // Assemble the function's parameter lists.
   std::vector<Type *> ParamTy;
   std::vector<Type *> AggParamTy;
-  std::vector<std::tuple<unsigned, Value *>> NumberedInputs;
-  std::vector<std::tuple<unsigned, Value *>> NumberedOutputs;
-  ValueSet StructValues;
   const DataLayout &DL = M->getDataLayout();
 
   // Add the types of the input values to the function's argument list
-  unsigned ArgNum = 0;
   for (Value *value : inputs) {
     LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
     if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
       AggParamTy.push_back(value->getType());
       StructValues.insert(value);
-    } else {
+    } else
       ParamTy.push_back(value->getType());
-      NumberedInputs.emplace_back(ArgNum++, value);
-    }
   }
 
   // Add the types of the output values to the function's argument list.
@@ -847,11 +829,9 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
     if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
       AggParamTy.push_back(output->getType());
       StructValues.insert(output);
-    } else {
+    } else
       ParamTy.push_back(
           PointerType::get(output->getType(), DL.getAllocaAddrSpace()));
-      NumberedOutputs.emplace_back(ArgNum++, output);
-    }
   }
 
   assert(
@@ -862,14 +842,13 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
          "Expeced StructValues only with AggregateArgs set");
 
   // Concatenate scalar and aggregate params in ParamTy.
-  size_t NumScalarParams = ParamTy.size();
-  StructType *StructTy = nullptr;
-  if (AggregateArgs && !AggParamTy.empty()) {
+  if (!AggParamTy.empty()) {
     StructTy = StructType::get(M->getContext(), AggParamTy);
     ParamTy.push_back(PointerType::get(
         StructTy, ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
   }
 
+  Type *RetTy = getSwitchType();
   LLVM_DEBUG({
     dbgs() << "Function type: " << *RetTy << " f(";
     for (Type *i : ParamTy)
@@ -880,15 +859,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   FunctionType *funcType = FunctionType::get(
       RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
 
-  std::string SuffixToUse =
-      Suffix.empty()
-          ? (header->getName().empty() ? "extracted" : header->getName().str())
-          : Suffix;
   // Create the new function
-  Function *newFunction = Function::Create(
-      funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
-      oldFunction->getName() + "." + SuffixToUse, M);
-  newFunction->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
+  Function *newFunction =
+      Function::Create(funcType, GlobalValue::InternalLinkage,
+                       oldFunction->getAddressSpace(), Name, M);
+
+  // Propagate personality info to the new function if there is one.
+  if (oldFunction->hasPersonalityFn())
+    newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
 
   // Inherit all of the target dependent attributes and white-listed
   // target independent attributes.
@@ -1017,65 +995,37 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
     newFunction->addFnAttr(Attr);
   }
 
-  if (NumExitBlocks == 0) {
-    // Mark the new function `noreturn` if applicable. Terminators which resume
-    // exception propagation are treated as returning instructions. This is to
-    // avoid inserting traps after calls to outlined functions which unwind.
-    if (none_of(Blocks, [](const BasicBlock *BB) {
-          const Instruction *Term = BB->getTerminator();
-          return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
-        }))
-      newFunction->setDoesNotReturn();
-  }
-
-  newFunction->insert(newFunction->end(), newRootNode);
-
   // Create scalar and aggregate iterators to name all of the arguments we
   // inserted.
   Function::arg_iterator ScalarAI = newFunction->arg_begin();
-  Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams);
 
-  // Rewrite all users of the inputs in the extracted region to use the
-  // arguments (or appropriate addressing into struct) instead.
-  for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
-    Value *RewriteVal;
-    if (AggregateArgs && StructValues.contains(inputs[i])) {
-      Value *Idx[2];
-      Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
-      BasicBlock::iterator TI = newFunction->begin()->getTerminator()->getIterator();
-      GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
-      RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
-                                "loadgep_" + inputs[i]->getName(), TI);
-      ++aggIdx;
-    } else
-      RewriteVal = &*ScalarAI++;
+  // Set names and attributes for input and output arguments.
+  ScalarAI = newFunction->arg_begin();
+  for (Value *input : inputs) {
+    if (StructValues.contains(input))
+      continue;
 
-    std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
-    for (User *use : Users)
-      if (Instruction *inst = dyn_cast<Instruction>(use))
-        if (Blocks.count(inst->getParent()))
-          inst->replaceUsesOfWith(inputs[i], RewriteVal);
+    ScalarAI->setName(input->getName());
+    if (input->isSwiftError())
+      newFunction->addParamAttr(ScalarAI - newFunction->arg_begin(),
+                                Attribute::SwiftError);
+    ++ScalarAI;
   }
+  for (Value *output : outputs) {
+    if (StructValues.contains(output))
+      continue;
 
-  // Set names for input and output arguments.
-  for (auto [i, argVal] : NumberedInputs)
-    newFunction->getArg(i)->setName(argVal->getName());
-  for (auto [i, argVal] : NumberedOutputs)
-    newFunction->getArg(i)->setName(argVal->getName() + ".out");
+    ScalarAI->setName(output->getName() + ".out");
+    ++ScalarAI;
+  }
 
-  // Rewrite branches to basic blocks outside of the loop to new dummy blocks
-  // within the new function. This must be done before we lose track of which
-  // blocks were originally in the code region.
-  std::vector<User *> Users(header->user_begin(), header->user_end());
-  for (auto &U : Users)
-    // The BasicBlock which contains the branch is not in the region
-    // modify the branch target to a new block
-    if (Instruction *I = dyn_cast<Instruction>(U))
-      if (I->isTerminator() && I->getFunction() == oldFunction &&
-          !Blocks.count(I->getParent()))
-        I->replaceUsesOfWith(header, newHeader);
+  // Update the entry count of the function.
+  if (BFI) {
+    auto Count = BFI->getProfileCountFromFreq(EntryFreq);
+    if (Count.has_value())
+      newFunction->setEntryCount(
+          ProfileCount(*Count, Function::PCT_Real)); // FIXME
+  }
 
   return newFunction;
 }
@@ -1171,315 +1121,8 @@ static void insertLifetimeMarkersSurroundingCall(
   }
 }
 
-/// emitCallAndSwitchStatement - This method sets up the caller side by adding
-/// the call instruction, splitting any PHI nodes in the header block as
-/// necessary.
-CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
-                                                    BasicBlock *codeReplacer,
-                                                    ValueSet &inputs,
-                                                    ValueSet &outputs) {
-  // Emit a call to the new function, passing in: *pointer to struct (if
-  // aggregating parameters), or plan inputs and allocated memory for outputs
-  std::vector<Value *> params, ReloadOutputs, Reloads;
-  ValueSet StructValues;
-
-  Module *M = newFunction->getParent();
-  LLVMContext &Context = M->getContext();
-  const DataLayout &DL = M->getDataLayout();
-  CallInst *call = nullptr;
-
-  // Add inputs as params, or to be filled into the struct
-  unsigned ScalarInputArgNo = 0;
-  SmallVector<unsigned, 1> SwiftErrorArgs;
-  for (Value *input : inputs) {
-    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
-      StructValues.insert(input);
-    else {
-      params.push_back(input);
-      if (input->isSwiftError())
-        SwiftErrorArgs.push_back(ScalarInputArgNo);
-    }
-    ++ScalarInputArgNo;
-  }
-
-  // Create allocas for the outputs
-  unsigned ScalarOutputArgNo = 0;
-  for (Value *output : outputs) {
-    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
-      StructValues.insert(output);
-    } else {
-      AllocaInst *alloca =
-        new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
-                       nullptr, output->getName() + ".loc",
-                       codeReplacer->getParent()->front().begin());
-      ReloadOutputs.push_back(alloca);
-      params.push_back(alloca);
-      ++ScalarOutputArgNo;
-    }
-  }
-
-  StructType *StructArgTy = nullptr;
-  AllocaInst *Struct = nullptr;
-  unsigned NumAggregatedInputs = 0;
-  if (AggregateArgs && !StructValues.empty()) {
-    std::vector<Type *> ArgTypes;
-    for (Value *V : StructValues)
-      ArgTypes.push_back(V->getType());
-
-    // Allocate a struct at the beginning of this function
-    StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
-    Struct = new AllocaInst(
-        StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
-        AllocationBlock ? AllocationBlock->getFirstInsertionPt()
-                        : codeReplacer->getParent()->front().begin());
-
-    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-      auto *StructSpaceCast = new AddrSpaceCastInst(
-          Struct, PointerType ::get(Context, 0), "structArg.ascast");
-      StructSpaceCast->insertAfter(Struct);
-      params.push_back(StructSpaceCast);
-    } else {
-      params.push_back(Struct);
-    }
-    // Store aggregated inputs in the struct.
-    for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
-      if (inputs.contains(StructValues[i])) {
-        Value *Idx[2];
-        Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-        Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
-        GetElementPtrInst *GEP = GetElementPtrInst::Create(
-            StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
-        GEP->insertInto(codeReplacer, codeReplacer->end());
-        new StoreInst(StructValues[i], GEP, codeReplacer);
-        NumAggregatedInputs++;
-      }
-    }
-  }
-
-  // Emit the call to the function
-  call = CallInst::Create(newFunction, params,
-                          NumExitBlocks > 1 ? "targetBlock" : "");
-  // Add debug location to the new call, if the original function has debug
-  // info. In that case, the terminator of the entry block of the extracted
-  // function contains the first debug location of the extracted function,
-  // set in extractCodeRegion.
-  if (codeReplacer->getParent()->getSubprogram()) {
-    if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
-      call->setDebugLoc(DL);
-  }
-  call->insertInto(codeReplacer, codeReplacer->end());
-
-  // Set swifterror parameter attributes.
-  for (unsigned SwiftErrArgNo : SwiftErrorArgs) {
-    call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
-    newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
-  }
-
-  // Reload the outputs passed in by reference, use the struct if output is in
-  // the aggregate or reload from the scalar argument.
-  for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
-                aggIdx = NumAggregatedInputs;
-       i != e; ++i) {
-    Value *Output = nullptr;
-    if (AggregateArgs && StructValues.contains(outputs[i])) {
-      Value *Idx[2];
-      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
-      GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
-      GEP->insertInto(codeReplacer, codeReplacer->end());
-      Output = GEP;
-      ++aggIdx;
-    } else {
-      Output = ReloadOutputs[scalarIdx];
-      ++scalarIdx;
-    }
-    LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
-                                  outputs[i]->getName() + ".reload",
-                                  codeReplacer);
-    Reloads.push_back(load);
-    std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
-    for (User *U : Users) {
-      Instruction *inst = cast<Instruction>(U);
-      if (!Blocks.count(inst->getParent()))
-        inst->replaceUsesOfWith(outputs[i], load);
-    }
-  }
-
-  // Now we can emit a switch statement using the call as a value.
-  SwitchInst *TheSwitch =
-      SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
-                         codeReplacer, 0, codeReplacer);
-
-  // Since there may be multiple exits from the original region, make the new
-  // function return an unsigned, switch on that number.  This loop iterates
-  // over all of the blocks in the extracted region, updating any terminator
-  // instructions in the to-be-extracted region that branch to blocks that are
-  // not in the region to be extracted.
-  std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
-
-  // Iterate over the previously collected targets, and create new blocks inside
-  // the function to branch to.
-  unsigned switchVal = 0;
-  for (BasicBlock *OldTarget : OldTargets) {
-    if (Blocks.count(OldTarget))
-      continue;
-    BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
-    if (NewTarget)
-      continue;
-
-    // If we don't already have an exit stub for this non-extracted
-    // destination, create one now!
-    NewTarget = BasicBlock::Create(Context,
-                                    OldTarget->getName() + ".exitStub",
-                                    newFunction);
-    unsigned SuccNum = switchVal++;
-
-    Value *brVal = nullptr;
-    assert(NumExitBlocks < 0xffff && "too many exit blocks for switch");
-    switch (NumExitBlocks) {
-    case 0:
-    case 1: break;  // No value needed.
-    case 2:         // Conditional branch, return a bool
-      brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
-      break;
-    default:
-      brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
-      break;
-    }
-
-    ReturnInst::Create(Context, brVal, NewTarget);
-
-    // Update the switch instruction.
-    TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
-                                        SuccNum),
-                        OldTarget);
-  }
-
-  for (BasicBlock *Block : Blocks) {
-    Instruction *TI = Block->getTerminator();
-    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
-      if (Blocks.count(TI->getSuccessor(i)))
-        continue;
-      BasicBlock *OldTarget = TI->getSuccessor(i);
-      // add a new basic block which returns the appropriate value
-      BasicBlock *NewTarget = ExitBlockMap[OldTarget];
-      assert(NewTarget && "Unknown target block!");
-
-      // rewrite the original branch instruction with this new target
-      TI->setSuccessor(i, NewTarget);
-   }
-  }
-
-  // Store the arguments right after the definition of output value.
-  // This should be proceeded after creating exit stubs to be ensure that invoke
-  // result restore will be placed in the outlined function.
-  Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
-  std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
-  Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
-  std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
-
-  for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
-       ++i) {
-    auto *OutI = dyn_cast<Instruction>(outputs[i]);
-    if (!OutI)
-      continue;
-
-    // Find proper insertion point.
-    BasicBlock::iterator InsertPt;
-    // In case OutI is an invoke, we insert the store at the beginning in the
-    // 'normal destination' BB. Otherwise we insert the store right after OutI.
-    if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
-      InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
-    else if (auto *Phi = dyn_cast<PHINode>(OutI))
-      InsertPt = Phi->getParent()->getFirstInsertionPt();
-    else
-      InsertPt = std::next(OutI->getIterator());
-
-    assert((InsertPt->getFunction() == newFunction ||
-            Blocks.count(InsertPt->getParent())) &&
-           "InsertPt should be in new function");
-    if (AggregateArgs && StructValues.contains(outputs[i])) {
-      assert(AggOutputArgBegin != newFunction->arg_end() &&
-             "Number of aggregate output arguments should match "
-             "the number of defined values");
-      Value *Idx[2];
-      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
-      GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
-          InsertPt);
-      new StoreInst(outputs[i], GEP, InsertPt);
-      ++aggIdx;
-      // Since there should be only one struct argument aggregating
-      // all the output values, we shouldn't increment AggOutputArgBegin, which
-      // always points to the struct argument, in this case.
-    } else {
-      assert(ScalarOutputArgBegin != newFunction->arg_end() &&
-             "Number of scalar output arguments should match "
-             "the number of defined values");
-      new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertPt);
-      ++ScalarOutputArgBegin;
-    }
-  }
-
-  // Now that we've done the deed, simplify the switch instruction.
-  Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
-  switch (NumExitBlocks) {
-  case 0:
-    // There are no successors (the block containing the switch itself), which
-    // means that previously this was the last part of the function, and hence
-    // this should be rewritten as a `ret` or `unreachable`.
-    if (newFunction->doesNotReturn()) {
-      // If fn is no return, end with an unreachable terminator.
-      (void)new UnreachableInst(Context, TheSwitch->getIterator());
-    } else if (OldFnRetTy->isVoidTy()) {
-      // We have no return value.
-      ReturnInst::Create(Context, nullptr,
-                         TheSwitch->getIterator()); // Return void
-    } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
-      // return what we have
-      ReturnInst::Create(Context, TheSwitch->getCondition(),
-                         TheSwitch->getIterator());
-    } else {
-      // Otherwise we must have code extracted an unwind or something, just
-      // return whatever we want.
-      ReturnInst::Create(Context, Constant::getNullValue(OldFnRetTy),
-                         TheSwitch->getIterator());
-    }
-
-    TheSwitch->eraseFromParent();
-    break;
-  case 1:
-    // Only a single destination, change the switch into an unconditional
-    // branch.
-    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getIterator());
-    TheSwitch->eraseFromParent();
-    break;
-  case 2:
-    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
-                       call, TheSwitch->getIterator());
-    TheSwitch->eraseFromParent();
-    break;
-  default:
-    // Otherwise, make the default destination of the switch instruction be one
-    // of the other successors.
-    TheSwitch->setCondition(call);
-    TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
-    // Remove redundant case
-    TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
-    break;
-  }
-
-  // Insert lifetime markers around the reloads of any output values. The
-  // allocas output values are stored in are only in-use in the codeRepl block.
-  insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call);
-
-  return call;
-}
-
 void CodeExtractor::moveCodeToFunction(Function *newFunction) {
-  auto newFuncIt = newFunction->front().getIterator();
+  auto newFuncIt = newFunction->begin();
   for (BasicBlock *Block : Blocks) {
     // Delete the basic block from the old function, and the list of blocks
     Block->removeFromParent();
@@ -1495,7 +1138,7 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) {
 
 void CodeExtractor::calculateNewCallTerminatorWeights(
     BasicBlock *CodeReplacer,
-    DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+    const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
     BranchProbabilityInfo *BPI) {
   using Distribution = BlockFrequencyInfoImplBase::Distribution;
   using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
@@ -1513,7 +1156,7 @@ void CodeExtractor::calculateNewCallTerminatorWeights(
   // Add each of the frequencies of the successors.
   for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
     BlockNode ExitNode(i);
-    uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
+    uint64_t ExitFreq = ExitWeights.lookup(TI->getSuccessor(i)).getFrequency();
     if (ExitFreq != 0)
       BranchDist.addExit(ExitNode, ExitFreq);
     else
@@ -1745,9 +1388,49 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
   BasicBlock *header = *Blocks.begin();
   Function *oldFunction = header->getParent();
 
+  normalizeCFGForExtraction(header);
+
+  // Remove @llvm.assume calls that will be moved to the new function from the
+  // old function's assumption cache.
+  for (BasicBlock *Block : Blocks) {
+    for (Instruction &I : llvm::make_early_inc_range(*Block)) {
+      if (auto *AI = dyn_cast<AssumeInst>(&I)) {
+        if (AC)
+          AC->unregisterAssumption(AI);
+        AI->eraseFromParent();
+      }
+    }
+  }
+
+  ValueSet SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
+  findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+  assert(HoistingCands.empty() || CommonExit);
+
+  // Find inputs to, outputs from the code region.
+  findInputsOutputs(inputs, outputs, SinkingCands);
+
+  // Collect objects which are inputs to the extraction region and also
+  // referenced by lifetime start markers within it. The effects of these
+  // markers must be replicated in the calling function to prevent the stack
+  // coloring pass from merging slots which store input objects.
+  ValueSet LifetimesStart;
+  eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
+
+  if (!HoistingCands.empty()) {
+    auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
+    Instruction *TI = HoistToBlock->getTerminator();
+    for (auto *II : HoistingCands)
+      cast<Instruction>(II)->moveBefore(TI);
+    computeExtractedFuncRetVals();
+  }
+
+  // CFG/ExitBlocks must not change hereafter
+
   // Calculate the entry frequency of the new function before we change the root
   //   block.
   BlockFrequency EntryFreq;
+  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
   if (BFI) {
     assert(BPI && "Both BPI and BFI are required to preserve profile info");
     for (BasicBlock *Pred : predecessors(header)) {
@@ -1756,140 +1439,227 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
       EntryFreq +=
           BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
     }
-  }
 
-  // Remove @llvm.assume calls that will be moved to the new function from the
-  // old function's assumption cache.
-  for (BasicBlock *Block : Blocks) {
-    for (Instruction &I : llvm::make_early_inc_range(*Block)) {
-      if (auto *AI = dyn_cast<AssumeInst>(&I)) {
-        if (AC)
-          AC->unregisterAssumption(AI);
-        AI->eraseFromParent();
+    for (BasicBlock *Succ : ExtractedFuncRetVals) {
+      for (BasicBlock *Block : predecessors(Succ)) {
+        if (!Blocks.count(Block))
+          continue;
+
+        // Update the branch weight for this successor.
+        BlockFrequency &BF = ExitWeights[Succ];
+        BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, Succ);
       }
     }
   }
 
+  // Determine position for the replacement code. Do so before header is moved
+  // to the new function.
+  BasicBlock *ReplIP = header;
+  while (ReplIP && Blocks.count(ReplIP))
+    ReplIP = ReplIP->getNextNode();
+
+  // Construct new function based on inputs/outputs & add allocas for all defs.
+  std::string SuffixToUse =
+      Suffix.empty()
+          ? (header->getName().empty() ? "extracted" : header->getName().str())
+          : Suffix;
+
+  ValueSet StructValues;
+  StructType *StructTy;
+  Function *newFunction = constructFunctionDeclaration(
+      inputs, outputs, EntryFreq, oldFunction->getName() + "." + SuffixToUse,
+      StructValues, StructTy);
+  newFunction->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
+
+  emitFunctionBody(inputs, outputs, StructValues, newFunction, StructTy, header,
+                   SinkingCands);
+
+  std::vector<Value *> Reloads;
+  CallInst *TheCall = emitReplacerCall(
+      inputs, outputs, StructValues, newFunction, StructTy, oldFunction, ReplIP,
+      EntryFreq, LifetimesStart.getArrayRef(), Reloads);
+
+  insertReplacerCall(oldFunction, header, TheCall->getParent(), outputs,
+                     Reloads, ExitWeights);
+
+  fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall);
+
+  LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) {
+    newFunction->dump();
+    report_fatal_error("verification of newFunction failed!");
+  });
+  LLVM_DEBUG(if (verifyFunction(*oldFunction))
+                 report_fatal_error("verification of oldFunction failed!"));
+  LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC))
+                 report_fatal_error("Stale Asumption cache for old Function!"));
+  return newFunction;
+}
+
+void CodeExtractor::normalizeCFGForExtraction(BasicBlock *&header) {
   // If we have any return instructions in the region, split those blocks so
   // that the return is not in the region.
   splitReturnBlocks();
 
-  // Calculate the exit blocks for the extracted region and the total exit
-  // weights for each of those blocks.
-  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
-  SetVector<BasicBlock *> ExitBlocks;
+  // If we have to split PHI nodes of the entry or exit blocks, do so now.
+  severSplitPHINodesOfEntry(header);
+
+  // If a PHI in an exit block has multiple incoming values from the outlined
+  // region, create a new PHI for those values within the region such that only
+  // PHI itself becomes an output value, not each of its incoming values
+  // individually.
+  computeExtractedFuncRetVals();
+  severSplitPHINodesOfExits();
+}
+
+void CodeExtractor::computeExtractedFuncRetVals() {
+  ExtractedFuncRetVals.clear();
+
+  SmallPtrSet<BasicBlock *, 2> ExitBlocks;
   for (BasicBlock *Block : Blocks) {
     for (BasicBlock *Succ : successors(Block)) {
-      if (!Blocks.count(Succ)) {
-        // Update the branch weight for this successor.
-        if (BFI) {
-          BlockFrequency &BF = ExitWeights[Succ];
-          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, Succ);
-        }
-        ExitBlocks.insert(Succ);
-      }
+      if (Blocks.count(Succ))
+        continue;
+
+      bool IsNew = ExitBlocks.insert(Succ).second;
+      if (IsNew)
+        ExtractedFuncRetVals.push_back(Succ);
     }
   }
-  NumExitBlocks = ExitBlocks.size();
+}
 
-  for (BasicBlock *Block : Blocks) {
-    for (BasicBlock *OldTarget : successors(Block))
-      if (!Blocks.contains(OldTarget))
-        OldTargets.push_back(OldTarget);
-  }
+Type *CodeExtractor::getSwitchType() {
+  LLVMContext &Context = Blocks.front()->getContext();
 
-  // If we have to split PHI nodes of the entry or exit blocks, do so now.
-  severSplitPHINodesOfEntry(header);
-  severSplitPHINodesOfExits(ExitBlocks);
+  assert(ExtractedFuncRetVals.size() < 0xffff &&
+         "too many exit blocks for switch");
+  switch (ExtractedFuncRetVals.size()) {
+  case 0:
+  case 1:
+    return Type::getVoidTy(Context);
+  case 2:
+    // Conditional branch, return a bool
+    return Type::getInt1Ty(Context);
+  default:
+    return Type::getInt16Ty(Context);
+  }
+}
 
-  // This takes place of the original loop
-  BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
-                                                "codeRepl", oldFunction,
-                                                header);
-  codeReplacer->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
+void CodeExtractor::emitFunctionBody(
+    const ValueSet &inputs, const ValueSet &outputs,
+    const ValueSet &StructValues, Function *newFunction,
+    StructType *StructArgTy, BasicBlock *header, const ValueSet &SinkingCands) {
+  Function *oldFunction = header->getParent();
+  LLVMContext &Context = oldFunction->getContext();
 
   // The new function needs a root node because other nodes can branch to the
   // head of the region, but the entry node of a function cannot have preds.
-  BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
-                                               "newFuncRoot");
+  BasicBlock *newFuncRoot =
+      BasicBlock::Create(Context, "newFuncRoot", newFunction);
   newFuncRoot->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
 
-  auto *BranchI = BranchInst::Create(header);
-  applyFirstDebugLoc(oldFunction, Blocks.getArrayRef(), BranchI);
-  BranchI->insertInto(newFuncRoot, newFuncRoot->end());
-
-  ValueSet SinkingCands, HoistingCands;
-  BasicBlock *CommonExit = nullptr;
-  findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
-  assert(HoistingCands.empty() || CommonExit);
-
-  // Find inputs to, outputs from the code region.
-  findInputsOutputs(inputs, outputs, SinkingCands);
-
   // Now sink all instructions which only have non-phi uses inside the region.
   // Group the allocas at the start of the block, so that any bitcast uses of
   // the allocas are well-defined.
-  AllocaInst *FirstSunkAlloca = nullptr;
   for (auto *II : SinkingCands) {
-    if (auto *AI = dyn_cast<AllocaInst>(II)) {
-      AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt());
-      if (!FirstSunkAlloca)
-        FirstSunkAlloca = AI;
+    if (!isa<AllocaInst>(II)) {
+      cast<Instruction>(II)->moveBefore(*newFuncRoot,
+                                        newFuncRoot->getFirstInsertionPt());
     }
   }
-  assert((SinkingCands.empty() || FirstSunkAlloca) &&
-         "Did not expect a sink candidate without any allocas");
   for (auto *II : SinkingCands) {
-    if (!isa<AllocaInst>(II)) {
-      cast<Instruction>(II)->moveAfter(FirstSunkAlloca);
+    if (auto *AI = dyn_cast<AllocaInst>(II)) {
+      AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt());
     }
   }
 
-  if (!HoistingCands.empty()) {
-    auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
-    Instruction *TI = HoistToBlock->getTerminator();
-    for (auto *II : HoistingCands)
-      cast<Instruction>(II)->moveBefore(TI);
+  Function::arg_iterator ScalarAI = newFunction->arg_begin();
+  Argument *AggArg = StructValues.empty()
+                         ? nullptr
+                         : newFunction->getArg(newFunction->arg_size() - 1);
+
+  // Rewrite all users of the inputs in the extracted region to use the
+  // arguments (or appropriate addressing into struct) instead.
+  SmallVector<Value *> NewValues;
+  for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
+    Value *RewriteVal;
+    if (StructValues.contains(inputs[i])) {
+      Value *Idx[2];
+      Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
+      GetElementPtrInst *GEP = GetElementPtrInst::Create(
+          StructArgTy, AggArg, Idx, "gep_" + inputs[i]->getName(), newFuncRoot);
+      RewriteVal = new LoadInst(StructArgTy->getElementType(aggIdx), GEP,
+                                "loadgep_" + inputs[i]->getName(), newFuncRoot);
+      ++aggIdx;
+    } else
+      RewriteVal = &*ScalarAI++;
+
+    NewValues.push_back(RewriteVal);
   }
 
-  // Collect objects which are inputs to the extraction region and also
-  // referenced by lifetime start markers within it. The effects of these
-  // markers must be replicated in the calling function to prevent the stack
-  // coloring pass from merging slots which store input objects.
-  ValueSet LifetimesStart;
-  eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
+  moveCodeToFunction(newFunction);
 
-  // Construct new function based on inputs/outputs & add allocas for all defs.
-  Function *newFunction =
-      constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
-                        oldFunction, oldFunction->getParent());
+  for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
+    Value *RewriteVal = NewValues[i];
 
-  // Update the entry count of the function.
-  if (BFI) {
-    auto Count = BFI->getProfileCountFromFreq(EntryFreq);
-    if (Count)
-      newFunction->setEntryCount(
-          ProfileCount(*Count, Function::PCT_Real)); // FIXME
-    BFI->setBlockFreq(codeReplacer, EntryFreq);
+    std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
+    for (User *use : Users)
+      if (Instruction *inst = dyn_cast<Instruction>(use))
+        if (Blocks.count(inst->getParent()))
+          inst->replaceUsesOfWith(inputs[i], RewriteVal);
   }
 
-  CallInst *TheCall =
-      emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
+  // Since there may be multiple exits from the original region, make the new
+  // function return an unsigned, switch on that number.  This loop iterates
+  // over all of the blocks in the extracted region, updating any terminator
+  // instructions in the to-be-extracted region that branch to blocks that are
+  // not in the region to be extracted.
+  std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
 
-  moveCodeToFunction(newFunction);
+  // Iterate over the previously collected targets, and create new blocks inside
+  // the function to branch to.
+  for (auto P : enumerate(ExtractedFuncRetVals)) {
+    BasicBlock *OldTarget = P.value();
+    size_t SuccNum = P.index();
 
-  // Replicate the effects of any lifetime start/end markers which referenced
-  // input objects in the extraction region by placing markers around the call.
-  insertLifetimeMarkersSurroundingCall(
-      oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall);
+    BasicBlock *NewTarget = BasicBlock::Create(
+        Context, OldTarget->getName() + ".exitStub", newFunction);
+    ExitBlockMap[OldTarget] = NewTarget;
 
-  // Propagate personality info to the new function if there is one.
-  if (oldFunction->hasPersonalityFn())
-    newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
+    Value *brVal = nullptr;
+    Type *RetTy = getSwitchType();
+    assert(ExtractedFuncRetVals.size() < 0xffff &&
+           "too many exit blocks for switch");
+    switch (ExtractedFuncRetVals.size()) {
+    case 0:
+    case 1:
+      // No value needed.
+      break;
+    case 2: // Conditional branch, return a bool
+      brVal = ConstantInt::get(RetTy, !SuccNum);
+      break;
+    default:
+      brVal = ConstantInt::get(RetTy, SuccNum);
+      break;
+    }
 
-  // Update the branch weights for the exit block.
-  if (BFI && NumExitBlocks > 1)
-    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+    ReturnInst::Create(Context, brVal, NewTarget);
+  }
+
+  for (BasicBlock *Block : Blocks) {
+    Instruction *TI = Block->getTerminator();
+    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
+      if (Blocks.count(TI->getSuccessor(i)))
+        continue;
+      BasicBlock *OldTarget = TI->getSuccessor(i);
+      // add a new basic block which returns the appropriate value
+      BasicBlock *NewTarget = ExitBlockMap[OldTarget];
+      assert(NewTarget && "Unknown target block!");
+
+      // rewrite the original branch instruction with this new target
+      TI->setSuccessor(i, NewTarget);
+    }
+  }
 
   // Loop over all of the PHI nodes in the header and exit blocks, and change
   // any references to the old incoming edge to be the new incoming edge.
@@ -1900,7 +1670,303 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
         PN->setIncomingBlock(i, newFuncRoot);
   }
 
-  for (BasicBlock *ExitBB : ExitBlocks)
+  // Connect newFunction entry block to new header.
+  BranchInst *BranchI = BranchInst::Create(header, newFuncRoot);
+  applyFirstDebugLoc(oldFunction, Blocks.getArrayRef(), BranchI);
+
+  // Store the arguments right after the definition of output value.
+  // This should be proceeded after creating exit stubs to be ensure that invoke
+  // result restore will be placed in the outlined function.
+  ScalarAI = newFunction->arg_begin();
+  unsigned AggIdx = 0;
+
+  for (Value *Input : inputs) {
+    if (StructValues.contains(Input))
+      ++AggIdx;
+    else
+      ++ScalarAI;
+  }
+
+  for (Value *Output : outputs) {
+    // Find proper insertion point.
+    // In case Output is an invoke, we insert the store at the beginning in the
+    // 'normal destination' BB. Otherwise we insert the store right after
+    // Output.
+    BasicBlock::iterator InsertPt;
+    if (auto *InvokeI = dyn_cast<InvokeInst>(Output))
+      InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
+    else if (auto *Phi = dyn_cast<PHINode>(Output))
+      InsertPt = Phi->getParent()->getFirstInsertionPt();
+    else if (auto *OutI = dyn_cast<Instruction>(Output))
+      InsertPt = std::next(OutI->getIterator());
+    else {
+      // Globals don't need to be updated, just advance to the next argument.
+      if (StructValues.contains(Output))
+        ++AggIdx;
+      else
+        ++ScalarAI;
+      continue;
+    }
+
+    assert((InsertPt->getFunction() == newFunction ||
+            Blocks.count(InsertPt->getParent())) &&
+           "InsertPt should be in new function");
+
+    if (StructValues.contains(Output)) {
+      assert(AggArg && "Number of aggregate output arguments should match "
+                       "the number of defined values");
+      Value *Idx[2];
+      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx);
+      GetElementPtrInst *GEP = GetElementPtrInst::Create(
+          StructArgTy, AggArg, Idx, "gep_" + Output->getName(), InsertPt);
+      new StoreInst(Output, GEP, InsertPt);
+      ++AggIdx;
+    } else {
+      assert(ScalarAI != newFunction->arg_end() &&
+             "Number of scalar output arguments should match "
+             "the number of defined values");
+      new StoreInst(Output, &*ScalarAI, InsertPt);
+      ++ScalarAI;
+    }
+  }
+
+  if (ExtractedFuncRetVals.empty()) {
+    // Mark the new function `noreturn` if applicable. Terminators which resume
+    // exception propagation are treated as returning instructions. This is to
+    // avoid inserting traps after calls to outlined functions which unwind.
+    if (none_of(Blocks, [](const BasicBlock *BB) {
+          const Instruction *Term = BB->getTerminator();
+          return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
+        }))
+      newFunction->setDoesNotReturn();
+  }
+}
+
+CallInst *CodeExtractor::emitReplacerCall(
+    const ValueSet &inputs, const ValueSet &outputs,
+    const ValueSet &StructValues, Function *newFunction,
+    StructType *StructArgTy, Function *oldFunction, BasicBlock *ReplIP,
+    BlockFrequency EntryFreq, ArrayRef<Value *> LifetimesStart,
+    std::vector<Value *> &Reloads) {
+  LLVMContext &Context = oldFunction->getContext();
+  Module *M = oldFunction->getParent();
+  const DataLayout &DL = M->getDataLayout();
+
+  // This takes place of the original loop
+  BasicBlock *codeReplacer =
+      BasicBlock::Create(Context, "codeRepl", oldFunction, ReplIP);
+  codeReplacer->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
+  BasicBlock *AllocaBlock =
+      AllocationBlock ? AllocationBlock : &oldFunction->getEntryBlock();
+  AllocaBlock->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat;
+
+  // Update the entry count of the function.
+  if (BFI)
+    BFI->setBlockFreq(codeReplacer, EntryFreq);
+
+  std::vector<Value *> params;
+
+  // Add inputs as params, or to be filled into the struct
+  for (Value *input : inputs) {
+    if (StructValues.contains(input))
+      continue;
+
+    params.push_back(input);
+  }
+
+  // Create allocas for the outputs
+  std::vector<Value *> ReloadOutputs;
+  for (Value *output : outputs) {
+    if (StructValues.contains(output))
+      continue;
+
+    AllocaInst *alloca = new AllocaInst(
+        output->getType(), DL.getAllocaAddrSpace(), nullptr,
+        output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
+    params.push_back(alloca);
+    ReloadOutputs.push_back(alloca);
+  }
+
+  AllocaInst *Struct = nullptr;
+  if (!StructValues.empty()) {
+    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
+                            "structArg", AllocaBlock->getFirstInsertionPt());
+    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+      auto *StructSpaceCast = new AddrSpaceCastInst(
+          Struct, PointerType ::get(Context, 0), "structArg.ascast");
+      StructSpaceCast->insertAfter(Struct);
+      params.push_back(StructSpaceCast);
+    } else {
+      params.push_back(Struct);
+    }
+
+    unsigned AggIdx = 0;
+    for (Value *input : inputs) {
+      if (!StructValues.contains(input))
+        continue;
+
+      Value *Idx[2];
+      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx);
+      GetElementPtrInst *GEP = GetElementPtrInst::Create(
+          StructArgTy, Struct, Idx, "gep_" + input->getName());
+      GEP->insertInto(codeReplacer, codeReplacer->end());
+      new StoreInst(input, GEP, codeReplacer);
+
+      ++AggIdx;
+    }
+  }
+
+  // Emit the call to the function
+  CallInst *call = CallInst::Create(
+      newFunction, params, ExtractedFuncRetVals.size() > 1 ? "targetBlock" : "",
+      codeReplacer);
+
+  // Set swifterror parameter attributes.
+  unsigned ParamIdx = 0;
+  unsigned AggIdx = 0;
+  for (auto input : inputs) {
+    if (StructValues.contains(input)) {
+      ++AggIdx;
+    } else {
+      if (input->isSwiftError())
+        call->addParamAttr(ParamIdx, Attribute::SwiftError);
+      ++ParamIdx;
+    }
+  }
+
+  // Add debug location to the new call, if the original function has debug
+  // info. In that case, the terminator of the entry block of the extracted
+  // function contains the first debug location of the extracted function,
+  // set in extractCodeRegion.
+  if (codeReplacer->getParent()->getSubprogram()) {
+    if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
+      call->setDebugLoc(DL);
+  }
+
+  // Reload the outputs passed in by reference, use the struct if output is in
+  // the aggregate or reload from the scalar argument.
+  for (unsigned i = 0, e = outputs.size(), scalarIdx = 0; i != e; ++i) {
+    Value *Output = nullptr;
+    if (StructValues.contains(outputs[i])) {
+      Value *Idx[2];
+      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx);
+      GetElementPtrInst *GEP = GetElementPtrInst::Create(
+          StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
+      GEP->insertInto(codeReplacer, codeReplacer->end());
+      Output = GEP;
+      ++AggIdx;
+    } else {
+      Output = ReloadOutputs[scalarIdx];
+      ++scalarIdx;
+    }
+    LoadInst *load =
+        new LoadInst(outputs[i]->getType(), Output,
+                     outputs[i]->getName() + ".reload", codeReplacer);
+    Reloads.push_back(load);
+  }
+
+  // Now we can emit a switch statement using the call as a value.
+  SwitchInst *TheSwitch =
+      SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
+                         codeReplacer, 0, codeReplacer);
+  for (auto P : enumerate(ExtractedFuncRetVals)) {
+    BasicBlock *OldTarget = P.value();
+    size_t SuccNum = P.index();
+
+    TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), SuccNum),
+                       OldTarget);
+  }
+
+  // Now that we've done the deed, simplify the switch instruction.
+  Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
+  switch (ExtractedFuncRetVals.size()) {
+  case 0:
+    // There are no successors (the block containing the switch itself), which
+    // means that previously this was the last part of the function, and hence
+    // this should be rewritten as a `ret` or `unreachable`.
+    if (newFunction->doesNotReturn()) {
+      // If fn is no return, end with an unreachable terminator.
+      (void)new UnreachableInst(Context, TheSwitch->getIterator());
+    } else if (OldFnRetTy->isVoidTy()) {
+      // We have no return value.
+      ReturnInst::Create(Context, nullptr,
+                         TheSwitch->getIterator()); // Return void
+    } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
+      // return what we have
+      ReturnInst::Create(Context, TheSwitch->getCondition(),
+                         TheSwitch->getIterator());
+    } else {
+      // Otherwise we must have code extracted an unwind or something, just
+      // return whatever we want.
+      ReturnInst::Create(Context, Constant::getNullValue(OldFnRetTy),
+                         TheSwitch->getIterator());
+    }
+
+    TheSwitch->eraseFromParent();
+    break;
+  case 1:
+    // Only a single destination, change the switch into an unconditional
+    // branch.
+    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getIterator());
+    TheSwitch->eraseFromParent();
+    break;
+  case 2:
+    // Only two destinations, convert to a condition branch.
+    // Remark: This also swaps the target branches:
+    // 0 -> false -> getSuccessor(2); 1 -> true -> getSuccessor(1)
+    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
+                       call, TheSwitch->getIterator());
+    TheSwitch->eraseFromParent();
+    break;
+  default:
+    // Otherwise, make the default destination of the switch instruction be one
+    // of the other successors.
+    TheSwitch->setCondition(call);
+    TheSwitch->setDefaultDest(
+        TheSwitch->getSuccessor(ExtractedFuncRetVals.size()));
+    // Remove redundant case
+    TheSwitch->removeCase(
+        SwitchInst::CaseIt(TheSwitch, ExtractedFuncRetVals.size() - 1));
+    break;
+  }
+
+  // Insert lifetime markers around the reloads of any output values. The
+  // allocas output values are stored in are only in-use in the codeRepl block.
+  insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call);
+
+  // Replicate the effects of any lifetime start/end markers which referenced
+  // input objects in the extraction region by placing markers around the call.
+  insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
+                                       {}, call);
+
+  return call;
+}
+
+void CodeExtractor::insertReplacerCall(
+    Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer,
+    const ValueSet &outputs, ArrayRef<Value *> Reloads,
+    const DenseMap<BasicBlock *, BlockFrequency> &ExitWeights) {
+
+  // Rewrite branches to basic blocks outside of the loop to new dummy blocks
+  // within the new function. This must be done before we lose track of which
+  // blocks were originally in the code region.
+  std::vector<User *> Users(header->user_begin(), header->user_end());
+  for (auto &U : Users)
+    // The BasicBlock which contains the branch is not in the region
+    // modify the branch target to a new block
+    if (Instruction *I = dyn_cast<Instruction>(U))
+      if (I->isTerminator() && I->getFunction() == oldFunction &&
+          !Blocks.count(I->getParent()))
+        I->replaceUsesOfWith(header, codeReplacer);
+
+  // When moving the code region it is sufficient to replace all uses to the
+  // extracted function values. Since the original definition's block
+  // dominated its use, it will also be dominated by codeReplacer's switch
+  // which joined multiple exit blocks.
+  for (BasicBlock *ExitBB : ExtractedFuncRetVals)
     for (PHINode &PN : ExitBB->phis()) {
       Value *IncomingCodeReplacerVal = nullptr;
       for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
@@ -1918,17 +1984,19 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
       }
     }
 
-  fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall);
+  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+    Value *load = Reloads[i];
+    std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
+    for (User *U : Users) {
+      Instruction *inst = cast<Instruction>(U);
+      if (inst->getParent()->getParent() == oldFunction)
+        inst->replaceUsesOfWith(outputs[i], load);
+    }
+  }
 
-  LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) {
-    newFunction->dump();
-    report_fatal_error("verification of newFunction failed!");
-  });
-  LLVM_DEBUG(if (verifyFunction(*oldFunction))
-             report_fatal_error("verification of oldFunction failed!"));
-  LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC))
-                 report_fatal_error("Stale Asumption cache for old Function!"));
-  return newFunction;
+  // Update the branch weights for the exit block.
+  if (BFI && ExtractedFuncRetVals.size() > 1)
+    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
 }
 
 bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,

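The exit handling in the refactored `CodeExtractor.cpp` hunks rests on one invariant. `ExtractedFuncRetVals` fixes an index for every exit block. The outlined function's `.exitStub` blocks return a value derived from that index, and the switch or branch in `codeRepl` maps the value back to the same exit.

The sketch below models that round trip in plain C++ with no LLVM dependency. Every name in it (`CFG`, `computeExits`, `RetKind`, `retKindFor`, `stubValue`, `callerExit`) is a hypothetical stand-in written for illustration, not LLVM API. It mirrors `computeExtractedFuncRetVals()`, `getSwitchType()`, the `brVal` chosen in `emitFunctionBody()`, and the switch simplification in `emitReplacerCall()`, as read from the diff. Note that with exactly two exits the encoding is inverted: exit 0 returns `true` and exit 1 returns `false`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <set>
#include <vector>

// Hypothetical CFG model (not LLVM API): Succs[B] lists block B's successors
// in terminator order.
using CFG = std::vector<std::vector<int>>;

// Models computeExtractedFuncRetVals(): every successor outside the region,
// deduplicated, in first-seen order. The position in the result is the
// exit's index.
std::vector<int> computeExits(const CFG &Succs, const std::vector<int> &Region) {
  std::set<int> InRegion(Region.begin(), Region.end());
  std::set<int> Seen;
  std::vector<int> Exits;
  for (int B : Region)
    for (int S : Succs[B])
      if (!InRegion.count(S) && Seen.insert(S).second)
        Exits.push_back(S);
  return Exits;
}

// Models getSwitchType(): the return type of the outlined function.
enum class RetKind { Void, I1, I16 };

RetKind retKindFor(std::size_t NumExits) {
  assert(NumExits < 0xffff && "too many exit blocks for switch");
  switch (NumExits) {
  case 0:
  case 1:
    return RetKind::Void;
  case 2:
    return RetKind::I1;
  default:
    return RetKind::I16;
  }
}

// Models the value returned by the ".exitStub" block created for exit SuccNum.
std::optional<std::uint16_t> stubValue(std::size_t NumExits, std::size_t SuccNum) {
  switch (NumExits) {
  case 0:
  case 1:
    return std::nullopt; // void return, no value needed
  case 2:
    return static_cast<std::uint16_t>(!SuccNum); // exit 0 -> true, exit 1 -> false
  default:
    return static_cast<std::uint16_t>(SuccNum);
  }
}

// Models the simplified switch in codeRepl: the exit index the caller takes
// for a given return value of the call.
std::size_t callerExit(std::size_t NumExits, std::optional<std::uint16_t> Ret) {
  assert(NumExits >= 1 && "with no exits, codeRepl returns or is unreachable");
  switch (NumExits) {
  case 1:
    return 0; // unconditional branch, the value is not inspected
  case 2:
    return *Ret ? 0 : 1; // br i1 %targetBlock, label %exit0, label %exit1
  default:
    // The last exit became the default destination and its explicit case
    // was removed, so any value not matching a remaining case lands there.
    return *Ret < NumExits - 1 ? *Ret : NumExits - 1;
  }
}
```

For a diamond where blocks 0, 1 and 2 are extracted and branch out to blocks 3 and 4, `computeExits` yields `{3, 4}`, so the outlined function returns `i1`. In this model, `callerExit(N, stubValue(N, S)) == S` holds for every exit `S`. That property is what lets the return instructions and the switch be generated independently from the single exit list, as the commit message describes.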
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 80c2a23a957963..cfe07a2f6c461e 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -7,11 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/CodeExtractor.h"
-#include "llvm/AsmParser/Parser.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -30,6 +31,13 @@ BasicBlock *getBlockByName(Function *F, StringRef name) {
   return nullptr;
 }
 
+Instruction *getInstByName(Function *F, StringRef Name) {
+  for (Instruction &I : instructions(F))
+    if (I.getName() == Name)
+      return &I;
+  return nullptr;
+}
+
 TEST(CodeExtractor, ExitStub) {
   LLVMContext Ctx;
   SMDiagnostic Err;
@@ -513,19 +521,28 @@ TEST(CodeExtractor, PartialAggregateArgs) {
     target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
     target triple = "x86_64-unknown-linux-gnu"
 
-    declare void @use(i32)
+    ; use different types such that an index mismatch will result in a type mismatch during verification.
+    declare void @use16(i16)
+    declare void @use32(i32)
+    declare void @use64(i64)
 
-    define void @foo(i32 %a, i32 %b, i32 %c) {
+    define void @foo(i16 %a, i32 %b, i64 %c) {
     entry:
       br label %extract
 
     extract:
-      call void @use(i32 %a)
-      call void @use(i32 %b)
-      call void @use(i32 %c)
+      call void @use16(i16 %a)
+      call void @use32(i32 %b)
+      call void @use64(i64 %c)
+      %d = add i16 21, 21
+      %e = add i32 21, 21
+      %f = add i64 21, 21
       br label %exit
 
     exit:
+      call void @use16(i16 %d)
+      call void @use32(i32 %e)
+      call void @use64(i64 %f)
       ret void
     }
   )ir",
@@ -544,18 +561,70 @@ TEST(CodeExtractor, PartialAggregateArgs) {
   BasicBlock *CommonExit = nullptr;
   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
-  // Exclude the first input from the argument aggregate.
-  CE.excludeArgFromAggregate(Inputs[0]);
+  // Exclude the middle input and output from the argument aggregate.
+  CE.excludeArgFromAggregate(Inputs[1]);
+  CE.excludeArgFromAggregate(Outputs[1]);
 
   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
   EXPECT_TRUE(Outlined);
-  // Expect 2 arguments in the outlined function: the excluded input and the
-  // struct aggregate for the remaining inputs.
-  EXPECT_EQ(Outlined->arg_size(), 2U);
+  // Expect 3 arguments in the outlined function: the excluded input, the
+  // excluded output, and the struct aggregate for the remaining inputs.
+  EXPECT_EQ(Outlined->arg_size(), 3U);
   EXPECT_FALSE(verifyFunction(*Outlined));
   EXPECT_FALSE(verifyFunction(*Func));
 }
 
+TEST(CodeExtractor, AllocaBlock) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define i32 @foo(i32 %x, i32 %y, i32 %z) {
+    entry:
+      br label %allocas
+
+    allocas:
+      br label %body
+
+    body:
+      %w = add i32 %x, %y
+      br label %notExtracted
+
+    notExtracted:
+      %r = add i32 %w, %x
+      ret i32 %r
+    }
+  )invalid",
+                                                Err, Ctx));
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "body")};
+
+  BasicBlock *AllocaBlock = getBlockByName(Func, "allocas");
+  CodeExtractor CE(Candidates, nullptr, true, nullptr, nullptr, nullptr, false,
+                   false, AllocaBlock);
+  CE.excludeArgFromAggregate(Func->getArg(0));
+  CE.excludeArgFromAggregate(getInstByName(Func, "w"));
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs;
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  EXPECT_TRUE(Outlined);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+
+  // The only added allocas may be in the dedicated alloca block. There should
+  // be one alloca for the struct, and another one for the reload value.
+  int NumAllocas = 0;
+  for (Instruction &I : instructions(Func)) {
+    if (!isa<AllocaInst>(I))
+      continue;
+    EXPECT_EQ(I.getParent(), AllocaBlock);
+    NumAllocas += 1;
+  }
+  EXPECT_EQ(NumAllocas, 2);
+}
+
 /// Regression test to ensure we don't crash trying to set the name of the ptr
 /// argument
 TEST(CodeExtractor, PartialAggregateArgs2) {
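The argument and alloca counts that PartialAggregateArgs (3 arguments) and AllocaBlock (2 allocas) expect can be derived by hand. The sketch below is a simplified bookkeeping model, not CodeExtractor itself. It assumes that each value excluded via excludeArgFromAggregate becomes its own parameter: inputs are passed by value, and outputs are passed through a pointer to a caller-side alloca. It also assumes that all remaining inputs and outputs share a single struct argument backed by one caller-side alloca. `ArgPlan`, `planArgs`, `numArgs` and `callerAllocas` are hypothetical names introduced only for illustration and are not part of LLVM's API.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical model (not LLVM code): it only counts how the values crossing
// the extracted region's boundary would be passed, under the assumptions
// stated above.
struct ArgPlan {
  std::size_t DirectInputs = 0;  // inputs excluded from the aggregate, passed by value
  std::size_t DirectOutputs = 0; // outputs excluded from the aggregate, passed by pointer
  bool HasAggregate = false;     // true if at least one value goes into the struct

  // Parameters of the outlined function: one per excluded value plus a
  // single pointer to the struct that holds everything else.
  std::size_t numArgs() const {
    return DirectInputs + DirectOutputs + (HasAggregate ? 1 : 0);
  }

  // Allocas in the caller: one reload slot per excluded output plus one
  // for the struct itself (excluded inputs need none).
  std::size_t callerAllocas() const {
    return DirectOutputs + (HasAggregate ? 1 : 0);
  }
};

// Partition the region's inputs and outputs into direct arguments and
// members of the aggregate struct.
ArgPlan planArgs(const std::vector<std::string> &Inputs,
                 const std::vector<std::string> &Outputs,
                 const std::set<std::string> &Excluded) {
  ArgPlan Plan;
  for (const std::string &V : Inputs) {
    if (Excluded.count(V))
      ++Plan.DirectInputs;
    else
      Plan.HasAggregate = true;
  }
  for (const std::string &V : Outputs) {
    if (Excluded.count(V))
      ++Plan.DirectOutputs;
    else
      Plan.HasAggregate = true;
  }
  return Plan;
}
```

For the PartialAggregateArgs setup, `planArgs({"a", "b", "c"}, {"d", "e", "f"}, {"b", "e"})` gives `numArgs() == 3`. The count comes out to 3 only because `d` and `f` are packed into the struct alongside `a` and `c`. For AllocaBlock, the inputs are `x` and `y`, and `z` is unused in `body`. The only output is `w`, and `x` and `w` are excluded. That gives 3 parameters and `callerAllocas() == 2`: the struct holding `y`, plus the reload slot for `w`.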