[Mlir-commits] [mlir] [MLIR][Bufferization] BufferResultsToOutParams: Add option to add attribute to output arguments (PR #84320)

Matthias Gehre llvmlistbot at llvm.org
Thu Mar 7 05:44:42 PST 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/84320

>From 46715655619dba59b6a642ba431e605a81a04860 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 7 Mar 2024 14:42:16 +0100
Subject: [PATCH] [MLIR][Bufferization] BufferResultsToOutParams: Add option to
 add attribute to output arguments

Adds a new pass option `add-result-attr` that makes the pass attach the unit
attribute `{bufferize.result}` to each function argument that was converted
from a result.

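To make the option testable from `mlir-opt`, it is also declared in the pass's
tablegen definition. Declaring options there makes tablegen generate an options
struct named `BufferResultsToOutParamsOptions`, which would collide with the
existing, manually defined struct of the same name, so the manual struct was
renamed to `BufferResultsToOutParamsOpts`.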
---
 .../Dialect/Bufferization/Transforms/Passes.h | 10 ++++---
 .../Bufferization/Transforms/Passes.td        |  5 ++++
 .../Transforms/BufferResultsToOutParams.cpp   | 26 +++++++++++++------
 ...results-to-out-params-add-result-attr.mlir | 18 +++++++++++++
 4 files changed, 48 insertions(+), 11 deletions(-)
 create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 809f03407258a8..a729bc99b987cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -148,7 +148,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
-struct BufferResultsToOutParamsOptions {
+struct BufferResultsToOutParamsOpts {
   /// Memcpy function: Generate a memcpy between two memrefs.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions {
   /// Memcpy function; used to create a copy between two memrefs.
   /// If this is empty, memref.copy is used.
   std::optional<MemCpyFn> memCpyFn;
+
+  /// If true, the pass adds a "bufferize.result" attribute to each output
+  /// parameter.
+  bool addResultAttribute = false;
 };
 
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
-    const BufferResultsToOutParamsOptions &options = {});
+    const BufferResultsToOutParamsOpts &options = {});
 
 /// Replace buffers that are returned from a function with an out parameter.
 /// Also update all call sites.
 LogicalResult
 promoteBufferResultsToOutParams(ModuleOp module,
-                                const BufferResultsToOutParamsOptions &options);
+                                const BufferResultsToOutParamsOpts &options);
 
 /// Creates a pass that drops memref function results that are equivalent to a
 /// function argument.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e01f36b8daa18d..1c3cdec81a39e0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     buffers for results need to be allocated in the caller. This currently only
     works for static shaped memrefs.
   }];
+  let options = [
+    Option<"addResultAttribute", "add-result-attr", "bool",
+       /*default=*/"false",
+       "Add the attribute 'bufferize.result' to all output parameters.">,
+  ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 930f035339c1d3..5ab347066c90cb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,7 +21,7 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
-using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
 // Any args appended to the entry block are added to `appendedEntryArgs`.
 static LogicalResult
 updateFuncOp(func::FuncOp func,
-             SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+             SmallVectorImpl<BlockArgument> &appendedEntryArgs,
+             bool addResultAttribute) {
   auto functionType = func.getFunctionType();
 
   // Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
     func.setArgAttrs(functionType.getNumInputs() + i,
                      func.getResultAttrs(*erasedIndicesIt));
+    if (addResultAttribute)
+      func.setArgAttr(functionType.getNumInputs() + i,
+                      StringAttr::get(func.getContext(), "bufferize.result"),
+                      UnitAttr::get(func.getContext()));
   }
 
   // Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
 // temporary buffers for newly introduced out params.
 static LogicalResult
 updateCalls(ModuleOp module,
-            const bufferization::BufferResultsToOutParamsOptions &options) {
+            const bufferization::BufferResultsToOutParamsOpts &options) {
   bool didFail = false;
   SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
 
 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
     ModuleOp module,
-    const bufferization::BufferResultsToOutParamsOptions &options) {
+    const bufferization::BufferResultsToOutParamsOpts &options) {
   for (auto func : module.getOps<func::FuncOp>()) {
     if (!options.filterFn(&func))
       continue;
     SmallVector<BlockArgument, 6> appendedEntryArgs;
-    if (failed(updateFuncOp(func, appendedEntryArgs)))
+    if (failed(
+            updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
       return failure();
     if (func.isExternal())
       continue;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
     : bufferization::impl::BufferResultsToOutParamsBase<
           BufferResultsToOutParamsPass> {
   explicit BufferResultsToOutParamsPass(
-      const bufferization::BufferResultsToOutParamsOptions &options)
+      const bufferization::BufferResultsToOutParamsOpts &options)
       : options(options) {}
 
   void runOnOperation() override {
+    // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
+    if (addResultAttribute)
+      options.addResultAttribute = true;
+
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
       return signalPassFailure();
   }
 
 private:
-  bufferization::BufferResultsToOutParamsOptions options;
+  bufferization::BufferResultsToOutParamsOpts options;
 };
 } // namespace
 
 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
-    const bufferization::BufferResultsToOutParamsOptions &options) {
+    const bufferization::BufferResultsToOutParamsOpts &options) {
   return std::make_unique<BufferResultsToOutParamsPass>(options);
 }
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
new file mode 100644
index 00000000000000..48d5d2372b869e
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: basic
+// CHECK-SAME:  memref<f32> {bufferize.result})
+func.func @basic() -> (memref<f32>) {
+  %0 = "test.source"() : () -> (memref<f32>)
+  return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: multiple_results
+// CHECK-SAME:  memref<1xf32> {bufferize.result},
+// CHECK-SAME:  memref<2xf32> {bufferize.result})
+func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
+  %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
+  return %0, %1 : memref<1xf32>, memref<2xf32>
+}



More information about the Mlir-commits mailing list