[llvm] r348205 - [CodeExtractor] Split PHI nodes with incoming values from outlined region (PR39433)

Vedant Kumar via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 3 14:40:21 PST 2018


Author: vedantk
Date: Mon Dec  3 14:40:21 2018
New Revision: 348205

URL: http://llvm.org/viewvc/llvm-project?rev=348205&view=rev
Log:
[CodeExtractor] Split PHI nodes with incoming values from outlined region (PR39433)

If a PHI node outside the extracted region has multiple incoming values from
the region, split this PHI into two parts. The first PHI has incoming values
only from the region and is extracted along with it (it is placed in a separate
basic block that is added to the list of outlined blocks); the incoming values
from the region in the original PHI are then replaced by this first PHI. A
similar solution is already used in CodeExtractor for PHIs in the entry block
(the severSplitPHINodes method, renamed by this patch to
severSplitPHINodesOfEntry). This fixes PR39433.

Patch by Sergei Kachkov!

Differential Revision: https://reviews.llvm.org/D55018

Modified:
    llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h
    llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp
    llvm/trunk/test/Transforms/HotColdSplit/duplicate-phi-preds-crash.ll
    llvm/trunk/unittests/Transforms/Utils/CodeExtractorTest.cpp

Modified: llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h?rev=348205&r1=348204&r2=348205&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h (original)
+++ llvm/trunk/include/llvm/Transforms/Utils/CodeExtractor.h Mon Dec  3 14:40:21 2018
@@ -18,6 +18,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include <limits>
 
 namespace llvm {
@@ -146,7 +147,8 @@ class Value;
     BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock);
 
   private:
-    void severSplitPHINodes(BasicBlock *&Header);
+    void severSplitPHINodesOfEntry(BasicBlock *&Header);
+    void severSplitPHINodesOfExits(const SmallPtrSetImpl<BasicBlock *> &Exits);
     void splitReturnBlocks();
 
     Function *constructFunction(const ValueSet &inputs,

Modified: llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp?rev=348205&r1=348204&r2=348205&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/CodeExtractor.cpp Mon Dec  3 14:40:21 2018
@@ -531,10 +531,10 @@ void CodeExtractor::findInputsOutputs(Va
   }
 }
 
-/// severSplitPHINodes - If a PHI node has multiple inputs from outside of the
-/// region, we need to split the entry block of the region so that the PHI node
-/// is easier to deal with.
-void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) {
+/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
+/// of the region, we need to split the entry block of the region so that the
+/// PHI node is easier to deal with.
+void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
   unsigned NumPredsFromRegion = 0;
   unsigned NumPredsOutsideRegion = 0;
 
@@ -606,6 +606,56 @@ void CodeExtractor::severSplitPHINodes(B
   }
 }
 
+/// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
+/// outlined region, we split these PHIs on two: one with inputs from region
+/// and other with remaining incoming blocks; then first PHIs are placed in
+/// outlined region.
+void CodeExtractor::severSplitPHINodesOfExits(
+    const SmallPtrSetImpl<BasicBlock *> &Exits) {
+  for (BasicBlock *ExitBB : Exits) {
+    BasicBlock *NewBB = nullptr;
+
+    for (PHINode &PN : ExitBB->phis()) {
+      // Find all incoming values from the outlining region.
+      SmallVector<unsigned, 2> IncomingVals;
+      for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
+        if (Blocks.count(PN.getIncomingBlock(i)))
+          IncomingVals.push_back(i);
+
+      // Do not process PHI if there is one (or fewer) predecessor from region.
+      // If PHI has exactly one predecessor from region, only this one incoming
+      // will be replaced on codeRepl block, so it should be safe to skip PHI.
+      if (IncomingVals.size() <= 1)
+        continue;
+
+      // Create block for new PHIs and add it to the list of outlined if it
+      // wasn't done before.
+      if (!NewBB) {
+        NewBB = BasicBlock::Create(ExitBB->getContext(),
+                                   ExitBB->getName() + ".split",
+                                   ExitBB->getParent(), ExitBB);
+        SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBB),
+                                           pred_end(ExitBB));
+        for (BasicBlock *PredBB : Preds)
+          if (Blocks.count(PredBB))
+            PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB);
+        BranchInst::Create(ExitBB, NewBB);
+        Blocks.insert(NewBB);
+      }
+
+      // Split this PHI.
+      PHINode *NewPN =
+          PHINode::Create(PN.getType(), IncomingVals.size(),
+                          PN.getName() + ".ce", NewBB->getFirstNonPHI());
+      for (unsigned i : IncomingVals)
+        NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
+      for (unsigned i : reverse(IncomingVals))
+        PN.removeIncomingValue(i, false);
+      PN.addIncoming(NewPN, NewBB);
+    }
+  }
+}
+
 void CodeExtractor::splitReturnBlocks() {
   for (BasicBlock *Block : Blocks)
     if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
@@ -1173,13 +1223,33 @@ Function *CodeExtractor::extractCodeRegi
     }
   }
 
-  // If we have to split PHI nodes or the entry block, do so now.
-  severSplitPHINodes(header);
-
   // If we have any return instructions in the region, split those blocks so
   // that the return is not in the region.
   splitReturnBlocks();
 
+  // Calculate the exit blocks for the extracted region and the total exit
+  // weights for each of those blocks.
+  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
+  SmallPtrSet<BasicBlock *, 1> ExitBlocks;
+  for (BasicBlock *Block : Blocks) {
+    for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
+         ++SI) {
+      if (!Blocks.count(*SI)) {
+        // Update the branch weight for this successor.
+        if (BFI) {
+          BlockFrequency &BF = ExitWeights[*SI];
+          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
+        }
+        ExitBlocks.insert(*SI);
+      }
+    }
+  }
+  NumExitBlocks = ExitBlocks.size();
+
+  // If we have to split PHI nodes of the entry or exit blocks, do so now.
+  severSplitPHINodesOfEntry(header);
+  severSplitPHINodesOfExits(ExitBlocks);
+
   // This takes place of the original loop
   BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
                                                 "codeRepl", oldFunction,
@@ -1224,25 +1294,6 @@ Function *CodeExtractor::extractCodeRegi
       cast<Instruction>(II)->moveBefore(TI);
   }
 
-  // Calculate the exit blocks for the extracted region and the total exit
-  // weights for each of those blocks.
-  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
-  SmallPtrSet<BasicBlock *, 1> ExitBlocks;
-  for (BasicBlock *Block : Blocks) {
-    for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
-         ++SI) {
-      if (!Blocks.count(*SI)) {
-        // Update the branch weight for this successor.
-        if (BFI) {
-          BlockFrequency &BF = ExitWeights[*SI];
-          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
-        }
-        ExitBlocks.insert(*SI);
-      }
-    }
-  }
-  NumExitBlocks = ExitBlocks.size();
-
   // Construct new function based on inputs/outputs & add allocas for all defs.
   Function *newFunction = constructFunction(inputs, outputs, header,
                                             newFuncRoot,
@@ -1270,8 +1321,8 @@ Function *CodeExtractor::extractCodeRegi
   if (BFI && NumExitBlocks > 1)
     calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
 
-  // Loop over all of the PHI nodes in the header block, and change any
-  // references to the old incoming edge to be the new incoming edge.
+  // Loop over all of the PHI nodes in the header and exit blocks, and change
+  // any references to the old incoming edge to be the new incoming edge.
   for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
     PHINode *PN = cast<PHINode>(I);
     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
@@ -1279,35 +1330,23 @@ Function *CodeExtractor::extractCodeRegi
         PN->setIncomingBlock(i, newFuncRoot);
   }
 
-  // Look at all successors of the codeReplacer block.  If any of these blocks
-  // had PHI nodes in them, we need to update the "from" block to be the code
-  // replacer, not the original block in the extracted region.
-  for (BasicBlock *SuccBB : successors(codeReplacer)) {
-    for (PHINode &PN : SuccBB->phis()) {
+  for (BasicBlock *ExitBB : ExitBlocks)
+    for (PHINode &PN : ExitBB->phis()) {
       Value *IncomingCodeReplacerVal = nullptr;
-      SmallVector<unsigned, 2> IncomingValsToRemove;
-      for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) {
-        BasicBlock *IncomingBB = PN.getIncomingBlock(I);
-
+      for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
         // Ignore incoming values from outside of the extracted region.
-        if (!Blocks.count(IncomingBB))
+        if (!Blocks.count(PN.getIncomingBlock(i)))
           continue;
 
         // Ensure that there is only one incoming value from codeReplacer.
         if (!IncomingCodeReplacerVal) {
-          PN.setIncomingBlock(I, codeReplacer);
-          IncomingCodeReplacerVal = PN.getIncomingValue(I);
-        } else {
-          assert(IncomingCodeReplacerVal == PN.getIncomingValue(I) &&
+          PN.setIncomingBlock(i, codeReplacer);
+          IncomingCodeReplacerVal = PN.getIncomingValue(i);
+        } else
+          assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
                  "PHI has two incompatbile incoming values from codeRepl");
-          IncomingValsToRemove.push_back(I);
-        }
       }
-
-      for (unsigned I : reverse(IncomingValsToRemove))
-        PN.removeIncomingValue(I, /*DeletePHIIfEmpty=*/false);
     }
-  }
 
   // Erase debug info intrinsics. Variable updates within the new function are
   // invisible to debuggers. This could be improved by defining a DISubprogram
@@ -1338,6 +1377,8 @@ Function *CodeExtractor::extractCodeRegi
     newFunction->setDoesNotReturn();
 
   LLVM_DEBUG(if (verifyFunction(*newFunction))
-                 report_fatal_error("verifyFunction failed!"));
+             report_fatal_error("verification of newFunction failed!"));
+  LLVM_DEBUG(if (verifyFunction(*oldFunction))
+             report_fatal_error("verification of oldFunction failed!"));
   return newFunction;
 }

Modified: llvm/trunk/test/Transforms/HotColdSplit/duplicate-phi-preds-crash.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/HotColdSplit/duplicate-phi-preds-crash.ll?rev=348205&r1=348204&r2=348205&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/HotColdSplit/duplicate-phi-preds-crash.ll (original)
+++ llvm/trunk/test/Transforms/HotColdSplit/duplicate-phi-preds-crash.ll Mon Dec  3 14:40:21 2018
@@ -15,9 +15,9 @@ declare void @sink() cold
 ; CHECK: call {{.*}}@sideeffect(
 ; CHECK: call {{.*}}@realloc(
 ; CHECK-LABEL: codeRepl:
-; CHECK-NEXT: call {{.*}}@realloc2.cold.1(i64 %size, i8* %ptr)
+; CHECK-NEXT: call {{.*}}@realloc2.cold.1(i64 %size, i8* %ptr, i8** %retval.0.ce.loc)
 ; CHECK-LABEL: cleanup:
-; CHECK-NEXT: phi i8* [ null, %if.then ], [ null, %codeRepl ], [ %call, %if.end ]
+; CHECK-NEXT: phi i8* [ null, %if.then ], [ %call, %if.end ], [ %retval.0.ce.reload, %codeRepl ]
 define i8* @realloc2(i8* %ptr, i64 %size) {
 entry:
   %0 = add i64 %size, -1

Modified: llvm/trunk/unittests/Transforms/Utils/CodeExtractorTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/Transforms/Utils/CodeExtractorTest.cpp?rev=348205&r1=348204&r2=348205&view=diff
==============================================================================
--- llvm/trunk/unittests/Transforms/Utils/CodeExtractorTest.cpp (original)
+++ llvm/trunk/unittests/Transforms/Utils/CodeExtractorTest.cpp Mon Dec  3 14:40:21 2018
@@ -11,6 +11,7 @@
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
@@ -21,7 +22,14 @@
 using namespace llvm;
 
 namespace {
-TEST(CodeExtractor, DISABLED_ExitStub) {
+BasicBlock *getBlockByName(Function *F, StringRef name) {
+  for (auto &BB : *F)
+    if (BB.getName() == name)
+      return &BB;
+  return nullptr;
+}
+
+TEST(CodeExtractor, ExitStub) {
   LLVMContext Ctx;
   SMDiagnostic Err;
   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
@@ -46,36 +54,10 @@ TEST(CodeExtractor, DISABLED_ExitStub) {
   )invalid",
                                                 Err, Ctx));
 
-  // CodeExtractor miscompiles this function. There appear to be some issues
-  // with the handling of outlined regions with live output values.
-  //
-  // In the original function, CE adds two reloads in the codeReplacer block:
-  //
-  //   codeRepl:                                         ; preds = %header
-  //     call void @foo_header.split(i32 %z, i32 %x, i32 %y, i32* %.loc, i32* %.loc1)
-  //     %.reload = load i32, i32* %.loc
-  //     %.reload2 = load i32, i32* %.loc1
-  //     br label %notExtracted
-  //
-  // These reloads must flow into the notExtracted block:
-  //
-  //   notExtracted:                                     ; preds = %codeRepl
-  //     %0 = phi i32 [ %.reload, %codeRepl ], [ %.reload2, %body2 ]
-  //
-  // The problem is that the PHI node in notExtracted now has an incoming
-  // value from a BasicBlock that's in a different function.
-
   Function *Func = M->getFunction("foo");
-  SmallVector<BasicBlock *, 3> Candidates;
-  for (auto &BB : *Func) {
-    if (BB.getName() == "body1")
-      Candidates.push_back(&BB);
-    if (BB.getName() == "body2")
-      Candidates.push_back(&BB);
-  }
-  // CodeExtractor requires the first basic block
-  // to dominate all the other ones.
-  Candidates.insert(Candidates.begin(), &Func->getEntryBlock());
+  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
+                                           getBlockByName(Func, "body1"),
+                                           getBlockByName(Func, "body2") };
 
   DominatorTree DT(*Func);
   CodeExtractor CE(Candidates, &DT);
@@ -83,6 +65,66 @@ TEST(CodeExtractor, DISABLED_ExitStub) {
 
   Function *Outlined = CE.extractCodeRegion();
   EXPECT_TRUE(Outlined);
+  BasicBlock *Exit = getBlockByName(Func, "notExtracted");
+  BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
+  // Ensure that PHI in exit block has only one incoming value (from code
+  // replacer block).
+  EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
+  // Ensure that there is a PHI in outlined function with 2 incoming values.
+  EXPECT_TRUE(ExitSplit &&
+              cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
+TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define i32 @foo() {
+    header:
+      br i1 undef, label %extracted1, label %pred
+
+    pred:
+      br i1 undef, label %exit1, label %exit2
+
+    extracted1:
+      br i1 undef, label %extracted2, label %exit1
+
+    extracted2:
+      br label %exit2
+
+    exit1:
+      %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
+      ret i32 %0
+
+    exit2:
+      %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
+      ret i32 %1
+    }
+  )invalid", Err, Ctx));
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 2> ExtractedBlocks{
+    getBlockByName(Func, "extracted1"),
+    getBlockByName(Func, "extracted2")
+  };
+
+  DominatorTree DT(*Func);
+  CodeExtractor CE(ExtractedBlocks, &DT);
+  EXPECT_TRUE(CE.isEligible());
+
+  Function *Outlined = CE.extractCodeRegion();
+  EXPECT_TRUE(Outlined);
+  BasicBlock *Exit1 = getBlockByName(Func, "exit1");
+  BasicBlock *Exit2 = getBlockByName(Func, "exit2");
+  // Ensure that PHIs in exits are not splitted (since that they have only one
+  // incoming value from extracted region).
+  EXPECT_TRUE(Exit1 &&
+          cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
+  EXPECT_TRUE(Exit2 &&
+          cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
   EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
 }
 } // end anonymous namespace




More information about the llvm-commits mailing list