[llvm] [Coroutines] Verify normalization was not missed (PR #108096)

Tyler Nowicki via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 10 14:37:21 PDT 2024


https://github.com/TylerNowicki created https://github.com/llvm/llvm-project/pull/108096

* Add asserts to verify that the coroutine has been normalized.
* This will be important if/when normalization becomes part of an ABI object.

From 34559ad06f82f7d1b48f7982e7a64944d72204b4 Mon Sep 17 00:00:00 2001
From: tnowicki <tnowicki.nowicki at amd.com>
Date: Fri, 23 Aug 2024 13:13:20 -0400
Subject: [PATCH 1/2] [Coroutines] Split buildCoroutineFrame

* Split buildCoroutineFrame into code related to normalization and code
  related to actually building the coroutine frame.
* This will enable future specialization of buildCoroutineFrame for
  different ABIs.
---
 llvm/lib/Transforms/Coroutines/CoroFrame.cpp  | 14 ++++++++------
 llvm/lib/Transforms/Coroutines/CoroInternal.h |  4 +++-
 llvm/lib/Transforms/Coroutines/CoroSplit.cpp  |  3 ++-
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index 4b76fc79361008..8ee4bfa3b888df 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -1754,7 +1754,8 @@ static bool willLeaveFunctionImmediatelyAfter(BasicBlock *BB,
   if (depth == 0) return false;
 
   // If this is a suspend block, we're about to exit the resumption function.
-  if (isSuspendBlock(BB)) return true;
+  if (isSuspendBlock(BB))
+    return true;
 
   // Recurse into the successors.
   for (auto *Succ : successors(BB)) {
@@ -2288,9 +2289,8 @@ static void doRematerializations(
   rewriteMaterializableInstructions(AllRemats);
 }
 
-void coro::buildCoroutineFrame(
-    Function &F, Shape &Shape, TargetTransformInfo &TTI,
-    const std::function<bool(Instruction &)> &MaterializableCallback) {
+void coro::normalizeCoroutine(Function &F, coro::Shape &Shape,
+                              TargetTransformInfo &TTI) {
   // Don't eliminate swifterror in async functions that won't be split.
   if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
     eliminateSwiftError(F, Shape);
@@ -2337,10 +2337,12 @@ void coro::buildCoroutineFrame(
   // Transforms multi-edge PHI Nodes, so that any value feeding into a PHI will
   // never have its definition separated from the PHI by the suspend point.
   rewritePHIs(F);
+}
 
-  // Build suspend crossing info.
+void coro::buildCoroutineFrame(
+    Function &F, Shape &Shape,
+    const std::function<bool(Instruction &)> &MaterializableCallback) {
   SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds);
-
   doRematerializations(F, Checker, MaterializableCallback);
 
   const DominatorTree DT(F);
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index be86f96525b677..698c21a797420a 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -281,8 +281,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
 };
 
 bool defaultMaterializable(Instruction &V);
+void normalizeCoroutine(Function &F, coro::Shape &Shape,
+                        TargetTransformInfo &TTI);
 void buildCoroutineFrame(
-    Function &F, Shape &Shape, TargetTransformInfo &TTI,
+    Function &F, Shape &Shape,
     const std::function<bool(Instruction &)> &MaterializableCallback);
 CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
                              TargetTransformInfo &TTI,
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 494c4d632de95f..dc3829d7f28eb1 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -2030,7 +2030,8 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
   lowerAwaitSuspends(F, Shape);
 
   simplifySuspendPoints(Shape);
-  buildCoroutineFrame(F, Shape, TTI, MaterializableCallback);
+  normalizeCoroutine(F, Shape, TTI);
+  buildCoroutineFrame(F, Shape, MaterializableCallback);
   replaceFrameSizeAndAlignment(Shape);
   bool isNoSuspendCoroutine = Shape.CoroSuspends.empty();
 

From fef22013e7a8864117d008df671d6926bffbf12f Mon Sep 17 00:00:00 2001
From: tnowicki <tnowicki.nowicki at amd.com>
Date: Tue, 10 Sep 2024 17:32:24 -0400
Subject: [PATCH 2/2] [Coroutines] Verify normalization was not missed

* Add asserts to verify that the coroutine has been normalized.
* This will be important if/when normalization becomes part of an ABI
  object.
---
 llvm/lib/Transforms/Coroutines/CoroFrame.cpp         |  2 +-
 .../Transforms/Coroutines/SuspendCrossingInfo.cpp    | 12 +++++++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index 8ee4bfa3b888df..b792835159a8a3 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -1284,7 +1284,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) {
 
       // If we have a single edge PHINode, remove it and replace it with a
       // reload from the coroutine frame. (We already took care of multi edge
-      // PHINodes by rewriting them in the rewritePHIs function).
+      // PHINodes by normalizing them in the rewritePHIs function).
       if (auto *PN = dyn_cast<PHINode>(U)) {
         assert(PN->getNumIncomingValues() == 1 &&
                "unexpected number of incoming "
diff --git a/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
index 6b0dc126d5471a..84699e653db602 100644
--- a/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
+++ b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
@@ -165,8 +165,13 @@ SuspendCrossingInfo::SuspendCrossingInfo(
   // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
   // the code beyond coro.end is reachable during initial invocation of the
   // coroutine.
-  for (auto *CE : CoroEnds)
+  for (auto *CE : CoroEnds) {
+    // Verify CoroEnd was normalized
+    assert(CE->getParent()->getFirstInsertionPt() == CE->getIterator() &&
+           CE->getParent()->size() <= 2 && "CoroEnd must be in its own BB");
+
     getBlockData(CE->getParent()).End = true;
+  }
 
   // Mark all suspend blocks and indicate that they kill everything they
   // consume. Note, that crossing coro.save also requires a spill, as any code
@@ -179,6 +184,11 @@ SuspendCrossingInfo::SuspendCrossingInfo(
     B.Kills |= B.Consumes;
   };
   for (auto *CSI : CoroSuspends) {
+    // Verify CoroSuspend was normalized
+    assert(CSI->getParent()->getFirstInsertionPt() == CSI->getIterator() &&
+           CSI->getParent()->size() <= 2 &&
+           "CoroSuspend must be in its own BB");
+
     markSuspendBlock(CSI);
     if (auto *Save = CSI->getCoroSave())
       markSuspendBlock(Save);



More information about the llvm-commits mailing list