[Mlir-commits] [mlir] c30ab6c - [mlir] Transform scf.parallel to scf.for + async.execute
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Nov 13 04:03:06 PST 2020
Author: Eugene Zhulenev
Date: 2020-11-13T04:02:56-08:00
New Revision: c30ab6c2a307cfdce8323ed94c3d70eb2d26bc14
URL: https://github.com/llvm/llvm-project/commit/c30ab6c2a307cfdce8323ed94c3d70eb2d26bc14
DIFF: https://github.com/llvm/llvm-project/commit/c30ab6c2a307cfdce8323ed94c3d70eb2d26bc14.diff
LOG: [mlir] Transform scf.parallel to scf.for + async.execute
Depends On D89958
1. Adds the `!async.group` type and the `async.create_group`/`async.add_to_group`/`async.await_all` operations to group together multiple async tokens/values (a short usage sketch follows the converted example below).
2. Rewrites the `scf.parallel` operation into multiple concurrent `async.execute` operations over non-overlapping subranges of the original loop.
Example:
```
scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
"do_some_compute"(%i, %j): () -> ()
}
```
Converted to:
```
%c0 = constant 0 : index
%c1 = constant 1 : index
// Compute block sizes for each induction variable.
%num_blocks_i = ... : index
%num_blocks_j = ... : index
%block_size_i = ... : index
%block_size_j = ... : index
// Create an async group to track async execute ops.
%group = async.create_group
scf.for %bi = %c0 to %num_blocks_i step %c1 {
%block_start_i = ... : index
%block_end_i = ... : index
scf.for %bj = %c0 to %num_blocks_j step %c1 {
%block_start_j = ... : index
%block_end_j = ... : index
// Execute the body of the original parallel operation for the current
// block.
%token = async.execute {
scf.for %i = %block_start_i to %block_end_i step %si {
scf.for %j = %block_start_j to %block_end_j step %sj {
"do_some_compute"(%i, %j): () -> ()
}
}
}
// Add produced async token to the group.
async.add_to_group %token, %group
}
}
// Await completion of all async.execute operations.
async.await_all %group
```
In this example the outer loop launches the inner block-level loops as separate
async.execute operations, which will be executed concurrently.
At the end it waits for the completion of all async.execute operations.
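For reference, a minimal sketch of how the new group operations compose, using the
op syntax defined in AsyncOps.td and exercised in mlir/test/Dialect/Async/ops.mlir
below:
```
%group = async.create_group

%token = async.execute {
  async.yield
}

%rank = async.add_to_group %token, %group : !async.token
async.await_all %group
```
The rewrite itself is exposed as the `-async-parallel-for` pass with a
`num-concurrent-async-execute` option (default 4); the integration tests below lower
the result with `-convert-async-to-llvm -convert-scf-to-std -convert-std-to-llvm` and
execute it under `mlir-cpu-runner`.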
Reviewed By: ftynse, mehdi_amini
Differential Revision: https://reviews.llvm.org/D89963
Added:
mlir/include/mlir/Dialect/Async/Passes.h
mlir/include/mlir/Dialect/Async/Passes.td
mlir/integration_test/Dialect/Async/CPU/lit.local.cfg
mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
mlir/lib/Dialect/Async/Transforms/PassDetail.h
mlir/test/Dialect/Async/async-parallel-for.mlir
mlir/test/mlir-cpu-runner/async-group.mlir
Modified:
mlir/include/mlir/Dialect/Async/CMakeLists.txt
mlir/include/mlir/Dialect/Async/IR/Async.h
mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Dialect/Async/CMakeLists.txt
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/Async/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/CMakeLists.txt b/mlir/include/mlir/Dialect/Async/CMakeLists.txt
index f33061b2d87c..fd4730cb982d 100644
--- a/mlir/include/mlir/Dialect/Async/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Async/CMakeLists.txt
@@ -1 +1,7 @@
add_subdirectory(IR)
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Async)
+add_public_tablegen_target(MLIRAsyncPassIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc AsyncPasses ./)
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index 28790557f36b..ad5a8aa03098 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -47,6 +47,12 @@ class ValueType
Type getValueType();
};
+/// The group type to represent async tokens or values grouped together.
+class GroupType : public Type::TypeBase<GroupType, Type, TypeStorage> {
+public:
+ using Base::Base;
+};
+
} // namespace async
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
index c411c4a21703..e7a5e90298da 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
@@ -56,6 +56,16 @@ class Async_ValueType<Type type>
Type valueType = type;
}
+def Async_GroupType : DialectType<AsyncDialect,
+ CPred<"$_self.isa<::mlir::async::GroupType>()">, "group type">,
+ BuildableType<"$_builder.getType<::mlir::async::GroupType>()"> {
+ let typeDescription = [{
+    `async.group` represents a set of async tokens or values and allows
+    executing async operations on all of them together (e.g. wait for the
+ completion of all/any of them).
+ }];
+}
+
def Async_AnyValueType : DialectType<AsyncDialect,
CPred<"$_self.isa<::mlir::async::ValueType>()">,
"async value type">;
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 7fad5ce48214..cc987856a28e 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -81,6 +81,20 @@ def Async_ExecuteOp :
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let verifier = [{ return ::verify(*this); }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilderDAG<(ins "TypeRange":$resultTypes, "ValueRange":$dependencies,
+ "ValueRange":$operands,
+ CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
+ "nullptr">:$bodyBuilder)>,
+ ];
+
+ let extraClassDeclaration = [{
+ using BodyBuilderFn =
+ function_ref<void(OpBuilder &, Location, ValueRange)>;
+
+ }];
}
def Async_YieldOp :
@@ -93,12 +107,12 @@ def Async_YieldOp :
let arguments = (ins Variadic<AnyType>:$operands);
- let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+ let assemblyFormat = "($operands^ `:` type($operands))? attr-dict";
let verifier = [{ return ::verify(*this); }];
}
-def Async_AwaitOp : Async_Op<"await", [NoSideEffect]> {
+def Async_AwaitOp : Async_Op<"await"> {
let summary = "waits for the argument to become ready";
let description = [{
The `async.await` operation waits until the argument becomes ready, and for
@@ -133,12 +147,84 @@ def Async_AwaitOp : Async_Op<"await", [NoSideEffect]> {
}];
let assemblyFormat = [{
- attr-dict $operand `:` custom<AwaitResultType>(
+ $operand `:` custom<AwaitResultType>(
type($operand), type($result)
- )
+ ) attr-dict
}];
let verifier = [{ return ::verify(*this); }];
}
+def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
+ let summary = "creates an empty async group";
+ let description = [{
+ The `async.create_group` allocates an empty async group. Async tokens or
+ values can be added to this group later.
+
+ Example:
+
+ ```mlir
+ %0 = async.create_group
+ ...
+ async.await_all %0
+ ```
+ }];
+
+ let arguments = (ins );
+ let results = (outs Async_GroupType:$result);
+
+ let assemblyFormat = "attr-dict";
+}
+
+def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
+  let summary = "adds an async token or value to the group";
+ let description = [{
+ The `async.add_to_group` adds an async token or value to the async group.
+ Returns the rank of the added element in the group. This rank is fixed
+ for the group lifetime.
+
+ Example:
+
+ ```mlir
+ %0 = async.create_group
+ %1 = ... : !async.token
+ %2 = async.add_to_group %1, %0 : !async.token
+ ```
+ }];
+
+ let arguments = (ins Async_AnyValueOrTokenType:$operand,
+ Async_GroupType:$group);
+ let results = (outs Index:$rank);
+
+ let assemblyFormat = "$operand `,` $group `:` type($operand) attr-dict";
+}
+
+def Async_AwaitAllOp : Async_Op<"await_all", []> {
+  let summary = "waits for all async tokens or values in the group to "
+ "become ready";
+ let description = [{
+ The `async.await_all` operation waits until all the tokens or values in the
+ group become ready.
+
+ Example:
+
+ ```mlir
+ %0 = async.create_group
+
+ %1 = ... : !async.token
+ %2 = async.add_to_group %1, %0 : !async.token
+
+ %3 = ... : !async.token
+    %4 = async.add_to_group %3, %0 : !async.token
+
+ async.await_all %0
+ ```
+ }];
+
+ let arguments = (ins Async_GroupType:$operand);
+ let results = (outs);
+
+ let assemblyFormat = "$operand attr-dict";
+}
+
#endif // ASYNC_OPS
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
new file mode 100644
index 000000000000..d5a8a82dab49
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -0,0 +1,32 @@
+//===- Passes.h - Async pass entry points -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ASYNC_PASSES_H_
+#define MLIR_DIALECT_ASYNC_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Async/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ASYNC_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
new file mode 100644
index 000000000000..51fd4e32c78e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -0,0 +1,27 @@
+//===-- Passes.td - Async pass definition file -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ASYNC_PASSES
+#define MLIR_DIALECT_ASYNC_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
+ let summary = "Convert scf.parallel operations to multiple async regions "
+ "executed concurrently for non-overlapping iteration ranges";
+ let constructor = "mlir::createAsyncParallelForPass()";
+ let options = [
+ Option<"numConcurrentAsyncExecute", "num-concurrent-async-execute",
+ "int32_t", /*default=*/"4",
+ "The number of async.execute operations that will be used for concurrent "
+ "loop execution.">
+ ];
+ let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
+}
+
+#endif // MLIR_DIALECT_ASYNC_PASSES
diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index e47c71c44dfc..12beffe9dd1c 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -14,6 +14,8 @@
#ifndef MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
#define MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
+#include <stdint.h>
+
#ifdef _WIN32
#ifndef MLIR_ASYNCRUNTIME_EXPORT
#ifdef mlir_async_runtime_EXPORTS
@@ -37,6 +39,9 @@
// Runtime implementation of `async.token` data type.
typedef struct AsyncToken MLIR_AsyncToken;
+// Runtime implementation of `async.group` data type.
+typedef struct AsyncGroup MLIR_AsyncGroup;
+
// Async runtime uses LLVM coroutines to represent asynchronous tasks. Task
// function is a coroutine handle and a resume function that continue coroutine
// execution from a suspension point.
@@ -46,6 +51,12 @@ using CoroResume = void (*)(void *); // coroutine resume function
// Create a new `async.token` in not-ready state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
+// Create a new `async.group` in empty state.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup();
+
+extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
+mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
+
// Switches `async.token` to ready state and runs all awaiters.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeEmplaceToken(AsyncToken *);
@@ -54,6 +65,10 @@ mlirAsyncRuntimeEmplaceToken(AsyncToken *);
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitToken(AsyncToken *);
+// Blocks the caller thread until the elements in the group become ready.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *);
+
// Executes the task (coro handle + resume function) in one of the threads
// managed by the runtime.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
@@ -64,6 +79,11 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume);
+// Executes the task (coro handle + resume function) in one of the threads
+// managed by the runtime after all members of the group become ready.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume);
+
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 11cf6ca5f8fb..2d57dd5081bf 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -16,6 +16,7 @@
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -47,6 +48,7 @@ inline void registerAllPasses() {
// Dialect passes
registerAffinePasses();
+ registerAsyncPasses();
registerGPUPasses();
registerLinalgPasses();
LLVM::registerLLVMPasses();
diff --git a/mlir/integration_test/Dialect/Async/CPU/lit.local.cfg b/mlir/integration_test/Dialect/Async/CPU/lit.local.cfg
new file mode 100644
index 000000000000..83247d7e3744
--- /dev/null
+++ b/mlir/integration_test/Dialect/Async/CPU/lit.local.cfg
@@ -0,0 +1,5 @@
+import sys
+
+# No JIT on win32.
+if sys.platform == 'win32':
+ config.unsupported = True
diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
new file mode 100644
index 000000000000..4bbc540ba942
--- /dev/null
+++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -async-parallel-for \
+// RUN: -convert-async-to-llvm \
+// RUN: -convert-scf-to-std \
+// RUN: -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void -O0 \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
+// RUN: | FileCheck %s --dump-input=always
+
+func @entry() {
+ %c0 = constant 0.0 : f32
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+
+ %lb = constant 0 : index
+ %ub = constant 9 : index
+
+ %A = alloc() : memref<9xf32>
+ %U = memref_cast %A : memref<9xf32> to memref<*xf32>
+
+ // 1. %i = (0) to (9) step (1)
+ scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
+ %0 = index_cast %i : index to i32
+ %1 = sitofp %0 : i32 to f32
+ store %1, %A[%i] : memref<9xf32>
+ }
+ // CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8]
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
+
+ scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
+ store %c0, %A[%i] : memref<9xf32>
+ }
+
+ // 2. %i = (0) to (9) step (2)
+ scf.parallel (%i) = (%lb) to (%ub) step (%c2) {
+ %0 = index_cast %i : index to i32
+ %1 = sitofp %0 : i32 to f32
+ store %1, %A[%i] : memref<9xf32>
+ }
+ // CHECK: [0, 0, 2, 0, 4, 0, 6, 0, 8]
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
+
+ scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
+ store %c0, %A[%i] : memref<9xf32>
+ }
+
+ // 3. %i = (-20) to (-11) step (3)
+ %lb0 = constant -20 : index
+ %ub0 = constant -11 : index
+ scf.parallel (%i) = (%lb0) to (%ub0) step (%c3) {
+ %0 = index_cast %i : index to i32
+ %1 = sitofp %0 : i32 to f32
+ %2 = constant 20 : index
+ %3 = addi %i, %2 : index
+ store %1, %A[%3] : memref<9xf32>
+ }
+ // CHECK: [-20, 0, 0, -17, 0, 0, -14, 0, 0]
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
+
+ dealloc %A : memref<9xf32>
+ return
+}
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
new file mode 100644
index 000000000000..8997b6835d5b
--- /dev/null
+++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt %s -async-parallel-for \
+// RUN: -convert-async-to-llvm \
+// RUN: -convert-scf-to-std \
+// RUN: -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void -O0 \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
+// RUN: | FileCheck %s --dump-input=always
+
+func @entry() {
+ %c0 = constant 0.0 : f32
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c8 = constant 8 : index
+
+ %lb = constant 0 : index
+ %ub = constant 8 : index
+
+ %A = alloc() : memref<8x8xf32>
+ %U = memref_cast %A : memref<8x8xf32> to memref<*xf32>
+
+ // 1. (%i, %i) = (0, 8) to (8, 8) step (1, 1)
+ scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) {
+ %0 = muli %i, %c8 : index
+ %1 = addi %j, %0 : index
+ %2 = index_cast %1 : index to i32
+ %3 = sitofp %2 : i32 to f32
+ store %3, %A[%i, %j] : memref<8x8xf32>
+ }
+
+ // CHECK: [0, 1, 2, 3, 4, 5, 6, 7]
+ // CHECK-NEXT: [8, 9, 10, 11, 12, 13, 14, 15]
+ // CHECK-NEXT: [16, 17, 18, 19, 20, 21, 22, 23]
+ // CHECK-NEXT: [24, 25, 26, 27, 28, 29, 30, 31]
+ // CHECK-NEXT: [32, 33, 34, 35, 36, 37, 38, 39]
+ // CHECK-NEXT: [40, 41, 42, 43, 44, 45, 46, 47]
+ // CHECK-NEXT: [48, 49, 50, 51, 52, 53, 54, 55]
+ // CHECK-NEXT: [56, 57, 58, 59, 60, 61, 62, 63]
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
+
+ scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) {
+ store %c0, %A[%i, %j] : memref<8x8xf32>
+ }
+
+ // 2. (%i, %i) = (0, 8) to (8, 8) step (2, 1)
+ scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c2, %c1) {
+ %0 = muli %i, %c8 : index
+ %1 = addi %j, %0 : index
+ %2 = index_cast %1 : index to i32
+ %3 = sitofp %2 : i32 to f32
+ store %3, %A[%i, %j] : memref<8x8xf32>
+ }
+
+ // CHECK: [0, 1, 2, 3, 4, 5, 6, 7]
+ // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0]
+ // CHECK-NEXT: [16, 17, 18, 19, 20, 21, 22, 23]
+ // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0]
+ // CHECK-NEXT: [32, 33, 34, 35, 36, 37, 38, 39]
+ // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0]
+ // CHECK-NEXT: [48, 49, 50, 51, 52, 53, 54, 55]
+ // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0]
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
+
+ scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) {
+ store %c0, %A[%i, %j] : memref<8x8xf32>
+ }
+
+ // 3. (%i, %i) = (0, 8) to (8, 8) step (1, 2)
+ scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c2) {
+ %0 = muli %i, %c8 : index
+ %1 = addi %j, %0 : index
+ %2 = index_cast %1 : index to i32
+ %3 = sitofp %2 : i32 to f32
+ store %3, %A[%i, %j] : memref<8x8xf32>
+ }
+
+ // CHECK: [0, 0, 2, 0, 4, 0, 6, 0]
+ // CHECK-NEXT: [8, 0, 10, 0, 12, 0, 14, 0]
+ // CHECK-NEXT: [16, 0, 18, 0, 20, 0, 22, 0]
+ // CHECK-NEXT: [24, 0, 26, 0, 28, 0, 30, 0]
+ // CHECK-NEXT: [32, 0, 34, 0, 36, 0, 38, 0]
+ // CHECK-NEXT: [40, 0, 42, 0, 44, 0, 46, 0]
+ // CHECK-NEXT: [48, 0, 50, 0, 52, 0, 54, 0]
+ // CHECK-NEXT: [56, 0, 58, 0, 60, 0, 62, 0]
+ call @print_memref_f32(%U): (memref<*xf32>) -> ()
+
+ dealloc %A : memref<8x8xf32>
+
+ return
+}
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 5233d1db179b..f063e02fd067 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -34,11 +34,17 @@ static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
//===----------------------------------------------------------------------===//
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
+static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
+static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
+static constexpr const char *kAddTokenToGroup =
+ "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitAndExecute =
"mlirAsyncRuntimeAwaitTokenAndExecute";
+static constexpr const char *kAwaitAllAndExecute =
+ "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
namespace {
// Async Runtime API function types.
@@ -47,6 +53,10 @@ struct AsyncAPI {
return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
}
+ static FunctionType createGroupFunctionType(MLIRContext *ctx) {
+ return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
+ }
+
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
}
@@ -55,18 +65,34 @@ struct AsyncAPI {
return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
}
+ static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
+ return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
+ }
+
static FunctionType executeFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resume = resumeFunctionType(ctx).getPointerTo();
return FunctionType::get({hdl, resume}, {}, ctx);
}
+ static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
+ auto i64 = IntegerType::get(64, ctx);
+ return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
+ ctx);
+ }
+
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resume = resumeFunctionType(ctx).getPointerTo();
return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
}
+ static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
+ auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
+ auto resume = resumeFunctionType(ctx).getPointerTo();
+ return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
+ }
+
// Auxiliary coroutine resume intrinsic wrapper.
static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
@@ -87,6 +113,10 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
builder.create<FuncOp>(loc, kCreateToken,
AsyncAPI::createTokenFunctionType(ctx));
+ if (!module.lookupSymbol(kCreateGroup))
+ builder.create<FuncOp>(loc, kCreateGroup,
+ AsyncAPI::createGroupFunctionType(ctx));
+
if (!module.lookupSymbol(kEmplaceToken))
builder.create<FuncOp>(loc, kEmplaceToken,
AsyncAPI::emplaceTokenFunctionType(ctx));
@@ -95,12 +125,24 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
builder.create<FuncOp>(loc, kAwaitToken,
AsyncAPI::awaitTokenFunctionType(ctx));
+ if (!module.lookupSymbol(kAwaitGroup))
+ builder.create<FuncOp>(loc, kAwaitGroup,
+ AsyncAPI::awaitGroupFunctionType(ctx));
+
if (!module.lookupSymbol(kExecute))
builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx));
+ if (!module.lookupSymbol(kAddTokenToGroup))
+ builder.create<FuncOp>(loc, kAddTokenToGroup,
+ AsyncAPI::addTokenToGroupFunctionType(ctx));
+
if (!module.lookupSymbol(kAwaitAndExecute))
builder.create<FuncOp>(loc, kAwaitAndExecute,
AsyncAPI::awaitAndExecuteFunctionType(ctx));
+
+ if (!module.lookupSymbol(kAwaitAllAndExecute))
+ builder.create<FuncOp>(loc, kAwaitAllAndExecute,
+ AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
}
//===----------------------------------------------------------------------===//
@@ -554,8 +596,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
static Type convertType(Type type) {
MLIRContext *ctx = type.getContext();
- // Convert async tokens to opaque pointers.
- if (type.isa<TokenType>())
+ // Convert async tokens and groups to opaque pointers.
+ if (type.isa<TokenType, GroupType>())
return LLVM::LLVMType::getInt8PtrTy(ctx);
return type;
}
@@ -590,28 +632,81 @@ class CallOpOpConversion : public ConversionPattern {
} // namespace
//===----------------------------------------------------------------------===//
-// async.await op lowering to mlirAsyncRuntimeAwaitToken function call.
+// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
//===----------------------------------------------------------------------===//
namespace {
-class AwaitOpLowering : public ConversionPattern {
+class CreateGroupOpLowering : public ConversionPattern {
public:
- explicit AwaitOpLowering(
+ explicit CreateGroupOpLowering(MLIRContext *ctx)
+ : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto retTy = GroupType::get(op->getContext());
+ rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// async.add_to_group op lowering to runtime function call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class AddToGroupOpLowering : public ConversionPattern {
+public:
+ explicit AddToGroupOpLowering(MLIRContext *ctx)
+ : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Currently we can only add tokens to the group.
+ auto addToGroup = cast<AddToGroupOp>(op);
+ if (!addToGroup.operand().getType().isa<TokenType>())
+ return failure();
+
+ auto i64 = IntegerType::get(64, op->getContext());
+ rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// async.await and async.await_all op lowerings to the corresponding async
+// runtime function calls.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename AwaitType, typename AwaitableType>
+class AwaitOpLoweringBase : public ConversionPattern {
+protected:
+ explicit AwaitOpLoweringBase(
MLIRContext *ctx,
- const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
- : ConversionPattern(AwaitOp::getOperationName(), 1, ctx),
- outlinedFunctions(outlinedFunctions) {}
+ const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
+ StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
+ : ConversionPattern(AwaitType::getOperationName(), 1, ctx),
+ outlinedFunctions(outlinedFunctions),
+ blockingAwaitFuncName(blockingAwaitFuncName),
+ coroAwaitFuncName(coroAwaitFuncName) {}
+public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- // We can only await on the token operand. Async valus are not supported.
- auto await = cast<AwaitOp>(op);
- if (!await.operand().getType().isa<TokenType>())
+    // We can only await on the `AwaitableType` operand (for `await` it can
+    // only be a `token`, for `await_all` it is a `group`).
+ auto await = cast<AwaitType>(op);
+ if (!await.operand().getType().template isa<AwaitableType>())
return failure();
- // Check if `async.await` is inside the outlined coroutine function.
- auto func = await.getParentOfType<FuncOp>();
+ // Check if await operation is inside the outlined coroutine function.
+ auto func = await.template getParentOfType<FuncOp>();
auto outlined = outlinedFunctions.find(func);
const bool isInCoroutine = outlined != outlinedFunctions.end();
@@ -620,7 +715,7 @@ class AwaitOpLowering : public ConversionPattern {
// Inside regular function we convert await operation to the blocking
// async API await function call.
if (!isInCoroutine)
- rewriter.create<CallOp>(loc, Type(), kAwaitToken,
+ rewriter.create<CallOp>(loc, Type(), blockingAwaitFuncName,
ValueRange(op->getOperand(0)));
// Inside the coroutine we convert await operation into coroutine suspension
@@ -645,7 +740,7 @@ class AwaitOpLowering : public ConversionPattern {
// the async await argument becomes ready.
SmallVector<Value, 3> awaitAndExecuteArgs = {
await.getOperand(), coro.coroHandle, resumePtr.res()};
- builder.create<CallOp>(loc, Type(), kAwaitAndExecute,
+ builder.create<CallOp>(loc, Type(), coroAwaitFuncName,
awaitAndExecuteArgs);
// Split the entry block before the await operation.
@@ -660,7 +755,32 @@ class AwaitOpLowering : public ConversionPattern {
private:
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
+ StringRef blockingAwaitFuncName;
+ StringRef coroAwaitFuncName;
+};
+
+// Lowering for `async.await` operation (only token operands are supported).
+class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
+ using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
+
+public:
+ explicit AwaitOpLowering(
+ MLIRContext *ctx,
+ const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+ : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
};
+
+// Lowering for `async.await_all` operation.
+class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
+ using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
+
+public:
+ explicit AwaitAllOpLowering(
+ MLIRContext *ctx,
+ const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+ : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -717,7 +837,8 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
patterns.insert<CallOpOpConversion>(ctx);
- patterns.insert<AwaitOpLowering>(ctx, outlinedFunctions);
+ patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
+ patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
ConversionTarget target(*ctx);
target.addLegalDialect<LLVM::LLVMDialect>();
diff --git a/mlir/lib/Dialect/Async/CMakeLists.txt b/mlir/lib/Dialect/Async/CMakeLists.txt
index f33061b2d87c..9f57627c321f 100644
--- a/mlir/lib/Dialect/Async/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 0c22ea22a9da..1e84ba3418bb 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -21,6 +21,7 @@ void AsyncDialect::initialize() {
>();
addTypes<TokenType>();
addTypes<ValueType>();
+ addTypes<GroupType>();
}
/// Parse a type registered to this dialect.
@@ -54,6 +55,7 @@ void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
os.printType(valueTy.getValueType());
os << '>';
})
+ .Case<GroupType>([&](GroupType) { os << "group"; })
.Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
}
@@ -139,6 +141,51 @@ void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
regions.push_back(RegionSuccessor(&body()));
}
+void ExecuteOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTypes, ValueRange dependencies,
+ ValueRange operands, BodyBuilderFn bodyBuilder) {
+
+ result.addOperands(dependencies);
+ result.addOperands(operands);
+
+ // Add derived `operand_segment_sizes` attribute based on parsed operands.
+ int32_t numDependencies = dependencies.size();
+ int32_t numOperands = operands.size();
+ auto operandSegmentSizes = DenseIntElementsAttr::get(
+ VectorType::get({2}, IntegerType::get(32, result.getContext())),
+ {numDependencies, numOperands});
+ result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
+
+ // First result is always a token, and then `resultTypes` wrapped into
+ // `async.value`.
+ result.addTypes({TokenType::get(result.getContext())});
+ for (Type type : resultTypes)
+ result.addTypes(ValueType::get(type));
+
+ // Add a body region with block arguments as unwrapped async value operands.
+ Region *bodyRegion = result.addRegion();
+ bodyRegion->push_back(new Block);
+ Block &bodyBlock = bodyRegion->front();
+ for (Value operand : operands) {
+ auto valueType = operand.getType().dyn_cast<ValueType>();
+ bodyBlock.addArgument(valueType ? valueType.getValueType()
+ : operand.getType());
+ }
+
+ // Create the default terminator if the builder is not provided and if the
+ // expected result is empty. Otherwise, leave this to the caller
+ // because we don't know which values to return from the execute op.
+ if (resultTypes.empty() && !bodyBuilder) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&bodyBlock);
+ builder.create<async::YieldOp>(result.location, ValueRange());
+ } else if (bodyBuilder) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&bodyBlock);
+ bodyBuilder(builder, result.location, bodyBlock.getArguments());
+ }
+}
+
static void print(OpAsmPrinter &p, ExecuteOp op) {
p << op.getOperationName();
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
new file mode 100644
index 000000000000..c6508610c796
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -0,0 +1,278 @@
+//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the scf.parallel to scf.for + async.execute conversion pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-parallel-for"
+
+namespace {
+
+// Rewrite scf.parallel operation into multiple concurrent async.execute
+// operations over non-overlapping subranges of the original loop.
+//
+// Example:
+//
+//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
+// "do_some_compute"(%i, %j): () -> ()
+// }
+//
+// Converted to:
+//
+// %c0 = constant 0 : index
+// %c1 = constant 1 : index
+//
+//   // Compute block sizes for each induction variable.
+// %num_blocks_i = ... : index
+// %num_blocks_j = ... : index
+// %block_size_i = ... : index
+// %block_size_j = ... : index
+//
+// // Create an async group to track async execute ops.
+// %group = async.create_group
+//
+// scf.for %bi = %c0 to %num_blocks_i step %c1 {
+// %block_start_i = ... : index
+// %block_end_i = ... : index
+//
+// scf.for %bj = %c0 to %num_blocks_j step %c1 {
+// %block_start_j = ... : index
+// %block_end_j = ... : index
+//
+// // Execute the body of original parallel operation for the current
+// // block.
+// %token = async.execute {
+// scf.for %i = %block_start_i to %block_end_i step %si {
+// scf.for %j = %block_start_j to %block_end_j step %sj {
+// "do_some_compute"(%i, %j): () -> ()
+// }
+// }
+// }
+//
+// // Add produced async token to the group.
+// async.add_to_group %token, %group
+// }
+// }
+//
+// // Await completion of all async.execute operations.
+// async.await_all %group
+//
+// In this example the outer loop launches the inner block-level loops as
+// separate async.execute operations, which will be executed concurrently.
+//
+// At the end it waits for the completion of all async.execute operations.
+//
+struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
+public:
+ AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute)
+ : OpRewritePattern(ctx),
+ numConcurrentAsyncExecute(numConcurrentAsyncExecute) {}
+
+ LogicalResult matchAndRewrite(scf::ParallelOp op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ int numConcurrentAsyncExecute;
+};
+
+struct AsyncParallelForPass
+ : public AsyncParallelForBase<AsyncParallelForPass> {
+ AsyncParallelForPass() = default;
+ void runOnFunction() override;
+};
+
+} // namespace
+
+LogicalResult
+AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
+ PatternRewriter &rewriter) const {
+ // We do not currently support rewrite for parallel op with reductions.
+ if (op.getNumReductions() != 0)
+ return failure();
+
+ MLIRContext *ctx = op.getContext();
+ Location loc = op.getLoc();
+
+ // Index constants used below.
+ auto indexTy = IndexType::get(ctx);
+ auto zero = IntegerAttr::get(indexTy, 0);
+ auto one = IntegerAttr::get(indexTy, 1);
+ auto c0 = rewriter.create<ConstantOp>(loc, indexTy, zero);
+ auto c1 = rewriter.create<ConstantOp>(loc, indexTy, one);
+
+ // Shorthand for signed integer ceil division operation.
+ auto divup = [&](Value x, Value y) -> Value {
+ return rewriter.create<SignedCeilDivIOp>(loc, x, y);
+ };
+
+ // Compute trip count for each loop induction variable:
+ // tripCount = divUp(upperBound - lowerBound, step);
+ SmallVector<Value, 4> tripCounts(op.getNumLoops());
+ for (size_t i = 0; i < op.getNumLoops(); ++i) {
+ auto lb = op.lowerBound()[i];
+ auto ub = op.upperBound()[i];
+ auto step = op.step()[i];
+ auto range = rewriter.create<SubIOp>(loc, ub, lb);
+ tripCounts[i] = divup(range, step);
+ }
+
+ // The target number of concurrent async.execute ops.
+ auto numExecuteOps = rewriter.create<ConstantOp>(
+ loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute));
+
+  // Block size configuration for each induction variable.
+
+ // We try to use maximum available concurrency in outer dimensions first
+  // (assuming that parallel induction variables correspond to some
+  // multidimensional access, e.g. in (%d0, %d1, ..., %dn) = (<from>) to (<to>)
+  // we will try to parallelize iteration along %d0). If %d0 is too small,
+  // we'll parallelize iteration over %d1, and so on.
+ SmallVector<Value, 4> targetNumBlocks(op.getNumLoops());
+ SmallVector<Value, 4> blockSize(op.getNumLoops());
+ SmallVector<Value, 4> numBlocks(op.getNumLoops());
+
+ // Compute block size and number of blocks along the first induction variable.
+ targetNumBlocks[0] = numExecuteOps;
+ blockSize[0] = divup(tripCounts[0], targetNumBlocks[0]);
+ numBlocks[0] = divup(tripCounts[0], blockSize[0]);
+
+ // Assign remaining available concurrency to other induction variables.
+ for (size_t i = 1; i < op.getNumLoops(); ++i) {
+ targetNumBlocks[i] = divup(targetNumBlocks[i - 1], numBlocks[i - 1]);
+ blockSize[i] = divup(tripCounts[i], targetNumBlocks[i]);
+ numBlocks[i] = divup(tripCounts[i], blockSize[i]);
+ }
+
+ // Create an async.group to wait on all async tokens from async execute ops.
+ auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));
+
+ // Build a scf.for loop nest from the parallel operation.
+
+  // Lower/upper bounds for the nested block-level computations.
+ SmallVector<Value, 4> blockLowerBounds(op.getNumLoops());
+ SmallVector<Value, 4> blockUpperBounds(op.getNumLoops());
+ SmallVector<Value, 4> blockInductionVars(op.getNumLoops());
+
+ using LoopBodyBuilder =
+ std::function<void(OpBuilder &, Location, Value, ValueRange)>;
+ using LoopBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;
+
+ // Builds inner loop nest inside async.execute operation that does all the
+ // work concurrently.
+ LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
+ return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ blockInductionVars[loopIdx] = iv;
+
+      // Continue building the async loop nest.
+ if (loopIdx < op.getNumLoops() - 1) {
+ b.create<scf::ForOp>(
+ loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1],
+ op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1));
+ b.create<scf::YieldOp>(loc);
+ return;
+ }
+
+ // Copy the body of the parallel op with new loop bounds.
+ BlockAndValueMapping mapping;
+ mapping.map(op.getInductionVars(), blockInductionVars);
+
+ for (auto &bodyOp : op.getLoopBody().getOps())
+ b.clone(bodyOp, mapping);
+ };
+ };
+
+ // Builds a loop nest that does async execute op dispatching.
+ LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
+ return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ auto lb = op.lowerBound()[loopIdx];
+ auto ub = op.upperBound()[loopIdx];
+ auto step = op.step()[loopIdx];
+
+ // Compute lower bound for the current block:
+ // blockLowerBound = iv * blockSize * step + lowerBound
+ auto s0 = b.create<MulIOp>(loc, iv, blockSize[loopIdx]);
+ auto s1 = b.create<MulIOp>(loc, s0, step);
+ auto s2 = b.create<AddIOp>(loc, s1, lb);
+ blockLowerBounds[loopIdx] = s2;
+
+ // Compute upper bound for the current block:
+ // blockUpperBound = min(upperBound,
+ // blockLowerBound + blockSize * step)
+ auto e0 = b.create<MulIOp>(loc, blockSize[loopIdx], step);
+ auto e1 = b.create<AddIOp>(loc, e0, s2);
+ auto e2 = b.create<CmpIOp>(loc, CmpIPredicate::slt, e1, ub);
+ auto e3 = b.create<SelectOp>(loc, e2, e1, ub);
+ blockUpperBounds[loopIdx] = e3;
+
+ // Continue building async dispatch loop nest.
+ if (loopIdx < op.getNumLoops() - 1) {
+ b.create<scf::ForOp>(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(),
+ asyncLoopBuilder(loopIdx + 1));
+ b.create<scf::YieldOp>(loc);
+ return;
+ }
+
+ // Build the inner loop nest that will do the actual work inside the
+ // `async.execute` body region.
+ auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
+ Location executeLoc,
+ ValueRange executeArgs) {
+ executeBuilder.create<scf::ForOp>(executeLoc, blockLowerBounds[0],
+ blockUpperBounds[0], op.step()[0],
+ ValueRange(), workLoopBuilder(0));
+ executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
+ };
+
+ auto execute = b.create<ExecuteOp>(
+ loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(),
+ /*operands=*/ValueRange(), executeBodyBuilder);
+ auto rankType = IndexType::get(ctx);
+ b.create<AddToGroupOp>(loc, rankType, execute.token(), group.result());
+ b.create<scf::YieldOp>(loc);
+ };
+ };
+
+ // Start building a loop nest from the first induction variable.
+ rewriter.create<scf::ForOp>(loc, c0, numBlocks[0], c1, ValueRange(),
+ asyncLoopBuilder(0));
+
+ // Wait for the completion of all subtasks.
+ rewriter.create<AwaitAllOp>(loc, group.result());
+
+ // Erase the original parallel operation.
+ rewriter.eraseOp(op);
+
+ return success();
+}
+
+void AsyncParallelForPass::runOnFunction() {
+ MLIRContext *ctx = &getContext();
+
+ OwningRewritePatternList patterns;
+ patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
+
+ if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
+ return std::make_unique<AsyncParallelForPass>();
+}
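A worked example of the block-size heuristic above, with illustrative numbers (not
taken from this patch): for a 2-D scf.parallel with trip counts 32 and 100 and the
default num-concurrent-async-execute = 4, the first induction variable gets
targetNumBlocks[0] = 4, blockSize[0] = ceil(32 / 4) = 8 and numBlocks[0] = 4; the
remaining concurrency for the second variable is targetNumBlocks[1] = ceil(4 / 4) = 1,
so blockSize[1] = 100 and numBlocks[1] = 1. The dispatch loop nest therefore creates
4 async.execute operations, each covering an 8 x 100 block of the original iteration
space.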
diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..9de43873039d
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRAsyncTransforms
+ AsyncParallelFor.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
+
+ DEPENDS
+ MLIRAsyncPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRAsync
+ MLIRSCF
+ MLIRPass
+ MLIRTransforms
+ MLIRTransformUtils
+)
diff --git a/mlir/lib/Dialect/Async/Transforms/PassDetail.h b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
new file mode 100644
index 000000000000..c047eaf383d9
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
@@ -0,0 +1,30 @@
+//===- PassDetail.h - Async Pass class details ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
+#define DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace async {
+class AsyncDialect;
+} // namespace async
+
+namespace scf {
+class SCFDialect;
+} // namespace scf
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/Async/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 9af1a8d89020..332c7ff1e2b9 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -15,6 +15,7 @@
#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
+#include <atomic>
#include <condition_variable>
#include <functional>
#include <iostream>
@@ -33,12 +34,50 @@ struct AsyncToken {
std::vector<std::function<void()>> awaiters;
};
+struct AsyncGroup {
+ std::atomic<int> pendingTokens{0};
+ std::atomic<int> rank{0};
+ std::mutex mu;
+ std::condition_variable cv;
+ std::vector<std::function<void()>> awaiters;
+};
+
// Create a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
AsyncToken *token = new AsyncToken;
return token;
}
+// Create a new `async.group` in empty state.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
+ AsyncGroup *group = new AsyncGroup;
+ return group;
+}
+
+extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
+mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
+ std::unique_lock<std::mutex> lockToken(token->mu);
+ std::unique_lock<std::mutex> lockGroup(group->mu);
+
+ group->pendingTokens.fetch_add(1);
+
+ auto onTokenReady = [group]() {
+ // Run all group awaiters if it was the last token in the group.
+ if (group->pendingTokens.fetch_sub(1) == 1) {
+ group->cv.notify_all();
+ for (auto &awaiter : group->awaiters)
+ awaiter();
+ }
+ };
+
+ if (token->ready)
+ onTokenReady();
+ else
+ token->awaiters.push_back([onTokenReady]() { onTokenReady(); });
+
+ return group->rank.fetch_add(1);
+}
+
// Switches `async.token` to ready state and runs all awaiters.
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
std::unique_lock<std::mutex> lock(token->mu);
@@ -52,7 +91,13 @@ extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
std::unique_lock<std::mutex> lock(token->mu);
if (!token->ready)
token->cv.wait(lock, [token] { return token->ready; });
- delete token;
+}
+
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
+ std::unique_lock<std::mutex> lock(group->mu);
+ if (group->pendingTokens != 0)
+ group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
@@ -69,9 +114,8 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
CoroResume resume) {
std::unique_lock<std::mutex> lock(token->mu);
- auto execute = [token, handle, resume]() {
+ auto execute = [handle, resume]() {
mlirAsyncRuntimeExecute(handle, resume);
- delete token;
};
if (token->ready)
@@ -80,6 +124,21 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
token->awaiters.push_back([execute]() { execute(); });
}
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
+ CoroResume resume) {
+ std::unique_lock<std::mutex> lock(group->mu);
+
+ auto execute = [handle, resume]() {
+ mlirAsyncRuntimeExecute(handle, resume);
+ };
+
+ if (group->pendingTokens == 0)
+ execute();
+ else
+ group->awaiters.push_back([execute]() { execute(); });
+}
+
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index aa912baef566..1fd71a65379e 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -156,4 +156,43 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: store %arg1, %arg2[%c0] : memref<1xf32>
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
+// -----
+
+// CHECK-LABEL: async_group_await_all
+func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
+ // CHECK: %0 = call @mlirAsyncRuntimeCreateGroup()
+ %0 = async.create_group
+
+ // CHECK: %[[TOKEN:.*]] = call @async_execute_fn
+ %token = async.execute { async.yield }
+ // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %0)
+ async.add_to_group %token, %0 : !async.token
+
+ // CHECK: call @async_execute_fn_0
+ async.execute {
+ async.await_all %0
+ async.yield
+ }
+
+ // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%0)
+ async.await_all %0
+
+ return
+}
+
+// Function outlined from the async.execute operation.
+// CHECK: func private @async_execute_fn_0(%arg0: !llvm.ptr<i8>)
+// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
+// Suspend coroutine in the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL_1]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Suspend coroutine second time waiting for the group.
+// CHECK: llvm.call @llvm.coro.save
+// CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute(%arg0, %[[HDL_1]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Emplace result token.
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
diff --git a/mlir/test/Dialect/Async/async-parallel-for.mlir b/mlir/test/Dialect/Async/async-parallel-for.mlir
new file mode 100644
index 000000000000..ad21650ab0ee
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -async-parallel-for | FileCheck %s
+
+// CHECK-LABEL: @loop_1d
+func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
+ // CHECK: %[[GROUP:.*]] = async.create_group
+ // CHECK: scf.for
+ // CHECK: %[[TOKEN:.*]] = async.execute {
+ // CHECK: scf.for
+ // CHECK: store
+ // CHECK: async.yield
+ // CHECK: }
+ // CHECK: async.add_to_group %[[TOKEN]], %[[GROUP]]
+ // CHECK: async.await_all %[[GROUP]]
+ scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
+ %one = constant 1.0 : f32
+ store %one, %arg3[%i] : memref<?xf32>
+ }
+
+ return
+}
+
+// CHECK-LABEL: @loop_2d
+func @loop_2d(%arg0: index, %arg1: index, %arg2: index, // lb, ub, step
+ %arg3: index, %arg4: index, %arg5: index, // lb, ub, step
+ %arg6: memref<?x?xf32>) {
+ // CHECK: %[[GROUP:.*]] = async.create_group
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: %[[TOKEN:.*]] = async.execute {
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: store
+ // CHECK: async.yield
+ // CHECK: }
+ // CHECK: async.add_to_group %[[TOKEN]], %[[GROUP]]
+ // CHECK: async.await_all %[[GROUP]]
+ scf.parallel (%i0, %i1) = (%arg0, %arg3) to (%arg1, %arg4)
+ step (%arg2, %arg5) {
+ %one = constant 1.0 : f32
+ store %one, %arg6[%i0, %i1] : memref<?x?xf32>
+ }
+
+ return
+}
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
index 8784b6f05a08..a95be650eff7 100644
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -120,3 +120,17 @@ func @await_value(%arg0: !async.value<f32>) -> f32 {
%0 = async.await %arg0 : !async.value<f32>
return %0 : f32
}
+
+// CHECK-LABEL: @create_group_and_await_all
+func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value<f32>) -> index {
+ %0 = async.create_group
+
+ // CHECK: async.add_to_group %arg0
+ // CHECK: async.add_to_group %arg1
+ %1 = async.add_to_group %arg0, %0 : !async.token
+ %2 = async.add_to_group %arg1, %0 : !async.value<f32>
+ async.await_all %0
+
+ %3 = addi %1, %2 : index
+ return %3 : index
+}
diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir
new file mode 100644
index 000000000000..0ae378b45be1
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/async-group.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -convert-async-to-llvm \
+// RUN: -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -e main -entry-point-result=void -O0 \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+ %group = async.create_group
+
+ %token0 = async.execute { async.yield }
+ %token1 = async.execute { async.yield }
+ %token2 = async.execute { async.yield }
+ %token3 = async.execute { async.yield }
+ %token4 = async.execute { async.yield }
+
+ %0 = async.add_to_group %token0, %group : !async.token
+ %1 = async.add_to_group %token1, %group : !async.token
+ %2 = async.add_to_group %token2, %group : !async.token
+ %3 = async.add_to_group %token3, %group : !async.token
+ %4 = async.add_to_group %token4, %group : !async.token
+
+ %token5 = async.execute {
+ async.await_all %group
+ async.yield
+ }
+
+ %group0 = async.create_group
+ %5 = async.add_to_group %token5, %group0 : !async.token
+ async.await_all %group0
+
+ // CHECK: Current thread id: [[THREAD:.*]]
+ call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
+
+ return
+}
+
+func @mlirAsyncRuntimePrintCurrentThreadId() -> ()