[Mlir-commits] [mlir] [mlir][bufferization] Generalize returns to be ops with ReturnLike trait (PR #124949)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 29 08:57:24 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: Yi Zhang (cathyzhyi)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/124949.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h (+1-1)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+12-8)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac73a..caf157b87be8725 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -23,7 +23,7 @@ class FuncOp;
namespace bufferization {
-/// Helper function that returns all func.return ops in the given function.
-SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
+/// Helper function that returns all ReturnLike terminators of `funcOp`.
+SmallVector<Operation *> getReturnOps(func::FuncOp funcOp);
namespace func_ext {
/// The state of analysis of a FuncOp.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4dd7..df2fe08d02c9084 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -20,11 +20,13 @@
namespace mlir {
-/// Return all func.return ops in the given function.
-SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
- SmallVector<func::ReturnOp> result;
- for (Block &b : funcOp.getBody())
- if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
- result.push_back(returnOp);
+/// Return all ReturnLike terminators of the blocks in `funcOp`'s body.
+SmallVector<Operation *> bufferization::getReturnOps(func::FuncOp funcOp) {
+ SmallVector<Operation *> result;
+ for (Block &b : funcOp.getBody()) {
+ Operation *terminator = b.getTerminator();
+ if (terminator->hasTrait<OpTrait::ReturnLike>())
+ result.push_back(terminator);
+ }
return result;
}
@@ -439,7 +441,7 @@ struct FuncOpInterface
return failure();
-// 2. Bufferize the operands of the all return op.
- for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+// 2. Bufferize the operands of all return ops.
+ for (Operation *returnOp : getReturnOps(funcOp)) {
assert(returnOp->getNumOperands() == retTypes.size() &&
"incorrect number of return values");
SmallVector<Value> returnValues;
@@ -457,11 +459,11 @@ struct FuncOpInterface
// Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
- returnOp.getLoc(), bufferizedType, returnVal);
+ returnOp->getLoc(), bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
}
- returnOp.getOperandsMutable().assign(returnValues);
+ // Replace all operands at once; mirrors the former
+ // `getOperandsMutable().assign()` for a generic ReturnLike op.
+ returnOp->setOperands(returnValues);
}
// 3. Set the new function type.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 71ea0fd9d43cde2..0ba4a1ecf799227 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -133,7 +133,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
}
- // Find all func.return ops.
- SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ // Find all ReturnLike terminators.
+ SmallVector<Operation *> returnOps = getReturnOps(funcOp);
assert(!returnOps.empty() && "expected at least one ReturnOp");
- // Build alias sets. Merge all aliases from all func.return ops.
+ // Build alias sets. Merge all aliases from all return ops.
@@ -142,7 +142,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
int64_t bbArgIdx = bbArg.getArgNumber();
// Store aliases in a set, so that we don't add the same alias twice.
SetVector<int64_t> aliases;
- for (func::ReturnOp returnOp : returnOps) {
+ for (Operation *returnOp : returnOps) {
for (OpOperand &returnVal : returnOp->getOpOperands()) {
if (isa<RankedTensorType>(returnVal.get().getType())) {
int64_t returnIdx = returnVal.getOperandNumber();
@@ -192,7 +192,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
// argument for the i-th operand. In contrast to aliasing information,
// which is just "merged", equivalence information must match across all
- // func.return ops.
- for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+ // return ops.
+ for (Operation *returnOp : ArrayRef(returnOps).drop_front()) {
std::optional<int64_t> maybeEquiv =
findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
if (maybeEquiv != bbArgIdx) {
@@ -398,7 +398,7 @@ static Value unpackCast(Value v) {
-/// func.return ops. This function returns as many types as the return ops have
-/// operands. If the i-th operand is not the same for all func.return ops, then
+/// return ops. This function returns as many types as the return ops have
+/// operands. If the i-th operand is not the same for all return ops, then
/// the i-th returned type is an "empty" type.
-static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+static SmallVector<Type> getReturnTypes(ArrayRef<Operation *> returnOps) {
assert(!returnOps.empty() && "expected at least one ReturnOp");
int numOperands = returnOps.front()->getNumOperands();
@@ -434,11 +434,11 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
return;
// Compute the common result types of all return ops.
- SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ SmallVector<Operation *> returnOps = getReturnOps(funcOp);
SmallVector<Type> resultTypes = getReturnTypes(returnOps);
// Remove direct casts.
- for (func::ReturnOp returnOp : returnOps) {
+ for (Operation *returnOp : returnOps) {
for (OpOperand &operand : returnOp->getOpOperands()) {
// Bail if no common result type was found.
if (resultTypes[operand.getOperandNumber()]) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/124949
More information about the Mlir-commits
mailing list