[Mlir-commits] [mlir] 6cca6b9 - Add async_funcs_only option to AsyncToAsyncRuntime pass

Eugene Zhulenev llvmlistbot at llvm.org
Wed Nov 30 10:27:05 PST 2022


Author: yijiagu
Date: 2022-11-30T10:27:00-08:00
New Revision: 6cca6b9ab9ff27d22cdb4b9b4aa98ec91b6b08d8

URL: https://github.com/llvm/llvm-project/commit/6cca6b9ab9ff27d22cdb4b9b4aa98ec91b6b08d8
DIFF: https://github.com/llvm/llvm-project/commit/6cca6b9ab9ff27d22cdb4b9b4aa98ec91b6b08d8.diff

LOG: Add async_funcs_only option to AsyncToAsyncRuntime pass

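This change adds an async_funcs_only mode to AsyncToAsyncRuntimePass, implemented as a separate `async-func-to-async-runtime` pass (AsyncFuncToAsyncRuntimePass) that lowers only `async.func`, `async.call` and `async.return` (leaving `async.execute` to `async-to-async-runtime`). The goal is to convert async functions to regular functions in the early stages of the compilation pipeline.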

Differential Revision: https://reviews.llvm.org/D138611

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/Passes.h
    mlir/include/mlir/Dialect/Async/Passes.td
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/test/Dialect/Async/async-to-async-runtime.mlir
    mlir/test/mlir-cpu-runner/async-func.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 78216125ab91b..090768cd0209c 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -17,6 +17,7 @@
 
 namespace mlir {
 class ModuleOp;
+class ConversionTarget;
 
 #define GEN_PASS_DECL
 #include "mlir/Dialect/Async/Passes.h.inc"
@@ -27,6 +28,11 @@ std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
                                                  int32_t numWorkerThreads,
                                                  int32_t minTaskSize);
 
+/// Adds to `patterns` the lowering of async.func, async.call and async.return,
+/// plus await/yield/assert lowering inside the resulting coroutines, and makes
+/// those inner ops dynamically illegal (inside such coroutines) in `target`.
+void populateAsyncFuncToAsyncRuntimeConversionPatterns(
+    RewritePatternSet &patterns, ConversionTarget &target);
+
+/// Creates the `async-func-to-async-runtime` pass (async.func lowering only).
+std::unique_ptr<OperationPass<ModuleOp>> createAsyncFuncToAsyncRuntimePass();
+
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
 
 std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();

diff  --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index aed5b4ff7865a..8b579dd591604 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -41,12 +41,19 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
 }
 
 def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
-  let summary = "Lower high level async operations (e.g. async.execute) to the"
-                "explicit async.runtime and async.coro operations";
+  let summary = "Lower all high level async operations (e.g. async.execute) "
+                "to the explicit async.runtime and async.coro operations";
   let constructor = "mlir::createAsyncToAsyncRuntimePass()";
   let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
 }
 
+def AsyncFuncToAsyncRuntime : Pass<"async-func-to-async-runtime", "ModuleOp"> {
+  let summary = "Lower async.func operations to the explicit async.runtime "
+                "and async.coro operations";
+  let constructor = "mlir::createAsyncFuncToAsyncRuntimePass()";
+  let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
+}
+
 def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
   let summary = "Automatic reference counting for Async runtime operations";
   let description = [{

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 66c5b731b6e76..e7859488b6b4e 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -30,6 +30,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
+#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
 #include "mlir/Dialect/Async/Passes.h.inc"
 } // namespace mlir
 
@@ -51,6 +52,17 @@ class AsyncToAsyncRuntimePass
 
 } // namespace
 
+namespace {
+/// Lowers only async.func, async.call and async.return (not async.execute).
+class AsyncFuncToAsyncRuntimePass
+    : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
+public:
+  AsyncFuncToAsyncRuntimePass() = default;
+  void runOnOperation() override;
+};
+
+} // namespace
+
 /// Function targeted for coroutine transformation has two additional blocks at
 /// the end: coroutine cleanup and coroutine suspension.
 ///
@@ -84,6 +96,9 @@ struct CoroMachinery {
 };
 } // namespace
 
+/// Shared ownership of the coroutine map: patterns and legality callbacks copy
+/// this pointer, so the map outlives the function that created it.
+using FuncCoroMapPtr =
+    std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
+
 /// Utility to partially update the regular function CFG to the coroutine CFG
 /// compatible with LLVM coroutines switched-resume lowering using
 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
@@ -399,9 +414,8 @@ namespace {
 
 class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
 public:
-  AsyncFuncOpLowering(MLIRContext *ctx,
-                      llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<async::FuncOp>(ctx), coros(coros) {}
+  AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+      : OpConversionPattern<async::FuncOp>(ctx), coros_(coros) {}
 
   LogicalResult
   matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
@@ -423,7 +437,7 @@ class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
                                 newFuncOp.end());
 
     CoroMachinery coro = setupCoroMachinery(newFuncOp);
-    coros[newFuncOp] = coro;
+    (*coros_)[newFuncOp] = coro;
     // no initial suspend, we should hot-start
 
     rewriter.eraseOp(op);
@@ -431,7 +445,7 @@ class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros_;
 };
 
 //===----------------------------------------------------------------------===//
@@ -458,16 +472,15 @@ class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
 
 class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
 public:
-  AsyncReturnOpLowering(MLIRContext *ctx,
-                        llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<async::ReturnOp>(ctx), coros(coros) {}
+  AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+      : OpConversionPattern<async::ReturnOp>(ctx), coros_(coros) {}
 
   LogicalResult
   matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto funcCoro = coros.find(func);
-    if (funcCoro == coros.end())
+    auto funcCoro = coros_->find(func);
+    if (funcCoro == coros_->end())
       return rewriter.notifyMatchFailure(
           op, "operation is not inside the async coroutine function");
 
@@ -494,7 +507,7 @@ class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros_;
 };
 } // namespace
 
@@ -509,9 +522,10 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
   using AwaitAdaptor = typename AwaitType::Adaptor;
 
 public:
-  AwaitOpLoweringBase(MLIRContext *ctx,
-                      llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<AwaitType>(ctx), coros(coros) {}
+  AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
+                      bool should_lower_blocking_wait)
+      : OpConversionPattern<AwaitType>(ctx), coros_(coros),
+        should_lower_blocking_wait_(should_lower_blocking_wait) {}
 
   LogicalResult
   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
@@ -521,16 +535,20 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
     if (!op.getOperand().getType().template isa<AwaitableType>())
       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
 
-    // Check if await operation is inside the outlined coroutine function.
+    // Check if await operation is inside the coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto funcCoro = coros.find(func);
-    const bool isInCoroutine = funcCoro != coros.end();
+    auto funcCoro = coros_->find(func);
+    const bool isInCoroutine = funcCoro != coros_->end();
 
     Location loc = op->getLoc();
     Value operand = adaptor.getOperand();
 
     Type i1 = rewriter.getI1Type();
 
+    // Delay lowering to block wait in case await op is inside async.execute
+    if (!isInCoroutine && !should_lower_blocking_wait_)
+      return failure();
+
     // Inside regular functions we use the blocking wait operation to wait for
     // the async object (token, value or group) to become available.
     if (!isInCoroutine) {
@@ -602,7 +620,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros_;
+  bool should_lower_blocking_wait_;
 };
 
 /// Lowering for `async.await` with a token operand.
@@ -645,17 +664,16 @@ class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
 
 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
 public:
-  YieldOpLowering(MLIRContext *ctx,
-                  const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<async::YieldOp>(ctx), coros(coros) {}
+  YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+      : OpConversionPattern<async::YieldOp>(ctx), coros_(coros) {}
 
   LogicalResult
   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check if yield operation is inside the async coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto funcCoro = coros.find(func);
-    if (funcCoro == coros.end())
+    auto funcCoro = coros_->find(func);
+    if (funcCoro == coros_->end())
       return rewriter.notifyMatchFailure(
           op, "operation is not inside the async coroutine function");
 
@@ -682,7 +700,7 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
   }
 
 private:
-  const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros_;
 };
 
 //===----------------------------------------------------------------------===//
@@ -691,17 +709,16 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
 
 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
 public:
-  AssertOpLowering(MLIRContext *ctx,
-                   llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<cf::AssertOp>(ctx), coros(coros) {}
+  AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+      : OpConversionPattern<cf::AssertOp>(ctx), coros_(coros) {}
 
   LogicalResult
   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check if assert operation is inside the async coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto funcCoro = coros.find(func);
-    if (funcCoro == coros.end())
+    auto funcCoro = coros_->find(func);
+    if (funcCoro == coros_->end())
       return rewriter.notifyMatchFailure(
           op, "operation is not inside the async coroutine function");
 
@@ -721,7 +738,7 @@ class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros_;
 };
 
 //===----------------------------------------------------------------------===//
@@ -730,22 +747,23 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
   SymbolTable symbolTable(module);
 
   // Functions with coroutine CFG setups, which are results of outlining
-  // `async.execute` body regions and converting async.func.
-  llvm::DenseMap<func::FuncOp, CoroMachinery> coros;
+  // `async.execute` body regions
+  FuncCoroMapPtr coros =
+      std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
 
   module.walk([&](ExecuteOp execute) {
-    coros.insert(outlineExecuteOp(symbolTable, execute));
+    coros->insert(outlineExecuteOp(symbolTable, execute));
   });
 
   LLVM_DEBUG({
-    llvm::dbgs() << "Outlined " << coros.size()
+    llvm::dbgs() << "Outlined " << coros->size()
                  << " functions built from async.execute operations\n";
   });
 
   // Returns true if operation is inside the coroutine.
   auto isInCoroutine = [&](Operation *op) -> bool {
     auto parentFunc = op->getParentOfType<func::FuncOp>();
-    return coros.find(parentFunc) != coros.end();
+    return coros->find(parentFunc) != coros->end();
   };
 
   // Lower async operations to async.runtime operations.
@@ -762,22 +780,18 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
   // types for async.runtime operations.
   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
 
-  // Lower async.func to func.func with coroutine cfg.
-  asyncPatterns.add<AsyncCallOpLowering>(ctx);
-  asyncPatterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
-
-  asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
-                    AwaitAllOpLowering, YieldOpLowering>(ctx, coros);
+  asyncPatterns
+      .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
+          ctx, coros, /*should_lower_blocking_wait=*/true);
 
   // Lower assertions to conditional branches into error blocks.
-  asyncPatterns.add<AssertOpLowering>(ctx, coros);
+  asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
 
   // All high level async operations must be lowered to the runtime operations.
   ConversionTarget runtimeTarget(*ctx);
   runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
-  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp,
-                             async::FuncOp, async::CallOp, async::ReturnOp>();
+  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
 
   // Decide if structured control flow has to be lowered to branch-based CFG.
   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
@@ -795,7 +809,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
   runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
       [&](cf::AssertOp op) -> bool {
         auto func = op->getParentOfType<func::FuncOp>();
-        return coros.find(func) == coros.end();
+        return coros->find(func) == coros->end();
       });
 
   if (failed(applyPartialConversion(module, runtimeTarget,
@@ -805,6 +819,59 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
   }
 }
 
+//===----------------------------------------------------------------------===//
+void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
+    RewritePatternSet &patterns, ConversionTarget &target) {
+  // Coroutine CFG setups of functions converted from async.func. Shared so the
+  // patterns and the legality callback below can keep it alive after return.
+  FuncCoroMapPtr coros =
+      std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
+  MLIRContext *ctx = patterns.getContext();
+  // Lower async.func to func.func with coroutine cfg.
+  patterns.add<AsyncCallOpLowering>(ctx);
+  patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
+  // Awaits outside these coroutines are left for -async-to-async-runtime.
+  patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
+      ctx, coros, /*should_lower_blocking_wait=*/false);
+  patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
+  // These ops are illegal only inside functions converted from async.func.
+  target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, async::YieldOp,
+                               cf::AssertOp>([coros](Operation *op) {
+    auto func = op->getParentOfType<func::FuncOp>();
+    return coros->find(func) == coros->end();
+  });
+}
+
+void AsyncFuncToAsyncRuntimePass::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  // Lower async.func and related ops to async.runtime/async.coro operations.
+  MLIRContext *ctx = module->getContext();
+  RewritePatternSet asyncPatterns(ctx);
+  ConversionTarget runtimeTarget(*ctx);
+
+  // Lower async.func to func.func with coroutine cfg.
+  populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns,
+                                                    runtimeTarget);
+  // Only async.func/async.call/async.return are unconditionally illegal.
+  runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
+  runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
+  // NOTE(review): presumably ops emitted while building the coroutine CFG.
+  runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
+                           cf::BranchOp, cf::CondBranchOp>();
+  // Partial conversion: ops that are not illegal may remain in the module.
+  if (failed(applyPartialConversion(module, runtimeTarget,
+                                    std::move(asyncPatterns)))) {
+    signalPassFailure();
+    return;
+  }
+}
+
 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
   return std::make_unique<AsyncToAsyncRuntimePass>();
 }
+
+/// Creates the pass registered as `async-func-to-async-runtime`.
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createAsyncFuncToAsyncRuntimePass() {
+  return std::make_unique<AsyncFuncToAsyncRuntimePass>();
+}

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 1551e55c90c08..38a88cc9de5b8 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime                  \
-// RUN:   | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -split-input-file -async-func-to-async-runtime             \
+// RUN:   -async-to-async-runtime | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: @execute_no_async_args
 func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {

diff  --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir
index 8b3d728d4667f..6f21ba906b222 100644
--- a/mlir/test/mlir-cpu-runner/async-func.mlir
+++ b/mlir/test/mlir-cpu-runner/async-func.mlir
@@ -1,4 +1,4 @@
-// RUN:   mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN:   mlir-opt %s -pass-pipeline="builtin.module(async-func-to-async-runtime,async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
 // RUN: | mlir-cpu-runner                                                      \
 // RUN:     -e main -entry-point-result=void -O0                               \
 // RUN:     -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext  \


        


More information about the Mlir-commits mailing list