[Mlir-commits] [mlir] [MLIR] BufferResultsToOutParams: Allow to configure memCpyFn (PR #83389)

Matthias Gehre llvmlistbot at llvm.org
Mon Mar 4 14:19:57 PST 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/83389

>From fda142fd14af3006b3dba8253dd356f1733ecf1f Mon Sep 17 00:00:00 2001
From: Matthias Gehre <93204396+mgehre-amd at users.noreply.github.com>
Date: Thu, 29 Feb 2024 07:32:30 +0100
Subject: [PATCH] [MLIR] BufferResultsToOutParams: Allow to configure memCpyFn

This allows configuring the pass to emit, e.g., linalg.copy instead of
memref.copy. When `memCpyFn` is unset (or holds an empty std::function),
memref.copy is emitted as before.

This is consistent with one-shot-bufferize, which also allows configuring
the `memCpyFn`, see
https://discord.com/channels/636084430946959380/642426447167881246/1211698722438783087

While at it, store `filterFn` as a std::function instead of an
llvm::function_ref: the function_ref only referred to the temporary default
lambda, which does not outlive construction of the options struct.
---
 .../Dialect/Bufferization/Transforms/Passes.h | 19 +++++++++--
 .../Transforms/BufferResultsToOutParams.cpp   | 34 +++++++++++++++----
 2 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index bb4b5221981638..809f03407258a8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -149,11 +149,24 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
 struct BufferResultsToOutParamsOptions {
+  /// Memcpy function: Generate a copy from the first memref (`from`) into
+  /// the second memref (`to`). The builder is positioned right before the
+  /// `func.return` whose operand is being copied. Returning failure makes the
+  /// transformation fail; the IR may then be left partially rewritten.
+  using MemCpyFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+
   // Filter function; returns true if the function should be converted.
   // Defaults to true, i.e. all functions are converted.
-  llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
-    return true;
-  };
+  // Held by value: an llvm::function_ref would only point at the temporary
+  // default lambda, which does not outlive construction of this struct.
+  std::function<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
+    return true;
+  };
+
+  /// Memcpy function; used to create a copy between two memrefs.
+  /// If this is unset, or holds an empty std::function, memref.copy is used.
+  std::optional<MemCpyFn> memCpyFn;
 };
 
 /// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index dd359c2dcca5dd..930f035339c1d3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,6 +21,14 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+
+/// Default `MemCpyFn`: emits a `memref.copy` from `from` into `to`.
+static LogicalResult createMemRefCopy(OpBuilder &builder, Location loc,
+                                      Value from, Value to) {
+  builder.create<memref::CopyOp>(loc, from, to);
+  return success();
+}
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -97,9 +105,12 @@ updateFuncOp(func::FuncOp func,
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
-static void updateReturnOps(func::FuncOp func,
-                            ArrayRef<BlockArgument> appendedEntryArgs) {
-  func.walk([&](func::ReturnOp op) {
+// Copies are emitted via `memCpyFn`. Fails as soon as `memCpyFn` fails, in
+// which case the IR may be left partially rewritten.
+static LogicalResult updateReturnOps(func::FuncOp func,
+                                     ArrayRef<BlockArgument> appendedEntryArgs,
+                                     const MemCpyFn &memCpyFn) {
+  WalkResult res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
     for (Value operand : op.getOperands()) {
@@ -109,12 +120,15 @@ static void updateReturnOps(func::FuncOp func,
         keepAsReturnOperands.push_back(operand);
     }
     OpBuilder builder(op);
-    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
-      builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
-                                     std::get<1>(t));
+    for (auto [from, to] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+      if (failed(memCpyFn(builder, op.getLoc(), from, to)))
+        return WalkResult::interrupt();
+    }
     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
     op.erase();
+    return WalkResult::advance();
   });
+  return failure(res.wasInterrupted());
 }
 
 // Updates all CallOps in the scope of the given ModuleOp by allocating
@@ -192,7 +206,13 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return failure();
     if (func.isExternal())
       continue;
-    updateReturnOps(func, appendedEntryArgs);
+    // Fall back to memref.copy when no copy function is configured, including
+    // when the option holds an empty std::function.
+    MemCpyFn memCpyFn = createMemRefCopy;
+    if (options.memCpyFn && *options.memCpyFn)
+      memCpyFn = *options.memCpyFn;
+    if (failed(updateReturnOps(func, appendedEntryArgs, memCpyFn)))
+      return failure();
   }
   if (failed(updateCalls(module, options)))
     return failure();



More information about the Mlir-commits mailing list