[Mlir-commits] [mlir] [mlir][bufferization] Add support for non-unique `func.return` (PR #114017)

Matthias Springer llvmlistbot at llvm.org
Tue Nov 12 04:59:38 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114017

>From 1122ffeddcbcd27838386b952849a29d792dc9f1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 29 Oct 2024 09:51:11 +0100
Subject: [PATCH] [mlir][bufferization] Add support for non-unique
 `func.return`

---
 .../FuncBufferizableOpInterfaceImpl.h         |   4 +
 .../FuncBufferizableOpInterfaceImpl.cpp       |  79 ++++----
 .../Transforms/OneShotModuleBufferize.cpp     | 174 +++++++++++++-----
 .../one-shot-module-bufferize-analysis.mlir   |  46 +++++
 .../one-shot-module-bufferize-invalid.mlir    |  22 +--
 .../Transforms/one-shot-module-bufferize.mlir |  25 +++
 6 files changed, 236 insertions(+), 114 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index 0b91d3d675b7c9..e8e6226460ac73 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 class DialectRegistry;
@@ -21,6 +22,9 @@ class FuncOp;
 } // namespace func
 
 namespace bufferization {
+/// Return the func.return terminators of all blocks in `funcOp` (block order).
+SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
+
 namespace func_ext {
 /// The state of analysis of a FuncOp.
 enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 11ed434f774a87..c45678f1e4b4dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -19,6 +19,15 @@
 #include <optional>
 
 namespace mlir {
+/// Collect the func.return terminators of all blocks of `funcOp`, in order.
+SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
+  SmallVector<func::ReturnOp> returnOps;
+  for (Block &block : funcOp.getBody())
+    if (auto ret = dyn_cast<func::ReturnOp>(block.getTerminator()))
+      returnOps.push_back(ret);
+  return returnOps;
+}
+
 namespace bufferization {
 namespace func_ext {
 
@@ -41,20 +50,6 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
@@ -391,15 +386,6 @@ struct FuncOpInterface
         getBufferType(op, value, options, invocationStack);
   }
 
-  LogicalResult verifyAnalysis(Operation *op,
-                               const AnalysisState &state) const {
-    auto funcOp = cast<func::FuncOp>(op);
-    // TODO: func.func with multiple returns are not supported.
-    if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
-      return op->emitOpError("op without unique func.return is not supported");
-    return success();
-  }
-
   /// Rewrite function bbArgs and return values into buffer form. This function
   /// bufferizes the function signature and the ReturnOp. When the entire
   /// function body has been bufferized, function return types can be switched
@@ -446,41 +432,38 @@ struct FuncOpInterface
       return success();
     }
 
-    // TODO: Support functions with multiple returns.
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    assert(returnOp && "expected func with single return op");
-    assert(returnOp->getNumOperands() == retTypes.size() &&
-           "incorrect number of return values");
-    Location loc = returnOp.getLoc();
-
     // 1. Bufferize every block.
     for (Block &block : funcOp.getBody())
       if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                         options)))
         return failure();
 
-    // 2. Bufferize all operands of the return op.
-    SmallVector<Value> returnValues;
-    for (auto [returnVal, bufferizedType] :
-         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
-      rewriter.setInsertionPoint(returnOp);
-
-      // If not a tensor type just forward it.
-      if (!tensorType) {
-        returnValues.push_back(returnVal);
-        continue;
+    // 2. Bufferize the operands of all return ops.
+    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+      assert(returnOp->getNumOperands() == retTypes.size() &&
+             "incorrect number of return values");
+      SmallVector<Value> returnValues;
+      for (auto [returnVal, bufferizedType] :
+           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        rewriter.setInsertionPoint(returnOp);
+
+        // If not a tensor type just forward it.
+        if (!tensorType) {
+          returnValues.push_back(returnVal);
+          continue;
+        }
+
+        // Note: If `inferFunctionResultLayout = true`, casts are later folded
+        // away.
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            returnOp.getLoc(), bufferizedType, returnVal);
+        returnValues.push_back(toMemrefOp);
       }
 
-      // Note: If `inferFunctionResultLayout = true`, casts are later folded
-      // away.
-      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, bufferizedType, returnVal);
-      returnValues.push_back(toMemrefOp);
+      returnOp.getOperandsMutable().assign(returnValues);
     }
 
-    returnOp.getOperandsMutable().assign(returnValues);
-
     // 3. Set the new function type.
     funcOp.setType(newFuncType);
     return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index a492bcdd0f3e38..71ea0fd9d43cde 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -86,20 +86,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
   return state.addExtension<FuncAnalysisState>();
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 namespace {
 
 /// Annotate IR with the results of the analysis. For testing purposes only.
@@ -146,24 +132,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     return success();
   }
 
-  // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (OpOperand &returnVal : returnOp->getOpOperands())
-    if (isa<RankedTensorType>(returnVal.get().getType()))
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (isa<RankedTensorType>(bbArg.getType())) {
-          int64_t returnIdx = returnVal.getOperandNumber();
-          int64_t bbArgIdx = bbArg.getArgNumber();
-          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
-            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
-            if (state.getOptions().testAnalysisOnly)
-              annotateEquivalentReturnBbArg(returnVal, bbArg);
+  // Find all func.return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+  // Build alias sets. Merge all aliases from all func.return ops.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    if (isa<RankedTensorType>(bbArg.getType())) {
+      int64_t bbArgIdx = bbArg.getArgNumber();
+      // Store aliases in a set, so that we don't add the same alias twice.
+      SetVector<int64_t> aliases;
+      for (func::ReturnOp returnOp : returnOps) {
+        for (OpOperand &returnVal : returnOp->getOpOperands()) {
+          if (isa<RankedTensorType>(returnVal.get().getType())) {
+            int64_t returnIdx = returnVal.getOperandNumber();
+            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+              aliases.insert(returnIdx);
           }
-          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
-            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
         }
+      }
+      for (int64_t alias : aliases)
+        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+    }
+  }
+
+  // Build equivalence sets.
+  // Helper function that finds an equivalent block argument index for the
+  // given OpOperand. Return std::nullopt if no equivalent block argument could
+  // be found.
+  auto findEquivalentBlockArgIdx =
+      [&](OpOperand &opOperand) -> std::optional<int64_t> {
+    Value v = opOperand.get();
+    if (!isa<TensorType>(v.getType()))
+      return std::nullopt;
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (isa<RankedTensorType>(bbArg.getType())) {
+        if (state.areEquivalentBufferizedValues(v, bbArg)) {
+          if (state.getOptions().testAnalysisOnly)
+            annotateEquivalentReturnBbArg(opOperand, bbArg);
+          return bbArg.getArgNumber();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  int64_t numResults = returnOps.front()->getNumOperands();
+  for (int64_t i = 0; i < numResults; ++i) {
+    // Find the equivalent block argument index for the i-th operand of the
+    // first func.return op.
+    std::optional<int64_t> maybeEquiv =
+        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+    if (!maybeEquiv.has_value())
+      continue;
+    int64_t bbArgIdx = *maybeEquiv;
+    bool allEquiv = true;
+
+    // Check if all other func.return ops have the same equivalent block
+    // argument for the i-th operand. In contrast to aliasing information,
+    // which is just "merged", equivalence information must match across all
+    // func.return ops.
+    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+      std::optional<int64_t> maybeEquiv =
+          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+      if (maybeEquiv != bbArgIdx) {
+        allEquiv = false;
+        break;
+      }
+    }
+
+    // All func.return ops have the same equivalent block argument for the i-th
+    // operand.
+    if (allEquiv)
+      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+  }
 
   return success();
 }
@@ -302,14 +344,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
-    }
-
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -351,6 +385,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
   return success();
 }
 
+/// Helper function that looks through a single memref.cast: if `v` is the
+/// result of a memref.cast, the cast's source is returned; otherwise `v` is
+/// returned unchanged.
+static Value unpackCast(Value v) {
+  if (auto castOp = v.getDefiningOp<memref::CastOp>())
+    return castOp.getSource();
+  return v;
+}
+
+/// Helper function that returns the return types (looking through memref.cast
+/// ops) of the given func.return ops, one type per return operand. If the
+/// i-th operand types (after skipping casts) are not the same for all
+/// func.return ops, then the i-th returned type is a null `Type()`.
+static SmallVector<Type> getReturnTypes(ArrayRef<func::ReturnOp> returnOps) {
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+  int numOperands = returnOps.front()->getNumOperands();
+
+  // Helper function that unpacks memref.cast ops and returns the type.
+  auto getSourceType = [](Value v) { return unpackCast(v).getType(); };
+
+  SmallVector<Type> result;
+  for (int i = 0; i < numOperands; ++i) {
+    // Get the type of the i-th operand of the first func.return op.
+    Type t = getSourceType(returnOps.front()->getOperand(i));
+    // If any other func.return op disagrees, there is no common type.
+    for (func::ReturnOp returnOp : returnOps.drop_front()) {
+      if (getSourceType(returnOp->getOperand(i)) != t) {
+        t = Type();
+        break;
+      }
+    }
+    result.push_back(t);
+  }
+  return result;
+}
+
 /// Fold return values that are memref casts and update function return types.
 ///
 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -359,21 +429,33 @@ static LogicalResult getFuncOpsOrderedByCalls(
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
 static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for bodiless ops.
   if (funcOp.getBody().empty())
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  SmallVector<Type> resultTypes;
+  // Compute the common result types of all return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
-  for (OpOperand &operand : returnOp->getOpOperands()) {
-    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
-      operand.set(castOp.getSource());
-      resultTypes.push_back(castOp.getSource().getType());
-    } else {
-      resultTypes.push_back(operand.get().getType());
+  // Remove direct casts.
+  for (func::ReturnOp returnOp : returnOps) {
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      // Only fold the cast if all return ops agree on the operand type.
+      if (resultTypes[operand.getOperandNumber()]) {
+        operand.set(unpackCast(operand.get()));
+      }
     }
   }
 
+  // Fill in the missing result types that were not the same among all
+  // func.return ops.
+  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+    if (resultTypes[i])
+      continue;
+    resultTypes[i] = funcOp.getFunctionType().getResult(i);
+  }
+
+  // Update the function type.
   auto newFuncType = FunctionType::get(
       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 3f6d182b57c031..35b28f7ec83919 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1360,3 +1360,49 @@ func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?
   %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
   return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
+func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t1 : tensor<5xf32>
+}
+
+//       CHECK-ALIAS-SETS: func @caller(
+//  CHECK-ALIAS-SETS-SAME:     %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
+  // The result may alias %t0 or %t1 (the callee returns either), but not %t2.
+  //      CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME:  __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
+  call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
+func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t0 : tensor<5xf32>
+}
+
+//       CHECK-ALIAS-SETS: func @caller(
+//  CHECK-ALIAS-SETS-SAME:     %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  // Both callee func.return ops return %t0, so the result is equivalent to %t0.
+  //      CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME:  __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
+  %r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  // CHECK-ALIAS-SETS: return {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
+  return %r : tensor<5xf32>
+}
+
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 28ce0735e47b74..d773e1af43a76e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,24 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-    -> (tensor<f32>, tensor<f32>)
-{
-  cf.cond_br %cond1, ^bb1, ^bb2
-
-  ^bb1:
-    %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
-      scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
-    } else {
-      scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
-    }
-    return %T#0, %T#1 : tensor<f32>, tensor<f32>
-  ^bb2:
-    return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
 func.func @scf_for(%A : tensor<?xf32>,
               %B : tensor<?xf32> {bufferization.writable = true},
               %C : tensor<4xf32>,
@@ -146,7 +127,8 @@ func.func @regression_scf_while() {
 
 // -----
 
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
 func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
   func.return %t : tensor<5xf32>
 ^bb1(%arg1 : tensor<5xf32>):
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 2b5b8631436705..65557a68d243a2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -771,3 +771,28 @@ func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
   %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
   return %0 : tensor<5xf32>
 }
+
+// -----
+
+// The two func.return operands have different layouts after bufferization, so
+// there is no common type: the memref.cast ops to the result type must remain.
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+  %t = tensor.empty() : tensor<10xf32>
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+  // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+  %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>>
+  return %0 : tensor<5xf32>
+^bb2:
+  // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+  // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+  %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast1]] : memref<5xf32, strided<[?], offset: ?>>
+  return %1 : tensor<5xf32>
+}
+



More information about the Mlir-commits mailing list