[Mlir-commits] [mlir] [mlir][memref] Fix hoist-static-allocs option of buffer-results-to-out-params when function parameters are returned (PR #102093)

llvmlistbot at llvm.org
Mon Aug 5 19:47:47 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Menooker (Menooker)

<details>
<summary>Changes</summary>

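The buffer-results-to-out-params pass crashes with a null-pointer error when the hoist-static-allocs option is on and a function returns one of its own parameters: for a block argument, `getDefiningOp()` returns null, and the old code passed that null pointer straight to `isa<memref::AllocOp>`. This PR fixes the crash and lets the pass remove the value from the ReturnOp without copying it into the out-param when
 * the value type is a memref with a static shape
 * and the value is a function parameter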
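The per-result decision made by the patched loop in `updateReturnOps` can be sketched as a small standalone model. `DefKind`, `ToyValue`, `Action` and `decide` are hypothetical names invented for illustration; the real pass inspects `mlir::Value`, `MemRefType` and `memref::AllocOp` and performs the rewrites directly rather than returning an enum.

```cpp
// Toy model of the per-result branch structure in the patched
// updateReturnOps. All names here are hypothetical, not MLIR API.

// How a returned memref value was produced.
enum class DefKind {
  AllocOp,        // result of memref.alloc
  OtherOp,        // result of some other op
  EntryBlockArg,  // function argument (argument of the entry block)
  OtherBlockArg   // argument of a non-entry block
};

struct ToyValue {
  DefKind kind;
  bool hasStaticShape;
};

// What the pass does for one returned memref and its appended out-param.
enum class Action {
  HoistAlloc,           // replace uses of the alloc with the out-param, erase the alloc
  LeaveOutParamUnused,  // no copy; the value is only removed from the return
  MemCpy                // copy the returned buffer into the out-param
};

constexpr Action decide(ToyValue orig, bool hoistStaticAllocs) {
  // Both special cases require the option and a static shape.
  bool mayHoistStaticAlloc = hoistStaticAllocs && orig.hasStaticShape;
  if (mayHoistStaticAlloc && orig.kind == DefKind::AllocOp)
    return Action::HoistAlloc;
  // Branch added by the patch: a returned function argument needs no copy.
  if (mayHoistStaticAlloc && orig.kind == DefKind::EntryBlockArg)
    return Action::LeaveOutParamUnused;
  return Action::MemCpy;
}
```

For example, `decide(ToyValue{DefKind::EntryBlockArg, true}, true)` yields `Action::LeaveOutParamUnused`, which corresponds to the new `@return_arg` test: `%arg0` is returned, no `memref.copy` is emitted, and the appended out-param stays unused. Arguments of non-entry blocks and dynamically shaped values still take the copy path.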

---
Full diff: https://github.com/llvm/llvm-project/pull/102093.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+5) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (+15-2) 
- (modified) mlir/test/Transforms/buffer-results-to-out-params-elim.mlir (+12-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1cece818dbbbc..d6f13b2153828 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -321,6 +321,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     This optimization applies on the returned memref which has static shape and
     is allocated by memref.alloc in the function. It will use the memref given
     in function argument to replace the allocated memref.
+
+    If the hoist-static-allocs option is on, and a function returns a memref
+    from the function argument, the pass will avoid the memory-copy from
+    the input function argument to the "out param", and leave the "out param"
+    unused. 
   }];
   let options = [
     Option<"addResultAttribute", "add-result-attr", "bool",
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b19636adaa69e..16a42ca779381 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -102,6 +102,14 @@ updateFuncOp(func::FuncOp func,
   return success();
 }
 
+static bool isFunctionArgument(mlir::Value value) {
+  // Check if the value is a Function argument
+  if (auto blockArg = dyn_cast<mlir::BlockArgument>(value)) {
+    return blockArg.getOwner()->isEntryBlock();
+  }
+  return false;
+}
+
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
@@ -120,10 +128,15 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     }
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
-          mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
+      bool mayHoistStaticAlloc =
+          hoistStaticAllocs &&
+          mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
+      if (mayHoistStaticAlloc &&
+          isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp())) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();
+      } else if (mayHoistStaticAlloc && isFunctionArgument(orig)) {
+        // do nothing but remove the value from the return op.
       } else {
         if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
           return WalkResult::interrupt();
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
index f77dbfaa6cb11..2bd9a9a045531 100644
--- a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -34,4 +34,15 @@ func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
   %b = memref.alloc(%d) : memref<?xf32>
   "test.source"(%b)  : (memref<?xf32>) -> ()
   return %b : memref<?xf32>
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL:   func @return_arg(
+// CHECK-SAME:        %[[ARG0:.*]]: memref<128x256xf32>, %[[ARG1:.*]]: memref<128x256xf32>, %[[ARG2:.*]]: memref<128x256xf32>) {
+// CHECK:           "test.source"(%[[ARG0]], %[[ARG1]])
+// CHECK-NOT:       memref.copy
+// CHECK:           return
+// CHECK:         }
+func.func @return_arg(%arg0: memref<128x256xf32>, %arg1: memref<128x256xf32>) -> memref<128x256xf32> {
+  "test.source"(%arg0, %arg1)  : (memref<128x256xf32>, memref<128x256xf32>) -> ()
+  return %arg0 : memref<128x256xf32>
+}

``````````
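Why the old condition crashed and the new one does not can be modelled the same way. In MLIR, `Value::getDefiningOp()` returns null for a block argument; `llvm::isa<>` must not be given a null pointer (assertion-enabled builds abort, otherwise the behaviour is undefined), whereas `llvm::isa_and_nonnull<>` simply returns false for null. `ToyOp`, `ToyValue`, `isAllocUnchecked` and `isAllocAndNonNull` below are hypothetical stand-ins, not LLVM or MLIR API.

```cpp
#include <cassert>

// Hypothetical stand-ins: a value either has a defining op or, like an
// MLIR block argument, a null defining op.
struct ToyOp {
  bool isAlloc;  // true for something like memref.alloc
};

struct ToyValue {
  const ToyOp *definingOp;  // nullptr models a block argument
};

// Like llvm::isa<> on a pointer: the caller must guarantee non-null.
// Mirrors the pre-patch check isa<memref::AllocOp>(orig.getDefiningOp()).
bool isAllocUnchecked(const ToyOp *op) {
  assert(op != nullptr && "kind check on a null pointer");
  return op->isAlloc;
}

// Like llvm::isa_and_nonnull<>: a null pointer simply yields false,
// mirroring the patched check.
bool isAllocAndNonNull(const ToyOp *op) {
  return op != nullptr && op->isAlloc;
}
```

With `ToyValue arg{nullptr}`, `isAllocAndNonNull(arg.definingOp)` is `false`, so control falls through to the `isFunctionArgument` branch instead of aborting; `isAllocUnchecked(arg.definingOp)` would trip the assertion, which models the null-pointer error described in the PR.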

</details>


https://github.com/llvm/llvm-project/pull/102093

