[Mlir-commits] [mlir] [mlir][tosa] Change 'shape' of RESHAPE from attribute to input shape … (PR #125789)
TatWai Chong
llvmlistbot at llvm.org
Wed Feb 5 13:14:12 PST 2025
https://github.com/tatwaichong updated https://github.com/llvm/llvm-project/pull/125789
From 8aaa01d6d8668cae8f6b2ec715c831943c922123 Mon Sep 17 00:00:00 2001
From: Won Jeon <won.jeon at arm.com>
Date: Wed, 6 Dec 2023 22:11:25 +0000
Subject: [PATCH] [mlir][tosa] Change 'shape' of RESHAPE from attribute to
input shape type
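This changes the RESHAPE operand from the DenseI64ArrayAttr 'new_shape'
attribute to a !tosa.shape operand produced by tosa.const_shape. As a
minimal illustrative sketch (mirroring the updated tests below; value
names are arbitrary):

  Before:
    %0 = tosa.reshape %arg0 {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>

  After:
    %s = tosa.const_shape { value = dense<[2, 3]> : tensor<2xindex> } : () -> !tosa.shape<2>
    %0 = tosa.reshape %arg0, %s : (tensor<6xf32>, !tosa.shape<2>) -> tensor<2x3xf32>

When the shape operand is not a compile-time constant, shape inference
now falls back to all-dynamic output dimensions and the verifier skips
the shape consistency checks.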
Co-authored-by: TatWai Chong <tatwai.chong at arm.com>
Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 +-
.../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 3 +
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 5 +-
.../Conversion/TosaToTensor/TosaToTensor.cpp | 8 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 8 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 32 +++--
.../Tosa/Transforms/TosaDecomposeConv2D.cpp | 37 +++---
.../Transforms/TosaDecomposeDepthwise.cpp | 12 +-
.../Transforms/TosaDecomposeTransposeConv.cpp | 23 +++-
.../Tosa/Transforms/TosaReduceTransposes.cpp | 11 +-
.../Dialect/Tosa/Utils/ConversionUtils.cpp | 19 +--
.../TosaToLinalg/tosa-to-linalg-invalid.mlir | 3 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 15 ++-
.../TosaToTensor/tosa-to-tensor.mlir | 110 ++++++++++++------
mlir/test/Dialect/Tosa/canonicalize.mlir | 57 +++++----
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 3 +-
mlir/test/Dialect/Tosa/inlining.mlir | 3 +-
mlir/test/Dialect/Tosa/invalid.mlir | 45 +++++--
mlir/test/Dialect/Tosa/level_check.mlir | 3 +-
mlir/test/Dialect/Tosa/ops.mlir | 24 ++--
.../Dialect/Tosa/tosa-decompose-conv2d.mlir | 42 ++++---
.../Tosa/tosa-decompose-depthwise.mlir | 44 ++++---
.../Tosa/tosa-decompose-transpose-conv.mlir | 68 ++++++-----
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 41 ++++---
.../Dialect/Tosa/tosa-reduce-transposes.mlir | 77 ++++++------
25 files changed, 449 insertions(+), 246 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 98bcbca3b02fa12..840558a81493f21 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1625,7 +1625,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
let arguments = (ins
Tosa_Tensor:$input1,
- DenseI64ArrayAttr:$new_shape
+ Tosa_Shape:$shape
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 78a8828855437ec..88c21629286525d 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -230,8 +230,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
}
// Computes shape value using tosa const_shape op.
+Value getTosaConstShape(ImplicitLocOpBuilder &builder,
+ llvm::ArrayRef<int64_t> shape);
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape);
+
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
bool getConstShapeValue(Operation *op,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 67218cee518d591..e4f055ea2f5c41b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1954,9 +1954,10 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
});
+ auto shapeValue = getTosaConstShape(
+ rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultTy, genericOp.getResult(0),
- rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+ op, resultTy, genericOp.getResult(0), shapeValue);
return success();
}
};
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 2a9b4d111bdfa2d..7f029d56e258283 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
return rewriter.notifyMatchFailure(reshape.getLoc(),
"expected input type to be tensor");
}
- auto newShape = reshape.getNewShape();
+
+ llvm::SmallVector<int64_t> newShape;
+ if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
+ newShape)) {
+ return failure();
+ }
// Infer all intermediate types
auto inputType = inferReshapeInputType(input, newShape);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8e22c879753a339..a9a65ac271b3c2e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, op.getType(), op.getInput1(),
- rewriter.getDenseI64ArrayAttr(newShape));
+ getTosaConstShape(rewriter, op.getLoc(), newShape));
return success();
}
};
@@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!getInput1().hasOneUse())
return {};
+ llvm::SmallVector<int64_t> shapeVec;
+ if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
+ return {};
+
return operand.reshape(
- llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+ llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
}
return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 031c279ff09e275..955021abdd67b12 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1335,8 +1335,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
- llvm::SmallVector<int64_t> newShapeValue =
- convertToMlirShape(adaptor.getNewShape());
+ llvm::SmallVector<int64_t> newShapeValue;
+ if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
+ newShapeValue)) {
+ auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
+ SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+ inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+ return success();
+ } else {
+ newShapeValue = convertToMlirShape(newShapeValue);
+ }
// We cannot infer from the total number of elements so we must take the
// shape attribute as exact.
@@ -1372,13 +1380,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();
- if ((int64_t)getNewShape().size() != outputType.getRank())
+ SmallVector<int64_t> shapeValues;
+ if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
+ // Skip the following checks if the shape is not constant.
+ return mlir::success();
+ }
+
+ if ((int64_t)shapeValues.size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";
for (auto [newShapeDim, outputShapeDim] :
- zip(getNewShape(), outputType.getShape())) {
- if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
- newShapeDim != outputShapeDim)
+ zip(shapeValues, outputType.getShape())) {
+ if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
+ outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
return emitOpError() << "new shape is inconsistent with result shape";
if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1397,10 +1411,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}
int64_t newShapeElementsNum = std::accumulate(
- getNewShape().begin(), getNewShape().end(), 1LL,
+ shapeValues.begin(), shapeValues.end(), 1LL,
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
bool isStaticNewShape =
- llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+ llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
return emitOpError() << "cannot reshape " << inputElementsNum
@@ -1408,7 +1422,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}
}
- int missingDims = llvm::count(getNewShape(), -1);
+ int missingDims = llvm::count(shapeValues, -1);
if (missingDims > 1)
return emitOpError() << "expected at most one target dimension to be -1";
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 4eba89b59bbd79f..617a59bc87c9f66 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -20,12 +20,6 @@ using namespace mlir::tosa;
namespace {
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
- return to_vector(llvm::map_range(shape, [](int64_t dim) {
- return ShapedType::isDynamic(dim) ? -1 : dim;
- }));
-}
-
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
explicit Conv2DIsFullyConnected(MLIRContext *context)
: OpRewritePattern(context) {}
@@ -98,12 +92,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
auto revisedInputShapeType =
RankedTensorType::get(revisedInputShape, inputType.getElementType());
- auto reshapedInput = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedInputShapeType, input,
- rewriter.getDenseI64ArrayAttr(
- convertFromMlirShape(revisedInputShape)))
- .getResult();
+ auto revisedInputShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
+ auto reshapedInput =
+ rewriter
+ .create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
+ revisedInputShapeValue)
+ .getResult();
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
@@ -111,12 +106,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
dyn_cast<RankedTensorType>(weight.getType()).getElementType());
- auto reshapedWeight = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedWeightShapeType, weight,
- rewriter.getDenseI64ArrayAttr(
- convertFromMlirShape(revisedWeightShape)))
- .getResult();
+ auto revisedWeightShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
+ auto reshapedWeight =
+ rewriter
+ .create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
+ weight, revisedWeightShapeValue)
+ .getResult();
// Perform a fully connected network over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
// Reshape output to [N, IH, IW, OC].
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
inputShape[2], weightShape[0]};
+ auto outputShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(outputShape));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultType, fullyConnectedValue,
- rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
+ op, resultType, fullyConnectedValue, outputShapeValue);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index ee857f1998a54d3..b26397d0e3ed7ab 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
inputType = RankedTensorType::get(
revisedInputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
+ auto revisedInputShapeValue =
+ getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
input = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), inputType, input,
- rewriter.getDenseI64ArrayAttr(revisedInputShape))
+ .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
+ revisedInputShapeValue)
.getResult();
Type inputETy = inputType.getElementType();
@@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
auto outputShapeType = RankedTensorType::get(
outputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
+ auto outputShapeValue =
+ getTosaConstShape(rewriter, op->getLoc(), outputShape);
Value outputValue = rewriter.create<tosa::ReshapeOp>(
- op.getLoc(), outputShapeType, mulValue,
- rewriter.getDenseI64ArrayAttr(outputShape));
+ op.getLoc(), outputShapeType, mulValue, outputShapeValue);
Value bias = op.getBias();
if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index b5b3e9d76c47e23..26baddcf1dd1564 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -159,9 +159,11 @@ class TransposeConvStridedConverter
outputChannels, weightHeight / stride[0],
stride[0], weightWidth / stride[1],
stride[1], inputChannels};
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
- rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
+ builder, UnrankedTensorType::get(weightETy), weight,
+ getTosaConstShape(rewriter, loc, weightReshapeDims0));
// Transpose the factored-out stride to the output channels.
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -173,12 +175,13 @@ class TransposeConvStridedConverter
transposeWeightVal);
// Collapse the strides and output channels into a single dimension.
- llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+ llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
weightWidth / stride[1], inputChannels};
+
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
+ getTosaConstShape(rewriter, loc, weightReshapeDims1));
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -257,9 +260,13 @@ class TransposeConvStridedConverter
// Factor striding out of the convolution result.
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+
+ auto convReshapeDims0Value =
+ getTosaConstShape(rewriter, loc, convReshapeDims0);
+
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
- rewriter.getDenseI64ArrayAttr(convReshapeDims0));
+ convReshapeDims0Value);
// Transpose the factored-out stride to the output channels.
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
@@ -273,9 +280,13 @@ class TransposeConvStridedConverter
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+
+ auto convReshapeDims1Value =
+ getTosaConstShape(rewriter, loc, convReshapeDims1);
+
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
- rewriter.getDenseI64ArrayAttr(convReshapeDims1));
+ convReshapeDims1Value);
// Determine the amount to slice / pad from the result start.
int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888b..281f0529a5c0816 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -402,13 +402,20 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
return std::nullopt;
// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+ llvm::SmallVector<int64_t> newShape;
+ if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
+ newShape)) {
+ // The shape is not constant, so the reshape cannot be folded.
+ return std::nullopt;
+ }
+ ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
auto foldedReshape = rewriter.create<ReshapeOp>(
reshapeOp.getLoc(),
RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
reshapeOutputType.getElementType()),
reshapeOp.getInput1(),
- rewriter.getDenseI64ArrayAttr(
- applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
+ getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
+ hoistedPerms)));
return foldedReshape->getResult(0);
}
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 62b0bc1857e3950..8ab12d038849f4e 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -145,10 +145,10 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
llvm::cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+ auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape);
auto reshapeLower = builder.create<tosa::ReshapeOp>(
- reshapeOutputType, lowerTensorValue,
- builder.getDenseI64ArrayAttr(reshapeOutputShape));
+ reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue);
if (input1Rank > input2Rank) {
input1 = higherTensorValue;
@@ -161,15 +161,20 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
return success();
}
-Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
llvm::ArrayRef<int64_t> shape) {
- auto attr = rewriter.getIndexTensorAttr(shape);
- auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
- mlir::Operation *mlir_op =
- rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+ auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+ auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
+ mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
return mlir_op->getResult(0);
}
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+ llvm::ArrayRef<int64_t> shape) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ return getTosaConstShape(builder, shape);
+}
+
SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
return to_vector(llvm::map_range(shape, [](int64_t dim) {
return ShapedType::isDynamic(dim) ? -1 : dim;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 75b48f2b06d8997..460e207d62de6af 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -24,7 +24,8 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
%reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
%1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
- %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
+ %s = tosa.const_shape {value = dense<[10, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %2 = tosa.reshape %0, %s : (tensor<*xf32>, !tosa.shape<2>) -> tensor<10x10xf32>
return %2 : tensor<10x10xf32>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6e8501aaaf2afe2..3031434e6d4ba1f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1387,7 +1387,8 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 4, 3>}
+ // CHECK: [[CONST3:%.+]] = tosa.const_shape {value = dense<[4, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: tosa.reshape [[GENERIC]], [[CONST3]]
%cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%0 = tosa.tile %arg0, %cst21: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<4x3xi8>
@@ -1395,7 +1396,8 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 2, 6>}
+ // CHECK: [[CONST8:%.+]] = tosa.const_shape {value = dense<[2, 6]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: tosa.reshape [[GENERIC]], [[CONST8]]
%cst12 = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
%1 = tosa.tile %arg0, %cst12: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x6xi8>
@@ -1403,8 +1405,9 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 10, 21>}
%cst57 = tosa.const_shape { value = dense<[5, 7]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ // CHECK: [[CONST13:%.+]] = tosa.const_shape {value = dense<[10, 21]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: tosa.reshape [[GENERIC]], [[CONST13]]
%2 = tosa.tile %arg0, %cst57: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<10x21xi8>
return
@@ -1424,7 +1427,8 @@ func.func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x3xi8>) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: -9223372036854775808, 3>}
+ // CHECK: %[[CONST3:.+]] = tosa.const_shape {value = dense<[-1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: tosa.reshape %[[GENERIC]], %[[CONST3]]
%cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%0 = tosa.tile %arg0, %cst21: (tensor<?x3xi8>, !tosa.shape<2>) -> tensor<?x3xi8>
@@ -1445,7 +1449,8 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK: linalg.yield %[[ARG1]] : i8
- // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: 2, -9223372036854775808>}
+ // CHECK: %[[CONST2:.+]] = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: tosa.reshape %[[GENERIC]], %[[CONST2]]
%cst = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%0 = tosa.tile %arg0, %cst: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x?xi8>
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index e83e898644bc091..c2eaba4c563d006 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -6,7 +6,8 @@
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
// CHECK: return %[[ARG_0]] : tensor<f32>
func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor<f32>) -> tensor<f32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<f32>) -> tensor<f32>
+ %s = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<f32>, !tosa.shape<0>) -> tensor<f32>
return %0 : tensor<f32>
}
@@ -18,7 +19,8 @@ func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
// CHECK: return %[[VAL_1]] : tensor<?xf32>
func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor<f32>) -> tensor<?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<f32>) -> tensor<?xf32>
+ %s = tosa.const_shape { value = dense<-1> : tensor<1xindex> } : () -> !tosa.shape<1>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<f32>, !tosa.shape<1>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -30,7 +32,8 @@ func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor<f32>) -> tensor<?xf32> {
// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
// CHECK: return %[[VAL_1]] : tensor<?xf32>
func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor<f32>) -> tensor<?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<?xf32>
+ %s = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<f32>, !tosa.shape<1>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -41,7 +44,8 @@ func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor<f32>) -> tensor<?xf32>
// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: return %[[VAL_0]] : tensor<1xf32>
func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor<f32>) -> tensor<1xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<f32>) -> tensor<1xf32>
+ %s = tosa.const_shape { value = dense<-1> : tensor<1xindex> } : () -> !tosa.shape<1>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<f32>, !tosa.shape<1>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}
@@ -52,7 +56,8 @@ func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor<f32>) -> tensor<1xf32> {
// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: return %[[VAL_0]] : tensor<1xf32>
func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor<f32>) -> tensor<1xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
+ %s = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<f32>, !tosa.shape<1>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}
@@ -64,7 +69,8 @@ func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor<f32>) -> tensor<1xf32>
// CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[VAL_0]] [] : tensor<1xf32> into tensor<f32>
// CHECK: return %[[VAL_1]] : tensor<f32>
func.func @test_reshape_1d_down_d2s_explicit(%arg0: tensor<?xf32>) -> tensor<f32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<?xf32>) -> tensor<f32>
+ %s = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?xf32>, !tosa.shape<0>) -> tensor<f32>
return %0 : tensor<f32>
}
@@ -75,7 +81,8 @@ func.func @test_reshape_1d_down_d2s_explicit(%arg0: tensor<?xf32>) -> tensor<f32
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] [] : tensor<1xf32> into tensor<f32>
// CHECK: return %[[VAL_0]] : tensor<f32>
func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor<f32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<1xf32>) -> tensor<f32>
+ %s = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<1xf32>, !tosa.shape<0>) -> tensor<f32>
return %0 : tensor<f32>
}
@@ -90,7 +97,8 @@ func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor<f32
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, %[[VAL_0]]] : tensor<?xf32> into tensor<2x?xf32>
// CHECK: return %[[EXPANDED]] : tensor<2x?xf32>
func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
+ %s = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?xf32>, !tosa.shape<2>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
@@ -101,7 +109,8 @@ func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor<?xf32>) -> tensor<2x?xf32>
// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
// CHECK: return %[[VAL_0]] : tensor<2x3xf32>
func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
+ %s = tosa.const_shape { value = dense<[2, 3]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<6xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
}
@@ -112,7 +121,8 @@ func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x?xf32> into tensor<?xf32>
// CHECK: return %[[VAL_0]] : tensor<?xf32>
func.func @test_reshape_2d_down_d2d_auto(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
+ %s = tosa.const_shape { value = dense<-1> : tensor<1xindex> } : () -> !tosa.shape<1>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x?xf32>, !tosa.shape<1>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -123,7 +133,8 @@ func.func @test_reshape_2d_down_d2d_auto(%arg0: tensor<2x?xf32>) -> tensor<?xf32
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x3xf32> into tensor<6xf32>
// CHECK: return %[[VAL_0]] : tensor<6xf32>
func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
+ %s = tosa.const_shape { value = dense<6> : tensor<1xindex> } : () -> !tosa.shape<1>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x3xf32>, !tosa.shape<1>) -> tensor<6xf32>
return %0 : tensor<6xf32>
}
@@ -139,7 +150,8 @@ func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, %[[DIV]]] : tensor<?xf32> into tensor<2x?xf32>
// CHECK: return %[[EXPANDED]] : tensor<2x?xf32>
func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
+ %s = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x2xf32>, !tosa.shape<2>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
@@ -152,10 +164,12 @@ func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor<?x2xf32>) -> tensor<2x?xf
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
// CHECK: return %[[VAL_2]] : tensor<?x2xf32>
func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 2>} : (tensor<2x4xf32>) -> tensor<?x2xf32>
+ %s = tosa.const_shape { value = dense<[-1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x4xf32>, !tosa.shape<2>) -> tensor<?x2xf32>
return %0 : tensor<?x2xf32>
}
+
// -----
// CHECK-LABEL: test_reshape_2d_same_s2d_explicit
@@ -165,7 +179,8 @@ func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor<?x2xf
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
// CHECK: return %[[VAL_2]] : tensor<?x2xf32>
func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 4, 2>} : (tensor<2x4xf32>) -> tensor<?x2xf32>
+ %s = tosa.const_shape { value = dense<[4, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x4xf32>, !tosa.shape<2>) -> tensor<?x2xf32>
return %0 : tensor<?x2xf32>
}
@@ -177,7 +192,8 @@ func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor<?
// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
// CHECK: return %[[VAL_1]] : tensor<2x3xf32>
func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
+ %s = tosa.const_shape { value = dense<[2, 3]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<3x2xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
}
@@ -194,7 +210,8 @@ func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<0x3x?xf32> to tensor<?x?x?xf32>
// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, -1>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
+ %s = tosa.const_shape { value = dense<[0, 3, -1]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<3x2x?xf32>, !tosa.shape<3>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
@@ -211,7 +228,8 @@ func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tens
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<?x?x?xf32>
// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x?x?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<2x?x?xf32>) -> tensor<?x?x?xf32>
+ %s = tosa.const_shape { value = dense<[2, -1, 4]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x?x?xf32>, !tosa.shape<3>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
@@ -227,7 +245,8 @@ func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x?
// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, 3, %[[DIV]]] : tensor<?xf32> into tensor<2x3x?xf32>
// CHECK: return %[[VAL_1]] : tensor<2x3x?xf32>
func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor<?x3x4xf32>) -> tensor<2x3x?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, -1>} : (tensor<?x3x4xf32>) -> tensor<2x3x?xf32>
+ %s = tosa.const_shape { value = dense<[2, 3, -1]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x?xf32>
return %0 : tensor<2x3x?xf32>
}
@@ -244,7 +263,8 @@ func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor<?x3x4xf32>) -> t
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x2xf32> to tensor<?x?x?xf32>
// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, 2>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
+ %s = tosa.const_shape { value = dense<[0, 3, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<3x2x?xf32>, !tosa.shape<3>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
@@ -261,7 +281,8 @@ func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) ->
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x4xf32> to tensor<?x?x?xf32>
// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %s = tosa.const_shape { value = dense<[2, 3, 4]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x?x?xf32>, !tosa.shape<3>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
@@ -272,7 +293,8 @@ func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor<?x?x?xf32>) -> tensor
// CHECK: %[[VAL_0:.*]] = tensor.cast %[[ARG_0]] : tensor<?x3x4xf32> to tensor<2x3x?xf32>
// CHECK: return %[[VAL_0]] : tensor<2x3x?xf32>
func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor<?x3x4xf32>) -> tensor<2x3x?xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x3x4xf32>) -> tensor<2x3x?xf32>
+ %s = tosa.const_shape { value = dense<[2, 3, 4]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x?xf32>
return %0 : tensor<2x3x?xf32>
}
@@ -289,7 +311,8 @@ func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor<?x3x4xf32>)
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x?x4xf32> to tensor<2x3x4xf32>
// CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
+ %s = tosa.const_shape { value = dense<[2, -1, 4]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x?x?xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
@@ -306,7 +329,8 @@ func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor<?x?x?xf32>) -> tensor<2x3
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x4xf32> to tensor<2x3x4xf32>
// CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
+ %s = tosa.const_shape { value = dense<[2, 3, 4]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x?x?xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
@@ -316,7 +340,8 @@ func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor<?x?x?xf32>) -> tensor
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x3x4xf32>
// CHECK: return %[[ARG_0]] : tensor<2x3x4xf32>
func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+ %s = tosa.const_shape { value = dense<[2, 3, 4]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
@@ -333,7 +358,8 @@ func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>)
// CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x2x1xf32> to tensor<1x3x2x1xf32>
// CHECK: return %[[CAST]] : tensor<1x3x2x1xf32>
func.func @test_reshape_3d_up_d2s_explicit(%input: tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32> {
- %0 = tosa.reshape %input {new_shape = array<i64: 1, 3, 2, 1>} : (tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32>
+ %s = tosa.const_shape { value = dense<[1, 3, 2, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %0 = tosa.reshape %input, %s : (tensor<?x?x?xf32>, !tosa.shape<4>) -> tensor<1x3x2x1xf32>
return %0 : tensor<1x3x2x1xf32>
}
@@ -345,7 +371,8 @@ func.func @test_reshape_3d_up_d2s_explicit(%input: tensor<?x?x?xf32>) -> tensor<
// CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[VAL_0]] [] : tensor<1x1x1x1xf32> into tensor<f32>
// CHECK: return %[[VAL_1]] : tensor<f32>
func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor<?x?x?x?xf32>) -> tensor<f32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64>} : (tensor<?x?x?x?xf32>) -> tensor<f32>
+ %s = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x?x?x?xf32>, !tosa.shape<0>) -> tensor<f32>
return %0 : tensor<f32>
}
@@ -361,7 +388,8 @@ func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor<?x?x?x?xf32>) -> tens
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 2, 3] : tensor<?xf32> into tensor<?x2x3xf32>
// CHECK: return %[[EXPANDED]] : tensor<?x2x3xf32>
func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 2, 3>} : (tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32>
+ %s = tosa.const_shape { value = dense<[-1, 2, 3]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x?x?x2x3xf32>, !tosa.shape<3>) -> tensor<?x2x3xf32>
return %0 : tensor<?x2x3xf32>
}
@@ -377,7 +405,8 @@ func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor<?x?x?x2x3xf32>) -> tensor
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 5, 77] : tensor<?xf32> into tensor<?x5x77xf32>
// CHECK: return %[[EXPANDED]] : tensor<?x5x77xf32>
func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+ %s = tosa.const_shape { value = dense<[-1, 5, 77]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<1x2x?x5x7x11xf32>, !tosa.shape<3>) -> tensor<?x5x77xf32>
return %0 : tensor<?x5x77xf32>
}
@@ -388,7 +417,8 @@ func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> ten
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32>
// CHECK: return %[[VAL_0]] : tensor<6x5x77xf32>
func.func @test_reshape_6d_down_s2s_auto(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, -1>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+ %s = tosa.const_shape { value = dense<[6, 5, -1]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<1x2x3x5x7x11xf32>, !tosa.shape<3>) -> tensor<6x5x77xf32>
return %0 : tensor<6x5x77xf32>
}
@@ -400,10 +430,13 @@ func.func @test_reshape_6d_down_s2s_auto(%arg0: tensor<1x2x3x5x7x11xf32>) -> ten
//
// See https://github.com/llvm/llvm-project/pull/91521 for a full description.
+// -----
+
// CHECK-LABEL: reshape_bug_fix
// CHECK: tensor.expand_shape
func.func @reshape_bug_fix(%arg0: tensor<?xf32>) -> tensor<1x1x1x?xf32> {
- %0 = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, -1>} : (tensor<?xf32>) -> tensor<1x1x1x?xf32>
+ %1 = "tosa.const_shape"() {value = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = "tosa.reshape"(%arg0, %1) : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32>
return %0 : tensor<1x1x1x?xf32>
}
@@ -414,21 +447,22 @@ func.func @reshape_bug_fix(%arg0: tensor<?xf32>) -> tensor<1x1x1x?xf32> {
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2], [3], [4, 5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x5x77xf32>
// CHECK: return %[[VAL_0]] : tensor<6x5x77xf32>
func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+ %s = tosa.const_shape { value = dense<[6, 5, 77]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<1x2x3x5x7x11xf32>, !tosa.shape<3>) -> tensor<6x5x77xf32>
return %0 : tensor<6x5x77xf32>
}
// -----
// CHECK-LABEL: @test_reshape_samerank_unsigned
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
+// CHECK-SAME: (%[[VAL_0:.*]]: tensor<3x2xui8>)
func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
- // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
- // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
- // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
- // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
- // CHECK-NEXT: return %[[CAST2]]
+ // CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<3x2xui8> to tensor<3x2xi8>
+ // CHECK: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
+ // CHECK: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
+ // CHECK: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8
+ %s = tosa.const_shape { value = dense<[2, 3]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = "tosa.reshape"(%arg0, %s): (tensor<3x2xui8>, !tosa.shape<2>) -> tensor<2x3xui8>
return %0 : tensor<2x3xui8>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e0e1de6a94d10d0..582fd77cd7bc8ee 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -542,17 +542,20 @@ func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK-LABEL: @reshape_canonicalize
func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK: return %arg0
- %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, 10>}: (tensor<?x10xf32>) -> tensor<?x10xf32>
- return %0 : tensor<?x10xf32>
+ %0 = "tosa.const_shape"() {value = dense<[-1, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %arg0, %0 : (tensor<?x10xf32>, !tosa.shape<2>) -> tensor<?x10xf32>
+ return %1 : tensor<?x10xf32>
}
// -----
// CHECK-LABEL: @reshape_canonicalize_dyn_nofold
func.func @reshape_canonicalize_dyn_nofold(%arg0: tensor<?x?x10xf32>) -> tensor<?x?x10xf32> {
- // CHECK: %[[VAR0:.+]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 2, 10>} : (tensor<?x?x10xf32>) -> tensor<?x?x10xf32>
+ // CHECK: %[[SHAPE:.+]] = tosa.const_shape {value = dense<[-1, 2, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[VAR0:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<?x?x10xf32>, !tosa.shape<3>) -> tensor<?x?x10xf32>
// CHECK: return %[[VAR0]] : tensor<?x?x10xf32>
- %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, 2, 10>} : (tensor<?x?x10xf32>) -> tensor<?x?x10xf32>
+ %s = "tosa.const_shape"() {value = dense<[-1, 2, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %0 = tosa.reshape %arg0, %s : (tensor<?x?x10xf32>, !tosa.shape<3>) -> tensor<?x?x10xf32>
return %0 : tensor<?x?x10xf32>
}
@@ -560,10 +563,13 @@ func.func @reshape_canonicalize_dyn_nofold(%arg0: tensor<?x?x10xf32>) -> tensor<
// CHECK-LABEL: @reshape_canonicalize_double
func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
- // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 5>}
+ // CHECK: %[[VAL_0:.*]] = tosa.const_shape {value = dense<[-1, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_0]]
// CHECK: return %[[VAL_1]]
- %0 = tosa.reshape %arg0 {new_shape = array<i64: 5, -1>}: (tensor<?x10xf32>) -> tensor<5x?xf32>
- %1 = tosa.reshape %0 {new_shape = array<i64: -1, 5>}: (tensor<5x?xf32>) -> tensor<?x5xf32>
+ %cst0 = "tosa.const_shape"() <{value = dense<[5, -1]> : tensor<2xindex>}> : () -> !tosa.shape<2>
+ %0 = tosa.reshape %arg0, %cst0 : (tensor<?x10xf32>, !tosa.shape<2>) -> tensor<5x?xf32>
+ %cst1 = "tosa.const_shape"() <{value = dense<[-1, 5]> : tensor<2xindex>}> : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %cst1 : (tensor<5x?xf32>, !tosa.shape<2>) -> tensor<?x5xf32>
return %1 : tensor<?x5xf32>
}
@@ -574,8 +580,9 @@ func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
// CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1, 2, 3, 4]]> : tensor<1x5xi32>}
// CHECK: return %[[VAR0]]
%0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
- %1 = tosa.reshape %0 {new_shape = array<i64: 1, 5>} : (tensor<5xi32>) -> tensor<1x5xi32>
- return %1 : tensor<1x5xi32>
+ %1 = "tosa.const_shape"() {value = dense<[1, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %2 = tosa.reshape %0, %1 : (tensor<5xi32>, !tosa.shape<2>) -> tensor<1x5xi32>
+ return %2 : tensor<1x5xi32>
}
// -----
@@ -584,7 +591,8 @@ func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> {
// CHECK: tosa.reshape
%0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
- %1 = tosa.reshape %0 {new_shape = array<i64: 1, 5>} : (tensor<5xi32>) -> tensor<1x?xi32>
+ %2 = "tosa.const_shape"() {value = dense<[1, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<5xi32>, !tosa.shape<2>) -> tensor<1x?xi32>
return %1 : tensor<1x?xi32>
}
@@ -596,7 +604,8 @@ func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi3
// CHECK-DAG: %[[VAR1:.+]] = "tosa.const"() <{value = dense<0> : tensor<1x10xi32>}
// CHECK: return %[[VAR0]], %[[VAR1]]
%0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
- %1 = tosa.reshape %0 {new_shape = array<i64: 1, 10>} : (tensor<10xi32>) -> tensor<1x10xi32>
+ %2 = "tosa.const_shape"() {value = dense<[1, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<10xi32>, !tosa.shape<2>) -> tensor<1x10xi32>
return %0 , %1 : tensor<10xi32>, tensor<1x10xi32>
}
@@ -606,7 +615,8 @@ func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi3
func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
// CHECK: tosa.reshape
%0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : ()-> tensor<3xi32>
- %1 = tosa.reshape %0 {new_shape = array<i64: 1, 3>} : (tensor<3xi32>) -> tensor<1x3xi32>
+ %2 = "tosa.const_shape"() {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32>
return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
}
@@ -616,9 +626,10 @@ func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32
func.func @reshape_canonicalize_quant_nofold() -> (tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
// disabled folding for quantized element types
// CHECK{LITERAL}: "tosa.const"() <{value = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
- // CHECK{LITERAL}: tosa.reshape %0 {new_shape = array<i64: 1, 3>} : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+ // CHECK{LITERAL}: tosa.reshape %0, %1 : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>, !tosa.shape<2>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
%0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi8>} : ()-> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
- %1 = tosa.reshape %0 {new_shape = array<i64: 1, 3>} : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+ %2 = "tosa.const_shape"() {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>, !tosa.shape<2>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
return %1 : tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
}
@@ -626,8 +637,9 @@ func.func @reshape_canonicalize_quant_nofold() -> (tensor<1x3x!quant.uniform<i8:
// CHECK-LABEL: @transpose_canonicalize_strip_quant
func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
- // CHECK: "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
- // CHECK: tosa.reshape %0 {new_shape = array<i64: 2, 1, 3>} : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+ // CHECK-DAG: tosa.const_shape {value = dense<[2, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK-DAG: "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
+ // CHECK: tosa.reshape %0, %1 : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, !tosa.shape<3>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
%perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
%1 = tosa.transpose %0, %perms : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, tensor<3xi32>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
@@ -691,7 +703,8 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
// CHECK-LABEL: @transpose_is_reshape
func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
- // CHECK: tosa.reshape %arg0 {new_shape = array<i64: 1, 4, 1, 5>} : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
+ // CHECK: %[[CONST0:.+]] = tosa.const_shape {value = dense<[1, 4, 1, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: tosa.reshape %arg0, %[[CONST0]]
%perms = "tosa.const"() <{value = dense<[3, 1, 0, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<1x4x5x1xf32>, tensor<4xi32>) -> tensor<1x4x1x5xf32>
return %0 : tensor<1x4x1x5xf32>
@@ -704,7 +717,8 @@ func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf3
func.func @single_bit_reshape() -> tensor<1xi1> {
// CHECK: "tosa.const"() <{value = dense<true> : tensor<1xi1>}
%0 = arith.constant dense<true> : tensor<1x1xi1>
- %1 = tosa.reshape %0 {new_shape = array<i64: 1>} : (tensor<1x1xi1>) -> tensor<1xi1>
+ %2 = "tosa.const_shape"() <{value = dense<1> : tensor<1xindex>}> : () -> !tosa.shape<1>
+ %1 = tosa.reshape %0, %2 : (tensor<1x1xi1>, !tosa.shape<1>) -> tensor<1xi1>
return %1 : tensor<1xi1>
}
@@ -870,8 +884,11 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
// check that segfault is fixed
func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
%0 = "tosa.const"() {value = dense<127> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %1 = tosa.reshape %0 {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %2 = tosa.rescale %1 {double_round = true, input_zp = -128 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x1x1x1xi32>
+ %cst0 = "tosa.const_shape"() {value = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %1 = tosa.reshape %0, %cst0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, !tosa.shape<4>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %2 = tosa.rescale %1 {double_round = true, input_zp = -128 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x1x1x1xi32>
return %2 : tensor<1x1x1x1xi32>
}
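Every hunk in these test files applies the same mechanical rewrite: the new_shape attribute is dropped, and the target shape is instead materialized as a tosa.const_shape value of type !tosa.shape<N> passed as a second operand. Using the shapes from the ops.mlir reshape test further down, the before/after pair is:

  // before
  %0 = tosa.reshape %arg0 {new_shape = array<i64: 1, 819>} : (tensor<13x21x3xf32>) -> tensor<1x819xf32>

  // after
  %1 = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
  %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>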
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 32677f06e22523e..40469987d89d087 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -500,7 +500,8 @@ func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) {
func.func @reshape_splat() -> tensor<6x5x4xi32> {
// CHECK: %[[SPLAT:.+]] = "tosa.const"() <{value = dense<42> : tensor<6x5x4xi32>}
%splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
- %reshape = tosa.reshape %splat { new_shape = array<i64: 6, 5, 4> } : (tensor<4x5x6xi32>) -> tensor<6x5x4xi32>
+ %const = tosa.const_shape {value = dense<[6, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %reshape = tosa.reshape %splat, %const : (tensor<4x5x6xi32>, !tosa.shape<3>) -> tensor<6x5x4xi32>
// CHECK: return %[[SPLAT]]
return %reshape : tensor<6x5x4xi32>
}
diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir
index e892fdaa2775051..2a3065e80d0ea86 100644
--- a/mlir/test/Dialect/Tosa/inlining.mlir
+++ b/mlir/test/Dialect/Tosa/inlining.mlir
@@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tenso
}
func.func private @while_body_50(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) {
%1 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
- %3 = "tosa.reshape"(%1) {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+ %4 = "tosa.const_shape"() {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %3 = "tosa.reshape"(%1, %4) : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
%2 = "tosa.add"(%arg3, %3) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
return %1, %arg1, %arg2, %2: tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 006c5bd52a9f669..2165e1f7ae3babb 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -316,7 +316,8 @@ func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tenso
func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
+ %3 = tosa.const_shape {value = dense<[273, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %arg0, %3 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<273x3xf32>
// expected-error at +1 {{'tosa.fully_connected' op weight of fully_connected is not constant}}
%2 = tosa.fully_connected %1, %arg1, %0 : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
return %2 : tensor<273x2xf32>
@@ -326,7 +327,8 @@ func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: ten
func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2xf32>) -> tensor<273x2xf32> {
%0 = "tosa.const"() {value = dense<[[-0.613216758, -0.63714242, -0.73500061], [0.180762768, 0.773053169, -0.933686495]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
+ %3 = tosa.const_shape {value = dense<[273, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %arg0, %3 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<273x3xf32>
// expected-error at +1 {{'tosa.fully_connected' op bias of fully_connected is not constant}}
%2 = tosa.fully_connected %1, %0, %arg1 : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
return %2 : tensor<273x2xf32>
@@ -426,81 +428,91 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
// -----
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
+ %1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error at +2 {{failed to infer returned types}}
// expected-error at +1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
- %0 = tosa.reshape %arg0 {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
+ %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
return
}
// -----
func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
// expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<13x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
return
}
// -----
func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
// expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
return
}
// -----
func.func @test_reshape_rank_mismatch(%arg0 : tensor<?xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.reshape' op new shape does not match result rank}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 4>} : (tensor<?xf32>) -> tensor<?xf32>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?xf32>, !tosa.shape<2>) -> tensor<?xf32>
return
}
// -----
func.func @test_reshape_inconsistent_result_type(%arg0 : tensor<?xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[2, 4, -1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// expected-error at +1 {{'tosa.reshape' op new shape is inconsistent with result shape}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 4, -1>} : (tensor<?xf32>) -> tensor<?x3x5xf32>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?xf32>, !tosa.shape<3>) -> tensor<?x3x5xf32>
return
}
// -----
func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[3, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.reshape' op cannot reshape 8 elements into 15}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 3, 5>} : (tensor<2x4xf32>) -> tensor<3x5xf32>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<2x4xf32>, !tosa.shape<2>) -> tensor<3x5xf32>
return
}
// -----
func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[-1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.reshape' op cannot reshape 1 elements into 4}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 4>} : (tensor<1xf32>) -> tensor<?x4xf32>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<1xf32>, !tosa.shape<2>) -> tensor<?x4xf32>
return
}
// -----
func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.reshape' op cannot reshape 8 elements into 4}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 4>} : (tensor<8xf32>) -> tensor<?x4xf32>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<8xf32>, !tosa.shape<2>) -> tensor<?x4xf32>
return
}
// -----
func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[2, -1, -1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// expected-error at +1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>
+ %0 = "tosa.reshape"(%arg0, %s) : (tensor<?xf32>, !tosa.shape<3>) -> tensor<2x?x?xf32>
return
}
// -----
func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
+ %s = tosa.const_shape {value = dense<[-2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.reshape' op new shape has invalid tensor dimension size -2}}
- %0 = "tosa.reshape" (%arg0) {new_shape = array<i64: -2, -1>} : (tensor<4x?xf32>) -> tensor<?x4xf32>
+ %0 = "tosa.reshape" (%arg0, %s) : (tensor<4x?xf32>, !tosa.shape<2>) -> tensor<?x4xf32>
return
}
@@ -514,6 +526,15 @@ func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
// -----
+func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
+ %1 = tosa.const_shape {value = dense<[13, 21, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // expected-error at +1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
+ %0 = "tosa.reshape"(%arg0, %1) : (tensor<?x0x3xf32>, !tosa.shape<3>) -> tensor<13x0x3xf32>
+ return
+}
+
+// -----
+
func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
// expected-error at +1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 26bebdd898a0d6a..a7f76f2d0fa641f 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -70,8 +70,9 @@ func.func @test_concat(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: tensor<1x1x1x1
// -----
func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
+ %1 = tosa.const_shape {value = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
// expected-error at +1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
- %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1, 1, 1, 819>} : (tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32>
+ %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
return %0 : tensor<1x1x1x1x1x1x819xf32>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index d00230d12aab1b4..baf09e089aa30fb 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -504,7 +504,8 @@ func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf
// CHECK-LABEL: reduce_all
func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
%0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
- %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
+ %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<1x21x3xi1>, !tosa.shape<2>) -> tensor<21x3xi1>
return %1 : tensor<21x3xi1>
}
@@ -512,7 +513,8 @@ func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
// CHECK-LABEL: reduce_any
func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
%0 = tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
- %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
+ %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<1x21x3xi1>, !tosa.shape<2>) -> tensor<21x3xi1>
return %1 : tensor<21x3xi1>
}
@@ -520,7 +522,8 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
// CHECK-LABEL: reduce_max
func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
%0 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
- %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+ %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<1x21x3xf32>, !tosa.shape<2>) -> tensor<21x3xf32>
return %1 : tensor<21x3xf32>
}
@@ -528,7 +531,8 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
// CHECK-LABEL: reduce_min
func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
%0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
- %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+ %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<1x21x3xf32>, !tosa.shape<2>) -> tensor<21x3xf32>
return %1 : tensor<21x3xf32>
}
@@ -536,7 +540,8 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
// CHECK-LABEL: reduce_product
func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
%0 = tosa.reduce_prod %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
- %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+ %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<1x21x3xf32>, !tosa.shape<2>) -> tensor<21x3xf32>
return %1 : tensor<21x3xf32>
}
@@ -544,7 +549,8 @@ func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
// CHECK-LABEL: reduce_sum
func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
%0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
- %1 = tosa.reshape %0 {new_shape = array<i64: 21, 3>} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+ %2 = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.reshape %0, %2 : (tensor<1x21x3xf32>, !tosa.shape<2>) -> tensor<21x3xf32>
return %1 : tensor<21x3xf32>
}
@@ -575,7 +581,8 @@ func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3
// -----
// CHECK-LABEL: reshape
func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
- %0 = tosa.reshape %arg0 {new_shape = array<i64: 1, 819>} : (tensor<13x21x3xf32>) -> tensor<1x819xf32>
+ %1 = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
return %0 : tensor<1x819xf32>
}
@@ -724,7 +731,8 @@ func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
%2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%3 = tosa.add %arg3, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
- %4 = tosa.reshape %2 {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+ %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %4 = tosa.reshape %2, %7 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
%5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
%6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
tosa.yield %6, %3, %5 : tensor<i32>, tensor<i32>, tensor<10xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index e4a2897908072a6..9aade2fe45eb6ac 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -5,13 +5,16 @@
// CHECK-LABEL: @conv2d_as_fully_connected
func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
// CHECK-NOT: tosa.conv2d
- // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 400, 2>}
+ // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[400, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[4, 10, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[CONST0]]
// CHECK-SAME: -> tensor<400x2xf32>
- // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
+ // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1, %[[CONST1]]
// CHECK-SAME: -> tensor<3x2xf32>
// CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
// CHECK-SAME: -> tensor<400x3xf32>
- // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
+ // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[CONST2]]
// CHECK-SAME: -> tensor<4x10x10x3xf32>
// CHECK: return %[[VAR3]]
%0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
@@ -23,14 +26,17 @@ func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor
// CHECK-LABEL: @conv2d_as_fully_connected_quant
func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
// CHECK-NOT: tosa.conv2d
- // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 400, 2>}
+ // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[400, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[4, 10, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[CONST0]]
// CHECK-SAME: -> tensor<400x2xi8>
- // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
+ // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1, %[[CONST1]]
// CHECK-SAME: -> tensor<3x2xi8>
// CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
// CHECK-SAME: {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK-SAME: -> tensor<400x3xi32>
- // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
+ // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[CONST2]]
// CHECK-SAME: -> tensor<4x10x10x3xi32>
// CHECK: return %[[VAR3]]
%input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
@@ -42,14 +48,14 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
// -----
// CHECK-LABEL: func.func @conv_with_dynamic_dim(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x14x14x64xi8>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<384x1x1x64xi8>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
-// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: -1, 64>} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
-// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 384, 64>} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
-// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
-// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
+// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[-1, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[384, 64]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[-1, 14, 14, 384]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_3:.*]] = tosa.reshape %arg0, %[[CONST0]]
+// CHECK: %[[VAL_4:.*]] = tosa.reshape %arg1, %[[CONST1]] : (tensor<384x1x1x64xi8>, !tosa.shape<2>) -> tensor<384x64xi8>
+// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %arg2 {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]], %[[CONST2]] : (tensor<?x384xi32>, !tosa.shape<4>) -> tensor<?x14x14x384xi32>
// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
// CHECK: }
%input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8>
@@ -62,15 +68,19 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
// CHECK-LABEL: @conv2d_as_fully_connected_padded
func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
+ // CHECK-DAG: %[[FULLY_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[4, 12, 12, 3]> : tensor<4xindex>}
+ // CHECK-DAG: %[[INPUT_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[576, 2]> : tensor<2xindex>}
+ // CHECK-DAG: %[[FILTER_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>}
// CHECK-DAG: %[[PAD_SHAPE:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor<i8>}
// CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
- // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
- // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
+ // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]], %[[INPUT_NEW_SHAPE]]
+ // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1, %[[FILTER_NEW_SHAPE]]
// CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {input_zp = 42 : i32, weight_zp = 24 : i32}
- // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
+ // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]], %[[FULLY_NEW_SHAPE]]
%input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
%weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x12x12x3xi32>
return %0 : tensor<4x12x12x3xi32>
}
+
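A side effect of carrying the shape as an operand shows up in the CHECK structure of these decomposition tests: the const_shape definitions can be emitted in any order relative to one another, so they are matched with CHECK-DAG while the computation itself keeps ordered CHECK lines, e.g.:

  // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
  // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1, %[[CONST1]]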
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index ce29d1a498b4f88..6562a7c2ab55c4f 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -5,15 +5,19 @@
// CHECK-LABEL: @depthwise_conv2d_as_mul
func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK-NOT: tosa.depthwise_conv2d
- // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
+ // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex>
+ // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>
+ // CHECK-DAG: %[[CONST2:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 6]> : tensor<4xindex>
+ // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>
+ // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0, %[[CONST0]]
// CHECK-SAME: -> tensor<4x10x10x2x1xf32>
- // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
+ // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1, %[[CONST1]]
// CHECK-SAME: -> tensor<1x1x1x2x3xf32>
// CHECK: %[[VAR2:.*]] = tosa.mul %[[VAR0]], %[[VAR1]]
// CHECK-SAME: -> tensor<4x10x10x2x3xf32>
- // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 6>}
+ // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[CONST2]]
// CHECK-SAME: -> tensor<4x10x10x6xf32>
- // CHECK: %[[VAR4:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
+ // CHECK: %[[VAR4:.*]] = tosa.reshape %arg2, %[[CONST3]]
// CHECK-SAME: -> tensor<1x1x1x6xf32>
// CHECK: %[[VAR5:.*]] = tosa.add %[[VAR3]], %[[VAR4]]
// CHECK-SAME: -> tensor<4x10x10x6xf32>
@@ -26,17 +30,22 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
+ // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex>
// CHECK-DAG: %[[iZp:.+]] = "tosa.const"() <{value = dense<7> : tensor<1x1x1x1x1xi32>}
// CHECK-DAG: %[[wZp:.+]] = "tosa.const"() <{value = dense<11> : tensor<1x1x1x1xi32>}
- // CHECK: %[[rIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
+ // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>
+ // CHECK-DAG: %[[CONST4:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 6]> : tensor<4xindex>
+ // CHECK-DAG: %[[CONST5:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>
+ // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK: %[[rIn:.+]] = tosa.reshape %arg0, %[[CONST0]]
// CHECK: %[[cIn:.+]] = tosa.cast %[[rIn]] : (tensor<4x10x10x2x1xi8>) -> tensor<4x10x10x2x1xi32>
// CHECK: %[[cWe:.+]] = tosa.cast %arg1 : (tensor<1x1x2x3xi8>) -> tensor<1x1x2x3xi32>
// CHECK: %[[sIn:.+]] = tosa.sub %[[cIn]], %[[iZp]]
// CHECK: %[[sWe:.+]] = tosa.sub %[[cWe]], %[[wZp]]
- // CHECK: %[[resWe:.+]] = tosa.reshape %[[sWe]] {new_shape = array<i64: 1, 1, 1, 2, 3>}
- // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]]
- // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>}
- // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
+ // CHECK: %[[resWe:.+]] = tosa.reshape %[[sWe]], %[[CONST3]]
+ // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]], %[[SHIFT]]
+ // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]], %[[CONST4]]
+ // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2, %[[CONST5]]
// CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]]
%input_zp = "tosa.const"() {value = dense<7> : tensor<1xi8>} : () -> tensor<1xi8>
%weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8>
@@ -48,14 +57,19 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
// CHECK-LABEL: @depthwise_conv2d_as_mul_padded
func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
- // CHECK-DAG: %[[pad:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10>
+ // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex>}
+ // CHECK-DAG: %[[pad:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10>
// CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
- // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
+ // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>}
+ // CHECK-DAG: %[[CONST4:.+]] = tosa.const_shape {value = dense<[4, 12, 12, 6]> : tensor<4xindex>}
+ // CHECK-DAG: %[[CONST5:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>}
+ // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK: %[[reIn:.+]] = tosa.reshape %arg0, %[[CONST0]]
// CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
- // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
- // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]]
- // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
- // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
+ // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1, %[[CONST3]]
+ // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]], %[[SHIFT]]
+ // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]], %[[CONST4]]
+ // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2, %[[CONST5]]
// CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]]
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32>
return %0 : tensor<4x12x12x6xf32>
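Note that the depthwise checks also absorb an unrelated interface change visible in the updated patterns: tosa.mul now carries an explicit shift operand, so the tests bind a materialized zero-shift constant (%[[SHIFT]]) alongside the shape constants. A minimal sketch of the new mul form, with illustrative operand names:

  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %prod = tosa.mul %a, %b, %shift : (tensor<4x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<4x10x10x2x3xf32>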
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 82838cc7e154514..bd18b7ea0fdff6b 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -56,11 +56,15 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]]
- // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
+ // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[5, 2, 2, 2, 3, 3]> : tensor<6xindex>}
+ // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]], %[[CONST1]]
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
- // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
+ // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[30, 2, 2, 3]> : tensor<4xindex>}
+ // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]], %[[CONST3]]
// CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[RESW2]] {axis = 1 : i32}
// CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
+ // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
@@ -70,13 +74,14 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// Manipulate the final shape.
// CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<30xf32>}
// CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
- // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
+ // CHECK-DAG: %[[CONST6:.+]] = tosa.const_shape {value = dense<[2, 18, 16, 2, 3, 5]> : tensor<6xindex>}
+ // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]], %[[CONST6]]
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
- // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]]
- // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
- // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
- // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
+ // CHECK-DAG: %[[CONST8:.+]] = tosa.const_shape {value = dense<[2, 36, 48, 5]> : tensor<4xindex>
+ // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]], %[[CONST8]]
+ // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
+ // CHECK-DAG: %[[CONST9:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 5]> : tensor<4xindex>}
+ // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST9]]
// CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
%1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
@@ -92,11 +97,15 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32}
- // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
+ // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[5, 2, 2, 2, 3, 3]> : tensor<6xindex>}
+ // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]], %[[CONST1]]
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
- // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
+ // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[30, 2, 2, 3]> : tensor<4xindex>}
+ // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]], %[[CONST3]]
// CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[RESW2]] {axis = 1 : i32}
// CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
+ // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
@@ -108,13 +117,14 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
// CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
// CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
- // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
+ // CHECK-DAG: %[[CONV_NEW_SHAPE:.*]] = tosa.const_shape {value = dense<[2, 18, 16, 2, 3, 5]> : tensor<6xindex>}
+ // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]], %[[CONV_NEW_SHAPE]]
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
- // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]]
- // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>}
- // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>}
- // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
- // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
+ // CHECK-DAG: %[[TRANS_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[2, 36, 48, 5]> : tensor<4xindex>}
+ // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]], %[[TRANS_NEW_SHAPE]]
+ // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
+ // CHECK-DAG: %[[ARG2_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 5]> : tensor<4xindex>}
+ // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[ARG2_NEW_SHAPE]]
// CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
%input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
%weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
@@ -126,25 +136,31 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-LABEL: @transpose_conv2d_strided_overpad
func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
- // CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ // CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>}
+ // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[1, 2, 1, 1, 2, 1]> : tensor<6xindex>}
// CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[2, 2, 1, 1]> : tensor<4xindex>}
+ // CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
+ // CHECK-DAG: %[[CONST6:.+]] = tosa.const_shape {value = dense<[1, 17, 1, 1, 2, 1]> : tensor<6xindex>}
// CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ // CHECK-DAG: %[[CONST8:.+]] = tosa.const_shape {value = dense<[1, 17, 2, 1]> : tensor<4xindex>}
+ // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>}
+ // CHECK-DAG: %[[CONST10:.+]] = tosa.const_shape {value = dense<1> : tensor<4xindex>}
+ // CHECK-DAG: %[[INPUT_ZP:.*]] = "tosa.const"() <{value = dense<-103> : tensor<1xi8>}>
+ // CHECK-DAG: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{value = dense<93> : tensor<1xi8>}>
// CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32}
- // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
+ // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]], %[[CONST1]]
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
- // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array<i64: 2, 2, 1, 1>}
+ // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]], %[[CONST3]]
// CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
// CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32}
- // CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]
- // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], input_zp = -103 : i32, weight_zp = 93 : i32, stride = [1, 1]}
- // CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
+ // CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ // CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]], %[[CONST6]]
// CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
- // CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]] {new_shape = array<i64: 1, 17, 2, 1>}
+ // CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]], %[[CONST8]]
// CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
- // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST10]]
// CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
%input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8>
%weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 73eabab657f380d..bdd403567a4ed5c 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -376,29 +376,42 @@ func.func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>)
// CHECK-LABEL: @test_static_reshape
func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
- // CHECK: tosa.reshape %arg0 {new_shape = array<i64: 16>} : (tensor<4x4xi32>) -> tensor<16xi32>
- %0 = tosa.reshape %arg0 {new_shape = array<i64: 16>} : (tensor<4x4xi32>) -> tensor<?xi32>
+ // CHECK: %[[CONST3:.+]] = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %3 = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK: tosa.reshape %arg0, %[[CONST3]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+ %0 = tosa.reshape %arg0, %3 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
- // CHECK: tosa.reshape %arg0 {new_shape = array<i64: -1>} : (tensor<4x4xi32>) -> tensor<16xi32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: -1>} : (tensor<4x4xi32>) -> tensor<?xi32>
+ // CHECK: %[[CONST4:.+]] = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK: tosa.reshape %arg0, %[[CONST4]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
+ %4 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.reshape %arg0, %4 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
- // CHECK: tosa.reshape %arg0 {new_shape = array<i64: 2, -1>} : (tensor<4x4xi32>) -> tensor<2x8xi32>
- %2 = tosa.reshape %arg0 {new_shape = array<i64: 2, -1>} : (tensor<4x4xi32>) -> tensor<?x?xi32>
+ // CHECK: %[[CONST5:.+]] = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: tosa.reshape %arg0, %[[CONST5]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
+ %5 = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %2 = tosa.reshape %arg0, %5 : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
return
}
+
// -----
// CHECK-LABEL: @test_dynamic_reshape
func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
- // CHECK: %0 = tosa.reshape %arg0 {new_shape = array<i64: 16>} : (tensor<4x?xi32>) -> tensor<16xi32>
- %0 = tosa.reshape %arg0 {new_shape = array<i64: 16>} : (tensor<4x?xi32>) -> tensor<?xi32>
-
- // CHECK: %1 = tosa.reshape %arg0 {new_shape = array<i64: -1>} : (tensor<4x?xi32>) -> tensor<?xi32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: -1>} : (tensor<4x?xi32>) -> tensor<?xi32>
-
- // CHECK: %2 = tosa.reshape %arg0 {new_shape = array<i64: 2, -1>} : (tensor<4x?xi32>) -> tensor<2x?xi32>
- %2 = tosa.reshape %arg0 {new_shape = array<i64: 2, -1>} : (tensor<4x?xi32>) -> tensor<?x?xi32>
+ // CHECK: %0 = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %0 = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK: %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
+ %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+
+ // CHECK: %2 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %2 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ // CHECK: %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+ %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+
+ // CHECK: %4 = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %4 = tosa.const_shape {value = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
+ %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<?x?xi32>
return
}
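These cases confirm that shape inference reads the target shape through the const_shape operand exactly as it previously read the attribute: a static shape value refines a dynamic result type in place. From the first dynamic case above:

  %0 = tosa.const_shape {value = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
  %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
  // after inference the result type is refined to tensor<16xi32>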
diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
index f274eb9c10a815a..947335e45a9d939 100644
--- a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
@@ -141,12 +141,14 @@ func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3
// COM: this case is a reshape we don't convert, since can't fold the transpose into it.
// COM: a transform actually occurs underneath the hood, but it results in identical IR.
// CHECK-LABEL: @test_basic_non_broadcasting_reshape
-// CHECK: "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK: tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 2>} : (tensor<2x3xi32>) -> tensor<1x3x2xi32>
-// CHECK: tosa.transpose %1, %0 : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}>
+// CHECK: %[[VAL_3:.*]] = tosa.reshape %arg0, %[[VAL_1]] : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32>
+// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_3]], %[[VAL_2]] : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
func.func @test_basic_non_broadcasting_reshape(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
+ %shape = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
%perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 2>} : (tensor<2x3xi32>) -> tensor<1x3x2xi32>
+ %1 = tosa.reshape %arg0, %shape : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32>
%2 = tosa.transpose %1, %perms : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
return %2 : tensor<1x2x3xi32>
}
@@ -154,11 +156,13 @@ func.func @test_basic_non_broadcasting_reshape(%arg0: tensor<2x3xi32>) -> tensor
// -----
// CHECK-LABEL: @test_dynamic_broadcasting_reshape
-// CHECK: %[[RES:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, -1>} : (tensor<?xi32>) -> tensor<1x1x?xi32>
+// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[1, 1, -1]> : tensor<3xindex>}
+// CHECK: %[[RES:.*]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<?xi32>, !tosa.shape<3>) -> tensor<1x1x?xi32>
// CHECK: return %[[RES]]
func.func @test_dynamic_broadcasting_reshape(%arg0: tensor<?xi32>) -> tensor<1x1x?xi32> {
+ %shape = tosa.const_shape {value = dense<[1, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, -1, 1>} : (tensor<?xi32>) -> tensor<1x?x1xi32>
+ %1 = tosa.reshape %arg0, %shape : (tensor<?xi32>, !tosa.shape<3>) -> tensor<1x?x1xi32>
%2 = tosa.transpose %1, %perms : (tensor<1x?x1xi32>, tensor<3xi32>) -> tensor<1x1x?xi32>
return %2 : tensor<1x1x?xi32>
}
@@ -167,12 +171,14 @@ func.func @test_dynamic_broadcasting_reshape(%arg0: tensor<?xi32>) -> tensor<1x1
// CHECK-LABEL: @test_reshape_for_broadcast
// CHECK-DAG: %[[RESHAPE_INPUT:.*]] = "tosa.const"() <{value = dense<[1, 2, 3, 4]>
-// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %[[RESHAPE_INPUT]] {new_shape = array<i64: 4, 1, 1>}
-// CHECK-DAG: %[[ADD:.*]] = tosa.add %arg0, %[[RESHAPE]]
+// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[4, 1, 1]> : tensor<3xindex>}
+// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[RESHAPE_INPUT]], %[[SHAPE]] : (tensor<4xi32>, !tosa.shape<3>) -> tensor<4x1x1xi32>
+// CHECK: %[[ADD:.*]] = tosa.add %arg0, %[[RESHAPE]]
// CHECK: return %[[ADD]]
func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2xi32> {
%0 = "tosa.const"() {value = dense<[1,2,3,4]> : tensor<4xi32>} : () -> tensor<4xi32>
- %reshape = tosa.reshape %0 {new_shape = array<i64: 1, 1, 4>} : (tensor<4xi32>) -> tensor<1x1x4xi32>
+ %1 = tosa.const_shape {value = dense<[1, 1, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %reshape = tosa.reshape %0, %1 : (tensor<4xi32>, !tosa.shape<3>) -> tensor<1x1x4xi32>
%perms0 = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
%transpose0 = tosa.transpose %arg0, %perms0 : (tensor<4x3x2xi32>, tensor<3xi32>) -> tensor<2x3x4xi32>
%add = tosa.add %transpose0, %reshape : (tensor<2x3x4xi32>, tensor<1x1x4xi32>) -> tensor<2x3x4xi32>
@@ -187,25 +193,28 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
// CHECK-LABEL: @test_resnet18_common_case
// COM: note that %74 is now represented by %arg2
-// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
-// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
-// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK-DAG: %[[VAL_6:.*]] = tosa.add %arg1, %[[VAL_4]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
-// CHECK-DAG: %[[VAL_7:.*]] = tosa.pow %[[VAL_6]], %[[VAL_5]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<64xf32>) -> tensor<64xf32>
-// CHECK-DAG: %[[VAL_9:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
-// CHECK-DAG: %[[VAL_10:.*]] = tosa.sub %arg2, %[[VAL_9]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
-// CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
-// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
-// CHECK-DAG: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
-// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
-// CHECK-DAG: %[[VAL_15:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
-// CHECK-DAG: %[[VAL_16:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = tosa.clamp %[[VAL_16]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
-// CHECK: return %[[VAL_17]] : tensor<1x112x112x64xf32>
-
+// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = tosa.add %arg1, %[[VAL_5]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_8:.*]] = tosa.pow %[[VAL_7]], %[[VAL_6]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_9:.*]] = tosa.reciprocal %[[VAL_8]] : (tensor<64xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_12:.*]] = tosa.sub %arg2, %[[VAL_11]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_14:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_13]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_16:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_17:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_16]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_19]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_21:.*]] = tosa.add %[[VAL_18]], %[[VAL_20]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_22:.*]] = tosa.clamp %[[VAL_21]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> {
+ %58 = tosa.const_shape {value = dense<[1, 64, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%59 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
%60 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
%63 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
@@ -216,13 +225,13 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
%76 = tosa.add %arg1, %69 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
%77 = tosa.pow %76, %70 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
%78 = tosa.reciprocal %77 : (tensor<64xf32>) -> tensor<64xf32>
- %79 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %79 = tosa.reshape %arg0, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
%80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
- %81 = tosa.reshape %78 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %81 = tosa.reshape %78, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
%82 = tosa.mul %80, %81 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
- %83 = tosa.reshape %60 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %83 = tosa.reshape %60, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
%84 = tosa.mul %82, %83 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
- %85 = tosa.reshape %59 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %85 = tosa.reshape %59, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
%86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
%87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
%88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
@@ -285,7 +294,8 @@ func.func @test_no_transform_if_outside_fan_in_cone(%arg0: tensor<3x3x3x3xi32>)
// CHECK: return %[[RESHAPE]], %[[CLAMP]]
func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<1x1x64xf32>) {
%0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1>} : (tensor<64xf32>) -> tensor<1x64x1xf32>
+ %shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %1 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32>
%2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
%3 = tosa.transpose %1, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
%4 = tosa.transpose %2, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
@@ -305,7 +315,8 @@ func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: t
func.func @test_two_different_downstream_converge_to_reshape_different_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<64x1x1xf32>) {
%0 = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
%1 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
- %2 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1>} : (tensor<64xf32>) -> tensor<1x64x1xf32>
+ %shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %2 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32>
%3 = tosa.clamp %2 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
%4 = tosa.transpose %2, %1 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
%5 = tosa.transpose %3, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<64x1x1xf32>
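One last detail worth noting in the resnet18 case above: because the shape is now an SSA value, a single tosa.const_shape can feed every reshape targeting the same shape (there, %58 feeds four reshapes), whereas the attribute form repeated the array on each op. A sketch of the sharing, with illustrative value names:

  %s = tosa.const_shape {value = dense<[1, 64, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %a = tosa.reshape %x, %s : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
  %b = tosa.reshape %y, %s : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>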