[Mlir-commits] [mlir] [mlir][bufferization] Add support for non-unique `func.return` (PR #114017)
Matthias Springer
llvmlistbot at llvm.org
Tue Nov 12 04:59:38 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114017
From 1122ffeddcbcd27838386b952849a29d792dc9f1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 29 Oct 2024 09:51:11 +0100
Subject: [PATCH] [mlir][bufferization] Add support for non-unique
`func.return`
---
.../FuncBufferizableOpInterfaceImpl.h | 4 +
.../FuncBufferizableOpInterfaceImpl.cpp | 79 ++++----
.../Transforms/OneShotModuleBufferize.cpp | 174 +++++++++++++-----
.../one-shot-module-bufferize-analysis.mlir | 46 +++++
.../one-shot-module-bufferize-invalid.mlir | 22 +--
.../Transforms/one-shot-module-bufferize.mlir | 25 +++
6 files changed, 236 insertions(+), 114 deletions(-)
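This patch lifts One-Shot Bufferize's long-standing restriction that a function body must be terminated by a single func.return when bufferizing across function boundaries. A minimal example of IR that used to be rejected with "cannot bufferize a FuncOp with tensors and without a unique ReturnOp" and is now supported (hypothetical function, modeled on the new test cases below):

  func.func @select_tensor(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>) -> tensor<5xf32> {
    cf.cond_br %c, ^bb1, ^bb2
  ^bb1:
    return %t0 : tensor<5xf32>
  ^bb2:
    // A second return-terminated block: there is no unique func.return.
    return %t1 : tensor<5xf32>
  }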
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index 0b91d3d675b7c9..e8e6226460ac73 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "llvm/ADT/SmallVector.h"
namespace mlir {
class DialectRegistry;
@@ -21,6 +22,9 @@ class FuncOp;
} // namespace func
namespace bufferization {
+/// Helper function that returns all func.return ops in the given function.
+SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
+
namespace func_ext {
/// The state of analysis of a FuncOp.
enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
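The new helper replaces the getAssumedUniqueReturnOp functions that were previously duplicated in the two .cpp files below. A minimal usage sketch (illustrative only; funcOp stands for any func::FuncOp in scope):

  // Visit every return-terminated block instead of assuming a unique
  // terminator. External (bodiless) functions yield an empty vector.
  for (func::ReturnOp returnOp : bufferization::getReturnOps(funcOp)) {
    // Each return op carries one operand per function result.
  }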
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 11ed434f774a87..c45678f1e4b4dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -19,6 +19,15 @@
#include <optional>
namespace mlir {
+/// Return all func.return ops in the given function.
+SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> result;
+ for (Block &b : funcOp.getBody())
+ if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+ result.push_back(returnOp);
+ return result;
+}
+
namespace bufferization {
namespace func_ext {
@@ -41,20 +50,6 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
#endif // NDEBUG
}
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
- func::ReturnOp returnOp;
- for (Block &b : funcOp.getBody()) {
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
- }
- }
- return returnOp;
-}
-
/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user (as per `options.functionArgTypeConverterFn`).
@@ -391,15 +386,6 @@ struct FuncOpInterface
getBufferType(op, value, options, invocationStack);
}
- LogicalResult verifyAnalysis(Operation *op,
- const AnalysisState &state) const {
- auto funcOp = cast<func::FuncOp>(op);
- // TODO: func.func with multiple returns are not supported.
- if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
- return op->emitOpError("op without unique func.return is not supported");
- return success();
- }
-
/// Rewrite function bbArgs and return values into buffer form. This function
/// bufferizes the function signature and the ReturnOp. When the entire
/// function body has been bufferized, function return types can be switched
@@ -446,41 +432,38 @@ struct FuncOpInterface
return success();
}
- // TODO: Support functions with multiple returns.
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
- assert(returnOp->getNumOperands() == retTypes.size() &&
- "incorrect number of return values");
- Location loc = returnOp.getLoc();
-
// 1. Bufferize every block.
for (Block &block : funcOp.getBody())
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
options)))
return failure();
- // 2. Bufferize all operands of the return op.
- SmallVector<Value> returnValues;
- for (auto [returnVal, bufferizedType] :
- llvm::zip_equal(returnOp->getOperands(), retTypes)) {
- auto tensorType = dyn_cast<TensorType>(returnVal.getType());
- rewriter.setInsertionPoint(returnOp);
-
- // If not a tensor type just forward it.
- if (!tensorType) {
- returnValues.push_back(returnVal);
- continue;
+      // 2. Bufferize the operands of all func.return ops.
+ for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+ assert(returnOp->getNumOperands() == retTypes.size() &&
+ "incorrect number of return values");
+ SmallVector<Value> returnValues;
+ for (auto [returnVal, bufferizedType] :
+ llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+ auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+ rewriter.setInsertionPoint(returnOp);
+
+ // If not a tensor type just forward it.
+ if (!tensorType) {
+ returnValues.push_back(returnVal);
+ continue;
+ }
+
+ // Note: If `inferFunctionResultLayout = true`, casts are later folded
+ // away.
+ Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+ returnOp.getLoc(), bufferizedType, returnVal);
+ returnValues.push_back(toMemrefOp);
}
- // Note: If `inferFunctionResultLayout = true`, casts are later folded
- // away.
- Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
- loc, bufferizedType, returnVal);
- returnValues.push_back(toMemrefOp);
+ returnOp.getOperandsMutable().assign(returnValues);
}
- returnOp.getOperandsMutable().assign(returnValues);
-
// 3. Set the new function type.
funcOp.setType(newFuncType);
return success();
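For a single return-terminated block, the rewritten step 2 has roughly this effect (a sketch, not output of the patch; the exact memref type depends on the function-boundary type conversion options):

  // Before bufferizing the terminator:
  //   return %t : tensor<5xf32>
  // After: one bufferization.to_memref per tensor operand, per return op.
  //   %m = bufferization.to_memref %t : memref<5xf32>
  //   return %m : memref<5xf32>

The same rewrite is now applied to every func.return op in the function rather than to one assumed-unique terminator.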
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index a492bcdd0f3e38..71ea0fd9d43cde 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -86,20 +86,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
return state.addExtension<FuncAnalysisState>();
}
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
- for (Block &b : funcOp.getBody()) {
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
- }
- }
- return returnOp;
-}
-
namespace {
/// Annotate IR with the results of the analysis. For testing purposes only.
@@ -146,24 +132,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
return success();
}
- // Support only single return-terminated block in the function.
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
-
- for (OpOperand &returnVal : returnOp->getOpOperands())
- if (isa<RankedTensorType>(returnVal.get().getType()))
- for (BlockArgument bbArg : funcOp.getArguments())
- if (isa<RankedTensorType>(bbArg.getType())) {
- int64_t returnIdx = returnVal.getOperandNumber();
- int64_t bbArgIdx = bbArg.getArgNumber();
- if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
- funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
- if (state.getOptions().testAnalysisOnly)
- annotateEquivalentReturnBbArg(returnVal, bbArg);
+ // Find all func.return ops.
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+ // Build alias sets. Merge all aliases from all func.return ops.
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (isa<RankedTensorType>(bbArg.getType())) {
+ int64_t bbArgIdx = bbArg.getArgNumber();
+ // Store aliases in a set, so that we don't add the same alias twice.
+ SetVector<int64_t> aliases;
+ for (func::ReturnOp returnOp : returnOps) {
+ for (OpOperand &returnVal : returnOp->getOpOperands()) {
+ if (isa<RankedTensorType>(returnVal.get().getType())) {
+ int64_t returnIdx = returnVal.getOperandNumber();
+ if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+ aliases.insert(returnIdx);
}
- if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
- funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
}
+ }
+ for (int64_t alias : aliases)
+ funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+ }
+ }
+
+ // Build equivalence sets.
+ // Helper function that finds an equivalent block argument index for the
+ // given OpOperand. Return std::nullopt if no equivalent block argument could
+ // be found.
+ auto findEquivalentBlockArgIdx =
+ [&](OpOperand &opOperand) -> std::optional<int64_t> {
+ Value v = opOperand.get();
+ if (!isa<TensorType>(v.getType()))
+ return std::nullopt;
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (isa<RankedTensorType>(bbArg.getType())) {
+ if (state.areEquivalentBufferizedValues(v, bbArg)) {
+ if (state.getOptions().testAnalysisOnly)
+ annotateEquivalentReturnBbArg(opOperand, bbArg);
+ return bbArg.getArgNumber();
+ }
+ }
+ }
+ return std::nullopt;
+ };
+
+ int64_t numResults = returnOps.front()->getNumOperands();
+ for (int64_t i = 0; i < numResults; ++i) {
+ // Find the equivalent block argument index for the i-th operand of the
+ // first func.return op.
+ std::optional<int64_t> maybeEquiv =
+ findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+ if (!maybeEquiv.has_value())
+ continue;
+ int64_t bbArgIdx = *maybeEquiv;
+ bool allEquiv = true;
+
+ // Check if all other func.return ops have the same equivalent block
+ // argument for the i-th operand. In contrast to aliasing information,
+ // which is just "merged", equivalence information must match across all
+ // func.return ops.
+ for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+ std::optional<int64_t> maybeEquiv =
+ findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+ if (maybeEquiv != bbArgIdx) {
+ allEquiv = false;
+ break;
+ }
+ }
+
+ // All func.return ops have the same equivalent block argument for the i-th
+ // operand.
+ if (allEquiv)
+ funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+ }
return success();
}
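In short (my reading of the analysis above, not wording from the patch): aliasing information is the union over all func.return ops, whereas equivalence must hold at every func.return op. For a function in which ^bb1 returns %t0 and ^bb2 returns %t1:

  // Aliasing:    result #0 may alias {%t0, %t1}   (merged across returns)
  // Equivalence: none                             (the returns disagree)
  // Had both blocks returned %t0, result #0 would be equivalent to %t0.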
@@ -302,14 +344,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
- if (!funcOp.getBody().empty()) {
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
- return funcOp->emitError()
- << "cannot bufferize a FuncOp with tensors and "
- "without a unique ReturnOp";
- }
-
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -351,6 +385,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
return success();
}
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+ auto castOp = v.getDefiningOp<memref::CastOp>();
+ if (!castOp)
+ return v;
+ return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the type of the i-th operand is not the same across all
+/// func.return ops, the i-th returned type is an "empty" (null) type.
+static SmallVector<Type> getReturnTypes(ArrayRef<func::ReturnOp> returnOps) {
+ assert(!returnOps.empty() && "expected at least one ReturnOp");
+ int numOperands = returnOps.front()->getNumOperands();
+
+ // Helper function that unpacks memref.cast ops and returns the type.
+ auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+ SmallVector<Type> result;
+ for (int i = 0; i < numOperands; ++i) {
+ // Get the type of the i-th operand of the first func.return ops.
+    // Get the type of the i-th operand of the first func.return op.
+
+ // Check if all other func.return ops have a matching operand type.
+ for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+ if (getSourceType(returnOps[j]->getOperand(i)) != t)
+ t = Type();
+
+ result.push_back(t);
+ }
+
+ return result;
+}
+
/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -359,21 +429,33 @@ static LogicalResult getFuncOpsOrderedByCalls(
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for external functions without a body.
if (funcOp.getBody().empty())
return;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- SmallVector<Type> resultTypes;
+ // Compute the common result types of all return ops.
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ SmallVector<Type> resultTypes = getReturnTypes(returnOps);
- for (OpOperand &operand : returnOp->getOpOperands()) {
- if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
- operand.set(castOp.getSource());
- resultTypes.push_back(castOp.getSource().getType());
- } else {
- resultTypes.push_back(operand.get().getType());
+ // Remove direct casts.
+ for (func::ReturnOp returnOp : returnOps) {
+ for (OpOperand &operand : returnOp->getOpOperands()) {
+      // Only fold the cast if a common result type was found for this operand.
+ if (resultTypes[operand.getOperandNumber()]) {
+ operand.set(unpackCast(operand.get()));
+ }
}
}
+ // Fill in the missing result types that were not the same among all
+ // func.return ops.
+ for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+ if (resultTypes[i])
+ continue;
+ resultTypes[i] = funcOp.getFunctionType().getResult(i);
+ }
+
+ // Update the function type.
auto newFuncType = FunctionType::get(
funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
funcOp.setType(newFuncType);
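Taken together, the two helpers make foldMemRefCasts behave roughly as follows on a multi-return function (hypothetical IR; the layouts are made up for illustration):

  ^bb1:
    %c0 = memref.cast %m : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
    return %c0 : memref<5xf32, strided<[?], offset: ?>>
  ^bb2:
    %c1 = memref.cast %m : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
    return %c1 : memref<5xf32, strided<[?], offset: ?>>

Both cast sources have type memref<5xf32>, so both casts fold away and the function result type tightens to memref<5xf32>. If the sources disagreed, getReturnTypes would produce a null Type for that position and the original, more general result type would be kept.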
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 3f6d182b57c031..35b28f7ec83919 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1360,3 +1360,49 @@ func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?
%0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
+func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ return %t0 : tensor<5xf32>
+^bb2:
+ return %t1 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
+ // Check that alias sets are computed correctly.
+ // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
+ // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+ // CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
+ call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+ return
+}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
+func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ return %t0 : tensor<5xf32>
+^bb2:
+ return %t0 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ // Check that equivalence sets are computed correctly.
+ // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
+ // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+ // CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
+ %r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  // CHECK-ALIAS-SETS: return {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
+ return %r : tensor<5xf32>
+}
+
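The contrast between the two tests above is the point: in @multiple_returns the blocks return different arguments, so the call result aliases both %t0 and %t1 but is equivalent to neither; in @multiple_equivalent_returns both blocks return %t0, so the result is additionally equivalent to argument #1. Schematically:

  // @multiple_returns:            alias set {%result, %t0, %t1}, no equivalent arg
  // @multiple_equivalent_returns: alias set {%result, %t0}, equivalent to arg #1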
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 28ce0735e47b74..d773e1af43a76e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,24 +1,5 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
- -> (tensor<f32>, tensor<f32>)
-{
- cf.cond_br %cond1, ^bb1, ^bb2
-
- ^bb1:
- %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
- scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
- } else {
- scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
- }
- return %T#0, %T#1 : tensor<f32>, tensor<f32>
- ^bb2:
- return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
func.func @scf_for(%A : tensor<?xf32>,
%B : tensor<?xf32> {bufferization.writable = true},
%C : tensor<4xf32>,
@@ -146,7 +127,8 @@ func.func @regression_scf_while() {
// -----
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
func.return %t : tensor<5xf32>
^bb1(%arg1 : tensor<5xf32>):
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 2b5b8631436705..65557a68d243a2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -771,3 +771,28 @@ func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
%0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
return %0 : tensor<5xf32>
}
+
+// -----
+
+// The operands of the two func.return ops have different memref types after
+// bufferization. Make sure that memref.cast ops are inserted.
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+ // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+ %t = tensor.empty() : tensor<10xf32>
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+ // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+ %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>>
+ return %0 : tensor<5xf32>
+^bb2:
+ // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+ // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+ %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+ // CHECK: return %[[cast1]] : memref<5xf32, strided<[?], offset: ?>>
+ return %1 : tensor<5xf32>
+}
+
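The casts in this test survive foldMemRefCasts because the two blocks yield different concrete layouts, so getReturnTypes finds no common type. Schematically:

  // ^bb1 yields memref<5xf32, strided<[2]>>
  // ^bb2 yields memref<5xf32, strided<[1], offset: 2>>
  // -> no common source type; the function keeps the generic result type
  //    memref<5xf32, strided<[?], offset: ?>> and each return goes through
  //    a memref.cast to it.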