[Mlir-commits] [mlir] 723979e - Move tosa.reshape lowering patterns from TosaToLinalg to TosaToTensor
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Mar 7 08:06:23 PST 2023
Author: Krzysztof Drewniak
Date: 2023-03-07T16:06:18Z
New Revision: 723979efc8638d192a680d2a0f814f758274a046
URL: https://github.com/llvm/llvm-project/commit/723979efc8638d192a680d2a0f814f758274a046
DIFF: https://github.com/llvm/llvm-project/commit/723979efc8638d192a680d2a0f814f758274a046.diff
LOG: Move tosa.reshape lowering patterns from TosaToLinalg to TosaToTensor
Converting tosa.reshape to tensor.expand_shape and
tensor.collapse_shape logically belongs in the tosa-to-tensor
conversion process. In addition, we (the rocMLIR downstream project) want to use
the reshape -> expand/collapse_shape logic to simplify parts of our
TOSA integration without running the full tosa-to-linalg flow, which
further motivates moving these patterns.
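As a minimal sketch of the lowering that moves (mirroring the
@test_reshape_downrank case in the relocated tests below; the function name
@example is made up for illustration):

  // Input: a rank-reducing tosa.reshape.
  func.func @example(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
    %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
  // After --tosa-to-tensor, the reshape becomes a tensor.collapse_shape:
  //   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>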
The downside of this change is that tosa-to-tensor now has to run after
tosa-to-linalg, which is probably a breaking change for existing
pipelines.
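Roughly, the new ordering looks like the following (a sketch only; the exact
pass flags and pipeline nesting depend on your MLIR build, and input.mlir is
just a placeholder):

  # tosa.reshape now survives tosa-to-linalg and is lowered by tosa-to-tensor.
  mlir-opt --pass-pipeline="builtin.module(func.func(tosa-to-linalg))" input.mlir \
    | mlir-opt --tosa-to-tensor

Previously, tosa-to-linalg alone emitted the tensor.expand_shape /
tensor.collapse_shape ops for these reshapes.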
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D145119
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 46f5292a9ad1a..82cadf2a07daa 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -807,133 +807,12 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
return rewriter.notifyMatchFailure(
op, "unable to create linalg.generic body for reduce op");
- SmallVector<ReassociationExprs, 4> reassociationMap;
- uint64_t expandInputRank =
- linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
- reassociationMap.resize(expandInputRank);
-
- for (uint64_t i = 0; i < expandInputRank; i++) {
- int32_t dimToPush = i > axis ? i + 1 : i;
- reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
- }
-
- if (expandInputRank != 0) {
- int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
- reassociationMap[expandedDim].push_back(
- rewriter.getAffineDimExpr(expandedDim + 1));
- }
-
- rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
- op, resultTy, linalgOp.getResults()[0], reassociationMap);
+ rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+ op, resultTy, linalgOp.getResults()[0],
+ rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
return success();
}
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
- ArrayRef<int64_t> rhsShape,
- SmallVector<int64_t> &intermediateShape,
- bool isDynamic) {
- if (isDynamic) {
- // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
- intermediateShape = {ShapedType::kDynamic};
- return true;
- }
-
- if (lhsShape.empty() || rhsShape.empty()) {
- intermediateShape = {};
- return true;
- }
-
- unsigned currLhsDim = 0, currRhsDim = 0;
- while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
- int64_t rhsSize = rhsShape[currRhsDim];
- int64_t lhsSize = lhsShape[currLhsDim];
- while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
- currRhsDim < rhsShape.size()) {
- if (lhsSize < rhsSize) {
- currLhsDim++;
- if (currLhsDim < lhsShape.size()) {
- lhsSize *= lhsShape[currLhsDim];
- }
- } else {
- currRhsDim++;
- if (currRhsDim < rhsShape.size()) {
- rhsSize *= rhsShape[currRhsDim];
- }
- }
- }
- if (lhsSize == rhsSize) {
- intermediateShape.push_back(lhsSize);
- }
- currRhsDim++;
- currLhsDim++;
- }
-
- // If the iterators didn't reach the end and their leftover dimensions are not
- // equal to 1 an intermediate shape was not found.
- while (currLhsDim < lhsShape.size()) {
- if (lhsShape[currLhsDim++] != 1) {
- return false;
- }
- }
-
- while (currRhsDim < rhsShape.size()) {
- if (rhsShape[currRhsDim++] != 1) {
- return false;
- }
- }
-
- return true;
-}
-
-static bool createReassociationMapsForCollapse(
- PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
- ArrayRef<int64_t> dstShape,
- SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
-
- // If the shape is dynamic, create a map for collapsing into one dimension.
- if (isDynamic) {
- SmallVector<AffineExpr, 2> exprs;
- for (int i = 0, s = srcShape.size(); i < s; ++i)
- exprs.push_back(rewriter.getAffineDimExpr(i));
- reassociationMap = {exprs};
- return true;
- }
-
- if (dstShape.empty()) {
- reassociationMap = {};
- return true;
- }
-
- reassociationMap.resize(dstShape.size());
- unsigned currSrcDim = 0, currDstDim = 0;
- while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
- int64_t dstSize = dstShape[currDstDim];
- int64_t srcSize = srcShape[currSrcDim];
- while (srcSize < dstSize && currSrcDim < srcShape.size()) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- srcSize *= srcShape[currSrcDim];
- }
- if (srcSize == dstSize) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- // If the next dim in collapsedShape is not 1, treat subsequent dims in
- // expandedShape which are 1 to be collapsed.
- if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
- while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
- reassociationMap[currDstDim].push_back(
- rewriter.getAffineDimExpr(currSrcDim++));
- }
- }
- }
- currDstDim++;
- }
-
- // If both iterators didn't reach the end, we have leftover dimentions which
- // implies that we have a mismatch in shape.
- return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
-}
-
namespace {
template <typename SrcOp>
@@ -947,115 +826,6 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
}
};
-class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
-public:
- using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
- ShapedType resultTy = reshape.getType().template cast<ShapedType>();
- bool isDynamic = !operandTy.hasStaticShape();
-
- if (isDynamic && resultTy.getRank() != 1) {
- return rewriter.notifyMatchFailure(
- reshape, "Cannot collapse dynamic dims to more than one dimension");
- }
-
- SmallVector<ReassociationExprs, 4> reassociationMap;
- if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
- resultTy.getShape(),
- reassociationMap, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape,
- "tosa.reshape Attempting to collapse into an incompatible shape");
- }
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
- intermediateShape, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape, "tosa.reshape Cannot collapse into given shape");
- }
-
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
- return success();
- }
-};
-
-class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
-public:
- using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
- ShapedType resultTy = reshape.getType().template cast<ShapedType>();
- bool isDynamic = !operandTy.hasStaticShape();
-
- if (isDynamic && operandTy.getRank() != 1) {
- return rewriter.notifyMatchFailure(
- reshape, "Cannot expand dynamic dims from more than one dimension");
- }
-
- SmallVector<ReassociationExprs, 4> reassociationMap;
- if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
- operandTy.getShape(),
- reassociationMap, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape,
- "tosa.reshape Attempting to expand into an incompatible shape");
- }
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
- intermediateShape, isDynamic) ||
- intermediateShape != operandTy.getShape()) {
- return rewriter.notifyMatchFailure(
- reshape, "tosa.reshape Cannot expand into given shape");
- }
- rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
- reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
- return success();
- }
-};
-
-class ReshapeConverterCollapseExpand
- : public OpConversionPattern<tosa::ReshapeOp> {
-public:
- using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
- ShapedType resultTy = reshape.getType().template cast<ShapedType>();
- bool isDynamic = !operandTy.hasStaticShape();
-
- SmallVector<int64_t> intermediateShape;
- if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
- intermediateShape, isDynamic)) {
- return rewriter.notifyMatchFailure(
- reshape, "tosa.reshape Cannot identify an intermediate shape between "
- "the given two shapes");
- }
-
- Value collapse = rewriter.create<tosa::ReshapeOp>(
- reshape.getLoc(),
- RankedTensorType::get(intermediateShape,
- reshape.getType().getElementType()),
- adaptor.getInput1());
- Value expand =
- rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
- rewriter.replaceOp(reshape, expand);
-
- return success();
- }
-};
-
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
@@ -2295,13 +2065,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
/*benefit=*/300);
- patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
- /*benefit=*/100);
- patterns->add<ReshapeConverterExpand>(patterns->getContext(),
- /*benefit=*/200);
- patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
- /*benefit=*/300);
-
patterns->add<
// clang-format off
PointwiseConverter<tosa::AddOp>,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 4b1e351e9746e..f3a76ddfcde78 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -56,6 +56,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.addLegalOp<tosa::ConstOp>();
target.addLegalOp<tosa::WhileOp>();
target.addLegalOp<tosa::SliceOp>();
+ target.addLegalOp<tosa::ReshapeOp>();
target.addLegalOp<tosa::PadOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 55acad6d0cf21..b276c773d55ba 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -15,21 +15,236 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace tosa;
+static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
+ ArrayRef<int64_t> rhsShape,
+ SmallVector<int64_t> &intermediateShape,
+ bool isDynamic) {
+ if (isDynamic) {
+ // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
+ intermediateShape = {ShapedType::kDynamic};
+ return true;
+ }
+
+ if (lhsShape.empty() || rhsShape.empty()) {
+ intermediateShape = {};
+ return true;
+ }
+
+ unsigned currLhsDim = 0, currRhsDim = 0;
+ while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
+ int64_t rhsSize = rhsShape[currRhsDim];
+ int64_t lhsSize = lhsShape[currLhsDim];
+ while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
+ currRhsDim < rhsShape.size()) {
+ if (lhsSize < rhsSize) {
+ currLhsDim++;
+ if (currLhsDim < lhsShape.size()) {
+ lhsSize *= lhsShape[currLhsDim];
+ }
+ } else {
+ currRhsDim++;
+ if (currRhsDim < rhsShape.size()) {
+ rhsSize *= rhsShape[currRhsDim];
+ }
+ }
+ }
+ if (lhsSize == rhsSize) {
+ intermediateShape.push_back(lhsSize);
+ }
+ currRhsDim++;
+ currLhsDim++;
+ }
+
+ // If the iterators didn't reach the end and their leftover dimensions are not
+ // equal to 1 an intermediate shape was not found.
+ while (currLhsDim < lhsShape.size()) {
+ if (lhsShape[currLhsDim++] != 1) {
+ return false;
+ }
+ }
+
+ while (currRhsDim < rhsShape.size()) {
+ if (rhsShape[currRhsDim++] != 1) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+static bool createReassociationMapsForCollapse(
+ PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
+ ArrayRef<int64_t> dstShape,
+ SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+
+ // If the shape is dynamic, create a map for collapsing into one dimension.
+ if (isDynamic) {
+ SmallVector<AffineExpr, 2> exprs;
+ for (int i = 0, s = srcShape.size(); i < s; ++i)
+ exprs.push_back(rewriter.getAffineDimExpr(i));
+ reassociationMap = {exprs};
+ return true;
+ }
+
+ if (dstShape.empty()) {
+ reassociationMap = {};
+ return true;
+ }
+
+ reassociationMap.resize(dstShape.size());
+ unsigned currSrcDim = 0, currDstDim = 0;
+ while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+ int64_t dstSize = dstShape[currDstDim];
+ int64_t srcSize = srcShape[currSrcDim];
+ while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+ reassociationMap[currDstDim].push_back(
+ rewriter.getAffineDimExpr(currSrcDim++));
+ srcSize *= srcShape[currSrcDim];
+ }
+ if (srcSize == dstSize) {
+ reassociationMap[currDstDim].push_back(
+ rewriter.getAffineDimExpr(currSrcDim++));
+ // If the next dim in collapsedShape is not 1, treat subsequent dims in
+ // expandedShape which are 1 to be collapsed.
+ if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+ while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+ reassociationMap[currDstDim].push_back(
+ rewriter.getAffineDimExpr(currSrcDim++));
+ }
+ }
+ }
+ currDstDim++;
+ }
+
+ // If both iterators didn't reach the end, we have leftover dimentions which
+ // implies that we have a mismatch in shape.
+ return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+}
+
namespace {
+class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+ using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-class SliceConverter : public OpRewritePattern<tosa::SliceOp> {
+ LogicalResult
+ matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+ ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ bool isDynamic = !operandTy.hasStaticShape();
+
+ if (isDynamic && resultTy.getRank() != 1) {
+ return rewriter.notifyMatchFailure(
+ reshape, "Cannot collapse dynamic dims to more than one dimension");
+ }
+
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+ resultTy.getShape(),
+ reassociationMap, isDynamic)) {
+ return rewriter.notifyMatchFailure(
+ reshape,
+ "tosa.reshape Attempting to collapse into an incompatible shape");
+ }
+
+ SmallVector<int64_t> intermediateShape;
+ if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+ intermediateShape, isDynamic)) {
+ return rewriter.notifyMatchFailure(
+ reshape, "tosa.reshape Cannot collapse into given shape");
+ }
+
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+ return success();
+ }
+};
+
+class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
public:
- using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+ using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+ ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ bool isDynamic = !operandTy.hasStaticShape();
+
+ if (isDynamic && operandTy.getRank() != 1) {
+ return rewriter.notifyMatchFailure(
+ reshape, "Cannot expand dynamic dims from more than one dimension");
+ }
+
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+ operandTy.getShape(),
+ reassociationMap, isDynamic)) {
+ return rewriter.notifyMatchFailure(
+ reshape,
+ "tosa.reshape Attempting to expand into an incompatible shape");
+ }
+
+ SmallVector<int64_t> intermediateShape;
+ if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+ intermediateShape, isDynamic) ||
+ intermediateShape != operandTy.getShape()) {
+ return rewriter.notifyMatchFailure(
+ reshape, "tosa.reshape Cannot expand into given shape");
+ }
+ rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+ reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+ return success();
+ }
+};
+
+class ReshapeConverterCollapseExpand
+ : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+ using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+ ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ bool isDynamic = !operandTy.hasStaticShape();
+
+ SmallVector<int64_t> intermediateShape;
+ if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
+ intermediateShape, isDynamic)) {
+ return rewriter.notifyMatchFailure(
+ reshape, "tosa.reshape Cannot identify an intermediate shape between "
+ "the given two shapes");
+ }
+
+ Value collapse = rewriter.create<tosa::ReshapeOp>(
+ reshape.getLoc(),
+ RankedTensorType::get(intermediateShape,
+ reshape.getType().getElementType()),
+ adaptor.getInput1());
+ Value expand =
+ rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
+ rewriter.replaceOp(reshape, expand);
+
+ return success();
+ }
+};
+
+class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
+public:
+ using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
Location loc = sliceOp.getLoc();
- Value input = sliceOp.getInput();
+ Value input = adaptor.getInput();
SmallVector<int64_t> strides, sizes;
ArrayRef<int64_t> starts = sliceOp.getStart();
strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
@@ -139,4 +354,10 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
void mlir::tosa::populateTosaToTensorConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<SliceConverter, PadConverter>(patterns->getContext());
+ patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
+ /*benefit=*/100);
+ patterns->add<ReshapeConverterExpand>(patterns->getContext(),
+ /*benefit=*/200);
+ patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
+ /*benefit=*/300);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index d3a84eb513ab4..138a30fa837d4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -96,7 +96,7 @@ func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<2xf32>
func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
- // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]]
+ // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG0]])
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %[[ARG1]] : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
// CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
// CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
@@ -116,7 +116,7 @@ func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xf32>
func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
- // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG1]]
+ // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG1]])
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
// CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
// CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
@@ -137,8 +137,8 @@ func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
- // CHECK: [[RESHAPE1:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- // CHECK: [[RESHAPE2:%.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+ // CHECK: [[RESHAPE1:%.+]] = "tosa.reshape"(%[[ARG0]]) {new_shape = array<i64: 3>}
+ // CHECK: [[RESHAPE2:%.+]] = "tosa.reshape"(%[[ARG1]]) {new_shape = array<i64: 2>}
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
// CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
// CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
@@ -536,94 +536,6 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
return
}
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
- // CHECK: return [[RESHAPE]]
- return %0 : tensor<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
- // CHECK: return [[RESHAPE]]
- return %0 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_uprank
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
- // CHECK: return [[RESHAPE]]
- return %0 : tensor<2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_uprank_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
- // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
- // CHECK: return [[RESHAPE]]
- return %0 : tensor<2x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_samerank
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
-func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
- // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
- // CHECK-NEXT: return %[[RESHAPE2]]
- return %0 : tensor<2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_samerank_dyn
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
-func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
- // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
- // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
- // CHECK-NEXT: return %[[RESHAPE2]]
- return %0 : tensor<2x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank_6D
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
- // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
- return %0 : tensor<6x5x77xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank_6D_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
- // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]]
- // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
- return %0 : tensor<?x5x77xf32>
-}
// -----
@@ -714,7 +626,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield [[RES]] : f32
- // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
+ // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
@@ -724,7 +636,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield [[RES]] : f32
- // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
+ // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
// CHECK: arith.constant 1.0
@@ -764,7 +676,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
+ // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1, 4>}
%0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
return
}
@@ -784,7 +696,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor<f32> into tensor<1xf32>
+ // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<?xf32>) -> tensor<1xf32>
return
}
@@ -806,7 +718,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
- // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
+ // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 5, -9223372036854775808, 1>}
%0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
return
}
@@ -828,7 +740,7 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
// CHECK: %[[MAX:.+]] = arith.maxf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[MAX]] : f32
- // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
+ // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1>}
%0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
return
}
@@ -849,7 +761,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
// CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
// CHECK: linalg.yield [[RES]] : i32
- // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
+ // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
// CHECK: [[INIT:%.+]] = tensor.empty()
@@ -859,7 +771,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
// CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
// CHECK: linalg.yield [[RES]] : i32
- // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
+ // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
// CHECK: arith.constant 1
@@ -899,7 +811,7 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i1, %[[ARG2:[0-9a-zA-Z_]+]]: i1)
// CHECK: [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
// CHECK: linalg.yield [[RES]] : i1
- // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
+ // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
%0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
// CHECK: arith.constant false
@@ -1231,21 +1143,21 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]]
+ // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 4, 3>}
%0 = "tosa.tile"(%arg0) {multiples = array<i64: 2, 1>} : (tensor<2x3xi8>) -> (tensor<4x3xi8>)
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
+ // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 2, 6>}
%1 = "tosa.tile"(%arg0) {multiples = array<i64: 1, 2>} : (tensor<2x3xi8>) -> (tensor<2x6xi8>)
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
+ // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 10, 21>}
%2 = "tosa.tile"(%arg0) {multiples = array<i64: 5, 7>} : (tensor<2x3xi8>) -> (tensor<10x21xi8>)
return
@@ -1265,8 +1177,7 @@ func.func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x3xi8>) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
- // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+ // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 3>}
%0 = "tosa.tile"(%arg0) {multiples = array<i64: 2, 1>} : (tensor<?x3xi8>) -> (tensor<?x3xi8>)
return
@@ -1286,8 +1197,7 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
- // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+ // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 2, -9223372036854775808>}
%0 = "tosa.tile"(%arg0) {multiples = array<i64: 2, -1>} : (tensor<2x3xi8>) -> (tensor<2x?xi8>)
return
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index b4ed6ca0ceae3..34084ebf5d3ce 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -1,6 +1,95 @@
// RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s
-// CHECK-LABEL: @slice
+// CHECK-LABEL: @test_reshape_downrank
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
+ // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
+ // CHECK: return [[RESHAPE]]
+ return %0 : tensor<6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_downrank_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
+ // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
+ // CHECK: return [[RESHAPE]]
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_uprank
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
+ // CHECK: return [[RESHAPE]]
+ return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_uprank_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
+ // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
+ // CHECK: return [[RESHAPE]]
+ return %0 : tensor<2x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_samerank
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
+func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
+ // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
+ // CHECK-NEXT: return %[[RESHAPE2]]
+ return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_samerank_dyn
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
+func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
+ // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+ // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
+ // CHECK-NEXT: return %[[RESHAPE2]]
+ return %0 : tensor<2x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_downrank_6D
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+ // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+ return %0 : tensor<6x5x77xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_downrank_6D_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
+ // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]]
+ // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+ return %0 : tensor<?x5x77xf32>
+}
+
+// -----
+
+// CHECK-LABLE: func @slice
func.func @slice(%arg0: tensor<6xf32>) ->() {
// CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
%0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 1>} : (tensor<6xf32>) -> (tensor<1xf32>)