[llvm] 144cd22 - [CodeExtractor] Creating exit stubs based off original order branch instructions.

Andrew Litteken via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 8 15:15:43 PDT 2021


Author: Andrew Litteken
Date: 2021-09-08T15:15:15-07:00
New Revision: 144cd22baef2d068f077e514de1b4f0d8b0973cf

URL: https://github.com/llvm/llvm-project/commit/144cd22baef2d068f077e514de1b4f0d8b0973cf
DIFF: https://github.com/llvm/llvm-project/commit/144cd22baef2d068f077e514de1b4f0d8b0973cf.diff

LOG: [CodeExtractor] Creating exit stubs based off original order branch instructions.

Previously, the CodeExtractor created the exit stubs, and the corresponding return values of the outlined function, based on the order of the out-of-region blocks as found after splitting any phi nodes and collecting the blocks to be outlined. This could cause the order to differ between two regions if their exit blocks differed in phi nodes. This patch moves the collection of the output target blocks to before the phi nodes are split, so that the assignment of target block to output value is the same regardless of the contents of the output blocks.

Reviewers: paquette, roelofs

Differential Revision: https://reviews.llvm.org/D108657

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/CodeExtractor.h
    llvm/lib/Transforms/Utils/CodeExtractor.cpp
    llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index 0205f23d7040f..f08173e45a5bf 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -100,6 +100,10 @@ class CodeExtractorAnalysisCache {
     unsigned NumExitBlocks = std::numeric_limits<unsigned>::max();
     Type *RetTy;
 
+    // The original exit blocks (branch targets outside the region), in the
+    // order used to create the exit stubs inside the new function.
+    SmallVector<BasicBlock *, 4> OldTargets;
+
     // Suffix to use when creating extracted function (appended to the original
     // function name + "."). If empty, the default is to use the entry block
     // label, if non-empty, otherwise "extracted".

diff  --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index e94dab18b9c08..8bd09198ee745 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -434,6 +434,7 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
   }
   // Now add the old exit block to the outline region.
   Blocks.insert(CommonExitBlock);
+  OldTargets.push_back(NewExitBlock);
   return CommonExitBlock;
 }
 
@@ -1248,45 +1249,57 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
   // not in the region to be extracted.
   std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
 
+  // Iterate over the previously collected targets, and create new blocks inside
+  // the function to branch to.
   unsigned switchVal = 0;
+  for (BasicBlock *OldTarget : OldTargets) {
+    if (Blocks.count(OldTarget))
+      continue;
+    BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
+    if (NewTarget)
+      continue;
+
+    // If we don't already have an exit stub for this non-extracted
+    // destination, create one now!
+    NewTarget = BasicBlock::Create(Context,
+                                    OldTarget->getName() + ".exitStub",
+                                    newFunction);
+    unsigned SuccNum = switchVal++;
+
+    Value *brVal = nullptr;
+    assert(NumExitBlocks < 0xffff && "too many exit blocks for switch");
+    switch (NumExitBlocks) {
+    case 0:
+    case 1: break;  // No value needed.
+    case 2:         // Conditional branch, return a bool
+      brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
+      break;
+    default:
+      brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
+      break;
+    }
+
+    ReturnInst::Create(Context, brVal, NewTarget);
+
+    // Update the switch instruction.
+    TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
+                                        SuccNum),
+                        OldTarget);
+  }
+
   for (BasicBlock *Block : Blocks) {
     Instruction *TI = Block->getTerminator();
-    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
-      if (!Blocks.count(TI->getSuccessor(i))) {
-        BasicBlock *OldTarget = TI->getSuccessor(i);
-        // add a new basic block which returns the appropriate value
-        BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
-        if (!NewTarget) {
-          // If we don't already have an exit stub for this non-extracted
-          // destination, create one now!
-          NewTarget = BasicBlock::Create(Context,
-                                         OldTarget->getName() + ".exitStub",
-                                         newFunction);
-          unsigned SuccNum = switchVal++;
-
-          Value *brVal = nullptr;
-          switch (NumExitBlocks) {
-          case 0:
-          case 1: break;  // No value needed.
-          case 2:         // Conditional branch, return a bool
-            brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
-            break;
-          default:
-            brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
-            break;
-          }
-
-          ReturnInst::Create(Context, brVal, NewTarget);
-
-          // Update the switch instruction.
-          TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
-                                              SuccNum),
-                             OldTarget);
-        }
-
-        // rewrite the original branch instruction with this new target
-        TI->setSuccessor(i, NewTarget);
-      }
+    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
+      if (Blocks.count(TI->getSuccessor(i)))
+        continue;
+      BasicBlock *OldTarget = TI->getSuccessor(i);
+      // add a new basic block which returns the appropriate value
+      BasicBlock *NewTarget = ExitBlockMap[OldTarget];
+      assert(NewTarget && "Unknown target block!");
+
+      // rewrite the original branch instruction with this new target
+      TI->setSuccessor(i, NewTarget);
+   }
   }
 
   // Store the arguments right after the definition of output value.
@@ -1640,6 +1653,16 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
   }
   NumExitBlocks = ExitBlocks.size();
 
+  for (BasicBlock *Block : Blocks) {
+    Instruction *TI = Block->getTerminator();
+    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
+      if (Blocks.count(TI->getSuccessor(i)))
+        continue;
+      BasicBlock *OldTarget = TI->getSuccessor(i);
+      OldTargets.push_back(OldTarget);
+    }
+  }
+
   // If we have to split PHI nodes of the entry or exit blocks, do so now.
   severSplitPHINodesOfEntry(header);
   severSplitPHINodesOfExits(ExitBlocks);

diff  --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 093bd980e9356..5f7b0111c1c62 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -135,6 +135,121 @@ TEST(CodeExtractor, InputOutputMonitoring) {
   EXPECT_FALSE(verifyFunction(*Func));
 }
 
+TEST(CodeExtractor, ExitBlockOrderingPhis) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define void @foo(i32 %a, i32 %b) {
+    entry:
+      %0 = alloca i32, align 4
+      br label %test0
+    test0:
+      %c = load i32, i32* %0, align 4
+      br label %test1
+    test1:
+      %e = load i32, i32* %0, align 4
+      br i1 true, label %first, label %test
+    test:
+      %d = load i32, i32* %0, align 4
+      br i1 true, label %first, label %next
+    first:
+      %1 = phi i32 [ %c, %test ], [ %e, %test1 ]
+      ret void
+    next:
+      %2 = add i32 %d, 1
+      %3 = add i32 %e, 1
+      ret void
+    }
+  )invalid",
+                                                Err, Ctx));
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
+                                           getBlockByName(Func, "test1"),
+                                           getBlockByName(Func, "test") };
+
+  CodeExtractor CE(Candidates);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  Function *Outlined = CE.extractCodeRegion(CEAC);
+  EXPECT_TRUE(Outlined);
+
+  BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
+  BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
+
+  Instruction *FirstTerm = FirstExitStub->getTerminator();
+  ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
+  EXPECT_TRUE(FirstReturn);
+  ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
+  EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
+
+  Instruction *NextTerm = NextExitStub->getTerminator();
+  ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
+  EXPECT_TRUE(NextReturn);
+  ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
+  EXPECT_TRUE(CINext->getLimitedValue() == 0u);
+  
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
+TEST(CodeExtractor, ExitBlockOrdering) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define void @foo(i32 %a, i32 %b) {
+    entry:
+      %0 = alloca i32, align 4
+      br label %test0
+    test0:
+      %c = load i32, i32* %0, align 4
+      br label %test1
+    test1:
+      %e = load i32, i32* %0, align 4
+      br i1 true, label %first, label %test
+    test:
+      %d = load i32, i32* %0, align 4
+      br i1 true, label %first, label %next
+    first:
+      ret void
+    next:
+      %1 = add i32 %d, 1
+      %2 = add i32 %e, 1
+      ret void
+    }
+  )invalid",
+                                                Err, Ctx));
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
+                                           getBlockByName(Func, "test1"),
+                                           getBlockByName(Func, "test") };
+
+  CodeExtractor CE(Candidates);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  Function *Outlined = CE.extractCodeRegion(CEAC);
+  EXPECT_TRUE(Outlined);
+
+  BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
+  BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
+
+  Instruction *FirstTerm = FirstExitStub->getTerminator();
+  ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
+  EXPECT_TRUE(FirstReturn);
+  ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
+  EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
+
+  Instruction *NextTerm = NextExitStub->getTerminator();
+  ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
+  EXPECT_TRUE(NextReturn);
+  ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
+  EXPECT_TRUE(CINext->getLimitedValue() == 0u);
+  
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
   LLVMContext Ctx;
   SMDiagnostic Err;


        


More information about the llvm-commits mailing list