[Mlir-commits] [mlir] [mlir][bufferize] Add hoist-dynamic-allocs-option to buffer-results-to-out-params (PR #160985)

lonely eagle llvmlistbot at llvm.org
Fri Sep 26 23:31:13 PDT 2025


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/160985

>From c1879c5ec1e78d24d5be3b4a3287b5f752c8c1af Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 27 Sep 2025 06:06:33 +0000
Subject: [PATCH 1/3] Add hoist-dynamic-allocs-option to
 buffer-results-to-out-params.

---
 .../Dialect/Bufferization/Transforms/Passes.h | 22 +++--
 .../Bufferization/Transforms/Passes.td        |  2 +
 .../Transforms/BufferResultsToOutParams.cpp   | 92 +++++++++++++++++--
 ...ts-to-out-params-hosit-dynamic-allocs.mlir | 79 ++++++++++++++++
 ...ts-to-out-params-hosit-static-allocs.mlir} |  0
 5 files changed, 181 insertions(+), 14 deletions(-)
 create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir
 rename mlir/test/Transforms/{buffer-results-to-out-params-elim.mlir => buffer-results-to-out-params-hosit-static-allocs.mlir} (100%)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index a2409f2796b94..e413a5ede5d64 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -5,6 +5,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/MapVector.h"
 
 namespace mlir {
 class FunctionOpInterface;
@@ -131,8 +132,8 @@ struct BufferResultsToOutParamsOpts {
   /// Allocator function: Generate a memref allocation with the given type.
-  /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
-  /// results, we don't allow passing a range of values for dynamic dims.
-  using AllocationFn =
-      std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
+  /// The `ValueRange` holds the sizes of the dynamic dimensions of the type;
+  /// it is empty for statically shaped types.
+  using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
+                                                      MemRefType, ValueRange)>;
 
   /// Memcpy function: Generate a memcpy between two memrefs.
   using MemCpyFn =
@@ -147,8 +148,9 @@ struct BufferResultsToOutParamsOpts {
   /// Allocation function; used to allocate a memref.
   /// Default memref.alloc is used
   AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
-                                 MemRefType type) {
-    return memref::AllocOp::create(builder, loc, type).getResult();
+                                 MemRefType type, ValueRange dynamicSizes) {
+    return memref::AllocOp::create(builder, loc, type, dynamicSizes)
+        .getResult();
   };
 
   /// Memcpy function; used to create a copy between two memrefs.
@@ -164,15 +166,23 @@ struct BufferResultsToOutParamsOpts {
   bool addResultAttribute = false;
 
   /// If true, the pass eliminates the memref.alloc and memcpy if the returned
-  /// memref is allocated in the current function.
+  /// memref is statically shaped and allocated in the current function.
   bool hoistStaticAllocs = false;
+
+  /// Like `hoistStaticAllocs`, but for dynamically shaped memrefs whose
+  /// dynamic sizes are all arguments of the current function.
+  bool hoistDynamicAllocs = false;
+
+  /// Internal state of the pass: maps each function to the dynamic sizes
+  /// (arguments of that function) recorded for its memref results.
+  llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>> dynamicSizesMap;
 };
 
 /// Replace buffers that are returned from a function with an out parameter.
 /// Also update all call sites.
 LogicalResult
 promoteBufferResultsToOutParams(ModuleOp module,
-                                const BufferResultsToOutParamsOpts &options);
+                                BufferResultsToOutParamsOpts &options);
 
 /// Drop all memref function results that are equivalent to a function argument.
 LogicalResult dropEquivalentBufferResults(ModuleOp module);
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index a0d113c150c5e..cad44cb15f479 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -256,6 +256,10 @@ def BufferResultsToOutParamsPass
               "Add the attribute 'bufferize.result' to all output parameters.">,
        Option<"hoistStaticAllocs", "hoist-static-allocs", "bool",
               /*default=*/"false", "Hoist static allocations to call sites.">,
+       Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool",
+              /*default=*/"false",
+              "Hoist dynamic allocations to call sites (all dynamic sizes "
+              "must be function arguments).">,
   ];
   let dependentDialects = ["memref::MemRefDialect"];
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index e30e094c28467..ae68477f57a0d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -43,6 +43,52 @@ static bool hasStaticIdentityLayout(MemRefType type) {
   return type.getLayout().isIdentity();
 }
 
+/// Return the dynamic shapes of the `memref` based on the define op. If the
+/// complete dynamic shape fails to be captured, return an empty value.
+/// Currently, only function parameters are supported for capturing.
+static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) {
+  auto *defOp = memref.getDefiningOp();
+  if (!defOp)
+    return {};
+  // Every operand must be one of the dynamic sizes: anything else (e.g. a
+  // source memref) cannot be re-created at the call site.
+  auto memrefType = cast<MemRefType>(memref.getType());
+  ValueRange operands = defOp->getOperands();
+  if (operands.size() != static_cast<size_t>(memrefType.getNumDynamicDims()))
+    return {};
+
+  SmallVector<Value> dynamicSizes;
+  Block &entryBlock = funcOp.getBody().front();
+  for (Value size : operands) {
+    // Only index-typed arguments of `funcOp` itself can be mapped onto the
+    // operands of a call.
+    auto sizeSrc = dyn_cast<BlockArgument>(size);
+    if (!sizeSrc || sizeSrc.getOwner() != &entryBlock ||
+        !sizeSrc.getType().isIndex())
+      return {};
+    dynamicSizes.push_back(sizeSrc);
+  }
+  return dynamicSizes;
+}
+
+/// Returns the dynamic sizes at the callee, through the call relationship
+/// between the caller and callee.
+static ValueRange mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee,
+                                         ValueRange dynamicSizes) {
+  SmallVector<Value> mapedDynamicSizes;
+  for (Value size : dynamicSizes) {
+    auto callOperands = call.getOperands();
+    for (size_t i = 0, e = callOperands.size(); i < e; ++i) {
+      Value src = callOperands[i];
+      BlockArgument dst = callee.getArgument(i);
+      if (size != dst)
+        continue;
+      mapedDynamicSizes.push_back(src);
+    }
+  }
+  return mapedDynamicSizes;
+}
+
 // Updates the func op and entry block.
 //
 // Any args appended to the entry block are added to `appendedEntryArgs`.
@@ -109,7 +155,7 @@ updateFuncOp(func::FuncOp func,
 // the given out-params.
 static LogicalResult
 updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
-                const bufferization::BufferResultsToOutParamsOpts &options) {
+                bufferization::BufferResultsToOutParamsOpts &options) {
   auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
@@ -120,12 +166,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
         keepAsReturnOperands.push_back(operand);
     }
     OpBuilder builder(op);
+    // One entry per out-param; empty unless a dynamic alloc was hoisted.
+    SmallVector<SmallVector<Value>> dynamicSizes;
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (options.hoistStaticAllocs &&
+      bool isStatic = mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
+      // A dynamic alloc is only hoisted if every dynamic size can be forwarded
+      // from the call site, i.e. is an argument of `func`.
+      SmallVector<Value> dynamicSize;
+      if (options.hoistDynamicAllocs && !isStatic)
+        dynamicSize = getDynamicSize(orig, func);
+      bool hoist = isStatic ? options.hoistStaticAllocs : !dynamicSize.empty();
+      if (hoist &&
           isa_and_nonnull<bufferization::AllocationOpInterface>(
-              orig.getDefiningOp()) &&
-          mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
+              orig.getDefiningOp())) {
         orig.replaceAllUsesWith(arg);
+        dynamicSizes.push_back(dynamicSize);
         orig.getDefiningOp()->erase();
       } else {
+        dynamicSizes.emplace_back();
         if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
@@ -134,6 +190,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
     }
     func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
     op.erase();
+    // Make the per-out-param sizes available to `updateCalls`. If `func` has
+    // several return ops, `try_emplace` keeps the entry of the first one
+    // visited.
+    options.dynamicSizesMap.try_emplace(func, std::move(dynamicSizes));
     return WalkResult::advance();
   });
   return failure(res.wasInterrupted());
@@ -166,8 +226,16 @@ updateCalls(ModuleOp module,
     }
     SmallVector<Value, 6> outParams;
     OpBuilder builder(op);
+    SmallVector<SmallVector<Value>> dynamicSizes =
+        options.dynamicSizesMap.lookup(callee);
+    size_t dynamicSizesIndex = 0;
     for (Value memref : replaceWithOutParams) {
-      if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
+      ValueRange dynamicSize = dynamicSizes.size() > dynamicSizesIndex
+                                   ? dynamicSizes[dynamicSizesIndex]
+                                   : SmallVector<Value>();
+      bool memrefStaticShape =
+          cast<MemRefType>(memref.getType()).hasStaticShape();
+      if (!memrefStaticShape && dynamicSize.empty()) {
         op.emitError()
             << "cannot create out param for dynamically shaped result";
         didFail = true;
@@ -177,8 +245,15 @@ updateCalls(ModuleOp module,
       auto allocType =
           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                           AffineMap(), memrefType.getMemorySpace());
+
+      // One `dynamicSizes` entry exists per out-param; advance for each.
+      ++dynamicSizesIndex;
+      if (memrefStaticShape)
+        dynamicSize = {};
+      else
+        dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize);
       auto maybeOutParam =
-          options.allocationFn(builder, op.getLoc(), allocType);
+          options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
       if (failed(maybeOutParam)) {
         op.emitError() << "failed to create allocation op";
         didFail = true;
@@ -211,8 +286,7 @@ updateCalls(ModuleOp module,
 }
 
 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
-    ModuleOp module,
-    const bufferization::BufferResultsToOutParamsOpts &options) {
+    ModuleOp module, bufferization::BufferResultsToOutParamsOpts &options) {
   for (auto func : module.getOps<func::FuncOp>()) {
     if (!options.filterFn(&func))
       continue;
@@ -243,6 +317,8 @@ struct BufferResultsToOutParamsPass
       options.addResultAttribute = true;
     if (hoistStaticAllocs)
       options.hoistStaticAllocs = true;
+    if (hoistDynamicAllocs)
+      options.hoistDynamicAllocs = true;
 
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir b/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir
new file mode 100644
index 0000000000000..f33eb8e26fbce
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-dynamic-allocs})' %s -split-input-file | FileCheck %s
+
+func.func private @single_alloc(%size : index) -> (memref<?xf32>) {
+  %alloc = memref.alloc(%size) : memref<?xf32>
+  return %alloc : memref<?xf32>
+}
+
+func.func @single_alloc_test(%size : index) {
+  %alloc = call @single_alloc(%size) : (index) -> (memref<?xf32>)
+  "test.sink"(%alloc) : (memref<?xf32>) -> ()
+}
+
+// CHECK-LABEL: func.func private @single_alloc(
+//  CHECK-SAME:   %{{.*}}: index,
+//  CHECK-SAME:   %{{.*}}: memref<?xf32>) {
+//   CHECK-NOT:   memref.{{alloc|copy}}
+//       CHECK:   return
+
+// CHECK-LABEL: func.func @single_alloc_test(
+//  CHECK-SAME:   %[[size:.*]]: index) {
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[size]]) : memref<?xf32>
+//       CHECK:   call @single_alloc(%[[size]], %[[alloc]]) : (index, memref<?xf32>) -> ()
+//       CHECK:   "test.sink"(%[[alloc]]) : (memref<?xf32>) -> ()
+//       CHECK: }
+
+// -----
+
+func.func private @mult_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<?xf32>) {
+  %alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
+  %alloc1 = memref.alloc(%size1) : memref<?xf32>
+  return %alloc0, %alloc1 : memref<?x?xf32>, memref<?xf32>
+}
+
+func.func @mult_alloc_test(%size0 : index, %size1: index) {
+  %alloc0, %alloc1 = call @mult_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<?xf32>)
+  "test.sink"(%alloc0, %alloc1) : (memref<?x?xf32>, memref<?xf32>) -> ()
+}
+
+// CHECK-LABEL: func private @mult_alloc(
+//  CHECK-SAME:    %{{.*}}: index,  %{{.*}}: index,
+//  CHECK-SAME:    %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?xf32>) {
+//   CHECK-NOT:   memref.{{alloc|copy}}
+//       CHECK:   return
+
+// CHECK-LABEL: func @mult_alloc_test(
+//  CHECK-SAME:   %[[size0:.*]]: index,
+//  CHECK-SAME:   %[[size1:.*]]: index) {
+//       CHECK:   %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
+//       CHECK:   %[[alloc1:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
+//       CHECK:   call @mult_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]]) : (index, index, memref<?x?xf32>, memref<?xf32>) -> ()
+//       CHECK:   "test.sink"(%[[alloc0]], %[[alloc1]]) : (memref<?x?xf32>, memref<?xf32>) -> ()
+//       CHECK: }
+
+
+// -----
+
+func.func private @complex_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) {
+  %alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
+  %alloc1 = memref.alloc() : memref<4xf32>
+  %alloc2 = memref.alloc(%size1) : memref<?xf32>
+  return %alloc0, %alloc1, %alloc2 : memref<?x?xf32>, memref<4xf32>, memref<?xf32>
+}
+
+func.func @complex_alloc_test(%size0 : index, %size1: index) {
+  %alloc0, %alloc1, %alloc2 = call @complex_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>)
+  "test.sink"(%alloc0, %alloc1, %alloc2) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
+}
+
+// CHECK-LABEL: func private @complex_alloc(
+//  CHECK-SAME:   %{{.*}}: index, %{{.*}}: index,
+//  CHECK-SAME:   %{{.*}}: memref<?x?xf32>,
+//  CHECK-SAME:   %{{.*}}: memref<4xf32>,
+//  CHECK-SAME:   %{{.*}}: memref<?xf32>) {
+//   CHECK-NOT:   memref.alloc(%
+//       CHECK:   %[[STATIC:.*]] = memref.alloc() : memref<4xf32>
+//   CHECK-NOT:   memref.alloc
+//       CHECK:   memref.copy %[[STATIC]]
+//       CHECK:   return
+
+// CHECK-LABEL: func @complex_alloc_test(
+//  CHECK-SAME:   %[[size0:.*]]: index,
+//  CHECK-SAME:   %[[size1:.*]]: index) {
+//       CHECK:   %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
+//       CHECK:   %[[alloc1:.*]] = memref.alloc() : memref<4xf32>
+//       CHECK:   %[[alloc2:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
+//       CHECK:   call @complex_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]], %[[alloc2]]) : (index, index, memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
+//       CHECK:   "test.sink"(%[[alloc0]], %[[alloc1]], %[[alloc2]]) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
+//       CHECK: }
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-hosit-static-allocs.mlir
similarity index 100%
rename from mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
rename to mlir/test/Transforms/buffer-results-to-out-params-hosit-static-allocs.mlir

>From 1fe33cf8d4d3218cc7cb042005255fac17eb3d4e Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 27 Sep 2025 06:15:19 +0000
Subject: [PATCH 2/3] Remove unused MapVector include from Passes.h

---
 mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e413a5ede5d64..6ded148ce9d84 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -5,7 +5,6 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
-#include "llvm/ADT/MapVector.h"
 
 namespace mlir {
 class FunctionOpInterface;

>From 348496966fee2eb876fc1ef658cc7b2966c72cad Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 27 Sep 2025 06:30:58 +0000
Subject: [PATCH 3/3] Fix build: return SmallVector from dynamic-size helpers

---
 .../Transforms/BufferResultsToOutParams.cpp         | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index ae68477f57a0d..1160f4232172e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -46,7 +46,7 @@ static bool hasStaticIdentityLayout(MemRefType type) {
-/// Return the dynamic shapes of the `memref` based on the define op. If the
-/// complete dynamic shape fails to be captured, return an empty value.
-/// Currently, only function parameters are supported for capturing.
-static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) {
+/// Returns the dynamic sizes of `memref` taken from its defining op, or an
+/// empty vector if they cannot all be captured. Currently only arguments of
+/// `funcOp` are supported as sizes.
+static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
   auto *defOp = memref.getDefiningOp();
   if (!defOp)
     return {};
@@ -73,8 +73,9 @@ static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) {
 
-/// Returns the dynamic sizes at the callee, through the call relationship
-/// between the caller and callee.
-static ValueRange mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee,
-                                         ValueRange dynamicSizes) {
+/// Maps `dynamicSizes`, which are arguments of `callee`, to the corresponding
+/// operands of `call`, i.e. returns the dynamic sizes at the caller.
+static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
+                                                 func::FuncOp callee,
+                                                 ValueRange dynamicSizes) {
   SmallVector<Value> mapedDynamicSizes;
   for (Value size : dynamicSizes) {
     auto callOperands = call.getOperands();
@@ -230,9 +231,9 @@ updateCalls(ModuleOp module,
         options.dynamicSizesMap.lookup(callee);
     size_t dynamicSizesIndex = 0;
     for (Value memref : replaceWithOutParams) {
-      ValueRange dynamicSize = dynamicSizes.size() > dynamicSizesIndex
-                                   ? dynamicSizes[dynamicSizesIndex]
-                                   : SmallVector<Value>();
+      SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
+                                           ? dynamicSizes[dynamicSizesIndex]
+                                           : SmallVector<Value>();
       bool memrefStaticShape =
           cast<MemRefType>(memref.getType()).hasStaticShape();
       if (!memrefStaticShape && dynamicSize.empty()) {



More information about the Mlir-commits mailing list