[Mlir-commits] [mlir] b537c5b - [mlir] Async: clone constants into async.execute functions and parallel compute functions
Eugene Zhulenev
llvmlistbot at llvm.org
Mon Aug 2 12:17:48 PDT 2021
Author: Eugene Zhulenev
Date: 2021-08-02T12:17:41-07:00
New Revision: b537c5b4147b6966fda8d80ed291f6b1f3857b16
URL: https://github.com/llvm/llvm-project/commit/b537c5b4147b6966fda8d80ed291f6b1f3857b16
DIFF: https://github.com/llvm/llvm-project/commit/b537c5b4147b6966fda8d80ed291f6b1f3857b16.diff
LOG: [mlir] Async: clone constants into async.execute functions and parallel compute functions
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D107007
Added:
mlir/lib/Dialect/Async/Transforms/PassDetail.cpp
mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
mlir/lib/Dialect/Async/Transforms/PassDetail.h
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/Async/async-to-async-runtime.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index a8858913cc1fb..cfc1968d523d1 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -190,6 +190,10 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
ModuleOp module = op->getParentOfType<ModuleOp>();
+ // Make sure that all constants will be inside the parallel operation body to
+ // reduce the number of parallel compute function arguments.
+ cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
+
ParallelComputeFunctionType computeFuncType =
getParallelComputeFunctionType(op, rewriter);
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 10dcba1f30444..9e70853b3fa38 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -235,6 +235,10 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
MLIRContext *ctx = module.getContext();
Location loc = execute.getLoc();
+ // Make sure that all constants will be inside the outlined async function to
+ // reduce the number of function arguments.
+ cloneConstantsIntoTheRegion(execute.body());
+
// Collect all outlined function inputs.
SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
execute.dependencies().end());
diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
index b3aa8a978e560..222db220208f7 100644
--- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
AsyncRuntimeRefCounting.cpp
AsyncRuntimeRefCountingOpt.cpp
AsyncToAsyncRuntime.cpp
+ PassDetail.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
diff --git a/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp b/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp
new file mode 100644
index 0000000000000..eaac566716a43
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp
@@ -0,0 +1,43 @@
+ //===- PassDetail.cpp - Async Pass class details --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+
+ /// Convenience overload: creates a builder for `region` and forwards to the
+ /// builder-taking overload, which clones captured ConstantLike operations
+ /// into the region entry block.
+ void mlir::async::cloneConstantsIntoTheRegion(Region &region) {
+   OpBuilder builder(&region);
+   cloneConstantsIntoTheRegion(region, builder);
+ }
+
+ /// Clones every ConstantLike operation that is defined above `region` and
+ /// used inside it into the start of the region entry block. All uses of the
+ /// original results inside `region` (including nested regions) are redirected
+ /// to the clones. The clones are created with `builder`, whose insertion point
+ /// is restored on return.
+ void mlir::async::cloneConstantsIntoTheRegion(Region &region,
+                                               OpBuilder &builder) {
+   // An empty region has no entry block to insert into and cannot capture
+   // anything; bail out before `region.front()` is dereferenced.
+   if (region.empty())
+     return;
+
+   // Values implicitly captured by the region.
+   llvm::SetVector<Value> captures;
+   getUsedValuesDefinedAbove(region, region, captures);
+
+   OpBuilder::InsertionGuard guard(builder);
+   builder.setInsertionPointToStart(&region.front());
+
+   // Clone ConstantLike operations into the region.
+   for (Value capture : captures) {
+     // Block arguments have no defining op; only constants are cloned.
+     Operation *op = capture.getDefiningOp();
+     if (!op || !op->hasTrait<OpTrait::ConstantLike>())
+       continue;
+
+     Operation *cloned = builder.clone(*op);
+
+     // Redirect in-region uses only; uses outside the region keep the original.
+     for (auto tuple : llvm::zip(op->getResults(), cloned->getResults())) {
+       Value orig = std::get<0>(tuple);
+       Value replacement = std::get<1>(tuple);
+       replaceAllUsesInRegionWith(orig, replacement, region);
+     }
+   }
+ }
diff --git a/mlir/lib/Dialect/Async/Transforms/PassDetail.h b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
index c047eaf383d9c..2065c14e13148 100644
--- a/mlir/lib/Dialect/Async/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
@@ -25,6 +25,24 @@ class SCFDialect;
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Async/Passes.h.inc"
+// -------------------------------------------------------------------------- //
+// Utility functions shared by Async Transformations.
+// -------------------------------------------------------------------------- //
+
+// Forward declarations.
+class OpBuilder;
+
+namespace async {
+
+/// Clone ConstantLike operations that are defined above the given region and
+/// have users in the region into the region entry block. We do that to reduce
+/// the number of function arguments when we outline `async.execute` and
+/// `scf.parallel` operations body into functions.
+void cloneConstantsIntoTheRegion(Region ®ion);
+void cloneConstantsIntoTheRegion(Region ®ion, OpBuilder &builder);
+
+} // namespace async
+
} // namespace mlir
#endif // DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index da96f306f0bf0..f8afa39060a9f 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -89,13 +89,14 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
}
// Function outlined from the inner async.execute operation.
-// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
+// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
// CHECK-SAME: -> !llvm.ptr<i8>
// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
// CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
// CHECK: call @mlirAsyncRuntimeExecute
// CHECK: llvm.intr.coro.suspend
-// CHECK: memref.store %arg0, %arg1[%arg2] : memref<1xf32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: memref.store %arg0, %arg1[%[[C0]]] : memref<1xf32>
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
// Function outlined from the outer async.execute operation.
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
new file mode 100644
index 0000000000000..34e9434dd9665
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s \
+// RUN: -async-parallel-for=async-dispatch=true \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s \
+// RUN: -async-parallel-for=async-dispatch=false \
+// RUN: -canonicalize -inline -symbol-dce \
+// RUN: | FileCheck %s
+
+// Check that constants defined outside of the `scf.parallel` body will be
+// sunk into the parallel compute function to avoid blowing up the number
+// of parallel compute function arguments.
+
+// CHECK-LABEL: func @clone_constant(
+ func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
+ // %one is defined outside of the scf.parallel body but used inside it, so it
+ // is expected to be cloned into the parallel compute function body instead of
+ // being passed as a function argument.
+ %one = constant 1.0 : f32
+
+ scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+ memref.store %one, %arg0[%i] : memref<?xf32>
+ }
+
+ return
+ }
+
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME: %[[TRIP_COUNT:arg[0-9]+]]: index,
+// CHECK-SAME: %[[LB:arg[0-9]+]]: index,
+// CHECK-SAME: %[[UB:arg[0-9]+]]: index,
+// CHECK-SAME: %[[STEP:arg[0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?xf32>
+// CHECK-SAME: ) {
+// CHECK: %[[CST:.*]] = constant 1.0{{.*}} : f32
+// CHECK: scf.for
+// CHECK: memref.store %[[CST]], %[[MEMREF]]
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 661d208e17662..9c61394aa8ed9 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -406,3 +406,26 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
// Check that structured control flow lowered to CFG.
// CHECK-NOT: scf.if
// CHECK: cond_br %[[FLAG]]
+
+// -----
+ // Constants captured by the async.execute region should be cloned into the
+ // outlined async execute function.
+
+// CHECK-LABEL: @clone_constants
+ func @clone_constants(%arg0: f32, %arg1: memref<1xf32>) {
+ // %c0 is defined outside of async.execute and captured by its region; it is
+ // expected to be cloned into the outlined function rather than passed in as
+ // an argument.
+ %c0 = constant 0 : index
+ %token = async.execute {
+ memref.store %arg0, %arg1[%c0] : memref<1xf32>
+ async.yield
+ }
+ async.await %token : !async.token
+ return
+ }
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(
+// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32,
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32>
+// CHECK-SAME: ) -> !async.token
+// CHECK: %[[CST:.*]] = constant 0 : index
+// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]
More information about the Mlir-commits
mailing list