[Mlir-commits] [mlir] [TOSA] Change PadOp padding to tosa.shape (PR #123133)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 16 16:33:04 PST 2025
https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/123133
>From 9f9ddd471e3249463050734c89a3dfd94c6829c5 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Wed, 24 Jan 2024 23:49:53 +0000
Subject: [PATCH] [TOSA] Change PadOp padding to tosa.shape
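This patch changes PadOp's padding input to type !tosa.shape<2 * rank>
(where rank is the rank of PadOp's input) instead of a 1-D tensor of
2 * rank i32/i64 values.
It also adds tosa::getTosaConstShape and related helpers to ConversionUtils,
and updates the TosaToTensor lowering, the Conv2D/DepthwiseConv2D/TransposeConv
decompositions and the tests to build the padding with tosa.const_shape.
The TosaToTensor lowering now requires the padding to be a constant shape.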
Signed-off-by: Tai Ly <tai.ly at arm.com>
Signed-off-by: Jerry Ge <Jerry.Ge at arm.com>
Change-Id: I08526a699d6b8ebbaf9ee092cd37580e5d78f919
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 10 ++--
.../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 8 +++
.../Conversion/TosaToTensor/TosaToTensor.cpp | 27 +++++-----
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 50 ++++++++----------
.../Tosa/Transforms/TosaDecomposeConv2D.cpp | 6 +--
.../Transforms/TosaDecomposeDepthwise.cpp | 6 +--
.../Transforms/TosaDecomposeTransposeConv.cpp | 29 ++++-------
.../Dialect/Tosa/Utils/ConversionUtils.cpp | 33 ++++++++++++
.../TosaToTensor/tosa-to-tensor.mlir | 51 +++++++++++++------
mlir/test/Dialect/Tosa/canonicalize.mlir | 45 ++++++++--------
mlir/test/Dialect/Tosa/invalid.mlir | 38 +++++++-------
mlir/test/Dialect/Tosa/ops.mlir | 10 ++--
.../Dialect/Tosa/tosa-decompose-conv2d.mlir | 4 +-
.../Tosa/tosa-decompose-depthwise.mlir | 4 +-
.../Tosa/tosa-decompose-transpose-conv.mlir | 14 ++---
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 20 +++-----
16 files changed, 196 insertions(+), 159 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e1efa7a3001b9f..2953e006bbe8d1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1557,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
Example:
```mlir
- %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
- tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
+ %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
```
Example 2:
```mlir
- %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
- tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
+ %0 = tosa.const_shape { value = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
```
}];
let arguments = (ins
Tosa_RankedTensor:$input1,
- TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
+ Tosa_Shape:$padding,
Optional<Tosa_ScalarTensor>:$pad_const,
OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
);
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 90fea1f68beb58..2407d71d4cbb55 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -229,6 +229,14 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
return permuted;
}
+// Creates a tosa.const_shape op holding `shape` and returns its !tosa.shape result.
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+ llvm::ArrayRef<int64_t> shape);
+SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
+
+bool ExtractConstShapeValue(Operation *op,
+ llvm::SmallVector<int64_t> &result_shape);
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index b5a0da15e780e0..5aa0269a675cbe 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -306,7 +306,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
ConversionPatternRewriter &rewriter) const final {
auto loc = padOp.getLoc();
auto input = padOp.getInput1();
- auto padding = padOp.getPadding();
+
+ ElementsAttr paddingElems;
+ if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+ return rewriter.notifyMatchFailure(
+ padOp, "padding must be a static shape value");
+ }
+ llvm::SmallVector<int64_t> paddingVals;
+ for (auto idx : paddingElems.getValues<IntegerAttr>()) {
+ paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
+ }
ShapedType inputTy = cast<ShapedType>(input.getType());
Type elementTy = inputTy.getElementType();
@@ -345,18 +354,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
highValues.reserve(rank);
for (int i = 0; i < rank; i++) {
- Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
- Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
- Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padding, ValueRange({lowIndex}));
- Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padding, ValueRange({highIndex}));
-
- lowVal = rewriter.createOrFold<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), lowVal);
- highVal = rewriter.createOrFold<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), highVal);
-
+ Value lowVal = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+ Value highVal = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
lowValues.push_back(lowVal);
highValues.push_back(highVal);
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 83cf4a9415fe68..f490b420eb9978 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -36,6 +36,7 @@ using namespace mlir;
using namespace mlir::tosa;
#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
@@ -823,51 +824,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
PadOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
- ShapeAdaptor paddingShape(adaptor.getPadding().getType());
+ auto paddingRank =
+ cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
SmallVector<int64_t> outputShape;
- // If both inputs have unknown shape, we cannot determine the shape of the
- // output.
- if (!inputShape.hasRank() && !paddingShape.hasRank()) {
- inferredReturnShapes.push_back(ShapedTypeComponents());
- return success();
- }
-
- // If the input rank is unknown we can info the output rank using the
- // padding shape's first dim.
+ // If the input rank is unknown, we can infer the output rank using the
+ // padding shape's rank divided by 2.
if (!inputShape.hasRank()) {
- if (paddingShape.isDynamicDim(0)) {
- inferredReturnShapes.push_back(ShapedTypeComponents());
- return success();
- }
-
- outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
+ outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
- DenseIntElementsAttr paddings;
+ SmallVector<int64_t> paddingValues;
// If the paddings value is not a constant, all dimensions must be dynamic.
- if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
+ if (!tosa::ExtractConstShapeValue(adaptor.getPadding().getDefiningOp(),
+ paddingValues)) {
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
- SmallVector<int64_t> paddingValues;
- for (auto val : paddings) {
- paddingValues.push_back(val.getSExtValue());
- }
-
outputShape.reserve(inputShape.getRank());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
if (inputShape.isDynamicDim(i)) {
outputShape.push_back(ShapedType::kDynamic);
continue;
}
+ auto padFront = paddingValues[i * 2];
+ auto padBack = paddingValues[i * 2 + 1];
+ if (padFront < 0 || padBack < 0) {
+ // If either padding for dim i is negative (e.g. -1), the output dim is unknown.
+ outputShape.push_back(ShapedType::kDynamic);
+ continue;
+ }
- outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
- paddingValues[i * 2 + 1]);
+ outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
}
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -877,17 +869,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
LogicalResult tosa::PadOp::verify() {
RankedTensorType inputType = getInput1().getType();
RankedTensorType outputType = getOutput().getType();
- RankedTensorType paddingType = getPadding().getType();
+ auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
if (inputType.getRank() != outputType.getRank())
return emitOpError() << "expect same input and output tensor rank.";
- if (!paddingType.isDynamicDim(0) &&
- paddingType.getDimSize(0) != inputType.getRank() * 2)
+ if (paddingRank != inputType.getRank() * 2)
return emitOpError() << "expected padding tensor dim 0 to have size "
<< inputType.getRank() * 2
- << " (2*rank(shape1)) but got size "
- << paddingType.getDimSize(0);
+ << " (2*rank(shape1)) but got size " << paddingRank;
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 04a709c5967795..cb08360f902286 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -81,11 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
}
}
- auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
- auto padSize =
- DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
- Value padSizeVal =
- rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+ Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
auto padTy = RankedTensorType::get({}, inputETy);
auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 14f392ab8c45c1..45f4419875b485 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -108,11 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
}
}
- auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
- auto padSize =
- DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
- Value padSizeVal =
- rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+ Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
auto padTy = RankedTensorType::get({}, inputETy);
auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index db1e219b601b30..1b97f0b245d9ba 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -135,15 +135,14 @@ class TransposeConvStridedConverter
int64_t inputChannels = weightTy.getDimSize(3);
// Pad the weight so that it is modulo of the striding.
- llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+ llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
weightPadding[3] =
(weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
weightPadding[5] =
- (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
- DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
- Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
- rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+ weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+
+ Value weightPaddingVal =
+ getTosaConstShape(rewriter, op->getLoc(), weightPadding);
if (op.getQuantizationInfo().has_value()) {
auto quantInfo = op.getQuantizationInfo().value();
@@ -197,17 +196,14 @@ class TransposeConvStridedConverter
/* axis = */ rewriter.getI32IntegerAttr(2));
// We need to pad the input far enough that we can pull all values.
- llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+ llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
- DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
-
- Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
- rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+ Value inputPaddingVal =
+ getTosaConstShape(rewriter, op->getLoc(), inputPadding);
if (op.getQuantizationInfo().has_value()) {
auto quantInfo = op.getQuantizationInfo().value();
@@ -310,17 +306,14 @@ class TransposeConvStridedConverter
rewriter.getDenseI64ArrayAttr(sliceSize))
.getResult();
- llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+ llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
resultPadding[2] = resultPadTop;
resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
resultPadding[4] = resultPadLeft;
resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
- DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
-
- Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
- rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
+ Value resultPaddingVal =
+ getTosaConstShape(rewriter, op->getLoc(), resultPadding);
Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), slice,
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 1f6e3b2ab83919..ab378ecfd1e371 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -160,3 +160,36 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
return success();
}
+
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+ llvm::ArrayRef<int64_t> shape) {
+ auto attr = rewriter.getIndexTensorAttr(shape);
+ auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
+ mlir::Operation *mlir_op =
+ rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+ return mlir_op->getResult(0);
+}
+
+SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
+ return to_vector(llvm::map_range(shape, [](int64_t dim) {
+ return ShapedType::isDynamic(dim) ? -1 : dim;
+ }));
+}
+
+bool mlir::tosa::ExtractConstShapeValue(
+ Operation *op, llvm::SmallVector<int64_t> &result_shape) {
+ if (!op) {
+ return false;
+ }
+ if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
+ Attribute constOpAttr = constOp->getAttr("value");
+ DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
+ for (int i = 0; i < elementsAttr.size(); i++) {
+ int64_t val = elementsAttr.getValues<int64_t>()[i];
+ result_shape.push_back(val);
+ }
+ return true;
+ }
+ // Not produced by a tosa.const_shape op: the shape is not a compile-time constant.
+ return false;
+}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 0b9a64494bc0f1..2f11b31aad2307 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -459,65 +459,84 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// CHECK-LABEL: @pad_float
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
- %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
return %1 : tensor<4x9xf32>
}
+// -----
func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
- %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: [[CST:%.+]] = arith.constant 0 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
+// -----
func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
- %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: [[CST:%.+]] = arith.constant 42 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
// -----
func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
- %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
%1 = arith.constant dense<42.0> : tensor<f32>
- %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>) -> (tensor<4x9xf32>)
+ %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>) -> (tensor<4x9xf32>)
return %2 : tensor<4x9xf32>
}
// -----
func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
- %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
+// -----
func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
- %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
+ %0 = tosa.const_shape {value = dense<[-1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK-DAG: [[INDEX1:%.+]] = arith.constant -1 : index
+ // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 889e2eda9e5b84..e394188e9a9311 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -210,8 +210,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
// CHECK-LABEL: @pad_noop
func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: return %arg0
- %0 = "tosa.const"() { value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+ %0 = tosa.const_shape { value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -221,8 +221,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]
- %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+ %shape = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %1 = tosa.pad %arg0, %shape : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -232,41 +232,44 @@ func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x
func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32> {
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]
-
- %c0_i32 = arith.constant 0 : i32
- %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<2xi32>
-
- %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<2xi32>) -> tensor<?xf32>
+ %shape = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, !tosa.shape<2>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: @pad_determine_val_i32
-func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
- // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
- // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
+func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+ // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+ %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %1 = tosa.pad %arg0, %0 : (tensor<?x?xi32>, !tosa.shape<4>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
// -----
// CHECK-LABEL: @pad_determine_val_f32
-func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<4xi32>) -> tensor<?x?xf32> {
- // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
- // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
+ // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+ %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @pad_determine_val_quant
-func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
- // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
- // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
+func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+ // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+ %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %1 = tosa.pad %arg0, %0 {input_zp = 42 : i32} : (tensor<?x?xi32>, !tosa.shape<4>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cc7fd009f01fa6..e58a45ca80990e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -165,52 +165,56 @@ func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : te
// -----
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
- // expected-error at +1 {{'tosa.pad' op padding of pad is not constant}}
- %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
+ // expected-error at +1 {{'tosa.pad' op shape operand is not compile time resolvable}}
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
- %0 = "tosa.const"() {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xi32>} : () -> tensor<6xi32>
+ %0 = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
// expected-error at +1 {{'tosa.pad' op pad_const of pad is not constant}}
- %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<6xi32>, tensor<i8>) -> tensor<13x21x3xi8>
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<i8>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}
// -----
-func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
+func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
+ %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error at +1 {{'tosa.pad' op expect same input and output tensor rank.}}
- %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
+ %1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
return
}
// -----
-func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
- // expected-error at +1 {{'tosa.pad' op operand #1 must be 1D tensor of 32-bit signless integer or 64-bit signless integer values, but got 'tensor<2x2xi32>'}}
- %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21xf32>
+func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
+ %0 = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // expected-error at +1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
+ %1 = tosa.pad %arg0, %0 : (tensor<13x21xf32>, !tosa.shape<6>) -> tensor<13x21xf32>
return
}
// -----
-func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
- %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
- // expected-error at +1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<1xf32>'}}
- %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<4xi32>, tensor<1xf32>) -> tensor<13x21xf32>
+func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+ %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %1 = "tosa.const"() {value = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
+ // expected-error at +1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<2xf32>'}}
+ %2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
return
}
// -----
-func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<4xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error at +1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
- %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+ %1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
+ return %1 : tensor<13x21x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 690e208af1e5f9..563c5fa457d351 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -525,16 +525,18 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -
// -----
// CHECK-LABEL: pad
-func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
- %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
+func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %0 = tosa.pad %arg0, %padding : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: pad_explicit_value
-func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
  %0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
- %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<6xi32>, tensor<f32>) -> tensor<13x21x3xf32>
+ // Zero padding amounts for all 3 dims (6 shape entries), plus the 0-D pad value %0 as the third operand.
+ %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x3xf32>
  return %1 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 8df4630f9c17ff..95d9bb1b98ab74 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -58,9 +58,9 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
// CHECK-LABEL: @conv2d_as_fully_connected_padded
func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
- // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>}
+ // CHECK-DAG: %[[PAD_SHAPE:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor<i8>}
- // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<8xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
+ // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
// CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
// CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
// CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index cfff6396ad486d..bbcc206e1490c7 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -46,10 +46,10 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
// CHECK-LABEL: @depthwise_conv2d_as_mul_padded
func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
- // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xi64>}
+ // CHECK-DAG: %[[pad:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10>
  // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
  // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
- // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<10xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+ // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
  // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
- // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
+ // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]] {shift = 0 : i8}
  // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index c361c7c2899fc3..96f71c349938b9 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -45,7 +45,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
// CHECK-LABEL: @transpose_conv2d_strided
func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
// Manipulate the weight matrix to handle striding.
- // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
+ // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]]
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -55,7 +55,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
// Pad out the input matrix to handle the transpose conv.
- // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
+ // CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]]
@@ -78,7 +78,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_strided_quantized
func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
// Manipulate the weight matrix to handle striding.
- // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
+ // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -88,7 +88,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
// Pad out the input matrix to handle the transpose conv.
- // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
+ // CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
@@ -109,12 +109,12 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-LABEL: @transpose_conv2d_strided_overpad
func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
- // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xi32>
+ // CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi32>}
+ // CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xi32>}
+ // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index f4da66ef561b26..314cccab1d709c 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -492,22 +492,14 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>)
return
}
-// -----
-
-// CHECK-LABEL: @test_padding_no_const
-func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<4xi32>) -> () {
- // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
- %0 = tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
- return
-}
// -----
// CHECK-LABEL:@test_padding_dynamic_input
func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
- %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
- // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<4x?xf32>
- %1 = tosa.pad %arg0, %0 : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+ // (before, after) pairs per dim: dim 0 infers to 1 + 1 + 2 = 4; the dynamic dim 1 stays dynamic.
+ // CHECK: %[[PADDING:.+]] = tosa.const_shape
+ %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: tosa.pad %arg0, %[[PADDING]] : (tensor<1x?xf32>, !tosa.shape<4>) -> tensor<4x?xf32>
+ %1 = tosa.pad %arg0, %0 : (tensor<1x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
  return
}
@@ -515,9 +507,9 @@ func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
// CHECK-LABEL: @test_padding_simple
func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
- %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
- // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<4x9xf32>
- %1 = tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+ // (before, after) pairs per dim: dim 0 infers to 1 + 1 + 2 = 4, dim 1 to 3 + 2 + 4 = 9.
+ // CHECK: %[[PADDING:.+]] = tosa.const_shape
+ %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: tosa.pad %arg0, %[[PADDING]] : (tensor<1x2xf32>, !tosa.shape<4>) -> tensor<4x9xf32>
+ %1 = tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
  return
}
More information about the Mlir-commits
mailing list