[Mlir-commits] [mlir] [mlir][tosa] Change 'shape' of RESHAPE from attribute to input shape … (PR #125789)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 4 16:19:37 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: TatWai Chong (tatwaichong)

<details>
<summary>Changes</summary>

The `shape` operand of RESHAPE is changed from an attribute (`new_shape`) to an input operand of shape type, as specified since TOSA v1.0.

Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110

---

Patch is 115.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125789.diff


25 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+3) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+3-2) 
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+7-1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+6-2) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+23-9) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+17-20) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+7-5) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+17-6) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+9-2) 
- (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+12-7) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+2-1) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-5) 
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+72-38) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+37-20) 
- (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+2-1) 
- (modified) mlir/test/Dialect/Tosa/inlining.mlir (+2-1) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+33-12) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+2-1) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16-8) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+25-12) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+29-15) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+42-26) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+27-14) 
- (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+44-33) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8ede271cc56a8a..869ab913a715ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1621,7 +1621,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$new_shape
+    Tosa_Shape:$shape
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 78a8828855437e..88c21629286525 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -230,8 +230,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
 }
 
 // Computes shape value using tosa const_shape op.
+Value getTosaConstShape(ImplicitLocOpBuilder &builder,
+                        llvm::ArrayRef<int64_t> shape);
 Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
                         llvm::ArrayRef<int64_t> shape);
+
 SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
 
 bool getConstShapeValue(Operation *op,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..edb04010d53fd9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1952,9 +1952,10 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
           nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
         });
 
+    auto shapeValue = getTosaConstShape(
+        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultTy, genericOp.getResult(0),
-        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+        op, resultTy, genericOp.getResult(0), shapeValue);
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..fdb8b1e1471a73 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
       return rewriter.notifyMatchFailure(reshape.getLoc(),
                                          "expected input type to be tensor");
     }
-    auto newShape = reshape.getNewShape();
+
+    llvm::SmallVector<int64_t> newShape;
+    if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
+                                  newShape)) {
+      return failure();
+    }
 
     // Infer all intermediate types
     auto inputType = inferReshapeInputType(input, newShape);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..229719f5ef84d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
 
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
         op, op.getType(), op.getInput1(),
-        rewriter.getDenseI64ArrayAttr(newShape));
+        getTosaConstShape(rewriter, op.getLoc(), newShape));
     return success();
   }
 };
@@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     if (!getInput1().hasOneUse())
       return {};
 
+    llvm::SmallVector<int64_t> shapeVec;
+    if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
+      return {};
+
     return operand.reshape(
-        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
   }
 
   return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..f88c6df8e2b458 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1309,8 +1309,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
-  llvm::SmallVector<int64_t> newShapeValue =
-      convertToMlirShape(adaptor.getNewShape());
+  llvm::SmallVector<int64_t> newShapeValue;
+  if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
+                                newShapeValue)) {
+    auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  } else {
+    newShapeValue = convertToMlirShape(newShapeValue);
+  }
 
   // We cannot infer from the total number of elements so we must take the
   // shape attribute as exact.
@@ -1346,13 +1354,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
-  if ((int64_t)getNewShape().size() != outputType.getRank())
+  SmallVector<int64_t> shapeValues;
+  if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
+    // skip following checks if shape is not constant
+    return mlir::success();
+  }
+
+  if ((int64_t)shapeValues.size() != outputType.getRank())
     return emitOpError() << "new shape does not match result rank";
 
   for (auto [newShapeDim, outputShapeDim] :
-       zip(getNewShape(), outputType.getShape())) {
-    if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
-        newShapeDim != outputShapeDim)
+       zip(shapeValues, outputType.getShape())) {
+    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
+        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
       return emitOpError() << "new shape is inconsistent with result shape";
 
     if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1371,10 +1385,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     }
 
     int64_t newShapeElementsNum = std::accumulate(
-        getNewShape().begin(), getNewShape().end(), 1LL,
+        shapeValues.begin(), shapeValues.end(), 1LL,
         [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
     bool isStaticNewShape =
-        llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
     if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
         (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
       return emitOpError() << "cannot reshape " << inputElementsNum
@@ -1382,7 +1396,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     }
   }
 
-  int missingDims = llvm::count(getNewShape(), -1);
+  int missingDims = llvm::count(shapeValues, -1);
   if (missingDims > 1)
     return emitOpError() << "expected at most one target dimension to be -1";
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe..04e8ad31cf2e2e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -20,12 +20,6 @@ using namespace mlir::tosa;
 
 namespace {
 
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return ShapedType::isDynamic(dim) ? -1 : dim;
-  }));
-}
-
 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
   explicit Conv2DIsFullyConnected(MLIRContext *context)
       : OpRewritePattern(context) {}
@@ -98,12 +92,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
     auto revisedInputShapeType =
         RankedTensorType::get(revisedInputShape, inputType.getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getDenseI64ArrayAttr(
-                                     convertFromMlirShape(revisedInputShape)))
-                             .getResult();
+    auto revisedInputShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
+    auto reshapedInput =
+        rewriter
+            .create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
+                                     revisedInputShapeValue)
+            .getResult();
 
     // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
     llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
@@ -111,12 +106,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     auto revisedWeightShapeType = RankedTensorType::get(
         revisedWeightShape,
         dyn_cast<RankedTensorType>(weight.getType()).getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getDenseI64ArrayAttr(
-                                      convertFromMlirShape(revisedWeightShape)))
-                              .getResult();
+    auto revisedWeightShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
+    auto reshapedWeight =
+        rewriter
+            .create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
+                                     weight, revisedWeightShapeValue)
+            .getResult();
 
     // Perform a fully connected network over the reshaped input and weight.
     llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     // Reshape output to [N, IH, IW, OC].
     llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
                                               inputShape[2], weightShape[0]};
+    auto outputShapeValue = getTosaConstShape(
+        rewriter, op.getLoc(), convertFromMlirShape(outputShape));
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
+        op, resultType, fullyConnectedValue, outputShapeValue);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index ee857f1998a54d..b26397d0e3ed7a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     inputType = RankedTensorType::get(
         revisedInputShape,
         dyn_cast<RankedTensorType>(input.getType()).getElementType());
+    auto revisedInputShapeValue =
+        getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
     input = rewriter
-                .create<tosa::ReshapeOp>(
-                    op.getLoc(), inputType, input,
-                    rewriter.getDenseI64ArrayAttr(revisedInputShape))
+                .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
+                                         revisedInputShapeValue)
                 .getResult();
 
     Type inputETy = inputType.getElementType();
@@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     auto outputShapeType = RankedTensorType::get(
         outputShape,
         dyn_cast<RankedTensorType>(input.getType()).getElementType());
+    auto outputShapeValue =
+        getTosaConstShape(rewriter, op->getLoc(), outputShape);
     Value outputValue = rewriter.create<tosa::ReshapeOp>(
-        op.getLoc(), outputShapeType, mulValue,
-        rewriter.getDenseI64ArrayAttr(outputShape));
+        op.getLoc(), outputShapeType, mulValue, outputShapeValue);
 
     Value bias = op.getBias();
     if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f2..69a66c98307e94 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -160,9 +160,11 @@ class TransposeConvStridedConverter
         outputChannels, weightHeight / stride[0],
         stride[0],      weightWidth / stride[1],
         stride[1],      inputChannels};
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
-        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
+        builder, UnrankedTensorType::get(weightETy), weight,
+        getTosaConstShape(rewriter, loc, weightReshapeDims0));
 
     // Transpose the factored-out stride to the output channels.
     Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -174,12 +176,13 @@ class TransposeConvStridedConverter
         transposeWeightVal);
 
     // Collapse the strides and output channels into a single dimension.
-    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+    llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
         outputChannels * stride[0] * stride[1], weightHeight / stride[0],
         weightWidth / stride[1], inputChannels};
+
     weight = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
+        getTosaConstShape(rewriter, loc, weightReshapeDims1));
     ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
 
     weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -258,9 +261,13 @@ class TransposeConvStridedConverter
     // Factor striding out of the convolution result.
     llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
         batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+
+    auto convReshapeDims0Value =
+        getTosaConstShape(rewriter, loc, convReshapeDims0);
+
     conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-        rewriter.getDenseI64ArrayAttr(convReshapeDims0));
+        convReshapeDims0Value);
 
     // Transpose the factored-out stride to the output channels.
     Value transposeConvVal = rewriter.create<tosa::ConstOp>(
@@ -274,9 +281,13 @@ class TransposeConvStridedConverter
     // Fuse striding behavior back into width / height.
     llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
         batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+
+    auto convReshapeDims1Value =
+        getTosaConstShape(rewriter, loc, convReshapeDims1);
+
     conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-        rewriter.getDenseI64ArrayAttr(convReshapeDims1));
+        convReshapeDims1Value);
 
     // Determine the amount to slice / pad from the result start.
     int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888..281f0529a5c081 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -402,13 +402,20 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
     return std::nullopt;
 
   // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+  llvm::SmallVector<int64_t> newShape;
+  if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
+                                newShape)) {
+    // this mean shape is not constant
+    return std::nullopt;
+  }
+  ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
   auto foldedReshape = rewriter.create<ReshapeOp>(
       reshapeOp.getLoc(),
       RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
                             reshapeOutputType.getElementType()),
       reshapeOp.getInput1(),
-      rewriter.getDenseI64ArrayAttr(
-          applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
+      getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
+                                                      hoistedPerms)));
   return foldedReshape->getResult(0);
 }
 
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 62b0bc1857e395..8ab12d038849f4 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -145,10 +145,10 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
       llvm::cast<RankedTensorType>(lowerTensorValue.getType());
   auto reshapeOutputType = RankedTensorType::get(
       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+  auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape);
 
   auto reshapeLower = builder.create<tosa::ReshapeOp>(
-      reshapeOutputType, lowerTensorValue,
-      builder.getDenseI64ArrayAttr(reshapeOutputShape));
+      reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue);
 
   if (input1Rank > input2Rank) {
     input1 = higherTensorValue;
@@ -161,15 +161,20 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
   return success();
 }
 
-Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
                                     llvm::ArrayRef<int64_t> shape) {
-  auto attr = rewriter.getIndexTensorAttr(shape);
-  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
-  mlir::Operation *mlir_op =
-      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+  auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
+  mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
   return mlir_op->getResult(0);
 }
 
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                                    llvm::ArrayRef<int64_t> shape) {
+  ImplicitLocOpBuilder builder(loc, rewriter);
+  return getTosaConstShape(builder, shape);
+}
+
 SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
   return to_vector(llvm::map_range(shape, [](int64_t dim) {
     return ShapedType::isDynamic(dim) ? -1 : dim;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 75b48f2b06d899..460e207d62de6a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -24,7 +24,8 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
   %reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
   %1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   %0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
-  %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/125789


More information about the Mlir-commits mailing list