[Mlir-commits] [mlir] b537c5b - [mlir] Async: clone constants into async.execute functions and parallel compute functions

Eugene Zhulenev llvmlistbot at llvm.org
Mon Aug 2 12:17:48 PDT 2021


Author: Eugene Zhulenev
Date: 2021-08-02T12:17:41-07:00
New Revision: b537c5b4147b6966fda8d80ed291f6b1f3857b16

URL: https://github.com/llvm/llvm-project/commit/b537c5b4147b6966fda8d80ed291f6b1f3857b16
DIFF: https://github.com/llvm/llvm-project/commit/b537c5b4147b6966fda8d80ed291f6b1f3857b16.diff

LOG: [mlir] Async: clone constants into async.execute functions and parallel compute functions

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D107007

Added: 
    mlir/lib/Dialect/Async/Transforms/PassDetail.cpp
    mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Async/Transforms/PassDetail.h
    mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
    mlir/test/Dialect/Async/async-to-async-runtime.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index a8858913cc1fb..cfc1968d523d1 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -190,6 +190,10 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
 
   ModuleOp module = op->getParentOfType<ModuleOp>();
 
+  // Make sure that all constants will be inside the parallel operation body to
+  // reduce the number of parallel compute function arguments.
+  cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
+
   ParallelComputeFunctionType computeFuncType =
       getParallelComputeFunctionType(op, rewriter);
 

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 10dcba1f30444..9e70853b3fa38 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -235,6 +235,10 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   MLIRContext *ctx = module.getContext();
   Location loc = execute.getLoc();
 
+  // Make sure that all constants will be inside the outlined async function to
+  // reduce the number of function arguments.
+  cloneConstantsIntoTheRegion(execute.body());
+
   // Collect all outlined function inputs.
   SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
                                         execute.dependencies().end());

diff  --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
index b3aa8a978e560..222db220208f7 100644
--- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
   AsyncRuntimeRefCounting.cpp
   AsyncRuntimeRefCountingOpt.cpp
   AsyncToAsyncRuntime.cpp
+  PassDetail.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async

diff  --git a/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp b/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp
new file mode 100644
index 0000000000000..eaac566716a43
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp
@@ -0,0 +1,43 @@
+//===- PassDetail.cpp - Async Pass class details ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+
+void mlir::async::cloneConstantsIntoTheRegion(Region &region) {
+  OpBuilder builder(&region); // Local builder handed to the overload below.
+  cloneConstantsIntoTheRegion(region, builder);
+}
+
+void mlir::async::cloneConstantsIntoTheRegion(Region &region,
+                                              OpBuilder &builder) {
+  // Values used inside `region` but defined above it (implicit captures).
+  llvm::SetVector<Value> captures;
+  getUsedValuesDefinedAbove(region, region, captures);
+  // Insert at the entry block start; the guard restores the builder on exit.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&region.front());
+
+  // Clone each captured ConstantLike op into the region's entry block.
+  for (Value capture : captures) {
+    Operation *op = capture.getDefiningOp();
+    if (!op || !op->hasTrait<OpTrait::ConstantLike>())
+      continue; // No defining op, or not a constant: leave it captured.
+
+    Operation *cloned = builder.clone(*op);
+    // Rewire uses inside the region only; outside uses keep the original.
+    for (auto tuple : llvm::zip(op->getResults(), cloned->getResults())) {
+      Value orig = std::get<0>(tuple);
+      Value replacement = std::get<1>(tuple);
+      replaceAllUsesInRegionWith(orig, replacement, region);
+    }
+  }
+}

diff  --git a/mlir/lib/Dialect/Async/Transforms/PassDetail.h b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
index c047eaf383d9c..2065c14e13148 100644
--- a/mlir/lib/Dialect/Async/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
@@ -25,6 +25,24 @@ class SCFDialect;
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Async/Passes.h.inc"
 
+// -------------------------------------------------------------------------- //
+// Utility functions shared by Async Transformations.
+// -------------------------------------------------------------------------- //
+
+// Forward declarations.
+class OpBuilder;
+
+namespace async {
+
+/// Clone ConstantLike operations that are defined above the given region and
+/// have users in the region into the region entry block. We do that to reduce
+/// the number of function arguments when we outline `async.execute` and
+/// `scf.parallel` operation bodies into functions.
+void cloneConstantsIntoTheRegion(Region &region);
+void cloneConstantsIntoTheRegion(Region &region, OpBuilder &builder);
+
+} // namespace async
+
 } // namespace mlir
 
 #endif // DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index da96f306f0bf0..f8afa39060a9f 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -89,13 +89,14 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
 }
 
 // Function outlined from the inner async.execute operation.
-// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
+// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
 // CHECK-SAME: -> !llvm.ptr<i8>
 // CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
 // CHECK: call @mlirAsyncRuntimeExecute
 // CHECK: llvm.intr.coro.suspend
-// CHECK: memref.store %arg0, %arg1[%arg2] : memref<1xf32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: memref.store %arg0, %arg1[%[[C0]]] : memref<1xf32>
 // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
 
 // Function outlined from the outer async.execute operation.

diff  --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
new file mode 100644
index 0000000000000..34e9434dd9665
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s                                                            \
+// RUN:    -async-parallel-for=async-dispatch=true                             \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s                                                            \
+// RUN:    -async-parallel-for=async-dispatch=false                            \
+// RUN:    -canonicalize -inline -symbol-dce                                   \
+// RUN: | FileCheck %s
+
+// Check that constants defined outside of the `scf.parallel` body will be
+// sunk into the parallel compute function to avoid blowing up the number
+// of parallel compute function arguments.
+
+// CHECK-LABEL: func @clone_constant(
+func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
+  %one = constant 1.0 : f32 // Defined above the loop, used inside its body.
+
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}
+
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?xf32>
+// CHECK-SAME: ) {
+// CHECK:        %[[CST:.*]] = constant 1.0{{.*}} : f32
+// CHECK:        scf.for
+// CHECK:          memref.store %[[CST]], %[[MEMREF]]

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 661d208e17662..9c61394aa8ed9 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -406,3 +406,26 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
 // Check that structured control flow lowered to CFG.
 // CHECK-NOT: scf.if
 // CHECK: cond_br %[[FLAG]]
+
+// -----
+// Constants captured by the async.execute region should be cloned into the
+// outlined async execute function.
+
+// CHECK-LABEL: @clone_constants
+func @clone_constants(%arg0: f32, %arg1: memref<1xf32>) {
+  %c0 = constant 0 : index // Defined outside async.execute, used inside it.
+  %token = async.execute {
+    memref.store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  async.await %token : !async.token
+  return
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(
+// CHECK-SAME:    %[[VALUE:arg[0-9]+]]: f32,
+// CHECK-SAME:    %[[MEMREF:arg[0-9]+]]: memref<1xf32>
+// CHECK-SAME:  ) -> !async.token
+// CHECK:         %[[CST:.*]] = constant 0 : index
+// CHECK:         memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]


        


More information about the Mlir-commits mailing list