[Mlir-commits] [mlir] [mlir][tosa] Add missing verifier for `tosa.pad` (PR #120934)

Longsheng Mou llvmlistbot at llvm.org
Fri Jan 3 05:11:49 PST 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/120934

>From 7e67d67713b6ee74c09c2be9b5b57d7ad7cc98ea Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 23 Dec 2024 14:46:50 +0800
Subject: [PATCH] [mlir][tosa] Add missing verifier for `tosa.pad`

This PR adds a missing verifier for `tosa.pad`, ensuring that the padding shape matches [2*rank(input)]
according to V1.0 specification.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 10 +++----
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 14 ++++-----
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 11 ++++---
 .../Tosa/Transforms/TosaDecomposeConv2D.cpp   |  2 +-
 .../Transforms/TosaDecomposeDepthwise.cpp     |  2 +-
 .../Transforms/TosaDecomposeTransposeConv.cpp |  2 +-
 .../TosaToTensor/tosa-to-tensor.mlir          | 24 +++++++--------
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 27 ++++++++---------
 mlir/test/Dialect/Tosa/invalid.mlir           | 30 ++++++++++++-------
 mlir/test/Dialect/Tosa/ops.mlir               |  8 ++---
 .../Dialect/Tosa/tosa-decompose-conv2d.mlir   |  4 +--
 .../Tosa/tosa-decompose-depthwise.mlir        |  4 +--
 .../Tosa/tosa-decompose-transpose-conv.mlir   | 14 ++++-----
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 18 +++++------
 14 files changed, 88 insertions(+), 82 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..09b52692974dd7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1552,21 +1552,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:
 
     ```mlir
-    %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<4x9xf32>)
+    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<4x9xf32>)
     ```
 
     Example 2:
 
     ```mlir
-    %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
+    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<?x9xf32>)
     ```
   }];
 
   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    Tosa_Int32Or64Tensor:$padding,
+    1DTensorOf<[Tosa_Int32Or64]>:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 6f085cb6ed06d2..dd180f4b1f9886 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -338,11 +338,6 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
           padOp, "tosa.pad was unable to determine the pad constant value.");
     }
 
-    Value lowIndex =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
-    Value highIndex =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-
     SmallVector<OpFoldResult, 3> lowValues;
     SmallVector<OpFoldResult, 3> highValues;
 
@@ -350,11 +345,14 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     highValues.reserve(rank);
 
     for (int i = 0; i < rank; i++) {
-      Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
+      Value lowIndex =
+          rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
+      Value highIndex =
+          rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
       Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({inputIndex, lowIndex}));
+          loc, padding, ValueRange({lowIndex}));
       Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({inputIndex, highIndex}));
+          loc, padding, ValueRange({highIndex}));
 
       lowVal = rewriter.createOrFold<arith::IndexCastOp>(
           loc, rewriter.getIndexType(), lowVal);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..d5a37eef76e2b9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -787,7 +787,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
       return success();
     }
 
-    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
+    outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
@@ -823,13 +823,16 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  TensorType paddingType = getPadding().getType();
+  RankedTensorType paddingType = getPadding().getType();
 
   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";
 
-  if (paddingType.hasRank() && paddingType.getRank() != 2)
-    return emitOpError() << "expect 'padding' tensor rank equal to 2.";
+  const int64_t expectedSize = 2 * inputType.getRank();
+  if (!paddingType.isDynamicDim(0) && paddingType.getDimSize(0) != expectedSize)
+    return emitOpError() << "expected padding tensor dim 0 to have size "
+                         << expectedSize << " (2*rank(input)) but got size "
+                         << paddingType.getDimSize(0);
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 44f64f76e9b027..04a709c5967795 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -81,7 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
         }
       }
 
-      auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
+      auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
       auto padSize =
           DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
       Value padSizeVal =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index e6fba211dc37ab..14f392ab8c45c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -108,7 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
         }
       }
 
-      auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
+      auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
       auto padSize =
           DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
       Value padSizeVal =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 0779cdb9667a1a..a574bd9643db0e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -139,7 +139,7 @@ class TransposeConvStridedConverter
     weightPadding[5] =
         (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
     DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+        RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
     Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
 
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 1e62e25176a007..bbf87d1a020589 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -459,7 +459,7 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 // CHECK-LABEL: @pad_float
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // TODO: Output contains multiple "arith.constant 1 : index".
   // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
@@ -469,32 +469,32 @@ func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<4x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<4x9xf32>)
   return %1 : tensor<4x9xf32>
 }
 
 func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK: [[CST:%.+]] = arith.constant 0 : i32
   // CHECK: tensor.pad
   // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, tensor<4xi32>)  -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
 
 func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // CHECK: [[CST:%.+]] = arith.constant 42 : i32
   // CHECK: tensor.pad
   // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>)  -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
 
 // -----
 
 func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // TODO: Output contains multiple "arith.constant 1 : index".
   // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
@@ -505,14 +505,14 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
   %1 = arith.constant dense<42.0> : tensor<f32>
-  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>)  -> (tensor<4x9xf32>)
+  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>)  -> (tensor<4x9xf32>)
   return %2 : tensor<4x9xf32>
 }
 
 // -----
 
 func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
   // TODO: Output contains multiple "arith.constant 1 : index".
   // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
@@ -522,12 +522,12 @@ func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
   // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<?x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<?x2xf32>, tensor<4xi32>)  -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }
 
 func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
+  %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
   // TODO: Output contains multiple "arith.constant 1 : index".
   // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
@@ -537,7 +537,7 @@ func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
   // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..063b0b2095df0b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -210,8 +210,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
 // CHECK-LABEL: @pad_noop
 func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
-  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = "tosa.const"() { value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
@@ -221,8 +221,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[PAD:.+]] = tosa.pad
   // CHECK: return %[[PAD]]
-  %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
@@ -234,42 +234,39 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
   // CHECK: return %[[PAD]]
 
   %c0_i32 = arith.constant 0 : i32
-  %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
+  %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<2xi32>
 
-  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
+  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<2xi32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @pad_determine_val_i32
-func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
 
 // -----
 
 // CHECK-LABEL: @pad_determine_val_f32
-func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
+func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<4xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @pad_determine_val_quant
-func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
   // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
-  %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
+  %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cca50b25d14d6b..3b27dd758e4164 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -77,48 +77,56 @@ func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : te
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
   // expected-error at +1 {{'tosa.pad' op padding of pad is not constant}}
-  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
-  %0 = "tosa.const"() {value = dense<[[0, 0], [0, 1], [0, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+  %0 = "tosa.const"() {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xi32>} : () -> tensor<6xi32>
   // expected-error at +1 {{'tosa.pad' op pad_const of pad is not constant}}
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<3x2xi32>, tensor<i8>) -> tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<6xi32>, tensor<i8>) -> tensor<13x21x3xi8>
   return %1 : tensor<13x21x3xi8>
 }
 
 // -----
 
-func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
   // expected-error at +1 {{'tosa.pad' op expect same input and output tensor rank.}}
-  %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21x3xf32>
+  %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
   return
 }
 
 // -----
 
-func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2xi32>) {
-  // expected-error at +1 {{'tosa.pad' op expect 'padding' tensor rank equal to 2.}}
-  %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2xi32>) -> tensor<13x21xf32>
+func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+  // expected-error at +1 {{'tosa.pad' op operand #1 must be 1D tensor of 32-bit signless integer or 64-bit signless integer values, but got 'tensor<2x2xi32>'}}
+  %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21xf32>
   return
 }
 
 // -----
 
-func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
   %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error at +1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<1xf32>'}}
-  %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<2x2xi32>, tensor<1xf32>) -> tensor<13x21xf32>
+  %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<4xi32>, tensor<1xf32>) -> tensor<13x21xf32>
   return
 }
 
 // -----
 
+func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<4xi32>) -> tensor<13x21x3xf32> {
+  // expected-error at +1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(input)) but got size 4}}
+  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
 func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
   // expected-error at +1 {{'tosa.transpose' op perms of transpose is not constant}}
   %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 88fa94ae90db69..a8c86960a6c86f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -525,16 +525,16 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -
 
 // -----
 // CHECK-LABEL: pad
-func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 // CHECK-LABEL: pad_explicit_value
-func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
   %0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
-  %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<13x21x3xf32>
+  %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<6xi32>, tensor<f32>) -> tensor<13x21x3xf32>
   return %1 : tensor<13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index d876ccfb3b9110..f322792ab8c9f5 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -58,9 +58,9 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
 
 // CHECK-LABEL: @conv2d_as_fully_connected_padded
 func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
-  // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>}
+  // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>}
   // CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor<i8>}
-  // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<4x2xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
+  // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<8xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
   // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
   // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
   // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 2224bf3f57b255..8a5c70781ba4b2 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -46,10 +46,10 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
 
 // CHECK-LABEL: @depthwise_conv2d_as_mul_padded
 func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
-  // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0], [0, 0]]> : tensor<5x2xi64>}
+  // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xi64>}
   // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
-  // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+  // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<10xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
   // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
   // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
   // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 1f2bb3fb9a3657..8ac6ed4789ac17 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -44,7 +44,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
 // CHECK-LABEL: @transpose_conv2d_strided
 func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
   // Manipulate the weight matrix to handle striding.
-  // CHECK-DAG: %[[PADV:.+]]  = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[PADV:.+]]  = "tosa.const"() <{value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
   // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[PADW:.+]]  = tosa.pad %arg1, %[[PADV]]
   // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -54,7 +54,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
   // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
 
   // Pad out the input matrix to handle the transpose conv.
-  // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() <{value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
   // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]]
 
@@ -77,7 +77,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
 // CHECK-LABEL: @transpose_conv2d_strided_quantized
 func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
   // Manipulate the weight matrix to handle striding.
-  // CHECK-DAG: %[[PADV:.+]]  = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[PADV:.+]]  = "tosa.const"() <{value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
   // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[PADW:.+]]  = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
   // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -87,7 +87,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
 
   // Pad out the input matrix to handle the transpose conv.
-  // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() <{value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
   // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
 
@@ -108,12 +108,12 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
 
 // CHECK-LABEL: @transpose_conv2d_strided_overpad
 func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
-  // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 0]]> : tensor<4x2xi32>
+  // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xi32>
   // CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
-  // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi32>}
   // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
   // CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
-  // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [2, 0], [0, 0], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xi32>}
   // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
   // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
   // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..b3c1bfe03af622 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -495,9 +495,9 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>)
 // -----
 
 // CHECK-LABEL: @test_padding_no_const
-func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
-  // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
-  %0 = tosa.pad %arg0, %arg1  : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<4xi32>) -> () {
+  // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+  %0 = tosa.pad %arg0, %arg1  : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return
 }
 
@@ -505,9 +505,9 @@ func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32
 
 // CHECK-LABEL:@test_padding_dynamic_input
 func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<4x?xf32>
-  %1 = tosa.pad %arg0, %0  : (tensor<1x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<4x?xf32>
+  %1 = tosa.pad %arg0, %0  : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return
 }
 
@@ -515,9 +515,9 @@ func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
 
 // CHECK-LABEL: @test_padding_simple
 func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32>
-  %1 = tosa.pad %arg0, %0  : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
+  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<4x9xf32>
+  %1 = tosa.pad %arg0, %0  : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
   return
 }
 



More information about the Mlir-commits mailing list