[Mlir-commits] [mlir] e6048b7 - [MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments (#84320)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 13 23:50:21 PDT 2024
Author: Matthias Gehre
Date: 2024-03-14T07:50:16+01:00
New Revision: e6048b728d3170651828ad2dd7ed5ad0bc5e4f06
URL: https://github.com/llvm/llvm-project/commit/e6048b728d3170651828ad2dd7ed5ad0bc5e4f06
DIFF: https://github.com/llvm/llvm-project/commit/e6048b728d3170651828ad2dd7ed5ad0bc5e4f06.diff
LOG: [MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments (#84320)
Adds a new pass option `add-result-attr` that will make the pass add the
attribute `{bufferize.result}` to each argument that was converted from
a result.
This is important, e.g., when later using the Python bindings / execution
engine, where one needs to know which arguments are actually results.
To be able to test this, the pass option was added to the TableGen pass
definition (Passes.td). To avoid collisions with the existing, manually
defined option struct `BufferResultsToOutParamsOptions`, that one was
renamed to `BufferResultsToOutParamsOpts`.
Added:
mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
+
+ /// If true, the pass adds a "bufferize.result" attribute to each output
+ /// parameter.
+ bool addResultAttribute = false;
};
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
- const BufferResultsToOutParamsOptions &options = {});
+ const BufferResultsToOutParamsOpts &options = {});
/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
- const BufferResultsToOutParamsOptions &options);
+ const BufferResultsToOutParamsOpts &options);
/// Creates a pass that drops memref function results that are equivalent to a
/// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.
}];
+ let options = [
+ Option<"addResultAttribute", "add-result-attr", "bool",
+ /*default=*/"false",
+ "Add the attribute 'bufferize.result' to all output parameters.">,
+ ];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..a2222e169c4d64 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -45,9 +45,12 @@ static bool hasStaticIdentityLayout(MemRefType type) {
// Updates the func op and entry block.
//
// Any args appended to the entry block are added to `appendedEntryArgs`.
+// If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
+// to each newly created function argument.
static LogicalResult
updateFuncOp(func::FuncOp func,
- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+ bool addResultAttribute) {
auto functionType = func.getFunctionType();
// Collect information about the results will become appended arguments.
@@ -80,6 +83,10 @@ updateFuncOp(func::FuncOp func,
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
func.setArgAttrs(functionType.getNumInputs() + i,
func.getResultAttrs(*erasedIndicesIt));
+ if (addResultAttribute)
+ func.setArgAttr(functionType.getNumInputs() + i,
+ StringAttr::get(func.getContext(), "bufferize.result"),
+ UnitAttr::get(func.getContext()));
}
// Erase the results.
@@ -127,7 +134,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
@@ -189,12 +196,13 @@ updateCalls(ModuleOp module,
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
SmallVector<BlockArgument, 6> appendedEntryArgs;
- if (failed(updateFuncOp(func, appendedEntryArgs)))
+ if (failed(
+ updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
return failure();
if (func.isExternal())
continue;
@@ -218,21 +226,25 @@ struct BufferResultsToOutParamsPass
: bufferization::impl::BufferResultsToOutParamsBase<
BufferResultsToOutParamsPass> {
explicit BufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options)
+ const bufferization::BufferResultsToOutParamsOpts &options)
: options(options) {}
void runOnOperation() override {
+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+ if (addResultAttribute)
+ options.addResultAttribute = true;
+
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
return signalPassFailure();
}
private:
- bufferization::BufferResultsToOutParamsOptions options;
+ bufferization::BufferResultsToOutParamsOpts options;
};
} // namespace
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
- const bufferization::BufferResultsToOutParamsOptions &options) {
+ const bufferization::BufferResultsToOutParamsOpts &options) {
return std::make_unique<BufferResultsToOutParamsPass>(options);
}
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..f4a95c73e29533
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @basic({{.*}}: memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+ %0 = "test.source"() : () -> (memref<f32>)
+ return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME: memref<1xf32> {bufferize.result}
+// CHECK-SAME: memref<2xf32> {bufferize.result}
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+ %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+ return %0, %1 : memref<1xf32>, memref<2xf32>
+}
More information about the Mlir-commits
mailing list