[Mlir-commits] [mlir] d8c84d2 - [mlir] Async: Add error propagation support to async groups

Eugene Zhulenev llvmlistbot at llvm.org
Thu May 27 09:35:17 PDT 2021


Author: Eugene Zhulenev
Date: 2021-05-27T09:35:11-07:00
New Revision: d8c84d2a4efc87b756d9d3df42b80d6f8762f62a

URL: https://github.com/llvm/llvm-project/commit/d8c84d2a4efc87b756d9d3df42b80d6f8762f62a
DIFF: https://github.com/llvm/llvm-project/commit/d8c84d2a4efc87b756d9d3df42b80d6f8762f62a.diff

LOG: [mlir] Async: Add error propagation support to async groups

Depends On D103109

If any of the tokens/values added to the `!async.group` switch to the error state, then the group itself switches to the error state.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D103203

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/lib/ExecutionEngine/AsyncRuntime.cpp
    mlir/test/Dialect/Async/async-to-async-runtime.mlir
    mlir/test/Dialect/Async/runtime.mlir
    mlir/test/mlir-cpu-runner/async-error.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 5f6eece6cf69d..9ef218ebe560f 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -373,7 +373,7 @@ def Async_RuntimeIsErrorOp : Async_Op<"runtime.is_error"> {
     resuming after `await_and_resume`.
   }];
 
-  let arguments = (ins Async_AnyValueOrTokenType:$operand);
+  let arguments = (ins Async_AnyAsyncType:$operand);
   let results = (outs I1:$is_error);
 
   let assemblyFormat = "$operand attr-dict `:` type($operand)";

diff  --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index df9adebbd6206..3b26bf61e622f 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -88,6 +88,10 @@ extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *);
 // Returns true if value is in the error state.
 extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *);
 
+// Returns true if group is in the error state (any of the tokens or values
+// added to the group are in the error state).
+extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *);
+
 // Blocks the caller thread until the token becomes ready.
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *);
 

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 55d1714171774..a66f246cf55ed 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #define DEBUG_TYPE "convert-async-to-llvm"
 
@@ -39,6 +40,7 @@ static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
+static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
@@ -125,6 +127,11 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {value}, {i1});
   }
 
+  static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
+    auto i1 = IntegerType::get(ctx, 1);
+    return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
+  }
+
   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
@@ -201,6 +208,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
+  addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
@@ -587,10 +595,13 @@ class RuntimeSetAvailableOpLowering
   LogicalResult
   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Type operandType = op.operand().getType();
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, operandType.isa<TokenType>() ? kEmplaceToken : kEmplaceValue,
-        TypeRange(), operands);
+    StringRef apiFuncName =
+        TypeSwitch<Type, StringRef>(op.operand().getType())
+            .Case<TokenType>([](Type) { return kEmplaceToken; })
+            .Case<ValueType>([](Type) { return kEmplaceValue; });
+
+    rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
+
     return success();
   }
 };
@@ -609,10 +620,13 @@ class RuntimeSetErrorOpLowering
   LogicalResult
   matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Type operandType = op.operand().getType();
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, operandType.isa<TokenType>() ? kSetTokenError : kSetValueError,
-        TypeRange(), operands);
+    StringRef apiFuncName =
+        TypeSwitch<Type, StringRef>(op.operand().getType())
+            .Case<TokenType>([](Type) { return kSetTokenError; })
+            .Case<ValueType>([](Type) { return kSetValueError; });
+
+    rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
+
     return success();
   }
 };
@@ -630,10 +644,14 @@ class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
   LogicalResult
   matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Type operandType = op.operand().getType();
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, operandType.isa<TokenType>() ? kIsTokenError : kIsValueError,
-        rewriter.getI1Type(), operands);
+    StringRef apiFuncName =
+        TypeSwitch<Type, StringRef>(op.operand().getType())
+            .Case<TokenType>([](Type) { return kIsTokenError; })
+            .Case<GroupType>([](Type) { return kIsGroupError; })
+            .Case<ValueType>([](Type) { return kIsValueError; });
+
+    rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
+                                        operands);
     return success();
   }
 };
@@ -651,17 +669,11 @@ class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
   LogicalResult
   matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Type operandType = op.operand().getType();
-
-    StringRef apiFuncName;
-    if (operandType.isa<TokenType>())
-      apiFuncName = kAwaitToken;
-    else if (operandType.isa<ValueType>())
-      apiFuncName = kAwaitValue;
-    else if (operandType.isa<GroupType>())
-      apiFuncName = kAwaitGroup;
-    else
-      return rewriter.notifyMatchFailure(op, "unsupported async type");
+    StringRef apiFuncName =
+        TypeSwitch<Type, StringRef>(op.operand().getType())
+            .Case<TokenType>([](Type) { return kAwaitToken; })
+            .Case<ValueType>([](Type) { return kAwaitValue; })
+            .Case<GroupType>([](Type) { return kAwaitGroup; });
 
     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
     rewriter.eraseOp(op);
@@ -684,17 +696,11 @@ class RuntimeAwaitAndResumeOpLowering
   LogicalResult
   matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Type operandType = op.operand().getType();
-
-    StringRef apiFuncName;
-    if (operandType.isa<TokenType>())
-      apiFuncName = kAwaitTokenAndExecute;
-    else if (operandType.isa<ValueType>())
-      apiFuncName = kAwaitValueAndExecute;
-    else if (operandType.isa<GroupType>())
-      apiFuncName = kAwaitAllAndExecute;
-    else
-      return rewriter.notifyMatchFailure(op, "unsupported async type");
+    StringRef apiFuncName =
+        TypeSwitch<Type, StringRef>(op.operand().getType())
+            .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
+            .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
+            .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
 
     Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
     Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index baa7d5d8c46ae..6ebf48ad64756 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -399,25 +399,22 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
                                     coro.cleanup);
 
-      // TODO: Async groups do not yet support runtime errors.
-      if (!std::is_same<AwaitAllOp, AwaitType>::value) {
-        // Split the resume block into error checking and continuation.
-        Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
-
-        // Check if the awaited value is in the error state.
-        builder.setInsertionPointToStart(resume);
-        auto isError = builder.create<RuntimeIsErrorOp>(
-            loc, rewriter.getI1Type(), operand);
-        builder.create<CondBranchOp>(isError,
-                                     /*trueDest=*/setupSetErrorBlock(coro),
-                                     /*trueArgs=*/ArrayRef<Value>(),
-                                     /*falseDest=*/continuation,
-                                     /*falseArgs=*/ArrayRef<Value>());
-
-        // Make sure that replacement value will be constructed in the
-        // continuation block.
-        rewriter.setInsertionPointToStart(continuation);
-      }
+      // Split the resume block into error checking and continuation.
+      Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
+
+      // Check if the awaited value is in the error state.
+      builder.setInsertionPointToStart(resume);
+      auto isError =
+          builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
+      builder.create<CondBranchOp>(isError,
+                                   /*trueDest=*/setupSetErrorBlock(coro),
+                                   /*trueArgs=*/ArrayRef<Value>(),
+                                   /*falseDest=*/continuation,
+                                   /*falseArgs=*/ArrayRef<Value>());
+
+      // Make sure that replacement value will be constructed in the
+      // continuation block.
+      rewriter.setInsertionPointToStart(continuation);
     }
 
     // Erase or replace the await operation with the new value.

diff  --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 856d1c7b74f9c..6bbb8e4052b10 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -212,9 +212,10 @@ struct AsyncValue : public RefCounted {
 // tokens or values added to the group).
 struct AsyncGroup : public RefCounted {
   AsyncGroup(AsyncRuntime *runtime)
-      : RefCounted(runtime), pendingTokens(0), rank(0) {}
+      : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {}
 
   std::atomic<int> pendingTokens;
+  std::atomic<int> numErrors;
   std::atomic<int> rank;
 
   // Pending awaiters are guarded by a mutex.
@@ -262,7 +263,11 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
   int rank = group->rank.fetch_add(1);
   group->pendingTokens.fetch_add(1);
 
-  auto onTokenReady = [group]() {
+  auto onTokenReady = [group, token]() {
+    // Increment the number of errors in the group.
+    if (State(token->state).isError())
+      group->numErrors.fetch_add(1);
+
     // Run all group awaiters if it was the last token in the group.
     if (group->pendingTokens.fetch_sub(1) == 1) {
       group->cv.notify_all();
@@ -356,6 +361,10 @@ extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
   return State(value->state).isError();
 }
 
+extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
+  return group->numErrors.load() > 0;
+}
+
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
   std::unique_lock<std::mutex> lock(token->mu);
   if (!State(token->state).isAvailableOrError())
@@ -483,6 +492,8 @@ void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
                &mlir::runtime::mlirAsyncRuntimeIsTokenError);
   exportSymbol("mlirAsyncRuntimeIsValueError",
                &mlir::runtime::mlirAsyncRuntimeIsValueError);
+  exportSymbol("mlirAsyncRuntimeIsGroupError",
+               &mlir::runtime::mlirAsyncRuntimeIsGroupError);
   exportSymbol("mlirAsyncRuntimeAwaitToken",
                &mlir::runtime::mlirAsyncRuntimeAwaitToken);
   exportSymbol("mlirAsyncRuntimeAwaitValue",

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 08607f89e2e57..b77b0d6e89a9a 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -216,8 +216,13 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
 // CHECK:   async.coro.suspend
 // CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
 
-// Emplace result token.
+// Check the error of the awaited token after resumption.
 // CHECK: ^[[RESUME_1]]:
+// CHECK:   %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
+// CHECK:   cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
+
+// Emplace result token after error checking.
+// CHECK: ^[[CONTINUATION:.*]]:
 // CHECK:   async.runtime.set_available %[[TOKEN]]
 
 // CHECK: ^[[CLEANUP]]:

diff  --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir
index c8f6f65cb8469..ede523f3c084a 100644
--- a/mlir/test/Dialect/Async/runtime.mlir
+++ b/mlir/test/Dialect/Async/runtime.mlir
@@ -66,6 +66,13 @@ func @is_value_error(%arg0: !async.value<f32>) -> i1 {
   return %0 : i1
 }
 
+// CHECK-LABEL: @is_group_error
+func @is_group_error(%arg0: !async.group) -> i1 {
+  // CHECK: %[[ERR:.*]] = async.runtime.is_error %arg0 : !async.group
+  %0 = async.runtime.is_error %arg0 : !async.group
+  return %0 : i1
+}
+
 // CHECK-LABEL: @await_token
 func @await_token(%arg0: !async.token) {
   // CHECK: async.runtime.await %arg0 : !async.token

diff  --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir
index 139296d6614f4..63b9a0077f5d9 100644
--- a/mlir/test/mlir-cpu-runner/async-error.mlir
+++ b/mlir/test/mlir-cpu-runner/async-error.mlir
@@ -81,5 +81,29 @@ func @main() {
   vector.print %err3_0 : i1
   vector.print %err3_1 : i1
 
+  // ------------------------------------------------------------------------ //
+  // Check error propagation from a token to the group.
+  // ------------------------------------------------------------------------ //
+
+  %group0 = async.create_group
+
+  %token4 = async.execute {
+    async.yield
+  }
+
+  %token5 = async.execute {
+    assert %false, "error"
+    async.yield
+  }
+
+  %idx0 = async.add_to_group %token4, %group0 : !async.token
+  %idx1 = async.add_to_group %token5, %group0 : !async.token
+
+  async.runtime.await %group0 : !async.group
+
+  // CHECK: 1
+  %err4 = async.runtime.is_error %group0 : !async.group
+  vector.print %err4 : i1
+
   return
 }


        


More information about the Mlir-commits mailing list