[Mlir-commits] [mlir] [MLIR] BufferResultsToOutParams: Allow to configure memCpyFn (PR #83389)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 28 23:26:31 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

<details>
<summary>Changes</summary>

This allows us to configure the pass to emit `linalg.copy` instead of `memref.copy`.

This is consistent with `one-shot-bufferize`, which also allows configuring the `memCpyFn`; see https://discord.com/channels/636084430946959380/642426447167881246/1211698722438783087

---
Full diff: https://github.com/llvm/llvm-project/pull/83389.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+8) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (+21-7) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index bb4b5221981638..809f03407258a8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -149,11 +149,19 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
 struct BufferResultsToOutParamsOptions {
+  /// Memcpy function: Generate a memcpy between two memrefs.
+  using MemCpyFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+
   // Filter function; returns true if the function should be converted.
   // Defaults to true, i.e. all functions are converted.
   llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
     return true;
   };
+
+  /// Memcpy function; used to create a copy between two memrefs.
+  /// If this is empty, memref.copy is used.
+  std::optional<MemCpyFn> memCpyFn;
 };
 
 /// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index dd359c2dcca5dd..930f035339c1d3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,6 +21,7 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func,
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
-static void updateReturnOps(func::FuncOp func,
-                            ArrayRef<BlockArgument> appendedEntryArgs) {
-  func.walk([&](func::ReturnOp op) {
+static LogicalResult updateReturnOps(func::FuncOp func,
+                                     ArrayRef<BlockArgument> appendedEntryArgs,
+                                     MemCpyFn memCpyFn) {
+  auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
     for (Value operand : op.getOperands()) {
@@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func,
         keepAsReturnOperands.push_back(operand);
     }
     OpBuilder builder(op);
-    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
-      builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
-                                     std::get<1>(t));
+    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+      if (failed(
+              memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
+        return WalkResult::interrupt();
+    }
     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
     op.erase();
+    return WalkResult::advance();
   });
+  return failure(res.wasInterrupted());
 }
 
 // Updates all CallOps in the scope of the given ModuleOp by allocating
@@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return failure();
     if (func.isExternal())
       continue;
-    updateReturnOps(func, appendedEntryArgs);
+    auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
+                              Value to) {
+      builder.create<memref::CopyOp>(loc, from, to);
+      return success();
+    };
+    if (failed(updateReturnOps(func, appendedEntryArgs,
+                               options.memCpyFn.value_or(defaultMemCpyFn)))) {
+      return failure();
+    }
   }
   if (failed(updateCalls(module, options)))
     return failure();

``````````

</details>


https://github.com/llvm/llvm-project/pull/83389


More information about the Mlir-commits mailing list