[Mlir-commits] [mlir] 26d896f - Fixes in 'tosa.reshape' lowering and folder (#85798)
llvmlistbot at llvm.org
Tue Mar 26 07:52:59 PDT 2024
Author: Rafael Ubal
Date: 2024-03-26T10:52:55-04:00
New Revision: 26d896f3688a8bff6faf85ccce557e320108997f
URL: https://github.com/llvm/llvm-project/commit/26d896f3688a8bff6faf85ccce557e320108997f
DIFF: https://github.com/llvm/llvm-project/commit/26d896f3688a8bff6faf85ccce557e320108997f.diff
LOG: Fixes in 'tosa.reshape' lowering and folder (#85798)
- Revamped the lowering conversion pattern for `tosa.reshape` to handle previously unsupported combinations of dynamic dimensions in the input and output tensors. The lowering strategy continues to rely on `tensor.collapse_shape` + `tensor.expand_shape` pairs, which allow for downstream fusion with surrounding `linalg.generic` ops (see the first example after this list).
- Fixed a bug in the canonicalization pattern `ReshapeOp::fold()` in `TosaCanonicalizations.cpp`. Equal input and result types are not a sufficient condition for folding: if the types contain more than one dynamic dimension, a productive reshape can still occur (see the second example after this list).
- This work exposed the fact that bufferization does not properly handle a `tensor.collapse_shape` op producing a 0D tensor from a dynamically shaped one, due to a limitation in `memref.collapse_shape`. The proper fix would be to lift the `memref.collapse_shape` restriction and verify correct bufferization, but this is left as possible future work. For now, the scenario is avoided by casting the `tosa.reshape` input tensor to a static shape if necessary (see `inferReshapeInputType()`).
- An extended set of tests is intended to cover the relevant conversion paths. Tests are named using the pattern `test_reshape_<rank>_{up|down|same}_{s2s|s2d|d2s|d2d}_{explicit|auto}[_empty][_identity]`, where:
- `<rank>` is the input rank (e.g., 3d, 6d).
- `{up|down|same}` indicates whether the reshape increases, decreases, or retains the input rank.
- `{s2s|s2d|d2s|d2d}` indicates whether the reshape converts a statically shaped input to a statically shaped result (`s2s`), a statically shaped input to a dynamically shaped result (`s2d`), etc.
- `{explicit|auto}` indicates whether all values in the `new_shape` attribute are >= 0 (`explicit`) or a -1 placeholder value is used (`auto`).
- `empty` indicates that `new_shape` includes a component set to 0.
- `identity` indicates that the input and result shapes are the same.
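For illustration, here is the expected lowering of a reshape between dynamically shaped 3D tensors, adapted from the `test_reshape_3d_same_d2d_auto` case in the updated `tosa-to-tensor.mlir` tests below (the SSA value names are illustrative). The input is collapsed to rank 1, expanded to the inferred intermediate shape, and cast to the final result type:

  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<2x?x?xf32>) -> tensor<?x?x?xf32>

  // lowers to, approximately:
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<2x?x?xf32> into tensor<?xf32>
  %expanded = tensor.expand_shape %collapsed [[0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
  %result = tensor.cast %expanded : tensor<2x?x4xf32> to tensor<?x?x?xf32>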
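As an example of the folder fix, the following reshape (the `reshape_canonicalize_dyn_nofold` case added to `canonicalize.mlir` below) has identical input and result types, yet with two dynamic dimensions it may still redistribute elements between them, so it must not fold to its input:

  %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, 2, 10>} : (tensor<?x?x10xf32>) -> tensor<?x?x10xf32>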
Added:
Modified:
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 505d85f211111c..11ba98ddf352b4 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -19,24 +19,99 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
+
using namespace mlir;
using namespace tosa;
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
- ArrayRef<int64_t> rhsShape,
- SmallVector<int64_t> &intermediateShape,
- bool isDynamic) {
- if (isDynamic) {
- // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
- intermediateShape = {ShapedType::kDynamic};
- return true;
- }
+namespace {
- if (lhsShape.empty() || rhsShape.empty()) {
- intermediateShape = {};
- return true;
- }
+// Infer the type to which the input of a 'tosa.reshape' op must be cast when
+// lowered.
+TensorType inferReshapeInputType(TypedValue<TensorType> input,
+ ArrayRef<int64_t> newShape) {
+ // No need to cast input for non-empty target shape
+ if (!newShape.empty())
+ return input.getType();
+
+ // The input type must be cast into a tensor with the same rank and all static
+ // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
+ // op that converts a dynamically shaped tensor into a 0D tensor. While such a
+ // construct is not incorrect on its own, bufferization cannot properly handle
+ // it at the moment, so we avoid it.
+ SmallVector<int64_t> shape(input.getType().getRank(), 1);
+ return input.getType().clone(shape);
+}
+
+// Infer the result type of 'tensor.expand_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapeExpandedType(TensorType inputType,
+ ArrayRef<int64_t> newShape) {
+ // Special case for 0D output tensor. Note: Watch out when using Type::clone()
+ // with just '{}', as it will invoke the incorrect overload.
+ if (newShape.empty())
+ return inputType.clone(ArrayRef<int64_t>{});
+
+ // Check if the input is static, and if so, get its total size
+ bool inputIsStatic = inputType.hasStaticShape();
+ int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
+
+ // Compute result shape
+ bool resultIsStatic = true;
+ auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+ // If this is not a placeholder, do not change it
+ if (size >= 0)
+ return size;
+
+ // If we do not know the total size of the tensor, keep this dimension
+ // dynamic in the result shape.
+ if (!inputIsStatic) {
+ resultIsStatic = false;
+ return ShapedType::kDynamic;
+ }
+ // Calculate the product of all elements in 'newShape' except for the -1
+ // placeholder, which we discard by negating the result.
+ int64_t totalSizeNoPlaceholder = -std::accumulate(
+ newShape.begin(), newShape.end(), 1, std::multiplies());
+
+ // If there is a 0 component in 'newShape', resolve the placeholder as 0.
+ if (totalSizeNoPlaceholder == 0)
+ return 0;
+
+ // Resolve the placeholder as the quotient between the total tensor size and
+ // the product of all other sizes.
+ return totalSize / totalSizeNoPlaceholder;
+ });
+
+ // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
+ // shaped input from being reshaped into a statically shaped result. We may
+ // simply turn the first result dimension dynamic to address this.
+ if (!inputIsStatic && resultIsStatic)
+ resultShape[0] = ShapedType::kDynamic;
+
+ // The 'tensor.expand_shape' op also forbids a statically shaped input from
+ // being reshaped into a dynamically shaped result, but the placeholder
+ // inference algorithm above guarantees that this will never be the case.
+ assert(!inputIsStatic || resultIsStatic);
+
+ // Create result type
+ return inputType.clone(resultShape);
+}
+
+// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
+ auto lhsShape = lhsType.getShape();
+ auto rhsShape = rhsType.getShape();
+
+ if (lhsShape.empty() || rhsShape.empty())
+ return lhsType.clone(ArrayRef<int64_t>{});
+
+ if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
+ return lhsType.clone({ShapedType::kDynamic});
+
+ SmallVector<int64_t> intermediateShape;
unsigned currLhsDim = 0, currRhsDim = 0;
while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
int64_t rhsSize = rhsShape[currRhsDim];
@@ -62,174 +137,113 @@ static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
currLhsDim++;
}
- // If the iterators didn't reach the end and their leftover dimensions are not
- // equal to 1 an intermediate shape was not found.
- while (currLhsDim < lhsShape.size()) {
- if (lhsShape[currLhsDim++] != 1) {
- return false;
- }
+ // Static shapes are guaranteed to be compatible by the op verifier, so all
+ // leftover dimensions should be 1.
+ for (; currLhsDim < lhsShape.size(); currLhsDim++) {
+ assert(lhsShape[currLhsDim] == 1);
}
-
- while (currRhsDim < rhsShape.size()) {
- if (rhsShape[currRhsDim++] != 1) {
- return false;
- }
+ for (; currRhsDim < rhsShape.size(); currRhsDim++) {
+ assert(rhsShape[currRhsDim] == 1);
}
-
- return true;
+
+ return lhsType.clone(intermediateShape);
}
-static bool createReassociationMapsForCollapse(
- PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
- ArrayRef<int64_t> dstShape,
- SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+SmallVector<ReassociationExprs>
+createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
+ auto srcShape = cast<TensorType>(srcType).getShape();
+ auto dstShape = cast<TensorType>(dstType).getShape();
- // If the shape is dynamic, create a map for collapsing into one dimension.
- if (isDynamic) {
- SmallVector<AffineExpr, 2> exprs;
- for (int i = 0, s = srcShape.size(); i < s; ++i)
- exprs.push_back(rewriter.getAffineDimExpr(i));
- reassociationMap = {exprs};
- return true;
- }
+ if (srcShape.empty() || dstShape.empty())
+ return {};
- if (dstShape.empty()) {
- reassociationMap = {};
- return true;
+ if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
+ assert(dstShape.size() == 1);
+ SmallVector<AffineExpr, 2> exprs;
+ for (auto i : llvm::seq<int64_t>(srcShape.size()))
+ exprs.push_back(builder.getAffineDimExpr(i));
+ return {exprs};
}
- reassociationMap.resize(dstShape.size());
+ SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
unsigned currSrcDim = 0, currDstDim = 0;
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
int64_t dstSize = dstShape[currDstDim];
int64_t srcSize = srcShape[currSrcDim];
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
+ builder.getAffineDimExpr(currSrcDim++));
srcSize *= srcShape[currSrcDim];
}
if (srcSize == dstSize) {
reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
+ builder.getAffineDimExpr(currSrcDim++));
// If the next dim in collapsedShape is not 1, treat subsequent dims in
// expandedShape which are 1 to be collapsed.
if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
+ builder.getAffineDimExpr(currSrcDim++));
}
}
}
currDstDim++;
}
- // If both iterators didn't reach the end, we have leftover dimentions which
- // implies that we have a mismatch in shape.
- return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+ // If the source and target shapes are compatible, both iterators must have
+ // reached the end. This condition is guaranteed by the op verifier for
+ // static shapes.
+ assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
+ return reassociationMap;
}
-namespace {
-Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
- ShapedType resultTy, Value operand) {
- ShapedType operandTy = cast<ShapedType>(operand.getType());
- if (resultTy == operandTy)
- return operand;
-
- bool isDynamic = !operandTy.hasStaticShape();
-
- if (isDynamic && resultTy.getRank() != 1) {
- (void)rewriter.notifyMatchFailure(
- loc, "Cannot collapse dynamic dims to more than one dimension");
- return {};
- }
-
- SmallVector<ReassociationExprs, 4> reassociationMap;
- if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
- resultTy.getShape(),
- reassociationMap, isDynamic)) {
- (void)rewriter.notifyMatchFailure(
- loc, "tosa.reshape Attempting to collapse into an incompatible shape");
- return {};
- }
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
- intermediateShape, isDynamic)) {
- (void)rewriter.notifyMatchFailure(
- loc, "tosa.reshape Cannot collapse into given shape");
- return {};
- }
- return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
- reassociationMap);
+// Create a tensor.collapse_shape op that reshapes the input into the given
+// result type.
+Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
+ Value input) {
+ auto reassociationMap =
+ createReassociationMapForCollapse(builder, input.getType(), resultType);
+ return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
+ reassociationMap);
}
-Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
- ShapedType resultTy, Value operand) {
- ShapedType operandTy = cast<ShapedType>(operand.getType());
- if (resultTy == operandTy)
- return operand;
-
- bool isDynamic = !operandTy.hasStaticShape();
-
- if (isDynamic && operandTy.getRank() != 1) {
- (void)rewriter.notifyMatchFailure(
- loc, "Cannot expand dynamic dims from more than one dimension");
- return {};
- }
-
- SmallVector<ReassociationExprs, 4> reassociationMap;
- if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
- operandTy.getShape(),
- reassociationMap, isDynamic)) {
- (void)rewriter.notifyMatchFailure(
- loc, "tosa.reshape Attempting to expand into an incompatible shape");
- return {};
- }
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
- intermediateShape, isDynamic) ||
- intermediateShape != operandTy.getShape()) {
- (void)rewriter.notifyMatchFailure(
- loc, "tosa.reshape Cannot expand into given shape");
- return {};
- }
- return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
- reassociationMap);
+// Create a tensor.expand_shape op that reshapes the input into the given result
+// type.
+Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
+ Value input) {
+ auto reassociationMap =
+ createReassociationMapForCollapse(builder, resultType, input.getType());
+ return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
+ reassociationMap);
}
-class ReshapeConverterCollapseExpand
- : public OpConversionPattern<tosa::ReshapeOp> {
+class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
- ShapedType resultTy = cast<ShapedType>(reshape.getType());
- bool isDynamic = !operandTy.hasStaticShape();
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
- intermediateShape, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape, "tosa.reshape Cannot identify an intermediate shape between "
- "the given two shapes");
- }
- auto intermediateTy = RankedTensorType::get(
- intermediateShape, reshape.getType().getElementType());
-
- Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
- adaptor.getInput1());
- if (!collapse)
- return failure();
-
- Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
- if (!expand)
- return failure();
-
- rewriter.replaceOp(reshape, expand);
+ auto loc = reshape.getLoc();
+ auto resultType = reshape.getResult().getType();
+ auto input = reshape.getInput1();
+ auto newShape = reshape.getNewShape();
+
+ // Infer all intermediate types
+ auto inputType = inferReshapeInputType(input, newShape);
+ auto expandedType = inferReshapeExpandedType(inputType, newShape);
+ auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
+
+ // Cast input if needed
+ auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
+
+ // Emit collapse-expand pair
+ auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
+ auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
+
+ // Cast to final result type if needed
+ auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
+ rewriter.replaceOp(reshape, result);
return success();
}
};
@@ -416,8 +430,10 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
void mlir::tosa::populateTosaToTensorConversionPatterns(
RewritePatternSet *patterns) {
- patterns->add<SliceConverter, PadConverter, ConcatConverter>(
- patterns->getContext());
-
- patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
+ patterns->add<
+ ConcatConverter,
+ PadConverter,
+ ReshapeConverter,
+ SliceConverter
+ >(patterns->getContext());
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 4c50aaecfe9488..d23c9fe824c94a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -795,7 +795,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!inputTy || !outputTy)
return {};
- if (inputTy == outputTy)
+ // Fold when the input and output types are the same. This is only safe when
+ // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
+ // there may still be a productive reshape.
+ if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
return getInput1();
// reshape(reshape(x)) -> reshape(x)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f461e7e1a555b8..6e6e8435073812 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -970,6 +970,11 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
<< " elements into " << outputElementsNum;
}
}
+
+ int missingDims = llvm::count(getNewShape(), -1);
+ if (missingDims > 1)
+ return emitOpError() << "At most one target dimension can be -1";
+
return mlir::success();
}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index daaa68a7260b71..a8a3c42e168422 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -1,95 +1,363 @@
// RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s
-// CHECK-LABEL: @test_reshape_downrank
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
- // CHECK: return [[RESHAPE]]
- return %0 : tensor<6xf32>
+// -----
+
+// CHECK-LABEL: test_reshape_0d_same_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: return %[[ARG_0]] : tensor<f32>
+func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
}
// -----
-// CHECK-LABEL: @test_reshape_downrank_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
- // CHECK: return [[RESHAPE]]
+// CHECK-LABEL: test_reshape_0d_up_s2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
+// CHECK: return %[[VAL_1]] : tensor<?xf32>
+func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor<f32>) -> tensor<?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
-// CHECK-LABEL: @test_reshape_uprank
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
- // CHECK: return [[RESHAPE]]
- return %0 : tensor<2x3xf32>
+// CHECK-LABEL: test_reshape_0d_up_s2d_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
+// CHECK: return %[[VAL_1]] : tensor<?xf32>
+func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor<f32>) -> tensor<?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_0d_up_s2s_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: return %[[VAL_0]] : tensor<1xf32>
+func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor<f32>) -> tensor<1xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<f32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_0d_up_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: return %[[VAL_0]] : tensor<1xf32>
+func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor<f32>) -> tensor<1xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_1d_down_d2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.cast %[[ARG_0]] : tensor<?xf32> to tensor<1xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[VAL_0]] [] : tensor<1xf32> into tensor<f32>
+// CHECK: return %[[VAL_1]] : tensor<f32>
+func.func @test_reshape_1d_down_d2s_explicit(%arg0: tensor<?xf32>) -> tensor<f32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<?xf32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_1d_down_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] [] : tensor<1xf32> into tensor<f32>
+// CHECK: return %[[VAL_0]] : tensor<f32>
+func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor<f32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<1xf32>) -> tensor<f32>
+ return %0 : tensor<f32>
}
// -----
-// CHECK-LABEL: @test_reshape_uprank_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-LABEL: test_reshape_1d_up_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
+// CHECK: return %[[VAL_0]] : tensor<2x?xf32>
+func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
- // CHECK: return [[RESHAPE]]
return %0 : tensor<2x?xf32>
}
// -----
-// CHECK-LABEL: @test_reshape_samerank
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
-func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
- // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
- // CHECK-NEXT: return %[[RESHAPE2]]
+// CHECK-LABEL: test_reshape_1d_up_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<6xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: return %[[VAL_0]] : tensor<2x3xf32>
+func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
}
// -----
-// CHECK-LABEL: @test_reshape_samerank_dyn
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
-func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
- // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+// CHECK-LABEL: test_reshape_2d_down_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x?xf32> into tensor<?xf32>
+// CHECK: return %[[VAL_0]] : tensor<?xf32>
+func.func @test_reshape_2d_down_d2d_auto(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_2d_down_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x3xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x3xf32> into tensor<6xf32>
+// CHECK: return %[[VAL_0]] : tensor<6xf32>
+func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
+ return %0 : tensor<6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_2d_same_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x2xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?x2xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
+// CHECK: return %[[VAL_1]] : tensor<2x?xf32>
+func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
- // CHECK-NEXT: return %[[RESHAPE2]]
return %0 : tensor<2x?xf32>
}
// -----
-// CHECK-LABEL: @test_reshape_downrank_6D
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
- // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
- return %0 : tensor<6x5x77xf32>
+// CHECK-LABEL: test_reshape_2d_same_s2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x2xf32>
+func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 2>} : (tensor<2x4xf32>) -> tensor<?x2xf32>
+ return %0 : tensor<?x2xf32>
}
// -----
-// CHECK-LABEL: @test_reshape_downrank_6D_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
- // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]]
- // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
+// CHECK-LABEL: test_reshape_2d_same_s2d_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x2xf32>
+func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 4, 2>} : (tensor<2x4xf32>) -> tensor<?x2xf32>
+ return %0 : tensor<?x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_2d_same_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: return %[[VAL_1]] : tensor<2x3xf32>
+func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
+ return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2d_auto_empty
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<0x3x?xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<0x3x?xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
+func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, -1>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<2x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
+func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<2x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2d_auto_identity
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x3x4xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x3x4xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x3x?xf32>
+// CHECK: return %[[VAL_1]] : tensor<2x3x?xf32>
+func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor<?x3x4xf32>) -> tensor<2x3x?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, -1>} : (tensor<?x3x4xf32>) -> tensor<2x3x?xf32>
+ return %0 : tensor<2x3x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2d_explicit_empty
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x2xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x2xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
+func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, 2>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2d_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x4xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
+func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2d_explicit_identity
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x3x4xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.cast %[[ARG_0]] : tensor<?x3x4xf32> to tensor<2x3x?xf32>
+// CHECK: return %[[VAL_0]] : tensor<2x3x?xf32>
+func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor<?x3x4xf32>) -> tensor<2x3x?xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x3x4xf32>) -> tensor<2x3x?xf32>
+ return %0 : tensor<2x3x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2s_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<2x3x4xf32>
+// CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
+func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
+ return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x4xf32> to tensor<2x3x4xf32>
+// CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
+func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
+ return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_s2s_explicit_identity
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x3x4xf32>
+// CHECK: return %[[ARG_0]] : tensor<2x3x4xf32>
+func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+ return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_3d_up_d2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : tensor<?xf32> into tensor<?x3x2x1xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x2x1xf32> to tensor<1x3x2x1xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x3x2x1xf32>
+func.func @test_reshape_3d_up_d2s_explicit(%input: tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32> {
+ %0 = tosa.reshape %input {new_shape = array<i64: 1, 3, 2, 1>} : (tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32>
+ return %0 : tensor<1x3x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_4d_down_d2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.cast %[[ARG_0]] : tensor<?x?x?x?xf32> to tensor<1x1x1x1xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[VAL_0]] [] : tensor<1x1x1x1xf32> into tensor<f32>
+// CHECK: return %[[VAL_1]] : tensor<f32>
+func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor<?x?x?x?xf32>) -> tensor<f32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<?x?x?x?xf32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_5d_down_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?x2x3xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor<?x?x?x2x3xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x2x3xf32>
+// CHECK: return %[[VAL_1]] : tensor<?x2x3xf32>
+func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 2, 3>} : (tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32>
+ return %0 : tensor<?x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_6d_down_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x?x5x7x11xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x5x77xf32>
+// CHECK: return %[[VAL_1]] : tensor<?x5x77xf32>
+func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
return %0 : tensor<?x5x77xf32>
}
// -----
-// CHECK-LABLE: func @slice
+// CHECK-LABEL: test_reshape_6d_down_s2s_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x3x5x7x11xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32>
+// CHECK: return %[[VAL_0]] : tensor<6x5x77xf32>
+func.func @test_reshape_6d_down_s2s_auto(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, -1>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+ return %0 : tensor<6x5x77xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_6d_down_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x3x5x7x11xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32>
+// CHECK: return %[[VAL_0]] : tensor<6x5x77xf32>
+func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+ return %0 : tensor<6x5x77xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @slice
func.func @slice(%arg0: tensor<6xf32>) ->() {
// CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
%0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 1>} : (tensor<6xf32>) -> (tensor<1xf32>)
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e7ede2e0ccef9a..6eac759a083645 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -365,6 +365,14 @@ func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
return %0 : tensor<?x10xf32>
}
+// CHECK-LABEL: @reshape_canonicalize_dyn_nofold
+func.func @reshape_canonicalize_dyn_nofold(%arg0: tensor<?x?x10xf32>) -> tensor<?x?x10xf32> {
+ // CHECK: %[[VAR0:.+]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 2, 10>} : (tensor<?x?x10xf32>) -> tensor<?x?x10xf32>
+ // CHECK: return %[[VAR0]] : tensor<?x?x10xf32>
+ %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, 2, 10>} : (tensor<?x?x10xf32>) -> tensor<?x?x10xf32>
+ return %0 : tensor<?x?x10xf32>
+}
+
// CHECK-LABEL: @reshape_canonicalize_double
func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
// CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 5>}