[Mlir-commits] [mlir] [MLIR] BufferResultsToOutParams: Allow to configure memCpyFn (PR #83389)
Matthias Gehre
llvmlistbot at llvm.org
Wed Feb 28 23:26:02 PST 2024
https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/83389
This allows us to configure the pass to emit `linalg.copy` instead of `memref.copy`.
This is consistent with `one-shot-bufferize`, which also lets users configure its `memCpyFn` option; see https://discord.com/channels/636084430946959380/642426447167881246/1211698722438783087
>From 0e3952ce38273c09b2b4fc8847ac983029a108be Mon Sep 17 00:00:00 2001
From: Matthias Gehre <93204396+mgehre-amd at users.noreply.github.com>
Date: Thu, 29 Feb 2024 07:32:30 +0100
Subject: [PATCH] [MLIR] BufferResultsToOutParams: Allow to configure memCpyFn
This allows us to configure the pass to emit
linalg.copy instead of memref.copy.
This is consistent with one-shot-bufferize, which also lets users configure its `memCpyFn` option;
see https://discord.com/channels/636084430946959380/642426447167881246/1211698722438783087
---
.../Dialect/Bufferization/Transforms/Passes.h | 8 ++++++
.../Transforms/BufferResultsToOutParams.cpp | 28 ++++++++++++++-----
2 files changed, 29 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index bb4b5221981638..809f03407258a8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -149,11 +149,19 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOptions {
+ /// Memcpy function: Generate a memcpy between two memrefs.
+ using MemCpyFn =
+ std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+
// Filter function; returns true if the function should be converted.
// Defaults to true, i.e. all functions are converted.
llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
return true;
};
+
+ /// Memcpy function; used to create a copy between two memrefs.
+ /// If this is empty, memref.copy is used.
+ std::optional<MemCpyFn> memCpyFn;
};
/// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index dd359c2dcca5dd..930f035339c1d3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,6 +21,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func,
// Updates all ReturnOps in the scope of the given func::FuncOp by either
// keeping them as return values or copying the associated buffer contents into
// the given out-params.
-static void updateReturnOps(func::FuncOp func,
- ArrayRef<BlockArgument> appendedEntryArgs) {
- func.walk([&](func::ReturnOp op) {
+static LogicalResult updateReturnOps(func::FuncOp func,
+ ArrayRef<BlockArgument> appendedEntryArgs,
+ MemCpyFn memCpyFn) {
+ auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
for (Value operand : op.getOperands()) {
@@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
- for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
- builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
- std::get<1>(t));
+ for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+ if (failed(
+ memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
+ return WalkResult::interrupt();
+ }
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
op.erase();
+ return WalkResult::advance();
});
+ return failure(res.wasInterrupted());
}
// Updates all CallOps in the scope of the given ModuleOp by allocating
@@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
- updateReturnOps(func, appendedEntryArgs);
+ auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
+ Value to) {
+ builder.create<memref::CopyOp>(loc, from, to);
+ return success();
+ };
+ if (failed(updateReturnOps(func, appendedEntryArgs,
+ options.memCpyFn.value_or(defaultMemCpyFn)))) {
+ return failure();
+ }
}
if (failed(updateCalls(module, options)))
return failure();
More information about the Mlir-commits
mailing list