[Mlir-commits] [mlir] [MLIR][Bufferization] BufferResultsToOutParams: Add an option to eliminate AllocOp and avoid Copy (PR #90011)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 28 19:45:42 PDT 2024


https://github.com/Menooker updated https://github.com/llvm/llvm-project/pull/90011

>From 5a15f778474d42c29756c5b0c3bb7917fac7dcf4 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Mon, 29 Apr 2024 10:39:19 +0800
Subject: [PATCH] [MLIR][Bufferization] BufferResultsToOutParams: Add an option
 to eliminate AllocOp and avoid Copy

Add an option, hoist-static-allocs, that removes the unnecessary memref.alloc and memref.copy this pass would otherwise produce when the memref in the ReturnOp is allocated by memref.alloc and is statically shaped. Instead, the pass replaces the uses of the allocated memref with the memref in the out argument.
By default, BufferResultsToOutParams emits a memcpy operation that copies the originally returned memref into the output argument memref. This is inefficient when the source of the memcpy (the returned memref in the original ReturnOp) comes from a local AllocOp: the pass can instead use the output argument memref in place of the locally allocated memref. hoist-static-allocs therefore avoids both the dynamic allocation and the memory movement.
This option is critical for performance-sensitive applications that rely on the BufferResultsToOutParams pass to implement a caller-owned output buffer calling convention.
---
 .../Dialect/Bufferization/Transforms/Passes.h |  4 ++
 .../Bufferization/Transforms/Passes.td        |  9 +++++
 .../Transforms/BufferResultsToOutParams.cpp   | 21 ++++++++---
 .../buffer-results-to-out-params-elim.mlir    | 37 +++++++++++++++++++
 4 files changed, 65 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-elim.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index a729bc99b987cd..459c252b707121 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -166,6 +166,10 @@ struct BufferResultsToOutParamsOpts {
   /// If true, the pass adds a "bufferize.result" attribute to each output
   /// parameter.
   bool addResultAttribute = false;
+
+  /// If true, the pass eliminates the memref.alloc and memcpy if the returned
+  /// memref has a static shape and is allocated in the current function.
+  bool hoistStaticAllocs = false;
 };
 
 /// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1303dc2c9ae10f..75ce85c9128c94 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -315,11 +315,20 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     The main issue with this pass (and the out-param calling convention) is that
     buffers for results need to be allocated in the caller. This currently only
     works for static shaped memrefs.
+
+    If the hoist-static-allocs option is on, the pass tries to eliminate the
+    allocation for the returned memref and avoid the memory-copy if possible.
+    This optimization applies to returned memrefs that have a static shape and
+    are allocated by memref.alloc in the function. It replaces the allocated
+    memref with the memref passed in through the out-param function argument.
   }];
   let options = [
     Option<"addResultAttribute", "add-result-attr", "bool",
        /*default=*/"false",
        "Add the attribute 'bufferize.result' to all output parameters.">,
+    Option<"hoistStaticAllocs", "hoist-static-allocs",
+       "bool", /*default=*/"false",
+       "Hoist static allocations to call sites.">,
   ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index a2222e169c4d64..a5f01eadb21343 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -107,7 +107,8 @@ updateFuncOp(func::FuncOp func,
 // the given out-params.
 static LogicalResult updateReturnOps(func::FuncOp func,
                                      ArrayRef<BlockArgument> appendedEntryArgs,
-                                     MemCpyFn memCpyFn) {
+                                     MemCpyFn memCpyFn,
+                                     bool hoistStaticAllocs) {
   auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
@@ -118,10 +119,15 @@ static LogicalResult updateReturnOps(func::FuncOp func,
         keepAsReturnOperands.push_back(operand);
     }
     OpBuilder builder(op);
-    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (failed(
-              memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
-        return WalkResult::interrupt();
+    for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+      // Null-safe check: `orig` may be a block argument (no defining op).
+      if (hoistStaticAllocs && orig.getDefiningOp<memref::AllocOp>() &&
+          cast<MemRefType>(orig.getType()).hasStaticShape()) {
+        orig.replaceAllUsesWith(arg);
+        orig.getDefiningOp()->erase();
+      } else if (failed(memCpyFn(builder, op.getLoc(), orig, arg))) {
+        return WalkResult::interrupt();
+      }
+    }
     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
     op.erase();
@@ -212,7 +218,8 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return success();
     };
     if (failed(updateReturnOps(func, appendedEntryArgs,
-                               options.memCpyFn.value_or(defaultMemCpyFn)))) {
+                               options.memCpyFn.value_or(defaultMemCpyFn),
+                               options.hoistStaticAllocs))) {
       return failure();
     }
   }
@@ -233,6 +240,8 @@ struct BufferResultsToOutParamsPass
     // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
     if (addResultAttribute)
       options.addResultAttribute = true;
+    if (hoistStaticAllocs)
+      options.hoistStaticAllocs = true;
 
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
new file mode 100644
index 00000000000000..f77dbfaa6cb11e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-static-allocs})'  %s | FileCheck %s
+
+// CHECK-LABEL:   func @basic(
+// CHECK-SAME:                %[[ARG:.*]]: memref<8x64xf32>) {
+// CHECK-NOT:        memref.alloc()
+// CHECK:           "test.source"(%[[ARG]])  : (memref<8x64xf32>) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @basic() -> (memref<8x64xf32>) {
+  %b = memref.alloc() : memref<8x64xf32>
+  "test.source"(%b)  : (memref<8x64xf32>) -> ()
+  return %b : memref<8x64xf32>
+}
+
+// CHECK-LABEL:   func @basic_no_change(
+// CHECK-SAME:                %[[ARG:.*]]: memref<f32>) {
+// CHECK:           %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
+// CHECK:           memref.copy %[[RESULT]], %[[ARG]]  : memref<f32> to memref<f32>
+// CHECK:           return
+// CHECK:         }
+func.func @basic_no_change() -> (memref<f32>) {
+  %0 = "test.source"() : () -> (memref<f32>)
+  return %0 : memref<f32>
+}
+
+// CHECK-LABEL:   func @basic_dynamic(
+// CHECK-SAME:                %[[D:.*]]: index, %[[ARG:.*]]: memref<?xf32>) {
+// CHECK:           %[[RESULT:.*]] = memref.alloc(%[[D]]) : memref<?xf32>
+// CHECK:           "test.source"(%[[RESULT]])  : (memref<?xf32>) -> ()
+// CHECK:           memref.copy %[[RESULT]], %[[ARG]]
+// CHECK:           return
+// CHECK:         }
+func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
+  %b = memref.alloc(%d) : memref<?xf32>
+  "test.source"(%b)  : (memref<?xf32>) -> ()
+  return %b : memref<?xf32>
+}
\ No newline at end of file



More information about the Mlir-commits mailing list