[Mlir-commits] [mlir] [mlir][bufferize] Make drop-equivalent-buffer-results support mult blocks (PR #163388)
lonely eagle
llvmlistbot at llvm.org
Tue Oct 14 09:23:47 PDT 2025
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/163388
>From 5c4c6de5d41b237f9aa2d918c0608f34f6e2131d Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 14 Oct 2025 12:23:17 +0000
Subject: [PATCH] make drop-equivalent-buffer-results support mult blocks.
---
.../DropEquivalentBufferResults.cpp | 68 +++++++++++++------
.../Dialect/Tensor/one-shot-bufferize.mlir | 29 ++++++++
2 files changed, 75 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71a5ffbb..9c300cc347ecf 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,38 @@ namespace bufferization {
using namespace mlir;
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
+/// Collect every `func::ReturnOp` that terminates a block of `funcOp`.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
+ returnOps.push_back(candidateOp);
}
}
- return returnOp;
+ return returnOps;
+}
+
+/// Gather operand `pos` of each op in `returnOps`, in `returnOps` order.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+ SmallVector<Value> operands;
+ for (func::ReturnOp returnOp : returnOps) {
+ operands.push_back(returnOp.getOperand(pos));
+ }
+ return operands;
+}
+
+/// True iff every operand, looking through memref.cast ops, is `argument`.
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+ BlockArgument argument) {
+ for (Value val : operands) {
+ while (auto castOp = val.getDefiningOp<memref::CastOp>())
+ val = castOp.getSource();
+
+ if (val != argument)
+ return false;
+ }
+ return true;
}
LogicalResult
@@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (funcOp.isExternal() || funcOp.isPublic())
continue;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- // TODO: Support functions with multiple blocks.
- if (!returnOp)
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ if (returnOps.empty())
continue;
+ func::ReturnOp returnOp = returnOps.front();
// Compute erased results.
- SmallVector<Value> newReturnValues;
+ SmallVector<SmallVector<Value>> newReturnValues(returnOps.size());
BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
DenseMap<int64_t, int64_t> resultToArgs;
- for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+ for (size_t i = 0, e = returnOp.getOperands().size(); i < e; ++i) {
bool erased = false;
+ SmallVector<Value> returnOperands =
+ getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
- Value val = it.value();
- while (auto castOp = val.getDefiningOp<memref::CastOp>())
- val = castOp.getSource();
-
- if (val == bbArg) {
- resultToArgs[it.index()] = bbArg.getArgNumber();
+ if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+ resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}
if (erased) {
- erasedResultIndices.set(it.index());
+ erasedResultIndices.set(i);
} else {
- newReturnValues.push_back(it.value());
+ for (auto [newReturnValue, operand] :
+ llvm::zip(newReturnValues, returnOperands)) {
+ newReturnValue.push_back(operand);
+ }
}
}
// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
- returnOp.getOperandsMutable().assign(newReturnValues);
+
+ for (auto [returnOp, newReturnValue] :
+ llvm::zip(returnOps, newReturnValues))
+ returnOp.getOperandsMutable().assign(newReturnValue);
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index b6c72bedef6c5..508e29303d37b 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -490,3 +490,32 @@ func.func @collapse_shape_regression(
tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: func private @mult_return_callee(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+ // CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+ // CHECK: ^bb1:
+ // CHECK: return %[[A]] : index
+ // CHECK: ^bb2:
+ // CHECK: return %[[B]] : index
+ cf.cond_br %cond,^a, ^b
+ ^a:
+ return %casted, %a : tensor<10xf32>, index
+ ^b:
+ return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+ // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+ %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
+ return %t_res, %v : tensor<10xf32>, index
+}
More information about the Mlir-commits
mailing list