[Mlir-commits] [mlir] 50d65a5 - [mlir][bufferize] Make drop-equivalent-buffer-results support mult blocks (#163388)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 17 02:51:44 PDT 2025
Author: lonely eagle
Date: 2025-10-17T17:51:40+08:00
New Revision: 50d65a5675daf3a433615c7676ee04b5f838af0b
URL: https://github.com/llvm/llvm-project/commit/50d65a5675daf3a433615c7676ee04b5f838af0b
DIFF: https://github.com/llvm/llvm-project/commit/50d65a5675daf3a433615c7676ee04b5f838af0b.diff
LOG: [mlir][bufferize] Make drop-equivalent-buffer-results support mult blocks (#163388)
Enable drop-equivalent-buffer-results to handle return ops in
multiple blocks within a function.
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71a5ffbb..bc1799099de31 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,37 @@ namespace bufferization {
using namespace mlir;
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
+/// Get all the ReturnOp in the funcOp.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
+ returnOps.push_back(candidateOp);
}
}
- return returnOp;
+ return returnOps;
+}
+
+/// Get the operands at the specified position for all returnOps.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+ return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) {
+ return returnOp.getOperand(pos);
+ });
+}
+
+/// Check if all given values are the same buffer as the block argument (modulo
+/// cast ops).
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+ BlockArgument argument) {
+ for (Value val : operands) {
+ while (auto castOp = val.getDefiningOp<memref::CastOp>())
+ val = castOp.getSource();
+
+ if (val != argument)
+ return false;
+ }
+ return true;
}
LogicalResult
@@ -72,40 +91,45 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (funcOp.isExternal() || funcOp.isPublic())
continue;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- // TODO: Support functions with multiple blocks.
- if (!returnOp)
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ if (returnOps.empty())
continue;
// Compute erased results.
- SmallVector<Value> newReturnValues;
- BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
+ size_t numReturnOps = returnOps.size();
+ size_t numReturnValues = funcOp.getFunctionType().getNumResults();
+ SmallVector<SmallVector<Value>> newReturnValues(numReturnOps);
+ BitVector erasedResultIndices(numReturnValues);
DenseMap<int64_t, int64_t> resultToArgs;
- for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+ for (size_t i = 0; i < numReturnValues; ++i) {
bool erased = false;
+ SmallVector<Value> returnOperands =
+ getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
- Value val = it.value();
- while (auto castOp = val.getDefiningOp<memref::CastOp>())
- val = castOp.getSource();
-
- if (val == bbArg) {
- resultToArgs[it.index()] = bbArg.getArgNumber();
+ if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+ resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}
if (erased) {
- erasedResultIndices.set(it.index());
+ erasedResultIndices.set(i);
} else {
- newReturnValues.push_back(it.value());
+ for (auto [newReturnValue, operand] :
+ llvm::zip(newReturnValues, returnOperands)) {
+ newReturnValue.push_back(operand);
+ }
}
}
// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
- returnOp.getOperandsMutable().assign(newReturnValues);
+
+ for (auto [returnOp, newReturnValue] :
+ llvm::zip(returnOps, newReturnValues))
+ returnOp.getOperandsMutable().assign(newReturnValue);
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index b6c72bedef6c5..f66cf7ae53266 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -490,3 +490,32 @@ func.func @collapse_shape_regression(
tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: func private @mult_return_callee(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+// CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[A]] : index
+// CHECK: ^bb2:
+// CHECK: return %[[B]] : index
+func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+ cf.cond_br %cond,^a, ^b
+^a:
+ return %casted, %a : tensor<10xf32>, index
+^b:
+ return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+ // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+ %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
+ return %t_res, %v : tensor<10xf32>, index
+}
More information about the Mlir-commits
mailing list