[Mlir-commits] [mlir] [mlir] [bufferize] fix crash when bufferize function without func.return returning op (PR #120675)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 19 19:32:32 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: donald chen (cxy-1993)
<details>
<summary>Changes</summary>
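This patch fixes a crash when bufferizing a function whose return operation is not `func.return` (for example, `llvm.return`). `getReturnOps` now finds a function's return operations through the `ReturnLike` trait instead of requiring `func.return`.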
Fixes https://github.com/llvm/llvm-project/issues/120535
---
Full diff: https://github.com/llvm/llvm-project/pull/120675.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h (+1-1)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+10-8)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+7-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac73..caf157b87be872 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -23,7 +23,7 @@ class FuncOp;
namespace bufferization {
/// Helper function that returns all func.return ops in the given function.
-SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
+SmallVector<Operation *> getReturnOps(func::FuncOp funcOp);
namespace func_ext {
/// The state of analysis of a FuncOp.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4dd..7a53c4d17ca085 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -20,11 +20,13 @@
namespace mlir {
/// Return all func.return ops in the given function.
-SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
- SmallVector<func::ReturnOp> result;
- for (Block &b : funcOp.getBody())
- if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
- result.push_back(returnOp);
+SmallVector<Operation *> bufferization::getReturnOps(func::FuncOp funcOp) {
+ SmallVector<Operation *> result;
+ for (Block &b : funcOp.getBody()) {
+ Operation *terminator = b.getTerminator();
+ if (terminator->hasTrait<OpTrait::ReturnLike>())
+ result.push_back(terminator);
+ }
return result;
}
@@ -439,7 +441,7 @@ struct FuncOpInterface
return failure();
// 2. Bufferize the operands of the all return op.
- for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+ for (Operation *returnOp : getReturnOps(funcOp)) {
assert(returnOp->getNumOperands() == retTypes.size() &&
"incorrect number of return values");
SmallVector<Value> returnValues;
@@ -457,11 +459,11 @@ struct FuncOpInterface
// Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
- returnOp.getLoc(), bufferizedType, returnVal);
+ returnOp->getLoc(), bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
}
- returnOp.getOperandsMutable().assign(returnValues);
+ returnOp->setOperands(returnValues);
}
// 3. Set the new function type.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 71ea0fd9d43cde..0ba4a1ecf79922 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -133,7 +133,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
}
// Find all func.return ops.
- SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ SmallVector<Operation *> returnOps = getReturnOps(funcOp);
assert(!returnOps.empty() && "expected at least one ReturnOp");
// Build alias sets. Merge all aliases from all func.return ops.
@@ -142,7 +142,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
int64_t bbArgIdx = bbArg.getArgNumber();
// Store aliases in a set, so that we don't add the same alias twice.
SetVector<int64_t> aliases;
- for (func::ReturnOp returnOp : returnOps) {
+ for (Operation *returnOp : returnOps) {
for (OpOperand &returnVal : returnOp->getOpOperands()) {
if (isa<RankedTensorType>(returnVal.get().getType())) {
int64_t returnIdx = returnVal.getOperandNumber();
@@ -192,7 +192,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
// argument for the i-th operand. In contrast to aliasing information,
// which is just "merged", equivalence information must match across all
// func.return ops.
- for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+ for (Operation *returnOp : ArrayRef(returnOps).drop_front()) {
std::optional<int64_t> maybeEquiv =
findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
if (maybeEquiv != bbArgIdx) {
@@ -398,7 +398,7 @@ static Value unpackCast(Value v) {
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the i-th operand is not the same for all func.return ops, then
/// the i-th returned type is an "empty" type.
-static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+static SmallVector<Type> getReturnTypes(SmallVector<Operation *> returnOps) {
assert(!returnOps.empty() && "expected at least one ReturnOp");
int numOperands = returnOps.front()->getNumOperands();
@@ -434,11 +434,11 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
return;
// Compute the common result types of all return ops.
- SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ SmallVector<Operation *> returnOps = getReturnOps(funcOp);
SmallVector<Type> resultTypes = getReturnTypes(returnOps);
// Remove direct casts.
- for (func::ReturnOp returnOp : returnOps) {
+ for (Operation *returnOp : returnOps) {
for (OpOperand &operand : returnOp->getOpOperands()) {
// Bail if no common result type was found.
if (resultTypes[operand.getOperandNumber()]) {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index e65c5b92949f6e..940f280c2aea4d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -268,4 +268,10 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
%r = tensor.extract %dest_filled[%idx] : tensor<5xf32>
return %0, %r : tensor<5xf32>, f32
-}
\ No newline at end of file
+}
+
+// -----
+// CHECK-LABEL: @llvm_return
+func.func @llvm_return() {
+ llvm.return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/120675
More information about the Mlir-commits mailing list