[Mlir-commits] [mlir] [MLIR][Bufferization] BufferResultsToOutParams: Add an option to eliminate AllocOp and avoid Copy (PR #90011)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 26 19:30:51 PDT 2024


https://github.com/Menooker updated https://github.com/llvm/llvm-project/pull/90011

>From ffb8740f2bfc8d5987678cd8eaff1fef0a171b4d Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Thu, 18 Apr 2024 16:04:37 +0800
Subject: [PATCH 1/3] [MLIR][Bufferization] BufferResultsToOutParams: Add an
 option to eliminate AllocOp and Copy

---
 .../Dialect/Bufferization/Transforms/Passes.h |  4 ++++
 .../Bufferization/Transforms/Passes.td        |  4 ++++
 .../Transforms/BufferResultsToOutParams.cpp   | 20 +++++++++++-----
 .../buffer-results-to-out-params-elim.mlir    | 24 +++++++++++++++++++
 4 files changed, 46 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-elim.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index a729bc99b987cd..6bb436de4f0821 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -166,6 +166,10 @@ struct BufferResultsToOutParamsOpts {
   /// If true, the pass adds a "bufferize.result" attribute to each output
   /// parameter.
   bool addResultAttribute = false;
+  
+  /// If true, the pass eliminates the memref.alloc and memcpy if the returned
+  /// memref is allocated in the current function.
+  bool eliminateAllocCopy = false;
 };
 
 /// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1303dc2c9ae10f..ef5e2293dec2de 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -320,6 +320,10 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     Option<"addResultAttribute", "add-result-attr", "bool",
        /*default=*/"false",
        "Add the attribute 'bufferize.result' to all output parameters.">,
+    Option<"eliminateAllocCopy", "elim-alloc-copy", "bool",
+       /*default=*/"false",
+       "When the returned memref is allocated by `memref.alloc`, eliminate the "
+       "allocation, and the memref.copy. And use the argument memref instead">,
   ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index a2222e169c4d64..4a5bfec94b4ff0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -107,7 +107,8 @@ updateFuncOp(func::FuncOp func,
 // the given out-params.
 static LogicalResult updateReturnOps(func::FuncOp func,
                                      ArrayRef<BlockArgument> appendedEntryArgs,
-                                     MemCpyFn memCpyFn) {
+                                     MemCpyFn memCpyFn,
+                                     bool eliminateAllocCopy) {
   auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
@@ -118,10 +119,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
         keepAsReturnOperands.push_back(operand);
     }
     OpBuilder builder(op);
-    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (failed(
-              memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
-        return WalkResult::interrupt();
+    for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+      if (eliminateAllocCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
+        orig.replaceAllUsesWith(arg);
+        orig.getDefiningOp()->erase();
+      } else {
+        if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
+          return WalkResult::interrupt();
+      }
     }
     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
     op.erase();
@@ -212,7 +217,8 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return success();
     };
     if (failed(updateReturnOps(func, appendedEntryArgs,
-                               options.memCpyFn.value_or(defaultMemCpyFn)))) {
+                               options.memCpyFn.value_or(defaultMemCpyFn),
+                               options.eliminateAllocCopy))) {
       return failure();
     }
   }
@@ -233,6 +239,8 @@ struct BufferResultsToOutParamsPass
     // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
     if (addResultAttribute)
       options.addResultAttribute = true;
+    if (eliminateAllocCopy)
+      options.eliminateAllocCopy = true;
 
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
new file mode 100644
index 00000000000000..ac739a4b9c257a
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{elim-alloc-copy})'  %s | FileCheck %s
+
+// CHECK-LABEL:   func @basic(
+// CHECK-SAME:                %[[ARG:.*]]: memref<8x64xf32>) {
+// CHECK-NOT:        memref.alloc()
+// CHECK:           "test.source"(%[[ARG]])  : (memref<8x64xf32>) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @basic() -> (memref<8x64xf32>) {
+  %b = memref.alloc() : memref<8x64xf32>
+  "test.source"(%b)  : (memref<8x64xf32>) -> ()
+  return %b : memref<8x64xf32>
+}
+
+// CHECK-LABEL:   func @basic_no_change(
+// CHECK-SAME:                %[[ARG:.*]]: memref<f32>) {
+// CHECK:           %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
+// CHECK:           memref.copy %[[RESULT]], %[[ARG]]  : memref<f32> to memref<f32>
+// CHECK:           return
+// CHECK:         }
+func.func @basic_no_change() -> (memref<f32>) {
+  %0 = "test.source"() : () -> (memref<f32>)
+  return %0 : memref<f32>
+}
\ No newline at end of file

>From 98f6640773e3c3a2116162acf16751e1abeeda38 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Thu, 25 Apr 2024 10:36:46 +0800
Subject: [PATCH 2/3] Rename eliminateAllocCopy option to avoidBufferResultAllocAndCopy

---
 .../mlir/Dialect/Bufferization/Transforms/Passes.h     |  2 +-
 .../mlir/Dialect/Bufferization/Transforms/Passes.td    |  9 +++++----
 .../Transforms/BufferResultsToOutParams.cpp            | 10 +++++-----
 .../Transforms/buffer-results-to-out-params-elim.mlir  |  2 +-
 4 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 6bb436de4f0821..e5d026d7469f98 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -169,7 +169,7 @@ struct BufferResultsToOutParamsOpts {
   
   /// If true, the pass eliminates the memref.alloc and memcpy if the returned
   /// memref is allocated in the current function.
-  bool eliminateAllocCopy = false;
+  bool avoidBufferResultAllocAndCopy = false;
 };
 
 /// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index ef5e2293dec2de..e3197cc16377ee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -320,10 +320,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     Option<"addResultAttribute", "add-result-attr", "bool",
        /*default=*/"false",
        "Add the attribute 'bufferize.result' to all output parameters.">,
-    Option<"eliminateAllocCopy", "elim-alloc-copy", "bool",
-       /*default=*/"false",
-       "When the returned memref is allocated by `memref.alloc`, eliminate the "
-       "allocation, and the memref.copy. And use the argument memref instead">,
+    Option<"avoidBufferResultAllocAndCopy", "avoid-buffer-result-alloc-copy",
+       "bool", /*default=*/"false",
+       "When the returned memref is allocated by `memref.alloc` in the function"
+       ", eliminate the allocation and the memref.copy. And use the memref"
+       " given in function argument instead">,
   ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 4a5bfec94b4ff0..ce6a4821ccc202 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -108,7 +108,7 @@ updateFuncOp(func::FuncOp func,
 static LogicalResult updateReturnOps(func::FuncOp func,
                                      ArrayRef<BlockArgument> appendedEntryArgs,
                                      MemCpyFn memCpyFn,
-                                     bool eliminateAllocCopy) {
+                                     bool avoidBufferResultAllocAndCopy) {
   auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
@@ -120,7 +120,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     }
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (eliminateAllocCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
+      if (avoidBufferResultAllocAndCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();
       } else {
@@ -218,7 +218,7 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
     };
     if (failed(updateReturnOps(func, appendedEntryArgs,
                                options.memCpyFn.value_or(defaultMemCpyFn),
-                               options.eliminateAllocCopy))) {
+                               options.avoidBufferResultAllocAndCopy))) {
       return failure();
     }
   }
@@ -239,8 +239,8 @@ struct BufferResultsToOutParamsPass
     // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
     if (addResultAttribute)
       options.addResultAttribute = true;
-    if (eliminateAllocCopy)
-      options.eliminateAllocCopy = true;
+    if (avoidBufferResultAllocAndCopy)
+      options.avoidBufferResultAllocAndCopy = true;
 
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
index ac739a4b9c257a..0b2a0b6e14d180 100644
--- a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{elim-alloc-copy})'  %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{avoid-buffer-result-alloc-copy})'  %s | FileCheck %s
 
 // CHECK-LABEL:   func @basic(
 // CHECK-SAME:                %[[ARG:.*]]: memref<8x64xf32>) {

>From 85a257085631779632bb4a2c84bacec69c366a67 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Sat, 27 Apr 2024 10:25:01 +0800
Subject: [PATCH 3/3] Skip the alloc/copy elimination for memrefs with dynamic shapes

---
 .../mlir/Dialect/Bufferization/Transforms/Passes.td |  6 +++---
 .../Transforms/BufferResultsToOutParams.cpp         |  4 +++-
 .../buffer-results-to-out-params-elim.mlir          | 13 +++++++++++++
 3 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index e3197cc16377ee..390a07c6b5512c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -322,9 +322,9 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
        "Add the attribute 'bufferize.result' to all output parameters.">,
     Option<"avoidBufferResultAllocAndCopy", "avoid-buffer-result-alloc-copy",
        "bool", /*default=*/"false",
-       "When the returned memref is allocated by `memref.alloc` in the function"
-       ", eliminate the allocation and the memref.copy. And use the memref"
-       " given in function argument instead">,
+       "When the returned memref has static shape and is allocated by "
+       "memref.alloc in the function, eliminate the allocation and avoid the "
+       "memref.copy, using the memref given as the function argument instead">,
   ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index ce6a4821ccc202..1cb777b4148be3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -120,7 +120,9 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     }
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (avoidBufferResultAllocAndCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
+      if (avoidBufferResultAllocAndCopy &&
+          isa<memref::AllocOp>(orig.getDefiningOp()) &&
+          orig.getType().cast<MemRefType>().hasStaticShape()) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();
       } else {
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
index 0b2a0b6e14d180..d3209a182034f2 100644
--- a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -21,4 +21,17 @@ func.func @basic() -> (memref<8x64xf32>) {
 func.func @basic_no_change() -> (memref<f32>) {
   %0 = "test.source"() : () -> (memref<f32>)
   return %0 : memref<f32>
+}
+
+// CHECK-LABEL:   func @basic_dynamic(
+// CHECK-SAME:                %[[D:.*]]: index, %[[ARG:.*]]: memref<?xf32>) {
+// CHECK:           %[[RESULT:.*]] = memref.alloc(%[[D]]) : memref<?xf32>
+// CHECK:           "test.source"(%[[RESULT]])  : (memref<?xf32>) -> ()
+// CHECK:           memref.copy %[[RESULT]], %[[ARG]]
+// CHECK:           return
+// CHECK:         }
+func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
+  %b = memref.alloc(%d) : memref<?xf32>
+  "test.source"(%b)  : (memref<?xf32>) -> ()
+  return %b : memref<?xf32>
 }
\ No newline at end of file



More information about the Mlir-commits mailing list