[Mlir-commits] [mlir] Fixes in 'tosa.reshape' lowering and folder (PR #85798)

Rafael Ubal llvmlistbot at llvm.org
Tue Mar 26 06:16:42 PDT 2024


https://github.com/rafaelubalmw updated https://github.com/llvm/llvm-project/pull/85798

>From a7653af040179b3fda73a5fadb486a7468aaa3d5 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Mon, 18 Mar 2024 15:25:40 -0400
Subject: [PATCH 1/5] Lowering of 'tosa.reshape' based on 'tensor.reshape'

---
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 239 ++++--------------
 1 file changed, 48 insertions(+), 191 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 505d85f211111c..0b51e28e191fb2 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -19,184 +19,48 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace tosa;
 
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
-
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
-
-  unsigned currLhsDim = 0, currRhsDim = 0;
-  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
-    int64_t rhsSize = rhsShape[currRhsDim];
-    int64_t lhsSize = lhsShape[currLhsDim];
-    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
-           currRhsDim < rhsShape.size()) {
-      if (lhsSize < rhsSize) {
-        currLhsDim++;
-        if (currLhsDim < lhsShape.size()) {
-          lhsSize *= lhsShape[currLhsDim];
-        }
-      } else {
-        currRhsDim++;
-        if (currRhsDim < rhsShape.size()) {
-          rhsSize *= rhsShape[currRhsDim];
-        }
-      }
-    }
-    if (lhsSize == rhsSize) {
-      intermediateShape.push_back(lhsSize);
-    }
-    currRhsDim++;
-    currLhsDim++;
-  }
-
-  // If the iterators didn't reach the end and their leftover dimensions are not
-  // equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
+// Compute the dimension size of the result tensor corresponding to the
+// placeholder value set to -1 in the 'new_shape' attribute of a 'tosa.reshape'
+// op. Argument 'newShape' is expected to contain exactly one value set to -1.
+static Value getReshapePlaceholderDimSize(OpBuilder &builder, Location loc,
+                                          TypedValue<TensorType> input,
+                                          ArrayRef<int64_t> newShape) {
+  // Calculate the product of all dimensions in the new shape. We expect to have
+  // exactly one size set to -1, so we can discard this component by just
+  // negating the final product.
+  auto newSizeValue = -std::accumulate(newShape.begin(), newShape.end(), 1,
+                                       std::multiplies<int64_t>());
+  assert(newSizeValue > 0);
+  auto newSize = builder.create<arith::ConstantIndexOp>(loc, newSizeValue);
+
+  // Calculate input tensor size. Traverse static dimensions first to fold as
+  // many results as possible.
+  auto inputShape = input.getType().getShape();
+  int64_t staticInputSize = 1;
+  for (auto size : inputShape)
+    if (size != ShapedType::kDynamic)
+      staticInputSize *= size;
+ 
+  // Continue calculating input tensor size by multiplying dynamic dimensions.
+  Value inputSize = builder.create<arith::ConstantIndexOp>(loc, staticInputSize);
+  for (auto [index, size] : llvm::enumerate(inputShape)) {
+    if (size == ShapedType::kDynamic) {
+      auto dimSize = builder.create<tensor::DimOp>(loc, input, index);
+      inputSize = builder.create<arith::MulIOp>(loc, inputSize, dimSize);
     }
   }
 
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
-    ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
-
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector<AffineExpr, 2> exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
-  }
-
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
-  }
-
-  reassociationMap.resize(dstShape.size());
-  unsigned currSrcDim = 0, currDstDim = 0;
-  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-    int64_t dstSize = dstShape[currDstDim];
-    int64_t srcSize = srcShape[currSrcDim];
-    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      srcSize *= srcShape[currSrcDim];
-    }
-    if (srcSize == dstSize) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      // If the next dim in collapsedShape is not 1, treat subsequent dims in
-      // expandedShape which are 1 to be collapsed.
-      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
-        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-          reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
-        }
-      }
-    }
-    currDstDim++;
-  }
-
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+  // Calculate placeholder size, which will be folded if input tensor was
+  // statically shaped.
+  return builder.createOrFold<arith::DivUIOp>(loc, inputSize, newSize);
 }
 
 namespace {
-Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
-                     ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && resultTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot collapse dynamic dims to more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                          resultTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot collapse into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
-                                                  reassociationMap);
-}
-
-Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
-                   ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && operandTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot expand dynamic dims from more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                          operandTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to expand into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic) ||
-      intermediateShape != operandTy.getShape()) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot expand into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
-                                                reassociationMap);
-}
 
 class ReshapeConverterCollapseExpand
     : public OpConversionPattern<tosa::ReshapeOp> {
@@ -206,30 +70,23 @@ class ReshapeConverterCollapseExpand
   LogicalResult
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot identify an intermediate shape between "
-                   "the given two shapes");
+    auto loc = reshape.getLoc();
+    auto input = reshape.getInput1();
+    auto newShape = reshape.getNewShape();
+
+    // Create list of values for new shape
+    SmallVector<Value> newShapeList(reshape.getNewShape().size());
+    for (auto [index, size] : llvm::enumerate(reshape.getNewShape())) {
+      newShapeList[index] = size == -1 ?
+          getReshapePlaceholderDimSize(rewriter, loc, input, newShape) :
+          rewriter.create<arith::ConstantIndexOp>(loc, size);
     }
-    auto intermediateTy = RankedTensorType::get(
-        intermediateShape, reshape.getType().getElementType());
-
-    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
-                                    adaptor.getInput1());
-    if (!collapse)
-      return failure();
-
-    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
-    if (!expand)
-      return failure();
 
-    rewriter.replaceOp(reshape, expand);
+    // Reshape tensor
+    auto newShapeTensor = rewriter.createOrFold<tensor::FromElementsOp>(
+        loc, newShapeList);
+    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(
+        reshape, reshape.getResult().getType(), input, newShapeTensor);
     return success();
   }
 };

>From 482e127e33011dd25c6a9468255cafbcd9b28259 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 19 Mar 2024 10:24:14 -0400
Subject: [PATCH 2/5] Bug fixed in 'tosa.reshape' canonicalization. Completed
 lowering for 'tosa.reshape'

---
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 106 ++++++-----
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |   5 +-
 .../TosaToTensor/tosa-to-tensor.mlir          | 167 +++++++++++++++---
 mlir/test/Dialect/Tosa/canonicalize.mlir      |   8 +
 4 files changed, 222 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 0b51e28e191fb2..62ed41ebda4f50 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -24,46 +24,71 @@
 using namespace mlir;
 using namespace tosa;
 
+static Value getIndexConstant(OpBuilder& builder, Location loc, int64_t index) {
+  return builder.create<arith::ConstantIndexOp>(loc, index);
+}
+
+// Return the total size of the given input tensor.
+static Value getTensorSize(OpBuilder& builder, Location loc, TypedValue<TensorType> input) {
+  // If the input tensor is statically shaped, return its size as a constant.
+  if (input.getType().hasStaticShape()) {
+    auto shape = input.getType().getShape();
+    auto size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies());
+    return getIndexConstant(builder, loc, size);
+  }
+
+  // When the input tensor has at least one dynamic dimension, collapse it into
+  // a 1D tensor and get its size.
+  auto rank = input.getType().getRank();
+  auto elementType = input.getType().getElementType();
+  auto collapsedType = RankedTensorType::get({ShapedType::kDynamic}, elementType);
+  auto reassociationIndices = SmallVector<ReassociationIndices>{
+    llvm::to_vector(llvm::seq<int64_t>(rank))
+  };
+  auto collapsed = builder.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, input, reassociationIndices);
+  return builder.create<tensor::DimOp>(loc, collapsed, 0);
+}
+
 // Compute the dimension size of the result tensor corresponding to the
 // placeholder value set to -1 in the 'new_shape' attribute of a 'tosa.reshape'
-// op. Argument 'newShape' is expected to contain exactly one value set to -1.
-static Value getReshapePlaceholderDimSize(OpBuilder &builder, Location loc,
-                                          TypedValue<TensorType> input,
-                                          ArrayRef<int64_t> newShape) {
+// op. Argument 'index' indicates the position of the -1 placeholder.
+static Value getReshapePlaceholderDimSize(OpBuilder &builder,
+                                          tosa::ReshapeOp reshape,
+                                          int64_t index) {
+  auto loc = reshape.getLoc();
+  auto input = reshape.getInput1();
+  auto newShape = reshape.getNewShape();
+  auto resultType = reshape.getResult().getType();
+
+  // If the corresponding dimension in the result type is static, take the
+  // dimension size from there.
+  assert(newShape[index] == -1);
+  if (!resultType.isDynamicDim(index))
+    return getIndexConstant(builder, loc, resultType.getDimSize(index));
+
   // Calculate the product of all dimensions in the new shape. We expect to have
   // exactly one size set to -1, so we can discard this component by just
   // negating the final product.
-  auto newSizeValue = -std::accumulate(newShape.begin(), newShape.end(), 1,
-                                       std::multiplies<int64_t>());
-  assert(newSizeValue > 0);
-  auto newSize = builder.create<arith::ConstantIndexOp>(loc, newSizeValue);
-
-  // Calculate input tensor size. Traverse static dimensions first to fold as
-  // many results as possible.
-  auto inputShape = input.getType().getShape();
-  int64_t staticInputSize = 1;
-  for (auto size : inputShape)
-    if (size != ShapedType::kDynamic)
-      staticInputSize *= size;
- 
-  // Continue calculating input tensor size by multiplying dynamic dimensions.
-  Value inputSize = builder.create<arith::ConstantIndexOp>(loc, staticInputSize);
-  for (auto [index, size] : llvm::enumerate(inputShape)) {
-    if (size == ShapedType::kDynamic) {
-      auto dimSize = builder.create<tensor::DimOp>(loc, input, index);
-      inputSize = builder.create<arith::MulIOp>(loc, inputSize, dimSize);
-    }
-  }
-
-  // Calculate placeholder size, which will be folded if input tensor was
-  // statically shaped.
+  auto newSizeLiteral = -std::accumulate(newShape.begin(), newShape.end(), 1,
+                                         std::multiplies<int64_t>());
+  assert(newSizeLiteral >= 0);
+  auto newSize = builder.create<arith::ConstantIndexOp>(loc, newSizeLiteral);
+
+  // Avoid a division by zero. If any of the given dimension sizes was set to
+  // zero, set the placeholder size to zero, too.
+  if (newSizeLiteral == 0)
+    return newSize;
+
+  // The size of the placeholder dimension is the size of the input tensor
+  // divided by all non-placeholder dimension sizes.
+  auto inputSize = getTensorSize(builder, loc, input);
   return builder.createOrFold<arith::DivUIOp>(loc, inputSize, newSize);
 }
 
 namespace {
 
-class ReshapeConverterCollapseExpand
-    : public OpConversionPattern<tosa::ReshapeOp> {
+class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
@@ -72,19 +97,18 @@ class ReshapeConverterCollapseExpand
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = reshape.getLoc();
     auto input = reshape.getInput1();
-    auto newShape = reshape.getNewShape();
 
     // Create list of values for new shape
-    SmallVector<Value> newShapeList(reshape.getNewShape().size());
+    SmallVector<Value> newShapeVector(reshape.getNewShape().size());
     for (auto [index, size] : llvm::enumerate(reshape.getNewShape())) {
-      newShapeList[index] = size == -1 ?
-          getReshapePlaceholderDimSize(rewriter, loc, input, newShape) :
-          rewriter.create<arith::ConstantIndexOp>(loc, size);
+      newShapeVector[index] = size == -1 ?
+          getReshapePlaceholderDimSize(rewriter, reshape, index) :
+          getIndexConstant(rewriter, loc, size);
     }
 
     // Reshape tensor
     auto newShapeTensor = rewriter.createOrFold<tensor::FromElementsOp>(
-        loc, newShapeList);
+        loc, newShapeVector);
     rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(
         reshape, reshape.getResult().getType(), input, newShapeTensor);
     return success();
@@ -273,8 +297,10 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<SliceConverter, PadConverter, ConcatConverter>(
-      patterns->getContext());
-
-  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
+  patterns->add<
+    ConcatConverter,
+    PadConverter,
+    ReshapeConverter,
+    SliceConverter
+  >(patterns->getContext());
 }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 4c50aaecfe9488..d23c9fe824c94a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -795,7 +795,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!inputTy || !outputTy)
     return {};
 
-  if (inputTy == outputTy)
+  // Fold when the input and output types are the same. This is only safe when
+  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
+  // there may still be a productive reshape.
+  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
     return getInput1();
 
   // reshape(reshape(x)) -> reshape(x)
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index daaa68a7260b71..e1fd7838293b6a 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -1,11 +1,15 @@
 // RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s
 
+// -----
+
 // CHECK-LABEL: @test_reshape_downrank
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK: %[[SHAPE:.+]] = arith.constant dense<6> : tensor<1xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<2x3xf32>, tensor<1xindex>) -> tensor<6xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
-  // CHECK: return [[RESHAPE]]
+
+  // CHECK: return %[[RESHAPE]] : tensor<6xf32>
   return %0 : tensor<6xf32>
 }
 
@@ -14,9 +18,16 @@ func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
 // CHECK-LABEL: @test_reshape_downrank_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+  // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{[\[]}}[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
+  // CHECK-DAG: %[[SIZE:.+]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+
+  // CHECK-DAG: %[[SHAPE:.+]] = tensor.from_elements %[[SIZE]] : tensor<1xindex>
+  // CHECK-DAG: %[[RESHAPED:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<2x?xf32>, tensor<1xindex>) -> tensor<?xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
-  // CHECK: return [[RESHAPE]]
+
+  // CHECK: return %[[RESHAPED]] : tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
@@ -25,9 +36,10 @@ func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
 // CHECK-LABEL: @test_reshape_uprank
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK: %[[SHAPE:.+]] = arith.constant dense<[2, 3]> : tensor<2xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<6xf32>, tensor<2xindex>) -> tensor<2x3xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
-  // CHECK: return [[RESHAPE]]
+  // CHECK: return %[[RESHAPE]] : tensor<2x3xf32>
   return %0 : tensor<2x3xf32>
 }
 
@@ -36,57 +48,166 @@ func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
 // CHECK-LABEL: @test_reshape_uprank_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C2_0:.+]] = arith.constant 2 : index
+
+  // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{[\[]}}[0]] : tensor<?xf32> into tensor<?xf32>
+  // CHECK-DAG: %[[SIZE:.+]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+  // CHECK-DAG: %[[PLACEHOLDER:.+]] = arith.divui %[[SIZE]], %[[C2_0]] : index
+
+  // CHECK: %[[SHAPE:.+]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]] : tensor<2xindex>
+  // CHECK: %[[RESHAPED:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?xf32>, tensor<2xindex>) -> tensor<2x?xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
-  // CHECK: return [[RESHAPE]]
+  
+  // CHECK: return %[[RESHAPED]] : tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @test_reshape_samerank
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
 func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3]> : tensor<2xindex>
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<3x2xf32>, tensor<2xindex>) -> tensor<2x3xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
-  // CHECK-NEXT: return %[[RESHAPE2]]
+  
+  // CHECK: return %[[RESHAPED]] : tensor<2x3xf32>
   return %0 : tensor<2x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @test_reshape_samerank_dyn
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
 func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index
+
+  // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<?x2xf32> into tensor<?xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+  // CHECK-DAG: %[[PLACEHOLDER:.*]] = arith.divui %[[SIZE]], %[[C2_0]] : index
+
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]] : tensor<2xindex>
+  // CHECK-DAG: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x2xf32>, tensor<2xindex>) -> tensor<2x?xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
-  // CHECK-NEXT: return %[[RESHAPE2]]
+
+  // CHECK: return %[[RESHAPED]] : tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_6D
+// CHECK-LABEL: @test_reshape_downrank_6d
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
-  // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]]
+func.func @test_reshape_downrank_6d(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[6, 5, 77]> : tensor<3xindex>
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<1x2x3x5x7x11xf32>, tensor<3xindex>) -> tensor<6x5x77xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  
+  // CHECK: return %[[RESHAPED]] : tensor<6x5x77xf32>
   return %0 : tensor<6x5x77xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_6D_dyn
+// CHECK-LABEL: @test_reshape_downrank_6d_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
-  // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]]
-  // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
+func.func @test_reshape_downrank_6d_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+  // CHECK-DAG: %[[C77:.*]] = arith.constant 77 : index
+  // CHECK-DAG: %[[C385:.*]] = arith.constant 385 : index
+
+  // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor<?xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+  // CHECK-DAG: %[[PLACEHOLDER:.*]] = arith.divui %[[SIZE]], %[[C385]] : index
+
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor.from_elements %[[PLACEHOLDER]], %[[C5]], %[[C77]] : tensor<3xindex>
+  // CHECK-DAG: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<1x2x?x5x7x11xf32>, tensor<3xindex>) -> tensor<?x5x77xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+
+  // CHECK: return %[[RESHAPED]] : tensor<?x5x77xf32>
   return %0 : tensor<?x5x77xf32>
 }
 
+// -----
+
+// CHECK-LABEL: test_reshape_zero
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
+func.func @test_reshape_zero(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[0, 3, 0]> : tensor<3xindex>
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<3x2x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, -1>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
+  
+  // CHECK: return %[[RESHAPED]] : tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_dynamic_to_static
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+func.func @test_reshape_dynamic_to_static(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3, 4]> : tensor<3xindex>
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<2x3x4xf32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
+  
+  // CHECK: return %[[RESHAPED]] : tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_dynamic_to_static_placeholder
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+func.func @test_reshape_dynamic_to_static_placeholder(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3, 4]> : tensor<3xindex>
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<2x3x4xf32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
+  
+  // CHECK: return %[[RESHAPED]] : tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_dynamic_to_dynamic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+func.func @test_reshape_dynamic_to_dynamic(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3, 4]> : tensor<3xindex>
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+  // CHECK: return %[[RESHAPED]] : tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_dynamic_to_dynamic_placeholder
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+func.func @test_reshape_dynamic_to_dynamic_placeholder(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+
+  // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+  // CHECK-DAG: %[[PLACEHOLDER:.*]] = arith.divui %[[SIZE]], %[[C8]] : index
+
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]], %[[C4]] : tensor<3xindex>
+  // CHECK-DAG: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  
+  // CHECK: return %[[RESHAPED]] : tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+
 // -----
 
 // CHECK-LABLE: func @slice
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e7ede2e0ccef9a..6eac759a083645 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -365,6 +365,14 @@ func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
   return %0 : tensor<?x10xf32>
 }
 
+// CHECK-LABEL: @reshape_canonicalize_dyn_nofold
+func.func @reshape_canonicalize_dyn_nofold(%arg0: tensor<?x?x10xf32>) -> tensor<?x?x10xf32> {
+  // CHECK: %[[VAR0:.+]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 2, 10>} : (tensor<?x?x10xf32>) -> tensor<?x?x10xf32>
+  // CHECK: return %[[VAR0]] : tensor<?x?x10xf32>
+  %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, 2, 10>} : (tensor<?x?x10xf32>) -> tensor<?x?x10xf32>
+  return %0 : tensor<?x?x10xf32>
+}
+
 // CHECK-LABEL: @reshape_canonicalize_double
 func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
   // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 5>}

>From 3738c9721e15de199ad9d0a2f4e523ea22c1f5e0 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 21 Mar 2024 16:02:16 -0400
Subject: [PATCH 3/5] Addressed feedback, reverting to collapse-expand pattern

---
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 245 +++++++++----
 .../TosaToTensor/tosa-to-tensor.mlir          | 340 +++++++++++-------
 2 files changed, 381 insertions(+), 204 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 62ed41ebda4f50..37ace4fd015938 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -24,69 +24,180 @@
 using namespace mlir;
 using namespace tosa;
 
-static Value getIndexConstant(OpBuilder& builder, Location loc, int64_t index) {
-  return builder.create<arith::ConstantIndexOp>(loc, index);
+namespace {
+
+// Infer the result type of 'tensor.expand_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapedType(TypedValue<TensorType> input,
+                             ArrayRef<int64_t> newShape) {
+  // Check if the input is static, and if so, get its total size
+  bool inputIsStatic = input.getType().hasStaticShape();
+  int64_t totalSize = inputIsStatic ? input.getType().getNumElements() : -1;
+ 
+  // Compute result shape
+  bool resultIsStatic = true;
+  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+    // If this is not a placeholder, do not change it
+    if (size >= 0)
+      return size;
+
+    // If we do not know the total size of the tensor, keep this dimension
+    // dynamic in the result shape.
+    if (!inputIsStatic) {
+      resultIsStatic = false;
+      return ShapedType::kDynamic;
+    }
+
+    // Calculate the product of all elements in 'newShape' except for the -1
+    // placeholder, which we discard by negating the result.
+    int64_t totalSizeNoPlaceholder = -std::accumulate(
+        newShape.begin(), newShape.end(), 1, std::multiplies());
+
+    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
+    if (totalSizeNoPlaceholder == 0)
+      return 0;
+
+    // Resolve the placeholder as the quotient between the total tensor size and
+    // the product of all other sizes.
+    return totalSize / totalSizeNoPlaceholder;
+  });
+
+  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
+  // shaped input from being reshaped into a statically shaped result. We may
+  // simply turn the first result dimension dynamic to address this.
+  if (!inputIsStatic && resultIsStatic)
+    resultShape[0] = ShapedType::kDynamic;
+  
+  // The 'tensor.expand_shape' op also forbids a statically shaped input from
+  // being reshaped into a dynamically shaped result, but the placeholder
+  // inference algorithm above guarantees that this will never be the case.
+  assert(!inputIsStatic || resultIsStatic);
+
+  // Create result type
+  return input.getType().clone(resultShape);
+}
+
+// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferIntermediateType(TensorType lhsType, TensorType rhsType) {
+  auto lhsShape = lhsType.getShape();
+  auto rhsShape = rhsType.getShape();
+
+  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
+    return lhsType.clone({ShapedType::kDynamic});
+
+  if (lhsShape.empty() || rhsShape.empty())
+    return lhsType.clone({});
+
+  SmallVector<int64_t> intermediateShape;
+  unsigned currLhsDim = 0, currRhsDim = 0;
+  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
+    int64_t rhsSize = rhsShape[currRhsDim];
+    int64_t lhsSize = lhsShape[currLhsDim];
+    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
+           currRhsDim < rhsShape.size()) {
+      if (lhsSize < rhsSize) {
+        currLhsDim++;
+        if (currLhsDim < lhsShape.size()) {
+          lhsSize *= lhsShape[currLhsDim];
+        }
+      } else {
+        currRhsDim++;
+        if (currRhsDim < rhsShape.size()) {
+          rhsSize *= rhsShape[currRhsDim];
+        }
+      }
+    }
+    if (lhsSize == rhsSize) {
+      intermediateShape.push_back(lhsSize);
+    }
+    currRhsDim++;
+    currLhsDim++;
+  }
+
+  // Static shapes are guaranteed to be compatible by the op verifier, so all
+  // leftover dimensions should be 1.
+  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
+    assert(lhsShape[currLhsDim] == 1);
+  }
+  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
+    assert(rhsShape[currRhsDim] == 1);
+  }
+  
+  return lhsType.clone(intermediateShape);
 }
 
-// Return the total size of the given input tensor.
-static Value getTensorSize(OpBuilder& builder, Location loc, TypedValue<TensorType> input) {
-  // If the input tensor is statically shaped, return its size as a constant.
-  if (input.getType().hasStaticShape()) {
-    auto shape = input.getType().getShape();
-    auto size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies());
-    return getIndexConstant(builder, loc, size);
+SmallVector<ReassociationExprs>
+createReassociationMapForCollapse(OpBuilder &builder,
+                                  ArrayRef<int64_t> srcShape,
+                                  ArrayRef<int64_t> dstShape) {
+  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
+    assert(dstShape.size() == 1);
+    SmallVector<AffineExpr, 2> exprs;
+    for (int64_t i = 0, s = srcShape.size(); i < s; ++i)
+      exprs.push_back(builder.getAffineDimExpr(i));
+    return {exprs};
   }
 
-  // When the input tensor has at least one dynamic dimension, collapse it into
-  // a 1D tensor and get its size.
-  auto rank = input.getType().getRank();
-  auto elementType = input.getType().getElementType();
-  auto collapsedType = RankedTensorType::get({ShapedType::kDynamic}, elementType);
-  auto reassociationIndices = SmallVector<ReassociationIndices>{
-    llvm::to_vector(llvm::seq<int64_t>(rank))
-  };
-  auto collapsed = builder.create<tensor::CollapseShapeOp>(
-      loc, collapsedType, input, reassociationIndices);
-  return builder.create<tensor::DimOp>(loc, collapsed, 0);
+  if (dstShape.empty())
+    return {};
+
+  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
+  unsigned currSrcDim = 0, currDstDim = 0;
+  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+    int64_t dstSize = dstShape[currDstDim];
+    int64_t srcSize = srcShape[currSrcDim];
+    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+      reassociationMap[currDstDim].push_back(
+          builder.getAffineDimExpr(currSrcDim++));
+      srcSize *= srcShape[currSrcDim];
+    }
+    if (srcSize == dstSize) {
+      reassociationMap[currDstDim].push_back(
+          builder.getAffineDimExpr(currSrcDim++));
+      // If this is the last dim in 'dstShape' or the next one is not 1, fold
+      // the subsequent size-1 dims of 'srcShape' into the current group.
+      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+          reassociationMap[currDstDim].push_back(
+              builder.getAffineDimExpr(currSrcDim++));
+        }
+      }
+    }
+    currDstDim++;
+  }
+
+  // If the source and target shapes are compatible, both iterators must have
+  // reached the end. This condition is guaranteed by the op verifier for
+  // static shapes.
+  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
+  return reassociationMap;
 }
 
-// Compute the dimension size of the result tensor corresponding to the
-// placeholder value set to -1 in the 'new_shape' attribute of a 'tosa.reshape'
-// op. Argument 'index' indicates the position of the -1 placeholder.
-static Value getReshapePlaceholderDimSize(OpBuilder &builder,
-                                          tosa::ReshapeOp reshape,
-                                          int64_t index) {
-  auto loc = reshape.getLoc();
-  auto input = reshape.getInput1();
-  auto newShape = reshape.getNewShape();
-  auto resultType = reshape.getResult().getType();
-
-  // If the corresponding dimension in the result type is static, take the
-  // dimension size from there.
-  assert(newShape[index] == -1);
-  if (!resultType.isDynamicDim(index))
-    return getIndexConstant(builder, loc, resultType.getDimSize(index));
-
-  // Calculate the product of all dimensions in the new shape. We expect to have
-  // exactly one size set to -1, so we can discard this component by just
-  // negating the final product.
-  auto newSizeLiteral = -std::accumulate(newShape.begin(), newShape.end(), 1,
-                                         std::multiplies<int64_t>());
-  assert(newSizeLiteral >= 0);
-  auto newSize = builder.create<arith::ConstantIndexOp>(loc, newSizeLiteral);
-
-  // Avoid a division by zero. If any of the given dimension sizes was set to
-  // zero, set the placeholder size to zero, too.
-  if (newSizeLiteral == 0)
-    return newSize;
-
-  // The size of the placeholder dimension is the size of the input tensor
-  // divided by all non-placeholder dimension sizes.
-  auto inputSize = getTensorSize(builder, loc, input);
-  return builder.createOrFold<arith::DivUIOp>(loc, inputSize, newSize);
+// Create a tensor.collapse_shape op that reshapes the input into the given
+// result type. Both 'resultType' and 'input' must be statically shaped.
+TypedValue<TensorType> createCollapse(OpBuilder &builder, Location loc,
+                                      TensorType resultType,
+                                      TypedValue<TensorType> input) {
+  auto reassociationMap = createReassociationMapForCollapse(
+      builder, input.getType().getShape(), resultType.getShape());
+  auto result = builder.createOrFold<tensor::CollapseShapeOp>(
+      loc, resultType, input, reassociationMap);
+  return cast<TypedValue<TensorType>>(result);
 }
 
-namespace {
+// Create a tensor.expand_shape op that reshapes the input into the given result
+// type. Both 'resultType' and 'input' must be statically shaped.
+TypedValue<TensorType> createExpand(OpBuilder &builder, Location loc,
+                                    TensorType resultType,
+                                    TypedValue<TensorType> input) {
+  // Emit tensor.expand_shape op
+  auto reassociationMap = createReassociationMapForCollapse(
+      builder, resultType.getShape(), input.getType().getShape());
+  auto result = builder.createOrFold<tensor::ExpandShapeOp>(
+      loc, resultType, input, reassociationMap);
+  return cast<TypedValue<TensorType>>(result);
+}
 
 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
@@ -96,21 +207,21 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = reshape.getLoc();
+    auto resultType = reshape.getResult().getType();
     auto input = reshape.getInput1();
+    auto newShape = reshape.getNewShape();
 
-    // Create list of values for new shape
-    SmallVector<Value> newShapeVector(reshape.getNewShape().size());
-    for (auto [index, size] : llvm::enumerate(reshape.getNewShape())) {
-      newShapeVector[index] = size == -1 ?
-          getReshapePlaceholderDimSize(rewriter, reshape, index) :
-          getIndexConstant(rewriter, loc, size);
-    }
+    // Infer the result types for the subsequently emitted ops.
+    auto reshapedType = inferReshapedType(input, newShape);
+    auto intermediateType = inferIntermediateType(input.getType(), reshapedType);
+
+    // Emit collapse-expand pair
+    auto collapsed = createCollapse(rewriter, loc, intermediateType, input);
+    auto expanded = createExpand(rewriter, loc, reshapedType, collapsed);
 
-    // Reshape tensor
-    auto newShapeTensor = rewriter.createOrFold<tensor::FromElementsOp>(
-        loc, newShapeVector);
-    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(
-        reshape, reshape.getResult().getType(), input, newShapeTensor);
+    // Cast to final result type if needed
+    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
+    rewriter.replaceOp(reshape, result);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index e1fd7838293b6a..dc234f1ec333f7 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -2,215 +2,281 @@
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
-  // CHECK: %[[SHAPE:.+]] = arith.constant dense<6> : tensor<1xindex>
-  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<2x3xf32>, tensor<1xindex>) -> tensor<6xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
+// CHECK-LABEL: test_reshape_0d_same_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: return %[[ARG_0]] : tensor<f32>
+func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
 
-  // CHECK: return %[[RESHAPE]] : tensor<6xf32>
-  return %0 : tensor<6xf32>
+// -----
+
+// CHECK-LABEL: test_reshape_1d_up_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
+// CHECK: return %[[VAL_0]] : tensor<2x?xf32>
+func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
-  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-LABEL: test_reshape_1d_up_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<6xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: return %[[VAL_0]] : tensor<2x3xf32>
+func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
+}
 
-  // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{[\[]}}[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
-  // CHECK-DAG: %[[SIZE:.+]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+// -----
 
-  // CHECK-DAG: %[[SHAPE:.+]] = tensor.from_elements %[[SIZE]] : tensor<1xindex>
-  // CHECK-DAG: %[[RESHAPED:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<2x?xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-LABEL: test_reshape_2d_down_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x?xf32> into tensor<?xf32>
+// CHECK: return %[[VAL_0]] : tensor<?xf32>
+func.func @test_reshape_2d_down_d2d_auto(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
-
-  // CHECK: return %[[RESHAPED]] : tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_uprank
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
-  // CHECK: %[[SHAPE:.+]] = arith.constant dense<[2, 3]> : tensor<2xindex>
-  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<6xf32>, tensor<2xindex>) -> tensor<2x3xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
-  // CHECK: return %[[RESHAPE]] : tensor<2x3xf32>
-  return %0 : tensor<2x3xf32>
+// CHECK-LABEL: test_reshape_2d_down_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x3xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x3xf32> into tensor<6xf32>
+// CHECK: return %[[VAL_0]] : tensor<6xf32>
+func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
+  return %0 : tensor<6xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_uprank_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
-  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-  // CHECK-DAG: %[[C2_0:.+]] = arith.constant 2 : index
+// CHECK-LABEL: test_reshape_2d_same_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x2xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?x2xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
+// CHECK: return %[[VAL_1]] : tensor<2x?xf32>
+func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}
 
-  // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{[\[]}}[0]] : tensor<?xf32> into tensor<?xf32>
-  // CHECK-DAG: %[[SIZE:.+]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
-  // CHECK-DAG: %[[PLACEHOLDER:.+]] = arith.divui %[[SIZE]], %[[C2_0]] : index
+// -----
 
-  // CHECK: %[[SHAPE:.+]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]] : tensor<2xindex>
-  // CHECK: %[[RESHAPED:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?xf32>, tensor<2xindex>) -> tensor<2x?xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
-  
-  // CHECK: return %[[RESHAPED]] : tensor<2x?xf32>
-  return %0 : tensor<2x?xf32>
+// CHECK-LABEL: test_reshape_2d_same_s2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x2xf32>
+func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 2>} : (tensor<2x4xf32>) -> tensor<?x2xf32>
+  return %0 : tensor<?x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_2d_same_s2d_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x2xf32>
+func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 4, 2>} : (tensor<2x4xf32>) -> tensor<?x2xf32>
+  return %0 : tensor<?x2xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_samerank
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
-func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
-  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3]> : tensor<2xindex>
-  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<3x2xf32>, tensor<2xindex>) -> tensor<2x3xf32>
+// CHECK-LABEL: test_reshape_2d_same_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: return %[[VAL_1]] : tensor<2x3xf32>
+func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
-  
-  // CHECK: return %[[RESHAPED]] : tensor<2x3xf32>
   return %0 : tensor<2x3xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_samerank_dyn
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
-func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK-LABEL: test_reshape_3d_same_d2d_auto_empty
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<0x3x?xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<0x3x?xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
+func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, -1>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
 
-  // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<?x2xf32> into tensor<?xf32>
-  // CHECK-DAG: %[[SIZE:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
-  // CHECK-DAG: %[[PLACEHOLDER:.*]] = arith.divui %[[SIZE]], %[[C2_0]] : index
+// -----
 
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]] : tensor<2xindex>
-  // CHECK-DAG: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x2xf32>, tensor<2xindex>) -> tensor<2x?xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
+// CHECK-LABEL: test_reshape_3d_same_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<2x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
+func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<2x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
 
-  // CHECK: return %[[RESHAPED]] : tensor<2x?xf32>
-  return %0 : tensor<2x?xf32>
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2d_auto_identity
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x3x4xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x3x4xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x3x?xf32>
+// CHECK: return %[[VAL_1]] : tensor<2x3x?xf32>
+func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor<?x3x4xf32>) -> tensor<2x3x?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, -1>} : (tensor<?x3x4xf32>) -> tensor<2x3x?xf32>
+  return %0 : tensor<2x3x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_6d
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6d(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
-  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[6, 5, 77]> : tensor<3xindex>
-  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<1x2x3x5x7x11xf32>, tensor<3xindex>) -> tensor<6x5x77xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
-  
-  // CHECK: return %[[RESHAPED]] : tensor<6x5x77xf32>
-  return %0 : tensor<6x5x77xf32>
+// CHECK-LABEL: test_reshape_3d_same_d2d_explicit_empty
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x2xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x2xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
+func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, 2>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_6d_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6d_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
-  // CHECK-DAG: %[[C77:.*]] = arith.constant 77 : index
-  // CHECK-DAG: %[[C385:.*]] = arith.constant 385 : index
-
-  // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor<?xf32>
-  // CHECK-DAG: %[[SIZE:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
-  // CHECK-DAG: %[[PLACEHOLDER:.*]] = arith.divui %[[SIZE]], %[[C385]] : index
-
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor.from_elements %[[PLACEHOLDER]], %[[C5]], %[[C77]] : tensor<3xindex>
-  // CHECK-DAG: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<1x2x?x5x7x11xf32>, tensor<3xindex>) -> tensor<?x5x77xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+// CHECK-LABEL: test_reshape_3d_same_d2d_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x4xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
+func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
 
-  // CHECK: return %[[RESHAPED]] : tensor<?x5x77xf32>
-  return %0 : tensor<?x5x77xf32>
+// -----
+
+// CHECK-LABEL: test_reshape_3d_same_d2d_explicit_identity
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x3x4xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.cast %[[ARG_0]] : tensor<?x3x4xf32> to tensor<2x3x?xf32>
+// CHECK: return %[[VAL_0]] : tensor<2x3x?xf32>
+func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor<?x3x4xf32>) -> tensor<2x3x?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x3x4xf32>) -> tensor<2x3x?xf32>
+  return %0 : tensor<2x3x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: test_reshape_zero
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
-func.func @test_reshape_zero(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
-  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[0, 3, 0]> : tensor<3xindex>
-  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<3x2x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, -1>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
-  
-  // CHECK: return %[[RESHAPED]] : tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
+// CHECK-LABEL: test_reshape_3d_same_d2s_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<2x3x4xf32>
+// CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
+func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: test_reshape_dynamic_to_static
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-func.func @test_reshape_dynamic_to_static(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
-  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3, 4]> : tensor<3xindex>
-  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<2x3x4xf32>
+// CHECK-LABEL: test_reshape_3d_same_d2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x4xf32> to tensor<2x3x4xf32>
+// CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
+func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
-  
-  // CHECK: return %[[RESHAPED]] : tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: test_reshape_dynamic_to_static_placeholder
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-func.func @test_reshape_dynamic_to_static_placeholder(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
-  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3, 4]> : tensor<3xindex>
-  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<2x3x4xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
-  
-  // CHECK: return %[[RESHAPED]] : tensor<2x3x4xf32>
+// CHECK-LABEL: test_reshape_3d_same_s2s_explicit_identity
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x3x4xf32>
+// CHECK: return %[[ARG_0]] : tensor<2x3x4xf32>
+func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: test_reshape_dynamic_to_dynamic
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-func.func @test_reshape_dynamic_to_dynamic(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3, 4]> : tensor<3xindex>
-  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-LABEL: test_reshape_3d_up_d2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : tensor<?xf32> into tensor<?x3x2x1xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x2x1xf32> to tensor<1x3x2x1xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x3x2x1xf32>
+func.func @test_reshape_3d_up_d2s_explicit(%input: tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32> {
+  %0 = tosa.reshape %input {new_shape = array<i64: 1, 3, 2, 1>} : (tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32>
+  return %0 : tensor<1x3x2x1xf32>
+}
 
-  // CHECK: return %[[RESHAPED]] : tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
+// -----
+
+// CHECK-LABEL: test_reshape_5d_down_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?x2x3xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor<?x?x?x2x3xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x2x3xf32>
+// CHECK: return %[[VAL_1]] : tensor<?x2x3xf32>
+func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 2, 3>} : (tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32>
+  return %0 : tensor<?x2x3xf32>
 }
 
 // -----
 
-// CHECK-LABEL: test_reshape_dynamic_to_dynamic_placeholder
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-func.func @test_reshape_dynamic_to_dynamic_placeholder(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-LABEL: test_reshape_6d_down_d2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x?x5x7x11xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor<?xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x5x77xf32>
+// CHECK: return %[[VAL_1]] : tensor<?x5x77xf32>
+func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+  return %0 : tensor<?x5x77xf32>
+}
 
-  // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-  // CHECK-DAG: %[[SIZE:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
-  // CHECK-DAG: %[[PLACEHOLDER:.*]] = arith.divui %[[SIZE]], %[[C8]] : index
+// -----
 
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]], %[[C4]] : tensor<3xindex>
-  // CHECK-DAG: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-  
-  // CHECK: return %[[RESHAPED]] : tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
+// CHECK-LABEL: test_reshape_6d_down_s2s_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x3x5x7x11xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32>
+// CHECK: return %[[VAL_0]] : tensor<6x5x77xf32>
+func.func @test_reshape_6d_down_s2s_auto(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, -1>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  return %0 : tensor<6x5x77xf32>
 }
 
+// -----
+
+// CHECK-LABEL: test_reshape_6d_down_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x3x5x7x11xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32>
+// CHECK: return %[[VAL_0]] : tensor<6x5x77xf32>
+func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  return %0 : tensor<6x5x77xf32>
+}
 
 // -----
 
-// CHECK-LABLE: func @slice
+// CHECK-LABEL: func @slice
 func.func @slice(%arg0: tensor<6xf32>) ->() {
   // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
   %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 1>} : (tensor<6xf32>)  -> (tensor<1xf32>)

>From 02dcd8fa411103b635ab48cf23fef755a9d7b5f3 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Fri, 22 Mar 2024 12:33:49 -0400
Subject: [PATCH 4/5] Added restriction for max number of missing target dims
 in op verifier

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f461e7e1a555b8..1aad8c8b12e786 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -970,6 +970,11 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
                            << " elements into " << outputElementsNum;
     }
   }
+
+  int missingDims = std::count(getNewShape().begin(), getNewShape().end(), -1);
+  if (missingDims > 1)
+    return emitOpError() << "At most one target dim can be -1";
+
   return mlir::success();
 }
 

>From f800e4d78e2d52c0d05728cddbf7db223a00f88b Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Sat, 23 Mar 2024 12:46:34 -0400
Subject: [PATCH 5/5] Added support for 0D tensor reshapes and addressed
 feedback

---
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 102 +++++++++++-------
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |   4 +-
 .../TosaToTensor/tosa-to-tensor.mlir          |  81 ++++++++++++++
 3 files changed, 145 insertions(+), 42 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 37ace4fd015938..11ba98ddf352b4 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -26,13 +26,35 @@ using namespace tosa;
 
 namespace {
 
+// Infer the type to which the input of a 'tosa.reshape' op must be cast when
+// lowered.
+TensorType inferReshapeInputType(TypedValue<TensorType> input,
+                                 ArrayRef<int64_t> newShape) {
+  // No need to cast input for non-empty target shape
+  if (!newShape.empty())
+    return input.getType();
+
+  // The input must be cast to a tensor of the same rank whose dimensions are
+  // all static and equal to 1. This prevents the generation of a
+  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
+  // 0D tensor. While such a construct is not incorrect on its own,
+  // bufferization cannot properly handle it at the moment, so we avoid it.
+  SmallVector<int64_t> shape(input.getType().getRank(), 1);
+  return input.getType().clone(shape);
+}
+
 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
 // pair emitted for a 'tosa.reshape' op.
-TensorType inferReshapedType(TypedValue<TensorType> input,
-                             ArrayRef<int64_t> newShape) {
+TensorType inferReshapeExpandedType(TensorType inputType,
+                                    ArrayRef<int64_t> newShape) {
+  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
+  // with just '{}', as it will invoke the incorrect overload.
+  if (newShape.empty())
+    return inputType.clone(ArrayRef<int64_t>{});
+
   // Check if the input is static, and if so, get its total size
-  bool inputIsStatic = input.getType().hasStaticShape();
-  int64_t totalSize = inputIsStatic ? input.getType().getNumElements() : -1;
+  bool inputIsStatic = inputType.hasStaticShape();
+  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
  
   // Compute result shape
   bool resultIsStatic = true;
@@ -74,21 +96,21 @@ TensorType inferReshapedType(TypedValue<TensorType> input,
   assert(!inputIsStatic || resultIsStatic);
 
   // Create result type
-  return input.getType().clone(resultShape);
+  return inputType.clone(resultShape);
 }
 
 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
 // pair emitted for a 'tosa.reshape' op.
-TensorType inferIntermediateType(TensorType lhsType, TensorType rhsType) {
+TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
   auto lhsShape = lhsType.getShape();
   auto rhsShape = rhsType.getShape();
 
+  if (lhsShape.empty() || rhsShape.empty())
+    return lhsType.clone(ArrayRef<int64_t>{});
+
   if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
     return lhsType.clone({ShapedType::kDynamic});
 
-  if (lhsShape.empty() || rhsShape.empty())
-    return lhsType.clone({});
-
   SmallVector<int64_t> intermediateShape;
   unsigned currLhsDim = 0, currRhsDim = 0;
   while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
@@ -128,20 +150,21 @@ TensorType inferIntermediateType(TensorType lhsType, TensorType rhsType) {
 }
 
 SmallVector<ReassociationExprs>
-createReassociationMapForCollapse(OpBuilder &builder,
-                                  ArrayRef<int64_t> srcShape,
-                                  ArrayRef<int64_t> dstShape) {
+createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
+  auto srcShape = cast<TensorType>(srcType).getShape();
+  auto dstShape = cast<TensorType>(dstType).getShape();
+
+  if (srcShape.empty() || dstShape.empty())
+    return {};
+
   if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
     assert(dstShape.size() == 1);
     SmallVector<AffineExpr, 2> exprs;
-    for (int64_t i = 0, s = srcShape.size(); i < s; ++i)
+    for (auto i : llvm::seq<int64_t>(srcShape.size()))
       exprs.push_back(builder.getAffineDimExpr(i));
     return {exprs};
   }
 
-  if (dstShape.empty())
-    return {};
-
   SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
   unsigned currSrcDim = 0, currDstDim = 0;
   while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
@@ -175,28 +198,23 @@ createReassociationMapForCollapse(OpBuilder &builder,
 }
 
 // Create a tensor.collapse_shape op that reshapes the input into the given
-// result type. Both 'returnType' and 'input' must be statically shaped.
-TypedValue<TensorType> createCollapse(OpBuilder &builder, Location loc,
-                                      TensorType resultType,
-                                      TypedValue<TensorType> input) {
-  auto reassociationMap = createReassociationMapForCollapse(
-      builder, input.getType().getShape(), resultType.getShape());
-  auto result = builder.createOrFold<tensor::CollapseShapeOp>(
-      loc, resultType, input, reassociationMap);
-  return cast<TypedValue<TensorType>>(result);
+// result type.
+Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
+                     Value input) {
+  auto reassociationMap =
+      createReassociationMapForCollapse(builder, input.getType(), resultType);
+  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
+                                                       reassociationMap);
 }
 
 // Create a tensor.expand_shape op that reshapes the input into the given result
-// type. Both 'returnType' and 'input' must be statically shaped.
-TypedValue<TensorType> createExpand(OpBuilder &builder, Location loc,
-                                    TensorType resultType,
-                                    TypedValue<TensorType> input) {
-  // Emit tensor.expand_shape op
-  auto reassociationMap = createReassociationMapForCollapse(
-      builder, resultType.getShape(), input.getType().getShape());
-  auto result = builder.createOrFold<tensor::ExpandShapeOp>(
-      loc, resultType, input, reassociationMap);
-  return cast<TypedValue<TensorType>>(result);
+// type.
+Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
+                   Value input) {
+  auto reassociationMap =
+      createReassociationMapForCollapse(builder, resultType, input.getType());
+  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
+                                                     reassociationMap);
 }
 
 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
@@ -211,13 +229,17 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
     auto input = reshape.getInput1();
     auto newShape = reshape.getNewShape();
 
-    // Infer the result types for the subsequently emitted ops.
-    auto reshapedType = inferReshapedType(input, newShape);
-    auto intermediateType = inferIntermediateType(input.getType(), reshapedType);
+    // Infer all intermediate types
+    auto inputType = inferReshapeInputType(input, newShape);
+    auto expandedType = inferReshapeExpandedType(inputType, newShape);
+    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
+
+    // Cast input if needed
+    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
 
     // Emit collaspe-expand pair
-    auto collapsed = createCollapse(rewriter, loc, intermediateType, input);
-    auto expanded = createExpand(rewriter, loc, reshapedType, collapsed);
+    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
+    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
 
     // Cast to final result type if needed
     auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1aad8c8b12e786..6e6e8435073812 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -971,9 +971,9 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
     }
   }
 
-  int missingDims = std::count(getNewShape().begin(), getNewShape().end(), -1);
+  int missingDims = llvm::count(getNewShape(), -1);
   if (missingDims > 1)
-    return emitOpError() << "At most one target dim can be -1";
+    return emitOpError() << "At most one target dimension can be -1";
 
   return mlir::success();
 }
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index dc234f1ec333f7..a8a3c42e168422 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -12,6 +12,75 @@ func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor<f32>) -> tensor<f32>
 
 // -----
 
+// CHECK-LABEL: test_reshape_0d_up_s2d_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
+// CHECK: return %[[VAL_1]] : tensor<?xf32>
+func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor<f32>) -> tensor<?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<f32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_0d_up_s2d_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
+// CHECK: return %[[VAL_1]] : tensor<?xf32>
+func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor<f32>) -> tensor<?xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_0d_up_s2s_auto
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: return %[[VAL_0]] : tensor<1xf32>
+func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor<f32>) -> tensor<1xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<f32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_0d_up_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: return %[[VAL_0]] : tensor<1xf32>
+func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor<f32>) -> tensor<1xf32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_1d_down_d2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.cast %[[ARG_0]] : tensor<?xf32> to tensor<1xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[VAL_0]] [] : tensor<1xf32> into tensor<f32>
+// CHECK: return %[[VAL_1]] : tensor<f32>
+func.func @test_reshape_1d_down_d2s_explicit(%arg0: tensor<?xf32>) -> tensor<f32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<?xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape_1d_down_s2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] [] : tensor<1xf32> into tensor<f32>
+// CHECK: return %[[VAL_0]] : tensor<f32>
+func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor<f32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<1xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: test_reshape_1d_up_d2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
@@ -230,6 +299,18 @@ func.func @test_reshape_3d_up_d2s_explicit(%input: tensor<?x?x?xf32>) -> tensor<
 
 // -----
 
+// CHECK-LABEL: test_reshape_4d_down_d2s_explicit
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.cast %[[ARG_0]] : tensor<?x?x?x?xf32> to tensor<1x1x1x1xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[VAL_0]] [] : tensor<1x1x1x1xf32> into tensor<f32>
+// CHECK: return %[[VAL_1]] : tensor<f32>
+func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor<?x?x?x?xf32>) -> tensor<f32> {
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<?x?x?x?xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: test_reshape_5d_down_d2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?x2x3xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor<?x?x?x2x3xf32> into tensor<?xf32>



More information about the Mlir-commits mailing list