[Mlir-commits] [mlir] fd52b43 - [mlir] Async: check awaited operand error state after sync await
Eugene Zhulenev
llvmlistbot at llvm.org
Sat Sep 4 05:00:25 PDT 2021
Author: Eugene Zhulenev
Date: 2021-09-04T05:00:17-07:00
New Revision: fd52b4357a6eb718c2c7f9cfe1d8f55ef195edb1
URL: https://github.com/llvm/llvm-project/commit/fd52b4357a6eb718c2c7f9cfe1d8f55ef195edb1
DIFF: https://github.com/llvm/llvm-project/commit/fd52b4357a6eb718c2c7f9cfe1d8f55ef195edb1.diff
LOG: [mlir] Async: check awaited operand error state after sync await
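Previously, only an await inside an async function (i.e. a coroutine after lowering to the async runtime) checked the error state of the awaited operand; a synchronous (blocking) await in a regular function now checks it as well.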
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D109229
Added:
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
mlir/test/Dialect/Async/async-to-async-runtime.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 17e768cee74ba..2dc6cfe9625e5 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -525,10 +525,6 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
bool isGroup = type.isa<GroupType>();
bool isValue = type.isa<ValueType>();
- // Drop reference after async token or group await (sync await)
- if (auto await = dyn_cast<RuntimeAwaitOp>(op))
- return (isToken || isGroup) ? -1 : 0;
-
// Drop reference after async token or group error check (coro await).
if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
return (isToken || isGroup) ? -1 : 0;
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 9e70853b3fa38..2127d7d4ec065 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -397,10 +397,23 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
Location loc = op->getLoc();
Value operand = AwaitAdaptor(operands).operand();
+ Type i1 = rewriter.getI1Type();
+
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
- if (!isInCoroutine)
- rewriter.create<RuntimeAwaitOp>(loc, operand);
+ if (!isInCoroutine) {
+ ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
+ builder.create<RuntimeAwaitOp>(loc, operand);
+
+ // Assert that the awaited operands is not in the error state.
+ Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
+ Value notError = builder.create<XOrOp>(
+ isError,
+ builder.create<ConstantOp>(loc, i1, builder.getIntegerAttr(i1, 1)));
+
+ builder.create<AssertOp>(notError,
+ "Awaited async operand is in error state");
+ }
// Inside the coroutine we convert await operation into coroutine suspension
// point, and resume execution asynchronously.
@@ -430,8 +443,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Check if the awaited value is in the error state.
builder.setInsertionPointToStart(resume);
- auto isError =
- builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
+ auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
builder.create<CondBranchOp>(isError,
/*trueDest=*/setupSetErrorBlock(coro),
/*trueArgs=*/ArrayRef<Value>(),
@@ -772,7 +784,8 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
});
return !walkResult.wasInterrupted();
});
- runtimeTarget.addLegalOp<BranchOp, CondBranchOp>();
+ runtimeTarget
+ .addLegalOp<AssertOp, XOrOp, ConstantOp, BranchOp, CondBranchOp>();
// Assertions must be converted to runtime errors inside async functions.
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index f8afa39060a9f..cfb7620350402 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -24,6 +24,10 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
async.yield
}
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
+ // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
+ // CHECK: %[[TRUE:.*]] = constant true
+ // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
+ // CHECK: assert %[[NOT_ERROR]]
// CHECK-NEXT: return
async.await %token : !async.token
return
@@ -83,7 +87,10 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
async.yield
}
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
- // CHECK-NEXT: return
+ // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
+ // CHECK: %[[TRUE:.*]] = constant true
+ // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
+ // CHECK: assert %[[NOT_ERROR]]
async.await %token0 : !async.token
return
}
diff --git a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
index 54640f552798d..34cfb84bee35d 100644
--- a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
+++ b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
@@ -4,7 +4,7 @@
// CHECK: %[[TOKEN:.*]]: !async.token
func @token_await(%arg0: !async.token) {
// CHECK: async.runtime.await %[[TOKEN]]
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK-NOT: async.runtime.drop_ref
async.runtime.await %arg0 : !async.token
return
}
@@ -13,7 +13,7 @@ func @token_await(%arg0: !async.token) {
// CHECK: %[[GROUP:.*]]: !async.group
func @group_await(%arg0: !async.group) {
// CHECK: async.runtime.await %[[GROUP]]
- // CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32}
+ // CHECK-NOT: async.runtime.drop_ref
async.runtime.await %arg0 : !async.group
return
}
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 9c61394aa8ed9..9128fc4a18688 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -60,6 +60,10 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
async.yield
}
// CHECK: async.runtime.await %[[TOKEN]]
+ // CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
+ // CHECK: %[[TRUE:.*]] = constant true
+ // CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
+ // CHECK: assert %[[NOT_ERROR]]
// CHECK-NEXT: return
async.await %token0 : !async.token
return
More information about the Mlir-commits mailing list