[llvm] 403772f - [Coroutines] Enhance symmetric transfer for constant CmpInst

Chuanqi Xu via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 11 18:15:27 PST 2022


Author: Chuanqi Xu
Date: 2022-01-12T10:14:37+08:00
New Revision: 403772ff1ce5618c8d02316531386b415312274a

URL: https://github.com/llvm/llvm-project/commit/403772ff1ce5618c8d02316531386b415312274a
DIFF: https://github.com/llvm/llvm-project/commit/403772ff1ce5618c8d02316531386b415312274a.diff

LOG: [Coroutines] Enhance symmetric transfer for constant CmpInst

This fixes bug 52896.

In short, some symmetric transfer optimization opportunities were lost
because commit 822b92a removed some optimization passes that used to run
on inlined code. This can cause stack overflow in situations that the
design of coroutines is meant to avoid. This patch fixes the problem by
folding constant CmpInst instructions, a transformation that the deleted
passes used to perform.

Reviewed By: rjmccall, junparser

Differential Revision: https://reviews.llvm.org/D116327

Added: 
    

Modified: 
    llvm/lib/Transforms/Coroutines/CoroSplit.cpp
    llvm/test/Transforms/Coroutines/coro-split-musttail4.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 960cbe9ea4f0f..8ac0b4f9636aa 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -29,6 +29,7 @@
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
@@ -1197,6 +1198,15 @@ scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
 static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
   DenseMap<Value *, Value *> ResolvedValues;
   BasicBlock *UnconditionalSucc = nullptr;
+  assert(InitialInst->getModule());
+  const DataLayout &DL = InitialInst->getModule()->getDataLayout();
+
+  auto TryResolveConstant = [&ResolvedValues](Value *V) {
+    auto It = ResolvedValues.find(V);
+    if (It != ResolvedValues.end())
+      V = It->second;
+    return dyn_cast<ConstantInt>(V);
+  };
 
   Instruction *I = InitialInst;
   while (I->isTerminator() ||
@@ -1213,47 +1223,65 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
     }
     if (auto *BR = dyn_cast<BranchInst>(I)) {
       if (BR->isUnconditional()) {
-        BasicBlock *BB = BR->getSuccessor(0);
+        BasicBlock *Succ = BR->getSuccessor(0);
         if (I == InitialInst)
-          UnconditionalSucc = BB;
-        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-        I = BB->getFirstNonPHIOrDbgOrLifetime();
+          UnconditionalSucc = Succ;
+        scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
+        I = Succ->getFirstNonPHIOrDbgOrLifetime();
+        continue;
+      }
+
+      BasicBlock *BB = BR->getParent();
+      // Handle the case the condition of the conditional branch is constant.
+      // e.g.,
+      //
+      //     br i1 false, label %cleanup, label %CoroEnd
+      //
+      // It is possible during the transformation. We could continue the
+      // simplifying in this case.
+      if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) {
+        // Handle this branch in next iteration.
+        I = BB->getTerminator();
         continue;
       }
     } else if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
+      // If the case number of suspended switch instruction is reduced to
+      // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
       auto *BR = dyn_cast<BranchInst>(I->getNextNode());
-      if (BR && BR->isConditional() && CondCmp == BR->getCondition()) {
-        // If the case number of suspended switch instruction is reduced to
-        // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
-        // And the comparsion looks like : %cond = icmp eq i8 %V, constant.
-        ConstantInt *CondConst = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
-        if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) {
-          Value *V = CondCmp->getOperand(0);
-          auto it = ResolvedValues.find(V);
-          if (it != ResolvedValues.end())
-            V = it->second;
-
-          if (ConstantInt *Cond0 = dyn_cast<ConstantInt>(V)) {
-            BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue())
-                                 ? BR->getSuccessor(0)
-                                 : BR->getSuccessor(1);
-            scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-            I = BB->getFirstNonPHIOrDbgOrLifetime();
-            continue;
-          }
-        }
-      }
+      if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
+        return false;
+
+      // And the comparsion looks like : %cond = icmp eq i8 %V, constant.
+      // So we try to resolve constant for the first operand only since the
+      // second operand should be literal constant by design.
+      ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
+      auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
+      if (!Cond0 || !Cond1)
+        return false;
+
+      // Both operands of the CmpInst are Constant. So that we could evaluate
+      // it immediately to get the destination.
+      auto *ConstResult =
+          dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
+              CondCmp->getPredicate(), Cond0, Cond1, DL));
+      if (!ConstResult)
+        return false;
+
+      CondCmp->replaceAllUsesWith(ConstResult);
+      CondCmp->eraseFromParent();
+
+      // Handle this branch in next iteration.
+      I = BR;
+      continue;
     } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
-      Value *V = SI->getCondition();
-      auto it = ResolvedValues.find(V);
-      if (it != ResolvedValues.end())
-        V = it->second;
-      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
-        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
-        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-        I = BB->getFirstNonPHIOrDbgOrLifetime();
-        continue;
-      }
+      ConstantInt *Cond = TryResolveConstant(SI->getCondition());
+      if (!Cond)
+        return false;
+
+      BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
+      scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
+      I = BB->getFirstNonPHIOrDbgOrLifetime();
+      continue;
     }
     return false;
   }

diff  --git a/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll
index 9fd8017996206..0d73d94a93f58 100644
--- a/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll
+++ b/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll
@@ -42,9 +42,9 @@ coro.end:
   ret void
 }
 
-; FIXME: The fakerresume1 here should be musttail call.
 ; CHECK-LABEL: @f.resume(
-; CHECK-NOT: musttail call fastcc void @fakeresume1(
+; CHECK:          musttail call fastcc void @fakeresume1(
+; CHECK-NEXT:     ret void
 
 declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
 declare i1 @llvm.coro.alloc(token) #2


        


More information about the llvm-commits mailing list