[Mlir-commits] [mlir] [mlir][bufferization] Generalize returns to be ops with ReturnLike trait (PR #124949)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 29 08:57:24 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-bufferization

Author: Yi Zhang (cathyzhyi)

<details>
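<summary>Changes</summary>

`bufferization::getReturnOps` now returns `SmallVector<Operation *>` and collects every block terminator that has the `ReturnLike` trait, instead of only `func::ReturnOp` terminators. Its callers in `FuncBufferizableOpInterfaceImpl.cpp` and `OneShotModuleBufferize.cpp` are updated to operate on `Operation *` accordingly.

---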
Full diff: https://github.com/llvm/llvm-project/pull/124949.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h (+1-1) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+12-8) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac73a..caf157b87be8725 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -23,7 +23,7 @@ class FuncOp;
 
 namespace bufferization {
 /// Helper function that returns all func.return ops in the given function.
-SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
+SmallVector<Operation *> getReturnOps(func::FuncOp funcOp);
 
 namespace func_ext {
 /// The state of analysis of a FuncOp.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4dd7..df2fe08d02c9084 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -20,11 +20,13 @@
 
 namespace mlir {
 /// Return all func.return ops in the given function.
-SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
-  SmallVector<func::ReturnOp> result;
-  for (Block &b : funcOp.getBody())
-    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
-      result.push_back(returnOp);
+SmallVector<Operation *> bufferization::getReturnOps(func::FuncOp funcOp) {
+  SmallVector<Operation *> result;
+  for (Block &b : funcOp.getBody()) {
+    Operation *terminator = b.getTerminator();
+    if (terminator->hasTrait<OpTrait::ReturnLike>())
+      result.push_back(b.getTerminator());
+  }
   return result;
 }
 
@@ -439,7 +441,7 @@ struct FuncOpInterface
         return failure();
 
     // 2. Bufferize the operands of the all return op.
-    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+    for (Operation *returnOp : getReturnOps(funcOp)) {
       assert(returnOp->getNumOperands() == retTypes.size() &&
              "incorrect number of return values");
       SmallVector<Value> returnValues;
@@ -457,11 +459,13 @@ struct FuncOpInterface
         // Note: If `inferFunctionResultLayout = true`, casts are later folded
         // away.
         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-            returnOp.getLoc(), bufferizedType, returnVal);
+            returnOp->getLoc(), bufferizedType, returnVal);
         returnValues.push_back(toMemrefOp);
       }
 
-      returnOp.getOperandsMutable().assign(returnValues);
+      for (auto [i, operand] : enumerate(returnValues)) {
+        returnOp->setOperand(i, operand);
+      }
     }
 
     // 3. Set the new function type.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 71ea0fd9d43cde2..0ba4a1ecf799227 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -133,7 +133,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   }
 
   // Find all func.return ops.
-  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Operation *> returnOps = getReturnOps(funcOp);
   assert(!returnOps.empty() && "expected at least one ReturnOp");
 
   // Build alias sets. Merge all aliases from all func.return ops.
@@ -142,7 +142,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
       int64_t bbArgIdx = bbArg.getArgNumber();
       // Store aliases in a set, so that we don't add the same alias twice.
       SetVector<int64_t> aliases;
-      for (func::ReturnOp returnOp : returnOps) {
+      for (Operation *returnOp : returnOps) {
         for (OpOperand &returnVal : returnOp->getOpOperands()) {
           if (isa<RankedTensorType>(returnVal.get().getType())) {
             int64_t returnIdx = returnVal.getOperandNumber();
@@ -192,7 +192,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     // argument for the i-th operand. In contrast to aliasing information,
     // which is just "merged", equivalence information must match across all
     // func.return ops.
-    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+    for (Operation *returnOp : ArrayRef(returnOps).drop_front()) {
       std::optional<int64_t> maybeEquiv =
           findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
       if (maybeEquiv != bbArgIdx) {
@@ -398,7 +398,7 @@ static Value unpackCast(Value v) {
 /// func.return ops. This function returns as many types as the return ops have
 /// operands. If the i-th operand is not the same for all func.return ops, then
 /// the i-th returned type is an "empty" type.
-static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+static SmallVector<Type> getReturnTypes(SmallVector<Operation *> returnOps) {
   assert(!returnOps.empty() && "expected at least one ReturnOp");
   int numOperands = returnOps.front()->getNumOperands();
 
@@ -434,11 +434,11 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
     return;
 
   // Compute the common result types of all return ops.
-  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Operation *> returnOps = getReturnOps(funcOp);
   SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
   // Remove direct casts.
-  for (func::ReturnOp returnOp : returnOps) {
+  for (Operation *returnOp : returnOps) {
     for (OpOperand &operand : returnOp->getOpOperands()) {
       // Bail if no common result type was found.
       if (resultTypes[operand.getOperandNumber()]) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/124949


More information about the Mlir-commits mailing list