[Mlir-commits] [mlir] [mlir] [bufferize] fix crash when bufferizing a function whose return op is not func.return (PR #120675)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 19 19:32:32 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

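This patch fixes a bug where a function's return operation could not be found when it was not a func.return (e.g. llvm.return): 'getReturnOps' now locates a function's return operations via the ReturnLike trait instead of matching func.return only.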

Fixed https://github.com/llvm/llvm-project/issues/120535

---
Full diff: https://github.com/llvm/llvm-project/pull/120675.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h (+1-1) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+10-8) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+7-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac73..caf157b87be872 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -23,7 +23,7 @@ class FuncOp;
 
 namespace bufferization {
-/// Helper function that returns all func.return ops in the given function.
-SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
+/// Helper that returns all ReturnLike terminators (e.g. func.return) in funcOp.
+SmallVector<Operation *> getReturnOps(func::FuncOp funcOp);
 
 namespace func_ext {
 /// The state of analysis of a FuncOp.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4dd..7a53c4d17ca085 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -20,11 +20,13 @@
 
 namespace mlir {
-/// Return all func.return ops in the given function.
-SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
-  SmallVector<func::ReturnOp> result;
-  for (Block &b : funcOp.getBody())
-    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
-      result.push_back(returnOp);
+/// Return all ReturnLike terminators (e.g. func.return) in the given function.
+SmallVector<Operation *> bufferization::getReturnOps(func::FuncOp funcOp) {
+  SmallVector<Operation *> result;
+  for (Block &b : funcOp.getBody()) {
+    Operation *terminator = b.getTerminator();
+    if (terminator->hasTrait<OpTrait::ReturnLike>())
+      result.push_back(terminator);
+  }
   return result;
 }
 
@@ -439,7 +441,7 @@ struct FuncOpInterface
         return failure();
 
-    // 2. Bufferize the operands of the all return op.
-    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+    // 2. Bufferize the operands of all return ops.
+    for (Operation *returnOp : getReturnOps(funcOp)) {
       assert(returnOp->getNumOperands() == retTypes.size() &&
              "incorrect number of return values");
       SmallVector<Value> returnValues;
@@ -457,11 +459,11 @@ struct FuncOpInterface
         // Note: If `inferFunctionResultLayout = true`, casts are later folded
         // away.
         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-            returnOp.getLoc(), bufferizedType, returnVal);
+            returnOp->getLoc(), bufferizedType, returnVal);
         returnValues.push_back(toMemrefOp);
       }
 
-      returnOp.getOperandsMutable().assign(returnValues);
+      returnOp->setOperands(returnValues);
     }
 
     // 3. Set the new function type.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 71ea0fd9d43cde..0ba4a1ecf79922 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -133,7 +133,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   }
 
-  // Find all func.return ops.
-  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  // Find all return-like terminators (e.g. func.return).
+  SmallVector<Operation *> returnOps = getReturnOps(funcOp);
   assert(!returnOps.empty() && "expected at least one ReturnOp");
 
-  // Build alias sets. Merge all aliases from all func.return ops.
+  // Build alias sets. Merge all aliases from all return ops.
@@ -142,7 +142,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
       int64_t bbArgIdx = bbArg.getArgNumber();
       // Store aliases in a set, so that we don't add the same alias twice.
       SetVector<int64_t> aliases;
-      for (func::ReturnOp returnOp : returnOps) {
+      for (Operation *returnOp : returnOps) {
         for (OpOperand &returnVal : returnOp->getOpOperands()) {
           if (isa<RankedTensorType>(returnVal.get().getType())) {
             int64_t returnIdx = returnVal.getOperandNumber();
@@ -192,7 +192,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     // argument for the i-th operand. In contrast to aliasing information,
     // which is just "merged", equivalence information must match across all
-    // func.return ops.
-    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+    // return ops.
+    for (Operation *returnOp : ArrayRef(returnOps).drop_front()) {
       std::optional<int64_t> maybeEquiv =
           findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
       if (maybeEquiv != bbArgIdx) {
@@ -398,7 +398,7 @@ static Value unpackCast(Value v) {
-/// func.return ops. This function returns as many types as the return ops have
-/// operands. If the i-th operand is not the same for all func.return ops, then
+/// return ops. This function returns as many types as the return ops have
+/// operands. If the i-th operand is not the same for all return ops, then
 /// the i-th returned type is an "empty" type.
-static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+static SmallVector<Type> getReturnTypes(ArrayRef<Operation *> returnOps) {
   assert(!returnOps.empty() && "expected at least one ReturnOp");
   int numOperands = returnOps.front()->getNumOperands();
 
@@ -434,11 +434,11 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
     return;
 
   // Compute the common result types of all return ops.
-  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Operation *> returnOps = getReturnOps(funcOp);
   SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
   // Remove direct casts.
-  for (func::ReturnOp returnOp : returnOps) {
+  for (Operation *returnOp : returnOps) {
     for (OpOperand &operand : returnOp->getOpOperands()) {
       // Bail if no common result type was found.
       if (resultTypes[operand.getOperandNumber()]) {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index e65c5b92949f6e..940f280c2aea4d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -268,4 +268,10 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
   %r = tensor.extract %dest_filled[%idx] : tensor<5xf32>
 
   return %0, %r : tensor<5xf32>, f32
-}
\ No newline at end of file
+}
+
+// -----
+// CHECK-LABEL: @llvm_return
+func.func @llvm_return() {
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/120675


More information about the Mlir-commits mailing list