[Mlir-commits] [mlir] fd52b43 - [mlir] Async: check awaited operand error state after sync await

Eugene Zhulenev llvmlistbot at llvm.org
Sat Sep 4 05:00:25 PDT 2021


Author: Eugene Zhulenev
Date: 2021-09-04T05:00:17-07:00
New Revision: fd52b4357a6eb718c2c7f9cfe1d8f55ef195edb1

URL: https://github.com/llvm/llvm-project/commit/fd52b4357a6eb718c2c7f9cfe1d8f55ef195edb1
DIFF: https://github.com/llvm/llvm-project/commit/fd52b4357a6eb718c2c7f9cfe1d8f55ef195edb1.diff

LOG: [mlir] Async: check awaited operand error state after sync await

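Previously, only an await inside an async function (a coroutine after lowering to the async runtime) would check the error state of the awaited operand. With this change, a synchronous (blocking) await outside of a coroutine also asserts that the awaited operand is not in the error state.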

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D109229

Added: 
    

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
    mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
    mlir/test/Dialect/Async/async-to-async-runtime.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 17e768cee74ba..2dc6cfe9625e5 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -525,10 +525,6 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
     bool isGroup = type.isa<GroupType>();
     bool isValue = type.isa<ValueType>();
 
-    // Drop reference after async token or group await (sync await)
-    if (auto await = dyn_cast<RuntimeAwaitOp>(op))
-      return (isToken || isGroup) ? -1 : 0;
-
     // Drop reference after async token or group error check (coro await).
     if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
       return (isToken || isGroup) ? -1 : 0;

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 9e70853b3fa38..2127d7d4ec065 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -397,10 +397,23 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
     Location loc = op->getLoc();
     Value operand = AwaitAdaptor(operands).operand();
 
+    Type i1 = rewriter.getI1Type();
+
     // Inside regular functions we use the blocking wait operation to wait for
     // the async object (token, value or group) to become available.
-    if (!isInCoroutine)
-      rewriter.create<RuntimeAwaitOp>(loc, operand);
+    if (!isInCoroutine) {
+      ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
+      builder.create<RuntimeAwaitOp>(loc, operand);
+
+      // Assert that the awaited operand is not in the error state.
+      Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
+      Value notError = builder.create<XOrOp>(
+          isError,
+          builder.create<ConstantOp>(loc, i1, builder.getIntegerAttr(i1, 1)));
+
+      builder.create<AssertOp>(notError,
+                               "Awaited async operand is in error state");
+    }
 
     // Inside the coroutine we convert await operation into coroutine suspension
     // point, and resume execution asynchronously.
@@ -430,8 +443,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
 
       // Check if the awaited value is in the error state.
       builder.setInsertionPointToStart(resume);
-      auto isError =
-          builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
+      auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
       builder.create<CondBranchOp>(isError,
                                    /*trueDest=*/setupSetErrorBlock(coro),
                                    /*trueArgs=*/ArrayRef<Value>(),
@@ -772,7 +784,8 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
     });
     return !walkResult.wasInterrupted();
   });
-  runtimeTarget.addLegalOp<BranchOp, CondBranchOp>();
+  runtimeTarget
+      .addLegalOp<AssertOp, XOrOp, ConstantOp, BranchOp, CondBranchOp>();
 
   // Assertions must be converted to runtime errors inside async functions.
   runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index f8afa39060a9f..cfb7620350402 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -24,6 +24,10 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
     async.yield
   }
   // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
+  // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
+  // CHECK: %[[TRUE:.*]] = constant true
+  // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
+  // CHECK: assert %[[NOT_ERROR]]
   // CHECK-NEXT: return
   async.await %token : !async.token
   return
@@ -83,7 +87,10 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
     async.yield
   }
   // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
-  // CHECK-NEXT: return
+  // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
+  // CHECK: %[[TRUE:.*]] = constant true
+  // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
+  // CHECK: assert %[[NOT_ERROR]]
   async.await %token0 : !async.token
   return
 }

diff  --git a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
index 54640f552798d..34cfb84bee35d 100644
--- a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
+++ b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
@@ -4,7 +4,7 @@
 // CHECK:         %[[TOKEN:.*]]: !async.token
 func @token_await(%arg0: !async.token) {
   // CHECK: async.runtime.await %[[TOKEN]]
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK-NOT: async.runtime.drop_ref
   async.runtime.await %arg0 : !async.token
   return
 }
@@ -13,7 +13,7 @@ func @token_await(%arg0: !async.token) {
 // CHECK:         %[[GROUP:.*]]: !async.group
 func @group_await(%arg0: !async.group) {
   // CHECK: async.runtime.await %[[GROUP]]
-  // CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32}
+  // CHECK-NOT: async.runtime.drop_ref
   async.runtime.await %arg0 : !async.group
   return
 }

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 9c61394aa8ed9..9128fc4a18688 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -60,6 +60,10 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
     async.yield
   }
   // CHECK: async.runtime.await %[[TOKEN]]
+  // CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
+  // CHECK: %[[TRUE:.*]] = constant true
+  // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
+  // CHECK: assert %[[NOT_ERROR]]
   // CHECK-NEXT: return
   async.await %token0 : !async.token
   return


        


More information about the Mlir-commits mailing list