[Mlir-commits] [mlir] 723979e - Move tosa.reshape lowering patterns from TosaToLinalg to TosaToTensor

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Mar 7 08:06:23 PST 2023


Author: Krzysztof Drewniak
Date: 2023-03-07T16:06:18Z
New Revision: 723979efc8638d192a680d2a0f814f758274a046

URL: https://github.com/llvm/llvm-project/commit/723979efc8638d192a680d2a0f814f758274a046
DIFF: https://github.com/llvm/llvm-project/commit/723979efc8638d192a680d2a0f814f758274a046.diff

LOG: Move tosa.reshape lowering patterns from TosaToLinalg to TosaToTensor

Converting tosa.reshape to tensor.expand_shape and
tensor.collapse_shape logically belongs in the tosa-to-tensor
conversion process. In addition, we (rocMLIR downstream) want to use
the reshape -> expand/collapse_shape logic to simplify parts of our
Tosa integration without using the full tosa-to-linalg flow, further
motivating moving these patterns.

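The downside of this change is that tosa-to-linalg now leaves tosa.reshape
ops in place (and emits new ones when lowering reductions), so tosa-to-tensor
must be run after tosa-to-linalg to lower them. This is probably a breaking
change for existing pass pipelines.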

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D145119

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 46f5292a9ad1a..82cadf2a07daa 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -807,133 +807,12 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");
 
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  uint64_t expandInputRank =
-      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
-  reassociationMap.resize(expandInputRank);
-
-  for (uint64_t i = 0; i < expandInputRank; i++) {
-    int32_t dimToPush = i > axis ? i + 1 : i;
-    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
-  }
-
-  if (expandInputRank != 0) {
-    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
-    reassociationMap[expandedDim].push_back(
-        rewriter.getAffineDimExpr(expandedDim + 1));
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp.getResults()[0], reassociationMap);
+  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+      op, resultTy, linalgOp.getResults()[0],
+      rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
   return success();
 }
 
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
-
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
-
-  unsigned currLhsDim = 0, currRhsDim = 0;
-  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
-    int64_t rhsSize = rhsShape[currRhsDim];
-    int64_t lhsSize = lhsShape[currLhsDim];
-    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
-           currRhsDim < rhsShape.size()) {
-      if (lhsSize < rhsSize) {
-        currLhsDim++;
-        if (currLhsDim < lhsShape.size()) {
-          lhsSize *= lhsShape[currLhsDim];
-        }
-      } else {
-        currRhsDim++;
-        if (currRhsDim < rhsShape.size()) {
-          rhsSize *= rhsShape[currRhsDim];
-        }
-      }
-    }
-    if (lhsSize == rhsSize) {
-      intermediateShape.push_back(lhsSize);
-    }
-    currRhsDim++;
-    currLhsDim++;
-  }
-
-  // If the iterators didn't reach the end and their leftover dimensions are not
-  // equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
-    ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
-
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector<AffineExpr, 2> exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
-  }
-
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
-  }
-
-  reassociationMap.resize(dstShape.size());
-  unsigned currSrcDim = 0, currDstDim = 0;
-  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-    int64_t dstSize = dstShape[currDstDim];
-    int64_t srcSize = srcShape[currSrcDim];
-    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      srcSize *= srcShape[currSrcDim];
-    }
-    if (srcSize == dstSize) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      // If the next dim in collapsedShape is not 1, treat subsequent dims in
-      // expandedShape which are 1 to be collapsed.
-      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
-        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-          reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
-        }
-      }
-    }
-    currDstDim++;
-  }
-
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
-}
-
 namespace {
 
 template <typename SrcOp>
@@ -947,115 +826,6 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
-class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
-    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (isDynamic && resultTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot collapse dynamic dims to more than one dimension");
-    }
-
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                            resultTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to collapse into an incompatible shape");
-    }
-
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot collapse into given shape");
-    }
-
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
-  }
-};
-
-class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
-    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (isDynamic && operandTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot expand dynamic dims from more than one dimension");
-    }
-
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                            operandTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to expand into an incompatible shape");
-    }
-
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic) ||
-        intermediateShape != operandTy.getShape()) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot expand into given shape");
-    }
-    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
-  }
-};
-
-class ReshapeConverterCollapseExpand
-    : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
-    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot identify an intermediate shape between "
-                   "the given two shapes");
-    }
-
-    Value collapse = rewriter.create<tosa::ReshapeOp>(
-        reshape.getLoc(),
-        RankedTensorType::get(intermediateShape,
-                              reshape.getType().getElementType()),
-        adaptor.getInput1());
-    Value expand =
-        rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
-    rewriter.replaceOp(reshape, expand);
-
-    return success();
-  }
-};
-
 class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 public:
   using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
@@ -2295,13 +2065,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
   patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                             /*benefit=*/300);
 
-  patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
-                                          /*benefit=*/100);
-  patterns->add<ReshapeConverterExpand>(patterns->getContext(),
-                                        /*benefit=*/200);
-  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
-                                                /*benefit=*/300);
-
   patterns->add<
       // clang-format off
       PointwiseConverter<tosa::AddOp>,

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 4b1e351e9746e..f3a76ddfcde78 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -56,6 +56,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
     target.addLegalOp<tosa::ConstOp>();
     target.addLegalOp<tosa::WhileOp>();
     target.addLegalOp<tosa::SliceOp>();
+    target.addLegalOp<tosa::ReshapeOp>();
     target.addLegalOp<tosa::PadOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

diff  --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 55acad6d0cf21..b276c773d55ba 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -15,21 +15,236 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace tosa;
 
+static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
+                                  ArrayRef<int64_t> rhsShape,
+                                  SmallVector<int64_t> &intermediateShape,
+                                  bool isDynamic) {
+  if (isDynamic) {
+    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
+    intermediateShape = {ShapedType::kDynamic};
+    return true;
+  }
+
+  if (lhsShape.empty() || rhsShape.empty()) {
+    intermediateShape = {};
+    return true;
+  }
+
+  unsigned currLhsDim = 0, currRhsDim = 0;
+  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
+    int64_t rhsSize = rhsShape[currRhsDim];
+    int64_t lhsSize = lhsShape[currLhsDim];
+    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
+           currRhsDim < rhsShape.size()) {
+      if (lhsSize < rhsSize) {
+        currLhsDim++;
+        if (currLhsDim < lhsShape.size()) {
+          lhsSize *= lhsShape[currLhsDim];
+        }
+      } else {
+        currRhsDim++;
+        if (currRhsDim < rhsShape.size()) {
+          rhsSize *= rhsShape[currRhsDim];
+        }
+      }
+    }
+    if (lhsSize == rhsSize) {
+      intermediateShape.push_back(lhsSize);
+    }
+    currRhsDim++;
+    currLhsDim++;
+  }
+
+  // If the iterators didn't reach the end and their leftover dimensions are not
+  // equal to 1 an intermediate shape was not found.
+  while (currLhsDim < lhsShape.size()) {
+    if (lhsShape[currLhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  while (currRhsDim < rhsShape.size()) {
+    if (rhsShape[currRhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+static bool createReassociationMapsForCollapse(
+    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
+    ArrayRef<int64_t> dstShape,
+    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+
+  // If the shape is dynamic, create a map for collapsing into one dimension.
+  if (isDynamic) {
+    SmallVector<AffineExpr, 2> exprs;
+    for (int i = 0, s = srcShape.size(); i < s; ++i)
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    reassociationMap = {exprs};
+    return true;
+  }
+
+  if (dstShape.empty()) {
+    reassociationMap = {};
+    return true;
+  }
+
+  reassociationMap.resize(dstShape.size());
+  unsigned currSrcDim = 0, currDstDim = 0;
+  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+    int64_t dstSize = dstShape[currDstDim];
+    int64_t srcSize = srcShape[currSrcDim];
+    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      srcSize *= srcShape[currSrcDim];
+    }
+    if (srcSize == dstSize) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      // If the next dim in collapsedShape is not 1, treat subsequent dims in
+      // expandedShape which are 1 to be collapsed.
+      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+          reassociationMap[currDstDim].push_back(
+              rewriter.getAffineDimExpr(currSrcDim++));
+        }
+      }
+    }
+    currDstDim++;
+  }
+
+  // If either iterator didn't reach the end, we have leftover dimensions,
+  // which implies that we have a mismatch in shape.
+  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+}
+
 namespace {
+class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
-class SliceConverter : public OpRewritePattern<tosa::SliceOp> {
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (isDynamic && resultTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot collapse dynamic dims to more than one dimension");
+    }
+
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                            resultTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to collapse into an incompatible shape");
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot collapse into given shape");
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    return success();
+  }
+};
+
+class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
 public:
-  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (isDynamic && operandTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot expand dynamic dims from more than one dimension");
+    }
+
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                            operandTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to expand into an incompatible shape");
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic) ||
+        intermediateShape != operandTy.getShape()) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot expand into given shape");
+    }
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    return success();
+  }
+};
+
+class ReshapeConverterCollapseExpand
+    : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot identify an intermediate shape between "
+                   "the given two shapes");
+    }
+
+    Value collapse = rewriter.create<tosa::ReshapeOp>(
+        reshape.getLoc(),
+        RankedTensorType::get(intermediateShape,
+                              reshape.getType().getElementType()),
+        adaptor.getInput1());
+    Value expand =
+        rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
+    rewriter.replaceOp(reshape, expand);
+
+    return success();
+  }
+};
+
+class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
+public:
+  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     Location loc = sliceOp.getLoc();
-    Value input = sliceOp.getInput();
+    Value input = adaptor.getInput();
     SmallVector<int64_t> strides, sizes;
     ArrayRef<int64_t> starts = sliceOp.getStart();
     strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
@@ -139,4 +354,10 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<SliceConverter, PadConverter>(patterns->getContext());
+  patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
+                                          /*benefit=*/100);
+  patterns->add<ReshapeConverterExpand>(patterns->getContext(),
+                                        /*benefit=*/200);
+  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
+                                                /*benefit=*/300);
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index d3a84eb513ab4..138a30fa837d4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -96,7 +96,7 @@ func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<2xf32>
 func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]]
+  // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG0]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %[[ARG1]] : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
   // CHECK:   [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
@@ -116,7 +116,7 @@ func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xf32>
 func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG1]]
+  // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG1]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
   // CHECK:   [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
@@ -137,8 +137,8 @@ func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32
 // CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
 func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
-  // CHECK: [[RESHAPE1:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK: [[RESHAPE2:%.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+  // CHECK: [[RESHAPE1:%.+]] = "tosa.reshape"(%[[ARG0]]) {new_shape = array<i64: 3>}
+  // CHECK: [[RESHAPE2:%.+]] = "tosa.reshape"(%[[ARG1]]) {new_shape = array<i64: 2>}
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
   // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
   // CHECK:   [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
@@ -536,94 +536,6 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   return
 }
 
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
-  // CHECK: return [[RESHAPE]]
-  return %0 : tensor<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
-  // CHECK: return [[RESHAPE]]
-  return %0 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_uprank
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
-  // CHECK: return [[RESHAPE]]
-  return %0 : tensor<2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_uprank_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
-  // CHECK: return [[RESHAPE]]
-  return %0 : tensor<2x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_samerank
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
-func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
-  // CHECK-NEXT: return %[[RESHAPE2]]
-  return %0 : tensor<2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_samerank_dyn
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
-func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
-  // CHECK-NEXT: return %[[RESHAPE2]]
-  return %0 : tensor<2x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank_6D
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
-  // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
-  return %0 : tensor<6x5x77xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @test_reshape_downrank_6D_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
-  // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]]
-  // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
-  return %0 : tensor<?x5x77xf32>
-}
 
 // -----
 
@@ -714,7 +626,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[RES]] : f32
-  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
+  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
@@ -724,7 +636,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[RES]] : f32
-  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
+  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
 
   // CHECK: arith.constant 1.0
@@ -764,7 +676,7 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
+  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1, 4>}
   %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
   return
 }
@@ -784,7 +696,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor<f32> into tensor<1xf32>
+  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<?xf32>) -> tensor<1xf32>
   return
 }
@@ -806,7 +718,7 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
-  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
+  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 5, -9223372036854775808, 1>}
   %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
   return
 }
@@ -828,7 +740,7 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
   // CHECK:   %[[MAX:.+]] = arith.maxf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[MAX]] : f32
-  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
+  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 1>}
   %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
   return
 }
@@ -849,7 +761,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
   // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK:   linalg.yield [[RES]] : i32
-  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
+  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
@@ -859,7 +771,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
   // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK:   linalg.yield [[RES]] : i32
-  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
+  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 5, 1>}
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
 
   // CHECK: arith.constant 1
@@ -899,7 +811,7 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i1, %[[ARG2:[0-9a-zA-Z_]+]]: i1)
   // CHECK:   [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
   // CHECK:   linalg.yield [[RES]] : i1
-  // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
+  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 1, 4>}
   %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
 
   // CHECK: arith.constant false
@@ -1231,21 +1143,21 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
-  // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]]
+  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 4, 3>}
   %0 = "tosa.tile"(%arg0) {multiples = array<i64: 2, 1>} : (tensor<2x3xi8>)  -> (tensor<4x3xi8>)
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
-  // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
+  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 2, 6>}
   %1 = "tosa.tile"(%arg0) {multiples = array<i64: 1, 2>} : (tensor<2x3xi8>)  -> (tensor<2x6xi8>)
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
-  // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
+  // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array<i64: 10, 21>}
   %2 = "tosa.tile"(%arg0) {multiples = array<i64: 5, 7>} : (tensor<2x3xi8>)  -> (tensor<10x21xi8>)
 
   return
@@ -1265,8 +1177,7 @@ func.func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x3xi8>) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
   // CHECK: ^bb0(%[[ARG1:.+]]: i8,
   // CHECK:   linalg.yield %[[ARG1]] : i8
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: -9223372036854775808, 3>}
   %0 = "tosa.tile"(%arg0) {multiples = array<i64: 2, 1>} : (tensor<?x3xi8>)  -> (tensor<?x3xi8>)
 
   return
@@ -1286,8 +1197,7 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
   // CHECK: ^bb0(%[[ARG1:.+]]: i8,
   // CHECK:   linalg.yield %[[ARG1]] : i8
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
+  // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array<i64: 2, -9223372036854775808>}
   %0 = "tosa.tile"(%arg0) {multiples = array<i64: 2, -1>} : (tensor<2x3xi8>)  -> (tensor<2x?xi8>)
 
   return

diff  --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index b4ed6ca0ceae3..34084ebf5d3ce 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -1,6 +1,95 @@
 // RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s
 
-// CHECK-LABEL: @slice
+// CHECK-LABEL: @test_reshape_downrank
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
+  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_downrank_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
+  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_uprank
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_uprank_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
+  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_samerank
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
+func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
+  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
+  // CHECK-NEXT: return %[[RESHAPE2]]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_samerank_dyn
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
+func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
+  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
+  // CHECK-NEXT: return %[[RESHAPE2]]
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_downrank_6D
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+  // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  return %0 : tensor<6x5x77xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_downrank_6D_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
+  // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]]
+  // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+  return %0 : tensor<?x5x77xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @slice
 func.func @slice(%arg0: tensor<6xf32>) ->() {
   // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
   %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 1>} : (tensor<6xf32>)  -> (tensor<1xf32>)


        


More information about the Mlir-commits mailing list