[Mlir-commits] [mlir] [mlir][bufferize] Add hoist-dynamic-allocs-option to buffer-results-to-out-params (PR #160985)
lonely eagle
llvmlistbot at llvm.org
Sat Sep 27 22:54:15 PDT 2025
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/160985
>From c1879c5ec1e78d24d5be3b4a3287b5f752c8c1af Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 27 Sep 2025 06:06:33 +0000
Subject: [PATCH 1/7] Add hoist-dynamic-allocs option to
buffer-results-to-out-params.
---
.../Dialect/Bufferization/Transforms/Passes.h | 22 +++--
.../Bufferization/Transforms/Passes.td | 2 +
.../Transforms/BufferResultsToOutParams.cpp | 92 +++++++++++++++++--
...ts-to-out-params-hosit-dynamic-allocs.mlir | 79 ++++++++++++++++
...ts-to-out-params-hosit-static-allocs.mlir} | 0
5 files changed, 181 insertions(+), 14 deletions(-)
create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir
rename mlir/test/Transforms/{buffer-results-to-out-params-elim.mlir => buffer-results-to-out-params-hosit-static-allocs.mlir} (100%)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index a2409f2796b94..e413a5ede5d64 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -5,6 +5,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/MapVector.h"
namespace mlir {
class FunctionOpInterface;
@@ -131,8 +132,8 @@ struct BufferResultsToOutParamsOpts {
/// Allocator function: Generate a memref allocation with the given type.
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
/// results, we don't allow passing a range of values for dynamic dims.
- using AllocationFn =
- std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
+ using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
+ MemRefType, ValueRange)>;
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
@@ -147,8 +148,9 @@ struct BufferResultsToOutParamsOpts {
/// Allocation function; used to allocate a memref.
/// Default memref.alloc is used
AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
- MemRefType type) {
- return memref::AllocOp::create(builder, loc, type).getResult();
+ MemRefType type, ValueRange dynamicSizes) {
+ return memref::AllocOp::create(builder, loc, type, dynamicSizes)
+ .getResult();
};
/// Memcpy function; used to create a copy between two memrefs.
@@ -164,15 +166,23 @@ struct BufferResultsToOutParamsOpts {
bool addResultAttribute = false;
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
- /// memref is allocated in the current function.
+ /// memref is static allocated in the current function.
bool hoistStaticAllocs = false;
+
+ /// If true, the pass eliminates the memref.alloc and memcpy if the returned
+ /// memref is dynamic allocated in the current function.
+ bool hoistDynamicAllocs = false;
+
+ /// It maps the shape source of the dynamic shape memref returned by each
+ /// function.
+ llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>> dynamicSizesMap;
};
/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
- const BufferResultsToOutParamsOpts &options);
+ BufferResultsToOutParamsOpts &options);
/// Drop all memref function results that are equivalent to a function argument.
LogicalResult dropEquivalentBufferResults(ModuleOp module);
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index a0d113c150c5e..cad44cb15f479 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -256,6 +256,8 @@ def BufferResultsToOutParamsPass
"Add the attribute 'bufferize.result' to all output parameters.">,
Option<"hoistStaticAllocs", "hoist-static-allocs", "bool",
/*default=*/"false", "Hoist static allocations to call sites.">,
+ Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool",
+ /*default=*/"false", "Hoist dynamic allocations to call sites.">,
];
let dependentDialects = ["memref::MemRefDialect"];
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index e30e094c28467..ae68477f57a0d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -43,6 +43,52 @@ static bool hasStaticIdentityLayout(MemRefType type) {
return type.getLayout().isIdentity();
}
+/// Return the dynamic shapes of the `memref` based on the define op. If the
+/// complete dynamic shape fails to be captured, return an empty value.
+/// Currently, only function parameters are supported for capturing.
+static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) {
+ auto *defOp = memref.getDefiningOp();
+ if (!defOp)
+ return {};
+ auto operands = defOp->getOperands();
+ SmallVector<Value> dynamicSizes;
+ for (Value size : operands) {
+ BlockArgument sizeSrc = mlir::dyn_cast<BlockArgument>(size);
+ if (!sizeSrc)
+ return {};
+
+ bool finded = false;
+ for (BlockArgument argument : funcOp.getArguments()) {
+ if (argument == sizeSrc) {
+ dynamicSizes.push_back(argument);
+ finded = true;
+ break;
+ }
+ }
+ if (!finded)
+ return {};
+ }
+ return dynamicSizes;
+}
+
+/// Returns the dynamic sizes at the callee, through the call relationship
+/// between the caller and callee.
+static ValueRange mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee,
+ ValueRange dynamicSizes) {
+ SmallVector<Value> mapedDynamicSizes;
+ for (Value size : dynamicSizes) {
+ auto callOperands = call.getOperands();
+ for (size_t i = 0, e = callOperands.size(); i < e; ++i) {
+ Value src = callOperands[i];
+ BlockArgument dst = callee.getArgument(i);
+ if (size != dst)
+ continue;
+ mapedDynamicSizes.push_back(src);
+ }
+ }
+ return mapedDynamicSizes;
+}
+
// Updates the func op and entry block.
//
// Any args appended to the entry block are added to `appendedEntryArgs`.
@@ -109,7 +155,7 @@ updateFuncOp(func::FuncOp func,
// the given out-params.
static LogicalResult
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
- const bufferization::BufferResultsToOutParamsOpts &options) {
+ bufferization::BufferResultsToOutParamsOpts &options) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
@@ -120,12 +166,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
+ SmallVector<SmallVector<Value>> dynamicSizes;
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
- if (options.hoistStaticAllocs &&
+ bool hoistStaticAllocs =
+ options.hoistStaticAllocs &&
+ mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
+ bool hoistDynamicAllocs =
+ options.hoistDynamicAllocs &&
+ !mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
+ if ((hoistStaticAllocs || hoistDynamicAllocs) &&
isa_and_nonnull<bufferization::AllocationOpInterface>(
- orig.getDefiningOp()) &&
- mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
+ orig.getDefiningOp())) {
orig.replaceAllUsesWith(arg);
+ if (hoistDynamicAllocs) {
+ SmallVector<Value> dynamicSize = getDynamicSize(orig, func);
+ dynamicSizes.push_back(dynamicSize);
+ }
orig.getDefiningOp()->erase();
} else {
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
@@ -134,6 +190,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
}
func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
op.erase();
+ auto dynamicSizePair =
+ std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
+ dynamicSizes);
+ options.dynamicSizesMap.insert(dynamicSizePair);
return WalkResult::advance();
});
return failure(res.wasInterrupted());
@@ -166,8 +226,16 @@ updateCalls(ModuleOp module,
}
SmallVector<Value, 6> outParams;
OpBuilder builder(op);
+ SmallVector<SmallVector<Value>> dynamicSizes =
+ options.dynamicSizesMap.lookup(callee);
+ size_t dynamicSizesIndex = 0;
for (Value memref : replaceWithOutParams) {
- if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
+ ValueRange dynamicSize = dynamicSizes.size() > dynamicSizesIndex
+ ? dynamicSizes[dynamicSizesIndex]
+ : SmallVector<Value>();
+ bool memrefStaticShape =
+ cast<MemRefType>(memref.getType()).hasStaticShape();
+ if (!memrefStaticShape && dynamicSize.empty()) {
op.emitError()
<< "cannot create out param for dynamically shaped result";
didFail = true;
@@ -177,8 +245,15 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
+
+ if (memrefStaticShape) {
+ dynamicSize = {};
+ } else {
+ ++dynamicSizesIndex;
+ dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize);
+ }
auto maybeOutParam =
- options.allocationFn(builder, op.getLoc(), allocType);
+ options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
if (failed(maybeOutParam)) {
op.emitError() << "failed to create allocation op";
didFail = true;
@@ -211,8 +286,7 @@ updateCalls(ModuleOp module,
}
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
- ModuleOp module,
- const bufferization::BufferResultsToOutParamsOpts &options) {
+ ModuleOp module, bufferization::BufferResultsToOutParamsOpts &options) {
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
@@ -243,6 +317,8 @@ struct BufferResultsToOutParamsPass
options.addResultAttribute = true;
if (hoistStaticAllocs)
options.hoistStaticAllocs = true;
+ if (hoistDynamicAllocs)
+ options.hoistDynamicAllocs = true;
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir b/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir
new file mode 100644
index 0000000000000..f33eb8e26fbce
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-dynamic-allocs})' %s -split-input-file | FileCheck %s
+
+func.func private @single_alloc(%size : index) -> (memref<?xf32>) {
+ %alloc = memref.alloc(%size) : memref<?xf32>
+ return %alloc : memref<?xf32>
+}
+
+func.func @single_alloc_test(%size : index) {
+ %alloc = call @single_alloc(%size) : (index) -> (memref<?xf32>)
+ "test.sink"(%alloc) : (memref<?xf32>) -> ()
+}
+
+// CHECK-LABEL: func.func private @single_alloc(
+// CHECK-SAME: %{{.*}}: index,
+// CHECK-SAME: %{{.*}}: memref<?xf32>) {
+
+// CHECK-LABEL: func.func @single_alloc_test(
+// CHECK-SAME: %[[size:.*]]: index) {
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[size]]) : memref<?xf32>
+// CHECK: call @single_alloc(%[[size]], %[[alloc]]) : (index, memref<?xf32>) -> ()
+// CHECK: "test.sink"(%[[alloc]]) : (memref<?xf32>) -> ()
+// CHECK: }
+
+// -----
+
+func.func private @mult_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<?xf32>) {
+ %alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
+ %alloc1 = memref.alloc(%size1) : memref<?xf32>
+ return %alloc0, %alloc1 : memref<?x?xf32>, memref<?xf32>
+}
+
+func.func @mult_alloc_test(%size0 : index, %size1: index) {
+ %alloc0, %alloc1 = call @mult_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<?xf32>)
+ "test.sink"(%alloc0, %alloc1) : (memref<?x?xf32>, memref<?xf32>) -> ()
+}
+
+// CHECK-LABEL: func private @mult_alloc(
+// CHECK-SAME: %{{.*}}: index, %{{.*}}: index,
+// CHECK-SAME: %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?xf32>) {
+
+// CHECK-LABEL: func @mult_alloc_test(
+// CHECK-SAME: %[[size0:.*]]: index,
+// CHECK-SAME: %[[size1:.*]]: index) {
+// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
+// CHECK: %[[alloc1:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
+// CHECK: call @mult_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]]) : (index, index, memref<?x?xf32>, memref<?xf32>) -> ()
+// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]]) : (memref<?x?xf32>, memref<?xf32>) -> ()
+// CHECK: }
+
+
+// -----
+
+func.func private @complex_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) {
+ %alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
+ %alloc1 = memref.alloc() : memref<4xf32>
+ %alloc2 = memref.alloc(%size1) : memref<?xf32>
+ return %alloc0, %alloc1, %alloc2 : memref<?x?xf32>, memref<4xf32>, memref<?xf32>
+}
+
+func.func @complex_alloc_test(%size0 : index, %size1: index) {
+ %alloc0, %alloc1, %alloc2 = call @complex_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>)
+ "test.sink"(%alloc0, %alloc1, %alloc2) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
+}
+
+// CHECK-LABEL: func private @complex_alloc(
+// CHECK-SAME: %{{.*}}: index, %{{.*}}: index,
+// CHECK-SAME: %{{.*}}: memref<?x?xf32>,
+// CHECK-SAME: %{{.*}}: memref<4xf32>,
+// CHECK-SAME: %{{.*}}: memref<?xf32>) {
+
+// CHECK-LABEL: func @complex_alloc_test(
+// CHECK-SAME: %[[size0:.*]]: index,
+// CHECK-SAME: %[[size1:.*]]: index) {
+// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
+// CHECK: %[[alloc1:.*]] = memref.alloc() : memref<4xf32>
+// CHECK: %[[alloc2:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
+// CHECK: call @complex_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]], %[[alloc2]]) : (index, index, memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
+// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]], %[[alloc2]]) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
+// CHECK: }
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-hosit-static-allocs.mlir
similarity index 100%
rename from mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
rename to mlir/test/Transforms/buffer-results-to-out-params-hosit-static-allocs.mlir
>From 1fe33cf8d4d3218cc7cb042005255fac17eb3d4e Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 27 Sep 2025 06:15:19 +0000
Subject: [PATCH 2/7] Clean up Passes.h
---
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e413a5ede5d64..6ded148ce9d84 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -5,7 +5,6 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
-#include "llvm/ADT/MapVector.h"
namespace mlir {
class FunctionOpInterface;
>From 348496966fee2eb876fc1ef658cc7b2966c72cad Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 27 Sep 2025 06:30:58 +0000
Subject: [PATCH 3/7] Fix build problem: return SmallVector instead of a dangling ValueRange
---
.../Transforms/BufferResultsToOutParams.cpp | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index ae68477f57a0d..1160f4232172e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -46,7 +46,7 @@ static bool hasStaticIdentityLayout(MemRefType type) {
/// Return the dynamic shapes of the `memref` based on the define op. If the
/// complete dynamic shape fails to be captured, return an empty value.
/// Currently, only function parameters are supported for capturing.
-static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) {
+static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
auto *defOp = memref.getDefiningOp();
if (!defOp)
return {};
@@ -73,8 +73,9 @@ static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) {
/// Returns the dynamic sizes at the callee, through the call relationship
/// between the caller and callee.
-static ValueRange mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee,
- ValueRange dynamicSizes) {
+static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
+ func::FuncOp callee,
+ ValueRange dynamicSizes) {
SmallVector<Value> mapedDynamicSizes;
for (Value size : dynamicSizes) {
auto callOperands = call.getOperands();
@@ -230,9 +231,9 @@ updateCalls(ModuleOp module,
options.dynamicSizesMap.lookup(callee);
size_t dynamicSizesIndex = 0;
for (Value memref : replaceWithOutParams) {
- ValueRange dynamicSize = dynamicSizes.size() > dynamicSizesIndex
- ? dynamicSizes[dynamicSizesIndex]
- : SmallVector<Value>();
+ SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
+ ? dynamicSizes[dynamicSizesIndex]
+ : SmallVector<Value>();
bool memrefStaticShape =
cast<MemRefType>(memref.getType()).hasStaticShape();
if (!memrefStaticShape && dynamicSize.empty()) {
>From 100dfcc773b7b82014d27f78b84d4ef5827297b1 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 27 Sep 2025 11:41:39 +0000
Subject: [PATCH 4/7] Keep the dynamic-size map local to the pass instead of in the public options.
---
.../Dialect/Bufferization/Transforms/Passes.h | 6 +--
.../Transforms/BufferResultsToOutParams.cpp | 54 ++++++++++---------
2 files changed, 29 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 6ded148ce9d84..78bd33ff619ce 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -171,17 +171,13 @@ struct BufferResultsToOutParamsOpts {
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
/// memref is dynamic allocated in the current function.
bool hoistDynamicAllocs = false;
-
- /// It maps the shape source of the dynamic shape memref returned by each
- /// function.
- llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>> dynamicSizesMap;
};
/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
- BufferResultsToOutParamsOpts &options);
+ const BufferResultsToOutParamsOpts &options);
/// Drop all memref function results that are equivalent to a function argument.
LogicalResult dropEquivalentBufferResults(ModuleOp module);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 1160f4232172e..a5a7b6222125d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -23,6 +23,8 @@ namespace bufferization {
using namespace mlir;
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
+using AllocDynamicSizesMap =
+ llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -43,30 +45,24 @@ static bool hasStaticIdentityLayout(MemRefType type) {
return type.getLayout().isIdentity();
}
-/// Return the dynamic shapes of the `memref` based on the define op. If the
+/// Return the dynamic shapes of the `memref` based on the defining op. If the
/// complete dynamic shape fails to be captured, return an empty value.
-/// Currently, only function parameters are supported for capturing.
+/// Currently, only function block arguments are supported for capturing.
static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
- auto *defOp = memref.getDefiningOp();
+ Operation *defOp = memref.getDefiningOp();
if (!defOp)
return {};
auto operands = defOp->getOperands();
SmallVector<Value> dynamicSizes;
for (Value size : operands) {
- BlockArgument sizeSrc = mlir::dyn_cast<BlockArgument>(size);
+ BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
if (!sizeSrc)
return {};
- bool finded = false;
- for (BlockArgument argument : funcOp.getArguments()) {
- if (argument == sizeSrc) {
- dynamicSizes.push_back(argument);
- finded = true;
- break;
- }
- }
- if (!finded)
+ auto iter = llvm::find(funcOp.getArguments(), sizeSrc);
+ if (!iter)
return {};
+ dynamicSizes.push_back(*iter);
}
return dynamicSizes;
}
@@ -76,7 +72,7 @@ static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
func::FuncOp callee,
ValueRange dynamicSizes) {
- SmallVector<Value> mapedDynamicSizes;
+ SmallVector<Value> mappedDynamicSizes;
for (Value size : dynamicSizes) {
auto callOperands = call.getOperands();
for (size_t i = 0, e = callOperands.size(); i < e; ++i) {
@@ -84,10 +80,12 @@ static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
BlockArgument dst = callee.getArgument(i);
if (size != dst)
continue;
- mapedDynamicSizes.push_back(src);
+ mappedDynamicSizes.push_back(src);
}
}
- return mapedDynamicSizes;
+ assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
+ "could not find all dynamic sizes");
+ return mappedDynamicSizes;
}
// Updates the func op and entry block.
@@ -156,7 +154,8 @@ updateFuncOp(func::FuncOp func,
// the given out-params.
static LogicalResult
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
- bufferization::BufferResultsToOutParamsOpts &options) {
+ AllocDynamicSizesMap &map,
+ const bufferization::BufferResultsToOutParamsOpts &options) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
@@ -171,10 +170,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
bool hoistStaticAllocs =
options.hoistStaticAllocs &&
- mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
+ cast<MemRefType>(orig.getType()).hasStaticShape();
bool hoistDynamicAllocs =
options.hoistDynamicAllocs &&
- !mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
+ !cast<MemRefType>(orig.getType()).hasStaticShape();
if ((hoistStaticAllocs || hoistDynamicAllocs) &&
isa_and_nonnull<bufferization::AllocationOpInterface>(
orig.getDefiningOp())) {
@@ -194,7 +193,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
auto dynamicSizePair =
std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
dynamicSizes);
- options.dynamicSizesMap.insert(dynamicSizePair);
+ map.insert(dynamicSizePair);
return WalkResult::advance();
});
return failure(res.wasInterrupted());
@@ -203,7 +202,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
static LogicalResult
-updateCalls(ModuleOp module,
+updateCalls(ModuleOp module, AllocDynamicSizesMap &map,
const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
@@ -227,8 +226,7 @@ updateCalls(ModuleOp module,
}
SmallVector<Value, 6> outParams;
OpBuilder builder(op);
- SmallVector<SmallVector<Value>> dynamicSizes =
- options.dynamicSizesMap.lookup(callee);
+ SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee);
size_t dynamicSizesIndex = 0;
for (Value memref : replaceWithOutParams) {
SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
@@ -287,7 +285,11 @@ updateCalls(ModuleOp module,
}
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
- ModuleOp module, bufferization::BufferResultsToOutParamsOpts &options) {
+ ModuleOp module,
+ const bufferization::BufferResultsToOutParamsOpts &options) {
+ /// It maps the shape source of the dynamic shape memref returned by each
+ /// function.
+ AllocDynamicSizesMap map;
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
@@ -297,11 +299,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
- if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
+ if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
return failure();
}
}
- if (failed(updateCalls(module, options)))
+ if (failed(updateCalls(module, map, options)))
return failure();
return success();
}
>From c44b91ec4161d7a3a5f92193e8e2c5017ad26d45 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sat, 27 Sep 2025 11:42:36 +0000
Subject: [PATCH 5/7] Fix nit: use // for a comment inside a function body.
---
.../Bufferization/Transforms/BufferResultsToOutParams.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index a5a7b6222125d..06f6acd0febc8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -287,8 +287,8 @@ updateCalls(ModuleOp module, AllocDynamicSizesMap &map,
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
const bufferization::BufferResultsToOutParamsOpts &options) {
- /// It maps the shape source of the dynamic shape memref returned by each
- /// function.
+ // It maps the shape source of the dynamic shape memref returned by each
+ // function.
AllocDynamicSizesMap map;
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
>From bb63a6d0186daa2013a58bd7db6a8436939fef29 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sun, 28 Sep 2025 02:50:25 +0000
Subject: [PATCH 6/7] Fix nits: option docs, llvm::find end check, zip_first.
---
.../mlir/Dialect/Bufferization/Transforms/Passes.h | 4 ++--
.../Transforms/BufferResultsToOutParams.cpp | 13 ++++++-------
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 78bd33ff619ce..67ac487d8226d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -165,11 +165,11 @@ struct BufferResultsToOutParamsOpts {
bool addResultAttribute = false;
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
- /// memref is static allocated in the current function.
+ /// memref is allocated in the current function.
bool hoistStaticAllocs = false;
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
- /// memref is dynamic allocated in the current function.
+ /// memref is allocated in the current function and has dynamic shape.
bool hoistDynamicAllocs = false;
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 06f6acd0febc8..aec54dfe7ceab 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -59,8 +59,9 @@ static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
if (!sizeSrc)
return {};
- auto iter = llvm::find(funcOp.getArguments(), sizeSrc);
- if (!iter)
+ auto arguments = funcOp.getArguments();
+ auto iter = llvm::find(arguments, sizeSrc);
+ if (iter == arguments.end())
return {};
dynamicSizes.push_back(*iter);
}
@@ -74,10 +75,8 @@ static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
ValueRange dynamicSizes) {
SmallVector<Value> mappedDynamicSizes;
for (Value size : dynamicSizes) {
- auto callOperands = call.getOperands();
- for (size_t i = 0, e = callOperands.size(); i < e; ++i) {
- Value src = callOperands[i];
- BlockArgument dst = callee.getArgument(i);
+ for (auto [src, dst] :
+ llvm::zip_first(call.getOperands(), callee.getArguments())) {
if (size != dst)
continue;
mappedDynamicSizes.push_back(src);
@@ -202,7 +201,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
static LogicalResult
-updateCalls(ModuleOp module, AllocDynamicSizesMap &map,
+updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
>From 6e4abeb2bf1020d88566ef8c12fe9f93a4e5250a Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Sun, 28 Sep 2025 03:10:26 +0000
Subject: [PATCH 7/7] Support realloc.
---
.../Bufferization/Transforms/BufferResultsToOutParams.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index aec54dfe7ceab..25f941dc16516 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -55,10 +55,12 @@ static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
auto operands = defOp->getOperands();
SmallVector<Value> dynamicSizes;
for (Value size : operands) {
+ if (!isa<IndexType>(size.getType()))
+ continue;
+
BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
if (!sizeSrc)
return {};
-
auto arguments = funcOp.getArguments();
auto iter = llvm::find(arguments, sizeSrc);
if (iter == arguments.end())
More information about the Mlir-commits
mailing list