[Mlir-commits] [mlir] [mlir][tosa] Add missing verifier for `tosa.pad` (PR #120934)
Longsheng Mou
llvmlistbot at llvm.org
Fri Jan 3 06:24:36 PST 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/120934
>From 6fc902013f0ac94dfa3ccf482515e3f4a1b84866 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 23 Dec 2024 14:46:50 +0800
Subject: [PATCH] [mlir][tosa] Add missing verifier for `tosa.pad`
This PR adds a missing verifier for `tosa.pad`, ensuring that the padding tensor has shape [2*rank(input)],
as required by the TOSA v1.0 specification.
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 10 ++--
.../Conversion/TosaToTensor/TosaToTensor.cpp | 12 ++---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 12 +++--
.../Tosa/Transforms/TosaDecomposeConv2D.cpp | 2 +-
.../Transforms/TosaDecomposeDepthwise.cpp | 2 +-
.../Transforms/TosaDecomposeTransposeConv.cpp | 6 +--
.../TosaToTensor/tosa-to-tensor.mlir | 52 ++++++-------------
mlir/test/Dialect/Tosa/canonicalize.mlir | 27 +++++-----
mlir/test/Dialect/Tosa/invalid.mlir | 30 +++++++----
mlir/test/Dialect/Tosa/ops.mlir | 8 +--
.../Dialect/Tosa/tosa-decompose-conv2d.mlir | 4 +-
.../Tosa/tosa-decompose-depthwise.mlir | 4 +-
.../Tosa/tosa-decompose-transpose-conv.mlir | 14 ++---
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 18 +++----
14 files changed, 93 insertions(+), 108 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..157d04db074547 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1552,21 +1552,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
Example:
```mlir
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
```
Example 2:
```mlir
- %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
- tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
+ tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
```
}];
let arguments = (ins
Tosa_RankedTensor:$input1,
- Tosa_Int32Or64Tensor:$padding,
+ 1DTensorOf<[Tosa_Int32Or64]>:$padding,
Optional<Tosa_ScalarTensor>:$pad_const,
OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
);
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 6f085cb6ed06d2..b5a0da15e780e0 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -338,11 +338,6 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
padOp, "tosa.pad was unable to determine the pad constant value.");
}
- Value lowIndex =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
- Value highIndex =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-
SmallVector<OpFoldResult, 3> lowValues;
SmallVector<OpFoldResult, 3> highValues;
@@ -350,11 +345,12 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
highValues.reserve(rank);
for (int i = 0; i < rank; i++) {
- Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
+ Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padding, ValueRange({inputIndex, lowIndex}));
+ loc, padding, ValueRange({lowIndex}));
Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
- loc, padding, ValueRange({inputIndex, highIndex}));
+ loc, padding, ValueRange({highIndex}));
lowVal = rewriter.createOrFold<arith::IndexCastOp>(
loc, rewriter.getIndexType(), lowVal);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..c7ccf467a21469 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -787,7 +787,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
return success();
}
- outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
+ outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
@@ -823,13 +823,17 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
LogicalResult tosa::PadOp::verify() {
RankedTensorType inputType = getInput1().getType();
RankedTensorType outputType = getOutput().getType();
- TensorType paddingType = getPadding().getType();
+ RankedTensorType paddingType = getPadding().getType();
if (inputType.getRank() != outputType.getRank())
return emitOpError() << "expect same input and output tensor rank.";
- if (paddingType.hasRank() && paddingType.getRank() != 2)
- return emitOpError() << "expect 'padding' tensor rank equal to 2.";
+ if (!paddingType.isDynamicDim(0) &&
+ paddingType.getDimSize(0) != inputType.getRank() * 2)
+ return emitOpError() << "expected padding tensor dim 0 to have size "
+ << inputType.getRank() * 2
+ << " (2*rank(input)) but got size "
+ << paddingType.getDimSize(0);
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 44f64f76e9b027..04a709c5967795 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -81,7 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
}
}
- auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
+ auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
auto padSize =
DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
Value padSizeVal =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index e6fba211dc37ab..14f392ab8c45c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -108,7 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
}
}
- auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
+ auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
auto padSize =
DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
Value padSizeVal =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 0779cdb9667a1a..fda39c516077d3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -139,7 +139,7 @@ class TransposeConvStridedConverter
weightPadding[5] =
(weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
@@ -202,7 +202,7 @@ class TransposeConvStridedConverter
inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
@@ -314,7 +314,7 @@ class TransposeConvStridedConverter
resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
+ RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 1e62e25176a007..0b9a64494bc0f1 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -459,85 +459,65 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// CHECK-LABEL: @pad_float
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
return %1 : tensor<4x9xf32>
}
func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: [[CST:%.+]] = arith.constant 0 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: [[CST:%.+]] = arith.constant 42 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
// -----
func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
%1 = arith.constant dense<42.0> : tensor<f32>
- %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>) -> (tensor<4x9xf32>)
+ %2 = "tosa.pad"(%arg0, %0, %1) : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>) -> (tensor<4x9xf32>)
return %2 : tensor<4x9xf32>
}
// -----
func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
- %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
- // TODO: Output contains multiple "arith.constant 1 : index".
- // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
- // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
- // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
- // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+ %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
- // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
+ // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
- %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
+ %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..063b0b2095df0b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -210,8 +210,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
// CHECK-LABEL: @pad_noop
func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: return %arg0
- %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %0 = "tosa.const"() { value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -221,8 +221,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[PAD:.+]] = tosa.pad
// CHECK: return %[[PAD]]
- %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -234,42 +234,39 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
// CHECK: return %[[PAD]]
%c0_i32 = arith.constant 0 : i32
- %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
+ %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<2xi32>
- %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
+ %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: @pad_determine_val_i32
-func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
// CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+ %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
// -----
// CHECK-LABEL: @pad_determine_val_f32
-func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
+func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<4xi32>) -> tensor<?x?xf32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
// CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @pad_determine_val_quant
-func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
// CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
- %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
- %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+ %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cca50b25d14d6b..483956c2bd1139 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -77,48 +77,56 @@ func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : te
// -----
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
// expected-error@+1 {{'tosa.pad' op padding of pad is not constant}}
- %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
- %0 = "tosa.const"() {value = dense<[[0, 0], [0, 1], [0, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %0 = "tosa.const"() {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xi32>} : () -> tensor<6xi32>
// expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
- %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<3x2xi32>, tensor<i8>) -> tensor<13x21x3xi8>
+ %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<6xi32>, tensor<i8>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}
// -----
-func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
- %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21x3xf32>
+ %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
return
}
// -----
-func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2xi32>) {
- // expected-error@+1 {{'tosa.pad' op expect 'padding' tensor rank equal to 2.}}
- %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2xi32>) -> tensor<13x21xf32>
+func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+ // expected-error@+1 {{'tosa.pad' op operand #1 must be 1D tensor of 32-bit signless integer or 64-bit signless integer values, but got 'tensor<2x2xi32>'}}
+ %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21xf32>
return
}
// -----
-func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
%0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<1xf32>'}}
- %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<2x2xi32>, tensor<1xf32>) -> tensor<13x21xf32>
+ %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<4xi32>, tensor<1xf32>) -> tensor<13x21xf32>
return
}
// -----
+func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<4xi32>) -> tensor<13x21x3xf32> {
+ // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(input)) but got size 4}}
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
// expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}}
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 88fa94ae90db69..a8c86960a6c86f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -525,16 +525,16 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -
// -----
// CHECK-LABEL: pad
-func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
- %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: pad_explicit_value
-func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
%0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
- %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<13x21x3xf32>
+ %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<6xi32>, tensor<f32>) -> tensor<13x21x3xf32>
return %1 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index d876ccfb3b9110..fc9c947e203c4f 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -58,9 +58,9 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
// CHECK-LABEL: @conv2d_as_fully_connected_padded
func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
- // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>}
+ // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>}
// CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor<i8>}
- // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<4x2xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
+ // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<8xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
// CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
// CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
// CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 2224bf3f57b255..0df299080d8512 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -46,10 +46,10 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
// CHECK-LABEL: @depthwise_conv2d_as_mul_padded
func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
- // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0], [0, 0]]> : tensor<5x2xi64>}
+ // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xi64>}
// CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
// CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
- // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+ // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<10xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
// CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
// CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
// CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 1f2bb3fb9a3657..893ec4a7de65db 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -44,7 +44,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
// CHECK-LABEL: @transpose_conv2d_strided
func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
// Manipulate the weight matrix to handle striding.
- // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]]
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -54,7 +54,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
// Pad out the input matrix to handle the transpose conv.
- // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]]
@@ -77,7 +77,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_strided_quantized
func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
// Manipulate the weight matrix to handle striding.
- // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -87,7 +87,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
// Pad out the input matrix to handle the transpose conv.
- // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
@@ -108,12 +108,12 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-LABEL: @transpose_conv2d_strided_overpad
func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
- // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 0]]> : tensor<4x2xi32>
+ // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xi32>
// CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi32>}
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [2, 0], [0, 0], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xi32>}
// CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..7daf46e375e12f 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -495,9 +495,9 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>)
// -----
// CHECK-LABEL: @test_padding_no_const
-func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
- // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
- %0 = tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<4xi32>) -> () {
+ // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+ %0 = tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return
}
@@ -505,9 +505,9 @@ func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32
// CHECK-LABEL:@test_padding_dynamic_input
func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<4x?xf32>
- %1 = tosa.pad %arg0, %0 : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<4x?xf32>
+ %1 = tosa.pad %arg0, %0 : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return
}
@@ -515,9 +515,9 @@ func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
// CHECK-LABEL: @test_padding_simple
func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
- %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32>
- %1 = tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+ %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<4x9xf32>
+ %1 = tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
return
}
More information about the Mlir-commits
mailing list