[Mlir-commits] [mlir] [mlir][memref] Fix hoist-static-allocs option of buffer-results-to-out-params when function parameters are returned (PR #102093)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 3 00:04:10 PDT 2024


https://github.com/Menooker updated https://github.com/llvm/llvm-project/pull/102093

>From 41ea0981e2a42b175a42367ea11189eec5e51c5e Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 30 Jul 2024 15:39:53 +0800
Subject: [PATCH 1/3] fix return arg

---
 .../Dialect/Bufferization/Transforms/Passes.td  |  5 +++++
 .../Transforms/BufferResultsToOutParams.cpp     | 17 +++++++++++++++--
 .../buffer-results-to-out-params-elim.mlir      | 13 ++++++++++++-
 3 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1cece818dbbbc3..d6f13b21538286 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -321,6 +321,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     This optimization applies on the returned memref which has static shape and
     is allocated by memref.alloc in the function. It will use the memref given
     in function argument to replace the allocated memref.
+
+    If the hoist-static-allocs option is on, and a function returns a memref
+    from the function argument, the pass will avoid the memory-copy from
+    the input function argument to the "out param", and leave the "out param"
+    unused. 
   }];
   let options = [
     Option<"addResultAttribute", "add-result-attr", "bool",
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b19636adaa69e6..16a42ca7793819 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -102,6 +102,14 @@ updateFuncOp(func::FuncOp func,
   return success();
 }
 
+static bool isFunctionArgument(mlir::Value value) {
+  // Check if the value is a Function argument
+  if (auto blockArg = dyn_cast<mlir::BlockArgument>(value)) {
+    return blockArg.getOwner()->isEntryBlock();
+  }
+  return false;
+}
+
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
@@ -120,10 +128,15 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     }
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
-          mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
+      bool mayHoistStaticAlloc =
+          hoistStaticAllocs &&
+          mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
+      if (mayHoistStaticAlloc &&
+          isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp())) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();
+      } else if (mayHoistStaticAlloc && isFunctionArgument(orig)) {
+        // do nothing but remove the value from the return op.
       } else {
         if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
           return WalkResult::interrupt();
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
index f77dbfaa6cb11e..2bd9a9a0455318 100644
--- a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -34,4 +34,15 @@ func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
   %b = memref.alloc(%d) : memref<?xf32>
   "test.source"(%b)  : (memref<?xf32>) -> ()
   return %b : memref<?xf32>
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL:   func @return_arg(
+// CHECK-SAME:        %[[ARG0:.*]]: memref<128x256xf32>, %[[ARG1:.*]]: memref<128x256xf32>, %[[ARG2:.*]]: memref<128x256xf32>) {
+// CHECK:           "test.source"(%[[ARG0]], %[[ARG1]])
+// CHECK-NOT:       memref.copy
+// CHECK:           return
+// CHECK:         }
+func.func @return_arg(%arg0: memref<128x256xf32>, %arg1: memref<128x256xf32>) -> memref<128x256xf32> {
+  "test.source"(%arg0, %arg1)  : (memref<128x256xf32>, memref<128x256xf32>) -> ()
+  return %arg0 : memref<128x256xf32>
+}

>From 0aa9133cbf903196388a9bd95c8a41ad7a5c1c91 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 20 Aug 2024 21:31:32 +0800
Subject: [PATCH 2/3] fix comment

---
 .../Bufferization/Transforms/BufferResultsToOutParams.cpp     | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 16a42ca7793819..bf5671cf934f73 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -105,7 +105,9 @@ updateFuncOp(func::FuncOp func,
 static bool isFunctionArgument(mlir::Value value) {
   // Check if the value is a Function argument
   if (auto blockArg = dyn_cast<mlir::BlockArgument>(value)) {
-    return blockArg.getOwner()->isEntryBlock();
+    return blockArg.getOwner()->isEntryBlock() &&
+           isa_and_nonnull<mlir::func::FuncOp>(
+               blockArg.getOwner()->getParentOp());
   }
   return false;
 }

>From 30d7d86faa91166faada3b168c28f49145ca771e Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 3 Sep 2024 15:01:02 +0800
Subject: [PATCH 3/3] remove optimizations

---
 .../Bufferization/Transforms/Passes.td        |  5 -----
 .../Transforms/BufferResultsToOutParams.cpp   | 20 +++----------------
 .../buffer-results-to-out-params-elim.mlir    |  4 +++-
 3 files changed, 6 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 9d4145b5989bc9..a610ddcc9899ed 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -337,11 +337,6 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     This optimization applies on the returned memref which has static shape and
     is allocated by memref.alloc in the function. It will use the memref given
     in function argument to replace the allocated memref.
-
-    If the hoist-static-allocs option is on, and a function returns a memref
-    from the function argument, the pass will avoid the memory-copy from
-    the input function argument to the "out param", and leave the "out param"
-    unused. 
   }];
   let options = [
     Option<"addResultAttribute", "add-result-attr", "bool",
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index bf5671cf934f73..b7755b2be8483b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -102,16 +102,6 @@ updateFuncOp(func::FuncOp func,
   return success();
 }
 
-static bool isFunctionArgument(mlir::Value value) {
-  // Check if the value is a Function argument
-  if (auto blockArg = dyn_cast<mlir::BlockArgument>(value)) {
-    return blockArg.getOwner()->isEntryBlock() &&
-           isa_and_nonnull<mlir::func::FuncOp>(
-               blockArg.getOwner()->getParentOp());
-  }
-  return false;
-}
-
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
@@ -130,15 +120,11 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     }
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      bool mayHoistStaticAlloc =
-          hoistStaticAllocs &&
-          mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
-      if (mayHoistStaticAlloc &&
-          isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp())) {
+      if (hoistStaticAllocs &&
+          isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
+          mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();
-      } else if (mayHoistStaticAlloc && isFunctionArgument(orig)) {
-        // do nothing but remove the value from the return op.
       } else {
         if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
           return WalkResult::interrupt();
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
index 2bd9a9a0455318..ee105cd7a29624 100644
--- a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -36,10 +36,12 @@ func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
   return %b : memref<?xf32>
 }
 
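+// No hoisting is applied here: the returned value is a function argument (not a memref.alloc result), so the copy into the out-param is kept.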
+
 // CHECK-LABEL:   func @return_arg(
 // CHECK-SAME:        %[[ARG0:.*]]: memref<128x256xf32>, %[[ARG1:.*]]: memref<128x256xf32>, %[[ARG2:.*]]: memref<128x256xf32>) {
 // CHECK:           "test.source"(%[[ARG0]], %[[ARG1]])
-// CHECK-NOT:       memref.copy
+// CHECK:           memref.copy
 // CHECK:           return
 // CHECK:         }
 func.func @return_arg(%arg0: memref<128x256xf32>, %arg1: memref<128x256xf32>) -> memref<128x256xf32> {



More information about the Mlir-commits mailing list