[llvm-branch-commits] [llvm] b9a243d - [Coroutines] Enhance symmetric transfer for constant CmpInst

Tom Stellard via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 12 12:07:14 PST 2022


Author: Chuanqi Xu
Date: 2022-01-12T12:06:51-08:00
New Revision: b9a243d1cac2aae0596dc5a1f097e1adc1fe5102

URL: https://github.com/llvm/llvm-project/commit/b9a243d1cac2aae0596dc5a1f097e1adc1fe5102
DIFF: https://github.com/llvm/llvm-project/commit/b9a243d1cac2aae0596dc5a1f097e1adc1fe5102.diff

LOG: [Coroutines] Enhance symmetric transfer for constant CmpInst

This fixes bug52896.

In short, some symmetric transfer optimization opportunities were lost
because we deleted some inlined optimization passes in 822b92a. This can
cause stack overflow in some situations, which the design of coroutines
is meant to avoid. This patch tries to fix this by transforming constant
CmpInst instructions, which the deleted passes used to do.

Reviewed By: rjmccall, junparser

Differential Revision: https://reviews.llvm.org/D116327

(cherry picked from commit 403772ff1ce5618c8d02316531386b415312274a)

Added: 
    llvm/test/Transforms/Coroutines/coro-split-musttail4.ll

Modified: 
    llvm/lib/Transforms/Coroutines/CoroSplit.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index b6932dbbfc3f..fc83befe3950 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -29,6 +29,7 @@
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
@@ -1174,6 +1175,15 @@ scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
 static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
   DenseMap<Value *, Value *> ResolvedValues;
   BasicBlock *UnconditionalSucc = nullptr;
+  assert(InitialInst->getModule());
+  const DataLayout &DL = InitialInst->getModule()->getDataLayout();
+
+  auto TryResolveConstant = [&ResolvedValues](Value *V) {
+    auto It = ResolvedValues.find(V);
+    if (It != ResolvedValues.end())
+      V = It->second;
+    return dyn_cast<ConstantInt>(V);
+  };
 
   Instruction *I = InitialInst;
   while (I->isTerminator() ||
@@ -1190,47 +1200,65 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
     }
     if (auto *BR = dyn_cast<BranchInst>(I)) {
       if (BR->isUnconditional()) {
-        BasicBlock *BB = BR->getSuccessor(0);
+        BasicBlock *Succ = BR->getSuccessor(0);
         if (I == InitialInst)
-          UnconditionalSucc = BB;
-        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-        I = BB->getFirstNonPHIOrDbgOrLifetime();
+          UnconditionalSucc = Succ;
+        scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
+        I = Succ->getFirstNonPHIOrDbgOrLifetime();
+        continue;
+      }
+
+      BasicBlock *BB = BR->getParent();
+      // Handle the case where the condition of the conditional branch is
+      // constant, e.g.,
+      //
+      //     br i1 false, label %cleanup, label %CoroEnd
+      //
+      // This can arise during the transformation. We can continue
+      // simplifying in this case.
+      if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) {
+        // Handle this branch in next iteration.
+        I = BB->getTerminator();
         continue;
       }
     } else if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
+      // If the number of cases of the suspended switch instruction is reduced
+      // to 1, it is simplified to a CmpInst by llvm::ConstantFoldTerminator.
       auto *BR = dyn_cast<BranchInst>(I->getNextNode());
-      if (BR && BR->isConditional() && CondCmp == BR->getCondition()) {
-        // If the case number of suspended switch instruction is reduced to
-        // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
-        // And the comparsion looks like : %cond = icmp eq i8 %V, constant.
-        ConstantInt *CondConst = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
-        if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) {
-          Value *V = CondCmp->getOperand(0);
-          auto it = ResolvedValues.find(V);
-          if (it != ResolvedValues.end())
-            V = it->second;
-
-          if (ConstantInt *Cond0 = dyn_cast<ConstantInt>(V)) {
-            BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue())
-                                 ? BR->getSuccessor(0)
-                                 : BR->getSuccessor(1);
-            scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-            I = BB->getFirstNonPHIOrDbgOrLifetime();
-            continue;
-          }
-        }
-      }
+      if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
+        return false;
+
+      // The comparison looks like: %cond = icmp eq i8 %V, constant.
+      // So we try to resolve a constant for the first operand only, since the
+      // second operand should be a literal constant by design.
+      ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
+      auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
+      if (!Cond0 || !Cond1)
+        return false;
+
+      // Both operands of the CmpInst are Constant. So that we could evaluate
+      // it immediately to get the destination.
+      auto *ConstResult =
+          dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
+              CondCmp->getPredicate(), Cond0, Cond1, DL));
+      if (!ConstResult)
+        return false;
+
+      CondCmp->replaceAllUsesWith(ConstResult);
+      CondCmp->eraseFromParent();
+
+      // Handle this branch in next iteration.
+      I = BR;
+      continue;
     } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
-      Value *V = SI->getCondition();
-      auto it = ResolvedValues.find(V);
-      if (it != ResolvedValues.end())
-        V = it->second;
-      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
-        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
-        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-        I = BB->getFirstNonPHIOrDbgOrLifetime();
-        continue;
-      }
+      ConstantInt *Cond = TryResolveConstant(SI->getCondition());
+      if (!Cond)
+        return false;
+
+      BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
+      scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
+      I = BB->getFirstNonPHIOrDbgOrLifetime();
+      continue;
     }
     return false;
   }

diff  --git a/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll
new file mode 100644
index 000000000000..0d73d94a93f5
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll
@@ -0,0 +1,65 @@
+; Tests that coro-split will convert a call before coro.suspend to a musttail call
+; when the user of the coro.suspend is an icmp instruction.
+; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
+
+define void @fakeresume1(i8*)  {
+entry:
+  ret void;
+}
+
+define void @f() #0 {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %alloc = call i8* @malloc(i64 16) #3
+  %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
+
+  %save = call token @llvm.coro.save(i8* null)
+
+  %init_suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %init_suspend, label %coro.end [
+    i8 0, label %await.ready
+    i8 1, label %coro.end
+  ]
+await.ready:
+  %save2 = call token @llvm.coro.save(i8* null)
+
+  call fastcc void @fakeresume1(i8* align 8 null)
+  %suspend = call i8 @llvm.coro.suspend(token %save2, i1 true)
+  %switch = icmp ult i8 %suspend, 2
+  br i1 %switch, label %cleanup, label %coro.end
+
+cleanup:
+  %free.handle = call i8* @llvm.coro.free(token %id, i8* %vFrame)
+  %.not = icmp eq i8* %free.handle, null
+  br i1 %.not, label %coro.end, label %coro.free
+
+coro.free:
+  call void @delete(i8* nonnull %free.handle) #2
+  br label %coro.end
+
+coro.end:
+  call i1 @llvm.coro.end(i8* null, i1 false)
+  ret void
+}
+
+; CHECK-LABEL: @f.resume(
+; CHECK:          musttail call fastcc void @fakeresume1(
+; CHECK-NEXT:     ret void
+
+declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
+declare i1 @llvm.coro.alloc(token) #2
+declare i64 @llvm.coro.size.i64() #3
+declare i8* @llvm.coro.begin(token, i8* writeonly) #2
+declare token @llvm.coro.save(i8*) #2
+declare i8* @llvm.coro.frame() #3
+declare i8 @llvm.coro.suspend(token, i1) #2
+declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1
+declare i1 @llvm.coro.end(i8*, i1) #2
+declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #1
+declare i8* @malloc(i64)
+declare void @delete(i8* nonnull) #2
+
+attributes #0 = { "coroutine.presplit"="1" }
+attributes #1 = { argmemonly nounwind readonly }
+attributes #2 = { nounwind }
+attributes #3 = { nounwind readnone }


        


More information about the llvm-branch-commits mailing list