[llvm] [Coroutines] Support for Custom ABIs (PR #111755)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 9 12:57:31 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Tyler Nowicki (TylerNowicki)

<details>
<summary>Changes</summary>

This change extends the current method for creating ABI objects to allow users (plugin libraries) to create custom ABI objects for their needs. This is accomplished by inheriting one of the common ABIs and overriding one or more of the methods to create a custom ABI. To use a custom ABI for a given coroutine, the coro.begin.custom.abi intrinsic is used in place of the coro.begin intrinsic. This takes an additional i32 arg that specifies the index of an ABI generator for the custom ABI object in a SmallVector passed to the CoroSplitPass constructor.

The detailed changes include:
* Add the llvm.coro.begin.custom.abi intrinsic used to specify the index of the custom ABI to use for the given coroutine.
* Add constructors to CoroSplitPass that take a list of generators that create the custom ABI objects.
* Extend the CreateNewABI function used by CoroSplit to take the list of custom ABI generators and, when coro.begin.custom.abi is used, return the ABI object created by the selected generator.
* Add hasCustomABI/getCustomABI methods to the CoroBeginInst class.
* Add a unittest for a custom ABI.

---
Full diff: https://github.com/llvm/llvm-project/pull/111755.diff


9 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1) 
- (modified) llvm/include/llvm/IR/Intrinsics.td (+2-1) 
- (modified) llvm/include/llvm/Transforms/Coroutines/ABI.h (+7-1) 
- (modified) llvm/include/llvm/Transforms/Coroutines/CoroInstr.h (+15-4) 
- (modified) llvm/include/llvm/Transforms/Coroutines/CoroSplit.h (+11-2) 
- (modified) llvm/lib/Transforms/Coroutines/CoroCleanup.cpp (+3-1) 
- (modified) llvm/lib/Transforms/Coroutines/CoroSplit.cpp (+35-3) 
- (modified) llvm/lib/Transforms/Coroutines/Coroutines.cpp (+3-1) 
- (modified) llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp (+87) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 01a16e7c7b1e59..f6888d001fed69 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -778,6 +778,7 @@ class TargetTransformInfoImplBase {
     case Intrinsic::experimental_gc_relocate:
     case Intrinsic::coro_alloc:
     case Intrinsic::coro_begin:
+    case Intrinsic::coro_begin_custom_abi:
     case Intrinsic::coro_free:
     case Intrinsic::coro_end:
     case Intrinsic::coro_frame:
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 20dd921ddbd230..8a0721cf23f538 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1719,7 +1719,8 @@ def int_coro_prepare_async : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty],
                                        [IntrNoMem]>;
 def int_coro_begin : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                                [WriteOnly<ArgIndex<1>>]>;
-
+def int_coro_begin_custom_abi : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty, llvm_i32_ty],
+                               [WriteOnly<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
 def int_coro_free : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                               [IntrReadMem, IntrArgMemOnly,
                                ReadOnly<ArgIndex<1>>,
diff --git a/llvm/include/llvm/Transforms/Coroutines/ABI.h b/llvm/include/llvm/Transforms/Coroutines/ABI.h
index e7568d275c1615..8b83c5308056eb 100644
--- a/llvm/include/llvm/Transforms/Coroutines/ABI.h
+++ b/llvm/include/llvm/Transforms/Coroutines/ABI.h
@@ -29,7 +29,13 @@ namespace coro {
 // This interface/API is to provide an object oriented way to implement ABI
 // functionality. This is intended to replace use of the ABI enum to perform
 // ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common
-// ABIs.
+// ABIs. However, specific users may need to modify the behavior of these. This
+// can be accomplished by inheriting one of the common ABIs and overriding one
+// or more of the methods to create a custom ABI. To use a custom ABI for a
+// given coroutine, the coro.begin.custom.abi intrinsic is used in place of the
+// coro.begin intrinsic. It takes an additional i32 argument that specifies the
+// index of an ABI generator for the custom ABI object in a SmallVector passed
+// to the CoroSplitPass constructor.
 
 class BaseABI {
 public:
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
index a329a06bf13891..3aa30bec85c3a5 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
@@ -124,7 +124,8 @@ class AnyCoroIdInst : public IntrinsicInst {
   IntrinsicInst *getCoroBegin() {
     for (User *U : users())
       if (auto *II = dyn_cast<IntrinsicInst>(U))
-        if (II->getIntrinsicID() == Intrinsic::coro_begin)
+        if (II->getIntrinsicID() == Intrinsic::coro_begin ||
+            II->getIntrinsicID() == Intrinsic::coro_begin_custom_abi)
           return II;
     llvm_unreachable("no coro.begin associated with coro.id");
   }
@@ -442,20 +443,30 @@ class CoroFreeInst : public IntrinsicInst {
   }
 };
 
-/// This class represents the llvm.coro.begin instructions.
+/// This class represents the llvm.coro.begin or llvm.coro.begin.custom.abi
+/// instructions.
 class CoroBeginInst : public IntrinsicInst {
-  enum { IdArg, MemArg };
+  enum { IdArg, MemArg, CustomABIArg };
 
 public:
   AnyCoroIdInst *getId() const {
     return cast<AnyCoroIdInst>(getArgOperand(IdArg));
   }
 
+  bool hasCustomABI() const {
+    return getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
+  }
+
+  unsigned getCustomABI() const {
+    return cast<ConstantInt>(getArgOperand(CustomABIArg))->getZExtValue();
+  }
+
   Value *getMem() const { return getArgOperand(MemArg); }
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const IntrinsicInst *I) {
-    return I->getIntrinsicID() == Intrinsic::coro_begin;
+    return I->getIntrinsicID() == Intrinsic::coro_begin ||
+           I->getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
   }
   static bool classof(const Value *V) {
     return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
index a5fd57f8f9dfab..6c6a982e828050 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
@@ -28,17 +28,26 @@ struct Shape;
 } // namespace coro
 
 struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
+  using BaseABITy =
+      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
 
   CoroSplitPass(bool OptimizeFrame = false);
+
+  CoroSplitPass(SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
                 bool OptimizeFrame = false);
 
+  CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
+                SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
   static bool isRequired() { return true; }
 
-  using BaseABITy =
-      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
   // Generator for an ABI transformer
   BaseABITy CreateAndInitABI;
 
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
index dd92b3593af92e..1cda7f93f72a2c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -53,6 +53,7 @@ bool Lowerer::lower(Function &F) {
       default:
         continue;
       case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi:
         II->replaceAllUsesWith(II->getArgOperand(1));
         break;
       case Intrinsic::coro_free:
@@ -112,7 +113,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) {
       M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr",
           "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon",
           "llvm.coro.id.async", "llvm.coro.id.retcon.once",
-          "llvm.coro.async.size.replace", "llvm.coro.async.resume"});
+          "llvm.coro.async.size.replace", "llvm.coro.async.resume",
+          "llvm.coro.begin.custom.abi"});
 }
 
 PreservedAnalyses CoroCleanupPass::run(Module &M,
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index ef1f27118bc14b..88ce331c8cfb64 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -2200,7 +2200,15 @@ static void addPrepareFunction(const Module &M,
 
 static std::unique_ptr<coro::BaseABI>
 CreateNewABI(Function &F, coro::Shape &S,
-             std::function<bool(Instruction &)> IsMatCallback) {
+             std::function<bool(Instruction &)> IsMatCallback,
+             const SmallVector<CoroSplitPass::BaseABITy> &GenCustomABIs) {
+  if (S.CoroBegin->hasCustomABI()) {
+    unsigned CustomABI = S.CoroBegin->getCustomABI();
+    if (CustomABI >= GenCustomABIs.size())
+      report_fatal_error("Custom ABI not found among those specified");
+    return GenCustomABIs[CustomABI](F, S);
+  }
+
   switch (S.ABI) {
   case coro::ABI::Switch:
     return std::unique_ptr<coro::BaseABI>(
@@ -2221,7 +2229,17 @@ CreateNewABI(Function &F, coro::Shape &S,
 CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
     : CreateAndInitABI([](Function &F, coro::Shape &S) {
         std::unique_ptr<coro::BaseABI> ABI =
-            CreateNewABI(F, S, coro::isTriviallyMaterializable);
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+CoroSplitPass::CoroSplitPass(
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
@@ -2232,7 +2250,21 @@ CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
 CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
                              bool OptimizeFrame)
     : CreateAndInitABI([=](Function &F, coro::Shape &S) {
-        std::unique_ptr<coro::BaseABI> ABI = CreateNewABI(F, S, IsMatCallback);
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+// Constructor taking both a materializable callback and custom ABI generators;
+// coroutines using coro.begin.custom.abi get their ABI from GenCustomABIs.
+CoroSplitPass::CoroSplitPass(
+    std::function<bool(Instruction &)> IsMatCallback,
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index f4d9a7a8aa8569..1c45bcd7f6a837 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -73,6 +73,7 @@ static const char *const CoroIntrinsics[] = {
     "llvm.coro.await.suspend.handle",
     "llvm.coro.await.suspend.void",
     "llvm.coro.begin",
+    "llvm.coro.begin.custom.abi",
     "llvm.coro.destroy",
     "llvm.coro.done",
     "llvm.coro.end",
@@ -247,7 +248,8 @@ void coro::Shape::analyze(Function &F,
         }
         break;
       }
-      case Intrinsic::coro_begin: {
+      case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi: {
         auto CB = cast<CoroBeginInst>(II);
 
         // Ignore coro id's that aren't pre-split.
diff --git a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
index 1d55889a32d7aa..c3394fdaa940ba 100644
--- a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
+++ b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
@@ -182,4 +182,91 @@ TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
   CallInst *CI = getCallByName(Resume1, "should.remat");
   ASSERT_TRUE(CI);
 }
+
+StringRef TextCoroBeginCustomABI = R"(
+    define ptr @f(i32 %n) presplitcoroutine {
+    entry:
+      %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+      %size = call i32 @llvm.coro.size.i32()
+      %alloc = call ptr @malloc(i32 %size)
+      %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0)
+
+      %inc1 = add i32 %n, 1
+      %val2 = call i32 @should.remat(i32 %inc1)
+      %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume1
+                                      i8 1, label %cleanup]
+    resume1:
+      %inc2 = add i32 %val2, 1
+      %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp2, label %suspend [i8 0, label %resume2
+                                      i8 1, label %cleanup]
+
+    resume2:
+      call void @print(i32 %val2)
+      call void @print(i32 %inc2)
+      br label %cleanup
+
+    cleanup:
+      %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
+      call void @free(ptr %mem)
+      br label %suspend
+    suspend:
+      call i1 @llvm.coro.end(ptr %hdl, i1 0)
+      ret ptr %hdl
+    }
+
+    declare ptr @llvm.coro.free(token, ptr)
+    declare i32 @llvm.coro.size.i32()
+    declare i8  @llvm.coro.suspend(token, i1)
+    declare void @llvm.coro.resume(ptr)
+    declare void @llvm.coro.destroy(ptr)
+
+    declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+    declare i1 @llvm.coro.alloc(token)
+    declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32)
+    declare i1 @llvm.coro.end(ptr, i1)
+
+    declare i32 @should.remat(i32)
+
+    declare noalias ptr @malloc(i32)
+    declare void @print(i32)
+    declare void @free(ptr)
+  )";
+
+// Custom ABI: a SwitchABI that uses ExtraMaterializable as its callback.
+class ExtraCustomABI : public coro::SwitchABI {
+public:
+  ExtraCustomABI(Function &F, coro::Shape &S)
+      : coro::SwitchABI(F, S, ExtraMaterializable) {}
+};
+
+TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) {
+  ParseAssembly(TextCoroBeginCustomABI);
+
+  ASSERT_TRUE(M);
+
+  CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) {
+    return std::unique_ptr<coro::BaseABI>(new ExtraCustomABI(F, S));
+  };
+
+  CGSCCPassManager CGPM;
+  CGPM.addPass(CoroSplitPass({GenCustomABI}));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+
+  // Verify that the extra rematerializable instruction was rematerialized
+  Function *F = M->getFunction("f.resume");
+  ASSERT_TRUE(F) << "could not find split function f.resume";
+
+  BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
+  ASSERT_TRUE(Resume1)
+      << "could not find expected BB resume1 in split function";
+
+  // The custom ABI's callback should have caused should.remat to be
+  // rematerialized into resume1
+  CallInst *CI = getCallByName(Resume1, "should.remat");
+  ASSERT_TRUE(CI);
+}
+
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/111755


More information about the llvm-commits mailing list