[llvm] 66227bf - [Coroutines] ABI Objects to improve code separation between different ABIs, users and utilities. (#109713)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 3 08:30:45 PDT 2024


Author: Tyler Nowicki
Date: 2024-10-03T11:30:42-04:00
New Revision: 66227bf7ee2cd674e0306d9a1e9f1d86bc75123f

URL: https://github.com/llvm/llvm-project/commit/66227bf7ee2cd674e0306d9a1e9f1d86bc75123f
DIFF: https://github.com/llvm/llvm-project/commit/66227bf7ee2cd674e0306d9a1e9f1d86bc75123f.diff

LOG: [Coroutines] ABI Objects to improve code separation between different ABIs, users and utilities. (#109713)

This patch re-lands https://github.com/llvm/llvm-project/pull/109338 and
fixes the various test failures.

--- Original description ---

* Adds an ABI object class hierarchy to implement the coroutine ABIs
(Switch, Async, and Retcon{Once})
* The ABI object improves the separation of the code related to users,
ABIs and utilities.
* No code changes are required by any existing users.
* Each ABI overrides delegate methods for initialization, building the
coroutine frame and splitting the coroutine; other methods may be added
later.
* CoroSplit invokes a generator lambda to instantiate the ABI object and
calls the ABI object to carry out its primary operations. In a follow-up
change this will be used to instantiate customized ABIs according to a
new intrinsic llvm.coro.begin.custom.abi.
* Note, in a follow-up change additional constructors will be added to
the ABI objects (Switch, Async, and AnyRetcon) to allow a set of
generator lambdas to be passed in from a CoroSplit constructor that will
also be added. The init() method is separate from the constructor to
avoid duplication of its code. It is a virtual method so it can be
invoked by CreateAndInitABI or potentially CoroSplit::run(). I wasn't
sure if we should call init() from within CoroSplit::run(),
CreateAndInitABI, or perhaps hard code a call in each constructor. One
consideration is that init() can change the IR, so perhaps it should
appear in CoroSplit::run(), but this looks a bit odd. Invoking init() in
the constructor would result in the call appearing 4 times per ABI
(after adding the new constructors). Invoking it in CreateAndInitABI
seems to be a balance between these.

See RFC for more info:
https://discourse.llvm.org/t/rfc-abi-objects-for-coroutines/81057

Added: 
    llvm/lib/Transforms/Coroutines/ABI.h

Modified: 
    llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
    llvm/lib/Transforms/Coroutines/CoroFrame.cpp
    llvm/lib/Transforms/Coroutines/CoroInternal.h
    llvm/lib/Transforms/Coroutines/CoroShape.h
    llvm/lib/Transforms/Coroutines/CoroSplit.cpp
    llvm/lib/Transforms/Coroutines/Coroutines.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
index a2be1099ff68fc..74ee948a339e31 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
@@ -21,19 +21,26 @@
 
 namespace llvm {
 
+namespace coro {
+class BaseABI;
+class Shape;
+} // namespace coro
+
 struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
-  const std::function<bool(Instruction &)> MaterializableCallback;
 
   CoroSplitPass(bool OptimizeFrame = false);
   CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
-                bool OptimizeFrame = false)
-      : MaterializableCallback(MaterializableCallback),
-        OptimizeFrame(OptimizeFrame) {}
+                bool OptimizeFrame = false);
 
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
   static bool isRequired() { return true; }
 
+  using BaseABITy =
+      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
+  // Generator for an ABI transformer
+  BaseABITy CreateAndInitABI;
+
   // Would be true if the Optimization level isn't O0.
   bool OptimizeFrame;
 };

diff  --git a/llvm/lib/Transforms/Coroutines/ABI.h b/llvm/lib/Transforms/Coroutines/ABI.h
new file mode 100644
index 00000000000000..c94bf7d356b650
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/ABI.h
@@ -0,0 +1,102 @@
+//===- ABI.h - Coroutine lowering class definitions (ABIs) ----*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file defines coroutine lowering classes. The interface for coroutine
+// lowering is defined by BaseABI. Each lowering method (ABI) implements the
+// interface. Note that the enum class ABI, such as ABI::Switch, determines
+// which ABI class, such as SwitchABI, is used to lower the coroutine. Both the
+// ABI enum and ABI class are used by the Coroutine passes when lowering.
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_TRANSFORMS_COROUTINES_ABI_H
+#define LIB_TRANSFORMS_COROUTINES_ABI_H
+
+#include "CoroShape.h"
+#include "SuspendCrossingInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+
+namespace llvm {
+
+class Function;
+
+namespace coro {
+
+// This interface/API is to provide an object oriented way to implement ABI
+// functionality. This is intended to replace use of the ABI enum to perform
+// ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common
+// ABIs.
+
+class LLVM_LIBRARY_VISIBILITY BaseABI {
+public:
+  BaseABI(Function &F, coro::Shape &S,
+          std::function<bool(Instruction &)> IsMaterializable)
+      : F(F), Shape(S), IsMaterializable(IsMaterializable) {}
+  virtual ~BaseABI() = default;
+
+  // Initialize the coroutine ABI
+  virtual void init() = 0;
+
+  // Allocate the coroutine frame and do spill/reload as needed.
+  virtual void buildCoroutineFrame();
+
+  // Perform the function splitting according to the ABI.
+  virtual void splitCoroutine(Function &F, coro::Shape &Shape,
+                              SmallVectorImpl<Function *> &Clones,
+                              TargetTransformInfo &TTI) = 0;
+
+  Function &F;
+  coro::Shape &Shape;
+
+  // Callback used by coro::BaseABI::buildCoroutineFrame for rematerialization.
+  // It is provided to coro::doMaterializations(..).
+  std::function<bool(Instruction &I)> IsMaterializable;
+};
+
+class LLVM_LIBRARY_VISIBILITY SwitchABI : public BaseABI {
+public:
+  SwitchABI(Function &F, coro::Shape &S,
+            std::function<bool(Instruction &)> IsMaterializable)
+      : BaseABI(F, S, IsMaterializable) {}
+
+  void init() override;
+
+  void splitCoroutine(Function &F, coro::Shape &Shape,
+                      SmallVectorImpl<Function *> &Clones,
+                      TargetTransformInfo &TTI) override;
+};
+
+class LLVM_LIBRARY_VISIBILITY AsyncABI : public BaseABI {
+public:
+  AsyncABI(Function &F, coro::Shape &S,
+           std::function<bool(Instruction &)> IsMaterializable)
+      : BaseABI(F, S, IsMaterializable) {}
+
+  void init() override;
+
+  void splitCoroutine(Function &F, coro::Shape &Shape,
+                      SmallVectorImpl<Function *> &Clones,
+                      TargetTransformInfo &TTI) override;
+};
+
+class LLVM_LIBRARY_VISIBILITY AnyRetconABI : public BaseABI {
+public:
+  AnyRetconABI(Function &F, coro::Shape &S,
+               std::function<bool(Instruction &)> IsMaterializable)
+      : BaseABI(F, S, IsMaterializable) {}
+
+  void init() override;
+
+  void splitCoroutine(Function &F, coro::Shape &Shape,
+                      SmallVectorImpl<Function *> &Clones,
+                      TargetTransformInfo &TTI) override;
+};
+
+} // end namespace coro
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_COROUTINES_ABI_H

diff  --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index c08f56b024dfc7..021b1f7a4156b9 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -15,6 +15,7 @@
 // the value into the coroutine frame.
 //===----------------------------------------------------------------------===//
 
+#include "ABI.h"
 #include "CoroInternal.h"
 #include "MaterializationUtils.h"
 #include "SpillUtils.h"
@@ -2055,11 +2056,9 @@ void coro::normalizeCoroutine(Function &F, coro::Shape &Shape,
   rewritePHIs(F);
 }
 
-void coro::buildCoroutineFrame(
-    Function &F, Shape &Shape,
-    const std::function<bool(Instruction &)> &MaterializableCallback) {
+void coro::BaseABI::buildCoroutineFrame() {
   SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds);
-  doRematerializations(F, Checker, MaterializableCallback);
+  doRematerializations(F, Checker, IsMaterializable);
 
   const DominatorTree DT(F);
   if (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon &&

diff  --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index fcbd31878bdea7..88d0f83c98c9ec 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -62,9 +62,6 @@ struct LowererBase {
 bool defaultMaterializable(Instruction &V);
 void normalizeCoroutine(Function &F, coro::Shape &Shape,
                         TargetTransformInfo &TTI);
-void buildCoroutineFrame(
-    Function &F, Shape &Shape,
-    const std::function<bool(Instruction &)> &MaterializableCallback);
 CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
                              TargetTransformInfo &TTI,
                              ArrayRef<Value *> Arguments, IRBuilder<> &);

diff  --git a/llvm/lib/Transforms/Coroutines/CoroShape.h b/llvm/lib/Transforms/Coroutines/CoroShape.h
index dcfe94ca076bd8..f4fb4baa6df314 100644
--- a/llvm/lib/Transforms/Coroutines/CoroShape.h
+++ b/llvm/lib/Transforms/Coroutines/CoroShape.h
@@ -275,7 +275,6 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
       invalidateCoroutine(F, CoroFrames);
       return;
     }
-    initABI();
     cleanCoroutine(CoroFrames, UnusedCoroSaves);
   }
 };

diff  --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 382bdfff1926f7..6fdead9a9c3501 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -19,8 +19,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Coroutines/CoroSplit.h"
+#include "ABI.h"
 #include "CoroInstr.h"
 #include "CoroInternal.h"
+#include "MaterializationUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/PriorityWorklist.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -1779,9 +1781,9 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
   return TailCall;
 }
 
-static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
-                                SmallVectorImpl<Function *> &Clones,
-                                TargetTransformInfo &TTI) {
+void coro::AsyncABI::splitCoroutine(Function &F, coro::Shape &Shape,
+                                    SmallVectorImpl<Function *> &Clones,
+                                    TargetTransformInfo &TTI) {
   assert(Shape.ABI == coro::ABI::Async);
   assert(Clones.empty());
   // Reset various things that the optimizer might have decided it
@@ -1874,9 +1876,9 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
   }
 }
 
-static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
-                                 SmallVectorImpl<Function *> &Clones,
-                                 TargetTransformInfo &TTI) {
+void coro::AnyRetconABI::splitCoroutine(Function &F, coro::Shape &Shape,
+                                        SmallVectorImpl<Function *> &Clones,
+                                        TargetTransformInfo &TTI) {
   assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
   assert(Clones.empty());
 
@@ -2044,26 +2046,27 @@ static bool hasSafeElideCaller(Function &F) {
   return false;
 }
 
-static coro::Shape
-splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
-               TargetTransformInfo &TTI, bool OptimizeFrame,
-               std::function<bool(Instruction &)> MaterializableCallback) {
-  PrettyStackTraceFunction prettyStackTrace(F);
+void coro::SwitchABI::splitCoroutine(Function &F, coro::Shape &Shape,
+                                     SmallVectorImpl<Function *> &Clones,
+                                     TargetTransformInfo &TTI) {
+  SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
+}
 
-  // The suspend-crossing algorithm in buildCoroutineFrame get tripped
-  // up by uses in unreachable blocks, so remove them as a first pass.
-  removeUnreachableBlocks(F);
+static void doSplitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
+                             coro::BaseABI &ABI, TargetTransformInfo &TTI) {
+  PrettyStackTraceFunction prettyStackTrace(F);
 
-  coro::Shape Shape(F, OptimizeFrame);
-  if (!Shape.CoroBegin)
-    return Shape;
+  auto &Shape = ABI.Shape;
+  assert(Shape.CoroBegin);
 
   lowerAwaitSuspends(F, Shape);
 
   simplifySuspendPoints(Shape);
+
   normalizeCoroutine(F, Shape, TTI);
-  buildCoroutineFrame(F, Shape, MaterializableCallback);
+  ABI.buildCoroutineFrame();
   replaceFrameSizeAndAlignment(Shape);
+
   bool isNoSuspendCoroutine = Shape.CoroSuspends.empty();
 
   bool shouldCreateNoAllocVariant = !isNoSuspendCoroutine &&
@@ -2075,18 +2078,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
   if (isNoSuspendCoroutine) {
     handleNoSuspendCoroutine(Shape);
   } else {
-    switch (Shape.ABI) {
-    case coro::ABI::Switch:
-      SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
-      break;
-    case coro::ABI::Async:
-      splitAsyncCoroutine(F, Shape, Clones, TTI);
-      break;
-    case coro::ABI::Retcon:
-    case coro::ABI::RetconOnce:
-      splitRetconCoroutine(F, Shape, Clones, TTI);
-      break;
-    }
+    ABI.splitCoroutine(F, Shape, Clones, TTI);
   }
 
   // Replace all the swifterror operations in the original function.
@@ -2107,8 +2099,6 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
 
   if (shouldCreateNoAllocVariant)
     SwitchCoroutineSplitter::createNoAllocVariant(F, Shape, Clones);
-
-  return Shape;
 }
 
 static LazyCallGraph::SCC &updateCallGraphAfterCoroutineSplit(
@@ -2207,8 +2197,44 @@ static void addPrepareFunction(const Module &M,
     Fns.push_back(PrepareFn);
 }
 
+static std::unique_ptr<coro::BaseABI>
+CreateNewABI(Function &F, coro::Shape &S,
+             std::function<bool(Instruction &)> IsMatCallback) {
+  switch (S.ABI) {
+  case coro::ABI::Switch:
+    return std::unique_ptr<coro::BaseABI>(
+        new coro::SwitchABI(F, S, IsMatCallback));
+  case coro::ABI::Async:
+    return std::unique_ptr<coro::BaseABI>(
+        new coro::AsyncABI(F, S, IsMatCallback));
+  case coro::ABI::Retcon:
+    return std::unique_ptr<coro::BaseABI>(
+        new coro::AnyRetconABI(F, S, IsMatCallback));
+  case coro::ABI::RetconOnce:
+    return std::unique_ptr<coro::BaseABI>(
+        new coro::AnyRetconABI(F, S, IsMatCallback));
+  }
+  llvm_unreachable("Unknown ABI");
+}
+
 CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
-    : MaterializableCallback(coro::defaultMaterializable),
+    : CreateAndInitABI([](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, coro::isTriviallyMaterializable);
+        ABI->init();
+        return std::move(ABI);
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+// For back compatibility, constructor takes a materializable callback and
+// creates a generator for an ABI with a modified materializable callback.
+CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
+                             bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI = CreateNewABI(F, S, IsMatCallback);
+        ABI->init();
+        return std::move(ABI);
+      }),
       OptimizeFrame(OptimizeFrame) {}
 
 PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
@@ -2241,12 +2267,23 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
     Function &F = N->getFunction();
     LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F.getName()
                       << "\n");
+
+    // The suspend-crossing algorithm in buildCoroutineFrame gets tripped up
+    // by unreachable blocks, so remove them as a first pass. Remove the
+    // unreachable blocks before collecting intrinsics into Shape.
+    removeUnreachableBlocks(F);
+
+    coro::Shape Shape(F, OptimizeFrame);
+    if (!Shape.CoroBegin)
+      continue;
+
     F.setSplittedCoroutine();
 
+    std::unique_ptr<coro::BaseABI> ABI = CreateAndInitABI(F, Shape);
+
     SmallVector<Function *, 4> Clones;
-    coro::Shape Shape =
-        splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
-                       OptimizeFrame, MaterializableCallback);
+    auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
+    doSplitCoroutine(F, Clones, *ABI, TTI);
     CurrentSCC = &updateCallGraphAfterCoroutineSplit(
         *N, Shape, Clones, *CurrentSCC, CG, AM, UR, FAM);
 

diff  --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index 453736912a8c5b..184efbfe903d20 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "ABI.h"
 #include "CoroInstr.h"
 #include "CoroInternal.h"
 #include "CoroShape.h"
@@ -373,11 +374,10 @@ void coro::Shape::invalidateCoroutine(
   }
 }
 
-// Perform semantic checking and initialization of the ABI
-void coro::Shape::initABI() {
-  switch (ABI) {
-  case coro::ABI::Switch: {
-    for (auto *AnySuspend : CoroSuspends) {
+void coro::SwitchABI::init() {
+  assert(Shape.ABI == coro::ABI::Switch);
+  {
+    for (auto *AnySuspend : Shape.CoroSuspends) {
       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
       if (!Suspend) {
 #ifndef NDEBUG
@@ -387,21 +387,22 @@ void coro::Shape::initABI() {
       }
 
       if (!Suspend->getCoroSave())
-        createCoroSave(CoroBegin, Suspend);
+        createCoroSave(Shape.CoroBegin, Suspend);
     }
-    break;
   }
-  case coro::ABI::Async: {
-    break;
-  };
-  case coro::ABI::Retcon:
-  case coro::ABI::RetconOnce: {
+}
+
+void coro::AsyncABI::init() { assert(Shape.ABI == coro::ABI::Async); }
+
+void coro::AnyRetconABI::init() {
+  assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
+  {
     // Determine the result value types, and make sure they match up with
     // the values passed to the suspends.
-    auto ResultTys = getRetconResultTypes();
-    auto ResumeTys = getRetconResumeTypes();
+    auto ResultTys = Shape.getRetconResultTypes();
+    auto ResumeTys = Shape.getRetconResumeTypes();
 
-    for (auto *AnySuspend : CoroSuspends) {
+    for (auto *AnySuspend : Shape.CoroSuspends) {
       auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
       if (!Suspend) {
 #ifndef NDEBUG
@@ -428,7 +429,7 @@ void coro::Shape::initABI() {
 
 #ifndef NDEBUG
           Suspend->dump();
-          RetconLowering.ResumePrototype->getFunctionType()->dump();
+          Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
 #endif
           report_fatal_error("argument to coro.suspend.retcon does not "
                              "match corresponding prototype function result");
@@ -437,7 +438,7 @@ void coro::Shape::initABI() {
       if (SI != SE || RI != RE) {
 #ifndef NDEBUG
         Suspend->dump();
-        RetconLowering.ResumePrototype->getFunctionType()->dump();
+        Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
 #endif
         report_fatal_error("wrong number of arguments to coro.suspend.retcon");
       }
@@ -456,7 +457,7 @@ void coro::Shape::initABI() {
       if (SuspendResultTys.size() != ResumeTys.size()) {
 #ifndef NDEBUG
         Suspend->dump();
-        RetconLowering.ResumePrototype->getFunctionType()->dump();
+        Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
 #endif
         report_fatal_error("wrong number of results from coro.suspend.retcon");
       }
@@ -464,15 +465,13 @@ void coro::Shape::initABI() {
         if (SuspendResultTys[I] != ResumeTys[I]) {
 #ifndef NDEBUG
           Suspend->dump();
-          RetconLowering.ResumePrototype->getFunctionType()->dump();
+          Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
 #endif
           report_fatal_error("result from coro.suspend.retcon does not "
                              "match corresponding prototype function param");
         }
       }
     }
-    break;
-  }
   }
 }
 


        


More information about the llvm-commits mailing list