[Mlir-commits] [mlir] [MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments (PR #84320)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 7 05:29:49 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: Matthias Gehre (mgehre-amd)
<details>
<summary>Changes</summary>
Adds a new pass option, `add-result-attr`, which makes the pass add the attribute `{bufferize.result}` to each argument that was converted from a result.
This is useful, e.g., when later using the Python bindings or the execution engine, to determine which arguments are actually results.
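To make this testable from `mlir-opt`, the option was added to the TableGen pass definition. Declaring a pass option in TableGen makes the generated pass declarations define an options struct named `BufferResultsToOutParamsOptions`. To avoid a collision with the existing, manually defined struct of the same name, that struct was renamed to `BufferResultsToOutParamsOpts`.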
---
Full diff: https://github.com/llvm/llvm-project/pull/84320.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+7-3)
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+5)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (+18-8)
- (added) mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir (+18)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
+
+ /// If true, the pass adds a "bufferize.result" attribute to each output
+ /// parameter.
+ bool addResultAttribute = false;
};
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
- const BufferResultsToOutParamsOptions &options = {});
+ const BufferResultsToOutParamsOpts &options = {});
/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
- const BufferResultsToOutParamsOptions &options);
+ const BufferResultsToOutParamsOpts &options);
/// Creates a pass that drops memref function results that are equivalent to a
/// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.
}];
+ let options = [
+ Option<"addResultAttribute", "add-result-attr", "bool",
+ /*default=*/"false",
+ "Add the attribute 'bufferize.result' to all output parameters.">,
+ ];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..5ab347066c90cb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
// Any args appended to the entry block are added to `appendedEntryArgs`.
static LogicalResult
updateFuncOp(func::FuncOp func,
- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+ bool addResultAttribute) {
auto functionType = func.getFunctionType();
// Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
func.setArgAttrs(functionType.getNumInputs() + i,
func.getResultAttrs(*erasedIndicesIt));
+ if (addResultAttribute)
+ func.setArgAttr(functionType.getNumInputs() + i,
+ StringAttr::get(func.getContext(), "bufferize.result"),
+ UnitAttr::get(func.getContext()));
}
// Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
SmallVector<BlockArgument, 6> appendedEntryArgs;
- if (failed(updateFuncOp(func, appendedEntryArgs)))
+ if (failed(
+ updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
return failure();
if (func.isExternal())
continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
: bufferization::impl::BufferResultsToOutParamsBase<
BufferResultsToOutParamsPass> {
explicit BufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options)
+ const bufferization::BufferResultsToOutParamsOpts &options)
: options(options) {}
void runOnOperation() override {
+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+ if (addResultAttribute)
+ options.addResultAttribute = true;
+
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
return signalPassFailure();
}
private:
- bufferization::BufferResultsToOutParamsOptions options;
+ bufferization::BufferResultsToOutParamsOpts options;
};
} // namespace
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
return std::make_unique<BufferResultsToOutParamsPass>(options);
}
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..48d5d2372b869e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: basic
+// CHECK-SAME: memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+ %0 = "test.source"() : () -> (memref<f32>)
+ return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME: memref<1xf32> {bufferize.result},
+// CHECK-SAME: memref<2xf32> {bufferize.result})
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+ %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+ return %0, %1 : memref<1xf32>, memref<2xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/84320
More information about the Mlir-commits
mailing list