[Mlir-commits] [mlir] a8f819c - [mlir:Async] Remove async operations if it is statically known that the parallel operation has a single compute block
Eugene Zhulenev
llvmlistbot at llvm.org
Tue Jun 29 09:26:35 PDT 2021
Author: Eugene Zhulenev
Date: 2021-06-29T09:26:28-07:00
New Revision: a8f819c6d85e1990954d8846dac769bb789d2ba9
URL: https://github.com/llvm/llvm-project/commit/a8f819c6d85e1990954d8846dac769bb789d2ba9
DIFF: https://github.com/llvm/llvm-project/commit/a8f819c6d85e1990954d8846dac769bb789d2ba9.diff
LOG: [mlir:Async] Remove async operations if it is statically known that the parallel operation has a single compute block
Depends On D104850
Add a test that verifies that canonicalization removes all async overheads if it is statically known that the scf.parallel operation will be computed using a single block.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D104891
Added:
mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir
Modified:
mlir/include/mlir/Dialect/Async/IR/Async.h
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index d84b8f8ea98a6..0783009d2855c 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -20,6 +20,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index f9ddd67a7961d..d168b8cefad8a 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -177,6 +177,8 @@ def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
let arguments = (ins Index:$size);
let results = (outs Async_GroupType:$result);
+ let hasCanonicalizeMethod = 1;
+
let assemblyFormat = "$size `:` type($result) attr-dict";
}
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index a06b2b6664690..bd627edbd4271 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -245,6 +245,36 @@ static LogicalResult verify(ExecuteOp op) {
return success();
}
+//===----------------------------------------------------------------------===//
+/// CreateGroupOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
+ PatternRewriter &rewriter) {
+ // Find all `await_all` users of the group.
+ llvm::SmallVector<AwaitAllOp> awaitAllUsers;
+
+ auto isAwaitAll = [&](Operation *op) -> bool {
+ if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
+ awaitAllUsers.push_back(awaitAll);
+ return true;
+ }
+ return false;
+ };
+
+ // Check if all users of the group are `await_all` operations.
+ if (!llvm::all_of(op->getUsers(), isAwaitAll))
+ return failure();
+
+ // If the group is only awaited without anything being added to it, we can
+ // safely erase the create operation and all of its users.
+ for (AwaitAllOp awaitAll : awaitAllUsers)
+ rewriter.eraseOp(awaitAll);
+ rewriter.eraseOp(op);
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
/// AwaitOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 1d545a52f7152..a104fb73571d9 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -513,18 +513,48 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
Value groupSize = b.create<SubIOp>(blockCount, c1);
Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
- // Pack the async dispath function operands to launch the work splitting.
- SmallVector<Value> asyncDispatchOperands = {group, c0, blockCount, blockSize};
- asyncDispatchOperands.append(tripCounts);
- asyncDispatchOperands.append(op.lowerBound().begin(), op.lowerBound().end());
- asyncDispatchOperands.append(op.upperBound().begin(), op.upperBound().end());
- asyncDispatchOperands.append(op.step().begin(), op.step().end());
- asyncDispatchOperands.append(parallelComputeFunction.captures);
-
- // Launch async dispatch function for [0, blockCount) range.
- b.create<CallOp>(asyncDispatchFunction.sym_name(),
- asyncDispatchFunction.getCallableResults(),
- asyncDispatchOperands);
+ // Appends operands shared by async dispatch and parallel compute functions to
+ // the given operands vector.
+ auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
+ operands.append(tripCounts);
+ operands.append(op.lowerBound().begin(), op.lowerBound().end());
+ operands.append(op.upperBound().begin(), op.upperBound().end());
+ operands.append(op.step().begin(), op.step().end());
+ operands.append(parallelComputeFunction.captures);
+ };
+
+ // Check if the block count is one: in this case we can skip the async
+ // dispatch completely. If this is known statically, then canonicalization
+ // will erase async group operations.
+ Value isSingleBlock = b.create<CmpIOp>(CmpIPredicate::eq, blockCount, c1);
+
+ auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+ ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+ // Call parallel compute function for the single block.
+ SmallVector<Value> operands = {c0, blockSize};
+ appendBlockComputeOperands(operands);
+
+ nb.create<CallOp>(parallelComputeFunction.func.sym_name(),
+ parallelComputeFunction.func.getCallableResults(),
+ operands);
+ nb.create<scf::YieldOp>();
+ };
+
+ auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+ ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+ // Launch async dispatch function for [0, blockCount) range.
+ SmallVector<Value> operands = {group, c0, blockCount, blockSize};
+ appendBlockComputeOperands(operands);
+
+ nb.create<CallOp>(asyncDispatchFunction.sym_name(),
+ asyncDispatchFunction.getCallableResults(), operands);
+ nb.create<scf::YieldOp>();
+ };
+
+ // Dispatch either single block compute function, or launch async dispatch.
+ b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
// Wait for the completion of all parallel compute operations.
b.create<AwaitAllOp>(group);
diff --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
index df538a4fc7661..a6e308e422e20 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -3,8 +3,13 @@
// CHECK-LABEL: @loop_1d
func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
+ // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[GROUP:.*]] = async.create_group
- // CHECK: call @async_dispatch_fn
+ // CHECK: scf.if {{.*}} {
+ // CHECK: call @parallel_compute_fn(%[[C0]]
+ // CHECK: } else {
+ // CHECK: call @async_dispatch_fn
+ // CHECK: }
// CHECK: async.await_all %[[GROUP]]
scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
%one = constant 1.0 : f32
diff --git a/mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir b/mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir
new file mode 100644
index 0000000000000..e26d99006b55a
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s \
+// RUN: -async-parallel-for=async-dispatch=true \
+// RUN: -canonicalize -inline -symbol-dce \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s \
+// RUN: -async-parallel-for=async-dispatch=false \
+// RUN: -canonicalize -inline -symbol-dce \
+// RUN: | FileCheck %s
+
+// Check that if we statically know that the parallel operation has a single
+// block then all async operations will be canonicalized away and we will
+// end up with a single synchronous compute function call.
+
+// CHECK-LABEL: @loop_1d(
+// CHECK: %[[MEMREF:.*]]: memref<?xf32>
+func @loop_1d(%arg0: memref<?xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+ // CHECK-DAG: %[[C100:.*]] = constant 100 : index
+ // CHECK-DAG: %[[ONE:.*]] = constant 1.000000e+00 : f32
+ // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]]
+ // CHECK: memref.store %[[ONE]], %[[MEMREF]][%[[I]]]
+ %lb = constant 0 : index
+ %ub = constant 100 : index
+ %st = constant 1 : index
+ scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+ %one = constant 1.0 : f32
+ memref.store %one, %arg0[%i] : memref<?xf32>
+ }
+
+ return
+}
More information about the Mlir-commits
mailing list