[Mlir-commits] [mlir] [MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments (PR #84320)
Matthias Gehre
llvmlistbot at llvm.org
Thu Mar 7 05:44:42 PST 2024
https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/84320
>From 46715655619dba59b6a642ba431e605a81a04860 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 7 Mar 2024 14:42:16 +0100
Subject: [PATCH] [MLIR][Bufferization] BufferResultsToOutParams: Add option to
add attribute to output arguments
Adds a new pass option `add-result-attr` that makes the pass attach the unit attribute
`{bufferize.result}` to each argument that was converted from a result.
To make this testable from the command line, the option is also declared in the
pass's TableGen definition. TableGen then generates an options struct named
`BufferResultsToOutParamsOptions`. To avoid a name collision with it, the existing,
manually defined options struct of the same name was renamed to
`BufferResultsToOutParamsOpts`.
---
.../Dialect/Bufferization/Transforms/Passes.h | 10 ++++---
.../Bufferization/Transforms/Passes.td | 5 ++++
.../Transforms/BufferResultsToOutParams.cpp | 26 +++++++++++++------
...results-to-out-params-add-result-attr.mlir | 18 +++++++++++++
4 files changed, 48 insertions(+), 11 deletions(-)
create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
+
+ /// If true, the pass adds a "bufferize.result" attribute to each output
+ /// parameter.
+ bool addResultAttribute = false;
};
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
- const BufferResultsToOutParamsOptions &options = {});
+ const BufferResultsToOutParamsOpts &options = {});
/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
- const BufferResultsToOutParamsOptions &options);
+ const BufferResultsToOutParamsOpts &options);
/// Creates a pass that drops memref function results that are equivalent to a
/// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.
}];
+ let options = [
+ Option<"addResultAttribute", "add-result-attr", "bool",
+ /*default=*/"false",
+ "Add the attribute 'bufferize.result' to all output parameters.">,
+ ];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..5ab347066c90cb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
// Any args appended to the entry block are added to `appendedEntryArgs`.
static LogicalResult
updateFuncOp(func::FuncOp func,
- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+ bool addResultAttribute) {
auto functionType = func.getFunctionType();
// Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
func.setArgAttrs(functionType.getNumInputs() + i,
func.getResultAttrs(*erasedIndicesIt));
+ if (addResultAttribute)
+ func.setArgAttr(functionType.getNumInputs() + i,
+ StringAttr::get(func.getContext(), "bufferize.result"),
+ UnitAttr::get(func.getContext()));
}
// Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
SmallVector<BlockArgument, 6> appendedEntryArgs;
- if (failed(updateFuncOp(func, appendedEntryArgs)))
+ if (failed(
+ updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
return failure();
if (func.isExternal())
continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
: bufferization::impl::BufferResultsToOutParamsBase<
BufferResultsToOutParamsPass> {
explicit BufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options)
+ const bufferization::BufferResultsToOutParamsOpts &options)
: options(options) {}
void runOnOperation() override {
+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+ if (addResultAttribute)
+ options.addResultAttribute = true;
+
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
return signalPassFailure();
}
private:
- bufferization::BufferResultsToOutParamsOptions options;
+ bufferization::BufferResultsToOutParamsOpts options;
};
} // namespace
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
return std::make_unique<BufferResultsToOutParamsPass>(options);
}
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..48d5d2372b869e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: basic
+// CHECK-SAME: memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+ %0 = "test.source"() : () -> (memref<f32>)
+ return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME: memref<1xf32> {bufferize.result},
+// CHECK-SAME: memref<2xf32> {bufferize.result})
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+ %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+ return %0, %1 : memref<1xf32>, memref<2xf32>
+}
More information about the Mlir-commits
mailing list