[Mlir-commits] [mlir] [MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments (PR #84320)

Matthias Gehre llvmlistbot at llvm.org
Thu Mar 7 05:29:20 PST 2024


https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/84320

Adds a new pass option `add-result-attr` that will make the pass add the attribute `{bufferize.result}` to each argument that was converted from a result.
This is useful, e.g., when the module is later run through the Python bindings or the execution engine, to tell which arguments are actually results.

To make the option testable from `mlir-opt`, it was added to the pass's TableGen definition. TableGen then generates a `BufferResultsToOutParamsOptions` struct, which collides with the existing, manually defined options struct of the same name; to avoid the collision, the manual struct was renamed to `BufferResultsToOutParamsOpts`.

>From b0c081b9c9a2ea8f97affd726d36dfb335ced688 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <93204396+mgehre-amd at users.noreply.github.com>
Date: Thu, 29 Feb 2024 07:46:04 +0100
Subject: [PATCH] [MLIR][Bufferization] BufferResultsToOutParams: Add option to
 add attribute to output arguments

Adds a new pass option `add-result-attr` that will make the pass add the attribute
`{bufferize.result}` to each argument that was converted from a result.

To make the option testable, it was added to the pass's TableGen
definition. TableGen then generates a `BufferResultsToOutParamsOptions`
struct, which collides with the existing, manually defined options
struct of the same name, so the manual struct was renamed to
`BufferResultsToOutParamsOpts`.
---
 .../Dialect/Bufferization/Transforms/Passes.h | 10 ++++---
 .../Bufferization/Transforms/Passes.td        |  5 ++++
 .../Transforms/BufferResultsToOutParams.cpp   | 26 +++++++++++++------
 ...results-to-out-params-add-result-attr.mlir | 18 +++++++++++++
 4 files changed, 48 insertions(+), 11 deletions(-)
 create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
   /// Memcpy function: Generate a memcpy between two memrefs.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
   /// Memcpy function; used to create a copy between two memrefs.
   /// If this is empty, memref.copy is used.
   std::optional<MemCpyFn> memCpyFn;
+
+  /// If true, the pass adds a "bufferize.result" attribute to each output
+  /// parameter.
+  bool addResultAttribute = false;
 };
 
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
-    const BufferResultsToOutParamsOptions &options = {});
+    const BufferResultsToOutParamsOpts &options = {});
 
 /// Replace buffers that are returned from a function with an out parameter.
 /// Also update all call sites.
 LogicalResult
 promoteBufferResultsToOutParams(ModuleOp module,
-                                const BufferResultsToOutParamsOptions &options);
+                                const BufferResultsToOutParamsOpts &options);
 
 /// Creates a pass that drops memref function results that are equivalent to a
 /// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     buffers for results need to be allocated in the caller. This currently only
     works for static shaped memrefs.
   }];
+  let options = [
+    Option<"addResultAttribute", "add-result-attr", "bool",
+       /*default=*/"false",
+       "Add the attribute 'bufferize.result' to all output parameters.">,
+  ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..5ab347066c90cb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
 // Any args appended to the entry block are added to `appendedEntryArgs`.
 static LogicalResult
 updateFuncOp(func::FuncOp func,
-             SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+             SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+             bool addResultAttribute) {
   auto functionType = func.getFunctionType();
 
   // Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
     func.setArgAttrs(functionType.getNumInputs() + i,
                      func.getResultAttrs(*erasedIndicesIt));
+    if (addResultAttribute)
+      func.setArgAttr(functionType.getNumInputs() + i,
+                      StringAttr::get(func.getContext(), "bufferize.result"),
+                      UnitAttr::get(func.getContext()));
   }
 
   // Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
 // temporary buffers for newly introduced out params.
 static LogicalResult
 updateCalls(ModuleOp module,
-            const bufferization::BufferResultsToOutParamsOptions &options) {
+            const bufferization::BufferResultsToOutParamsOpts &options) {
   bool didFail = false;
   SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
 
 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
     ModuleOp module,
-    const bufferization::BufferResultsToOutParamsOptions &options) {
+    const bufferization::BufferResultsToOutParamsOpts &options) {
   for (auto func : module.getOps<func::FuncOp>()) {
     if (!options.filterFn(&func))
       continue;
     SmallVector<BlockArgument, 6> appendedEntryArgs;
-    if (failed(updateFuncOp(func, appendedEntryArgs)))
+    if (failed(
+            updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
       return failure();
     if (func.isExternal())
       continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
     : bufferization::impl::BufferResultsToOutParamsBase<
           BufferResultsToOutParamsPass> {
   explicit BufferResultsToOutParamsPass(
-      const bufferization::BufferResultsToOutParamsOptions &options)
+      const bufferization::BufferResultsToOutParamsOpts &options)
       : options(options) {}
 
   void runOnOperation() override {
+    // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+    if (addResultAttribute)
+      options.addResultAttribute = true;
+
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
       return signalPassFailure();
   }
 
 private:
-  bufferization::BufferResultsToOutParamsOptions options;
+  bufferization::BufferResultsToOutParamsOpts options;
 };
 } // namespace
 
 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
-    const bufferization::BufferResultsToOutParamsOptions &options) {
+    const bufferization::BufferResultsToOutParamsOpts &options) {
   return std::make_unique<BufferResultsToOutParamsPass>(options);
 }
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..48d5d2372b869e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: basic
+// CHECK-SAME:  memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+  %0 = "test.source"() : () -> (memref<f32>)
+  return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME:  memref<1xf32> {bufferize.result},
+// CHECK-SAME:  memref<2xf32> {bufferize.result})
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+  %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+  return %0, %1 : memref<1xf32>, memref<2xf32>
+}



More information about the Mlir-commits mailing list