[Mlir-commits] [mlir] [TOSA] Change PadOp padding to tosa.shape (PR #123133)

llvmlistbot at llvm.org
Wed Jan 15 14:33:02 PST 2025


https://github.com/Jerry-Ge created https://github.com/llvm/llvm-project/pull/123133

This patch changes PadOp's padding input to type !tosa.shape<2 * rank> (where rank is the rank of the PadOp's input), instead of a rank-1 integer tensor of 2 * rank elements.

This patch is also part of the TOSA v1.0 effort: https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708
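
For reviewers skimming the diff, here is a minimal before/after sketch of the op form (adapted from the updated examples in TosaOps.td below):

```mlir
// Before: padding supplied as a rank-1 i32 tensor.
%padding = "tosa.const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = tosa.pad %arg0, %padding : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<4x9xf32>

// After: padding supplied as a !tosa.shape value built by tosa.const_shape.
%padding = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
%0 = tosa.pad %arg0, %padding : (tensor<1x2xf32>, !tosa.shape<4>) -> tensor<4x9xf32>
```

Note that the padding must now be compile-time resolvable: a padding operand not produced by tosa.const_shape is rejected ("shape operand is not compile time resolvable"), and the TosaToTensor lowering likewise requires a constant padding value (see the updated invalid.mlir diagnostics below).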

From 6291f0846288e9899c72aac99b3a209d44371a2d Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Wed, 24 Jan 2024 23:49:53 +0000
Subject: [PATCH] [TOSA] Change PadOp padding to tosa.shape

This patch changes PadOp's padding input to type !tosa.shape<2 * rank>
(where rank is the rank of the PadOp's input), instead of a rank-1
integer tensor of 2 * rank elements.

Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I08526a699d6b8ebbaf9ee092cd37580e5d78f919
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |  2 +
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 10 +--
 .../mlir/Dialect/Tosa/Utils/ConversionUtils.h |  5 ++
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 27 +++----
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 70 +++++++++++--------
 .../Tosa/Transforms/TosaDecomposeConv2D.cpp   |  6 +-
 .../Transforms/TosaDecomposeDepthwise.cpp     |  6 +-
 .../Transforms/TosaDecomposeTransposeConv.cpp | 29 +++-----
 .../Dialect/Tosa/Utils/ConversionUtils.cpp    | 14 ++++
 .../TosaToTensor/tosa-to-tensor.mlir          | 51 +++++++++-----
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 45 ++++++------
 mlir/test/Dialect/Tosa/invalid.mlir           | 38 +++++-----
 mlir/test/Dialect/Tosa/ops.mlir               | 10 +--
 .../Dialect/Tosa/tosa-decompose-conv2d.mlir   |  4 +-
 .../Tosa/tosa-decompose-depthwise.mlir        |  4 +-
 .../Tosa/tosa-decompose-transpose-conv.mlir   | 14 ++--
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 20 ++----
 17 files changed, 196 insertions(+), 159 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index e4f5d09064cd75..d54d4e17de8b54 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -39,6 +39,8 @@ ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
 void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                      Attribute attr);
 
+bool collectShapeValue(Operation *op, llvm::SmallVector<int64_t> &newShape);
+
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
 } // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e1efa7a3001b9f..2953e006bbe8d1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1557,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:
 
     ```mlir
-    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<4x9xf32>)
+    %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>)  -> (tensor<4x9xf32>)
     ```
 
     Example 2:
 
     ```mlir
-    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<?x9xf32>)
+    %0 = tosa.const_shape { value = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>)  -> (tensor<?x9xf32>)
     ```
   }];
 
   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
+    Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 90fea1f68beb58..9b406f1083135c 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -229,6 +229,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
   return permuted;
 }
 
+// Creates a tosa.const_shape op holding the given shape values.
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape);
+SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index b5a0da15e780e0..5aa0269a675cbe 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -306,7 +306,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
-    auto padding = padOp.getPadding();
+
+    ElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          padOp, "padding must be a static shape value");
+    }
+    llvm::SmallVector<int64_t> paddingVals;
+    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
+      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
+    }
 
     ShapedType inputTy = cast<ShapedType>(input.getType());
     Type elementTy = inputTy.getElementType();
@@ -345,18 +354,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     highValues.reserve(rank);
 
     for (int i = 0; i < rank; i++) {
-      Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
-      Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
-      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({lowIndex}));
-      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({highIndex}));
-
-      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), lowVal);
-      highVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), highVal);
-
+      Value lowVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+      Value highVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
       lowValues.push_back(lowVal);
       highValues.push_back(highVal);
     }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 83cf4a9415fe68..ee0f31652ba36a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -210,6 +210,26 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA shape inference helper
+//===----------------------------------------------------------------------===//
+bool mlir::tosa::collectShapeValue(Operation *op,
+                                   llvm::SmallVector<int64_t> &newShape) {
+  if (!op) {
+    return false;
+  }
+  if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
+    Attribute constOpAttr = constOp->getAttr("value");
+    DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
+    for (int64_t val : elementsAttr.getValues<int64_t>()) {
+      newShape.push_back(val);
+    }
+    return true;
+  }
+  // Any op other than tosa.const_shape cannot be resolved at compile time.
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -823,51 +843,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     PadOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
-  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
+  auto paddingRank =
+      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
   SmallVector<int64_t> outputShape;
 
-  // If both inputs have unknown shape, we cannot determine the shape of the
-  // output.
-  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
-    return success();
-  }
-
-  // If the input rank is unknown we can info the output rank using the
-  // padding shape's first dim.
+  // If the input rank is unknown, we can infer the output rank using the
+  // padding shape's rank divided by 2.
   if (!inputShape.hasRank()) {
-    if (paddingShape.isDynamicDim(0)) {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
-      return success();
-    }
-
-    outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
+    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  DenseIntElementsAttr paddings;
+  SmallVector<int64_t> paddingValues;
   // If the paddings value is not a constant, all dimensions must be dynamic.
-  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
+  if (!tosa::collectShapeValue(adaptor.getPadding().getDefiningOp(),
+                               paddingValues)) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  SmallVector<int64_t> paddingValues;
-  for (auto val : paddings) {
-    paddingValues.push_back(val.getSExtValue());
-  }
-
   outputShape.reserve(inputShape.getRank());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
     if (inputShape.isDynamicDim(i)) {
       outputShape.push_back(ShapedType::kDynamic);
       continue;
     }
+    auto padFront = paddingValues[i * 2];
+    auto padBack = paddingValues[i * 2 + 1];
+    if (padFront < 0 || padBack < 0) {
+      // If either padding value is negative, the output dim is dynamic.
+      outputShape.push_back(ShapedType::kDynamic);
+      continue;
+    }
 
-    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
-                          paddingValues[i * 2 + 1]);
+    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -877,17 +888,16 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  RankedTensorType paddingType = getPadding().getType();
+  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
 
   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";
-
-  if (!paddingType.isDynamicDim(0) &&
-      paddingType.getDimSize(0) != inputType.getRank() * 2)
+
+  if (paddingRank != inputType.getRank() * 2)
     return emitOpError() << "expected padding tensor dim 0 to have size "
                          << inputType.getRank() * 2
                          << " (2*rank(shape1)) but got size "
-                         << paddingType.getDimSize(0);
+                         << paddingRank;
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 04a709c5967795..cb08360f902286 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -81,11 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
         }
       }
 
-      auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
-      auto padSize =
-          DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-      Value padSizeVal =
-          rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
       auto padTy = RankedTensorType::get({}, inputETy);
       auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 14f392ab8c45c1..45f4419875b485 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -108,11 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
         }
       }
 
-      auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
-      auto padSize =
-          DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-      Value padSizeVal =
-          rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
       auto padTy = RankedTensorType::get({}, inputETy);
       auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index db1e219b601b30..1b97f0b245d9ba 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -135,15 +135,14 @@ class TransposeConvStridedConverter
     int64_t inputChannels = weightTy.getDimSize(3);
 
     // Pad the weight so that it is modulo of the striding.
-    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     weightPadding[3] =
         (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
     weightPadding[5] =
-        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
-    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
-    Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
+
+    Value weightPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), weightPadding);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -197,17 +196,14 @@ class TransposeConvStridedConverter
         /* axis = */ rewriter.getI32IntegerAttr(2));
 
     // We need to pad the input far enough that we can pull all values.
-    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
 
-    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
-
-    Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+    Value inputPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), inputPadding);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -310,17 +306,14 @@ class TransposeConvStridedConverter
                      rewriter.getDenseI64ArrayAttr(sliceSize))
                      .getResult();
 
-    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     resultPadding[2] = resultPadTop;
     resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
     resultPadding[4] = resultPadLeft;
     resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
 
-    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
-
-    Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
+    Value resultPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), resultPadding);
 
     Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), slice,
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 1f6e3b2ab83919..3a56ec2a10bcf6 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -160,3 +160,17 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
 
   return success();
 }
+
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                                    llvm::ArrayRef<int64_t> shape) {
+  auto attr = rewriter.getIndexTensorAttr(shape);
+  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
+  auto constShapeOp = rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  return constShapeOp.getResult();
+}
+
+SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return ShapedType::isDynamic(dim) ? -1 : dim;
+  }));
+}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 0b9a64494bc0f1..2f11b31aad2307 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -459,65 +459,84 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 // CHECK-LABEL: @pad_float
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<4x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, !tosa.shape<4>)  -> (tensor<4x9xf32>)
   return %1 : tensor<4x9xf32>
 }
+// -----
 
 func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: [[CST:%.+]] = arith.constant 0 : i32
   // CHECK: tensor.pad
   // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, tensor<4xi32>)  -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, !tosa.shape<4>)  -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
+// -----
 
 func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: [[CST:%.+]] = arith.constant 42 : i32
   // CHECK: tensor.pad
   // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>)  -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, !tosa.shape<4>)  -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
 
 // -----
 
 func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
   %1 = arith.constant dense<42.0> : tensor<f32>
-  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>)  -> (tensor<4x9xf32>)
+  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>)  -> (tensor<4x9xf32>)
   return %2 : tensor<4x9xf32>
 }
 
 // -----
 
 func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<?x2xf32>, tensor<4xi32>)  -> (tensor<?x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<?x2xf32>, !tosa.shape<4>)  -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }
+// -----
 
 func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[-1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant -1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<?x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, !tosa.shape<4>)  -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 889e2eda9e5b84..e394188e9a9311 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -210,8 +210,8 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
 // CHECK-LABEL: @pad_noop
 func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
-  %0 = "tosa.const"() { value = dense<0> : tensor<4xi32>} : () -> tensor<4xi32>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+  %0 = tosa.const_shape { value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
@@ -221,8 +221,8 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[PAD:.+]] = tosa.pad
   // CHECK: return %[[PAD]]
-  %0 = "tosa.const"() { value = dense_resource<__elided__> : tensor<4xi32>} : () -> tensor<4xi32>
-  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+  %shape = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.pad %arg0, %shape : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
@@ -232,41 +232,44 @@ func.func @pad_noop_padding_mismatch_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x
 func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32> {
   // CHECK: %[[PAD:.+]] = tosa.pad
   // CHECK: return %[[PAD]]
-
-  %c0_i32 = arith.constant 0 : i32
-  %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<2xi32>
-
-  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<2xi32>) -> tensor<?xf32>
+  %shape = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, !tosa.shape<2>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @pad_determine_val_i32
-func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
-  // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
+func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+  %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xi32>, !tosa.shape<4>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
 
 // -----
 
 // CHECK-LABEL: @pad_determine_val_f32
-func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<4xi32>) -> tensor<?x?xf32> {
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
-  // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %1 = tosa.pad %arg0, %arg1 : (tensor<?x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+  %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.pad %arg0, %0 : (tensor<?x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @pad_determine_val_quant
-func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<4xi32>) -> tensor<?x?xi32> {
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
-  // CHECK: tosa.pad %arg0, %arg1, %[[ZERO]]
-  %1 = tosa.pad %arg0, %arg1 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, tensor<4xi32>) -> tensor<?x?xi32>
+func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
+  // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
+  %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.pad %arg0, %0 {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<?x?xi32>, !tosa.shape<4>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cc7fd009f01fa6..e58a45ca80990e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -165,52 +165,56 @@ func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : te
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
-  // expected-error at +1 {{'tosa.pad' op padding of pad is not constant}}
-  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
+  // expected-error at +1 {{'tosa.pad' op shape operand is not compile time resolvable}}
+  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
-  %0 = "tosa.const"() {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xi32>} : () -> tensor<6xi32>
+  %0 = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error at +1 {{'tosa.pad' op pad_const of pad is not constant}}
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, tensor<6xi32>, tensor<i8>) -> tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<i8>) -> tensor<13x21x3xi8>
   return %1 : tensor<13x21x3xi8>
 }
 
 // -----
 
-func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
+func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
+  %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
   // expected-error at +1 {{'tosa.pad' op expect same input and output tensor rank.}}
-  %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
+  %1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
   return
 }
 
 // -----
 
-func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
-  // expected-error at +1 {{'tosa.pad' op operand #1 must be 1D tensor of 32-bit signless integer or 64-bit signless integer values, but got 'tensor<2x2xi32>'}}
-  %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21xf32>
+func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
+  %0 = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error at +1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
+  %1 = tosa.pad %arg0, %0 : (tensor<13x21xf32>, !tosa.shape<6>) -> tensor<13x21xf32>
   return
 }
 
 // -----
 
-func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<4xi32>) {
-  %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error at +1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<1xf32>'}}
-  %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<4xi32>, tensor<1xf32>) -> tensor<13x21xf32>
+func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+  %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = "tosa.const"() {value = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
+  // expected-error at +1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<2xf32>'}}
+  %2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
   return
 }
 
 // -----
 
-func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<4xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   // expected-error at +1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
-  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+  %1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
+  return %1 : tensor<13x21x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 690e208af1e5f9..563c5fa457d351 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -525,16 +525,18 @@ func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -
 
 // -----
 // CHECK-LABEL: pad
-func.func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<6xi32>) -> tensor<13x21x3xf32>
+func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %0 = tosa.pad %arg0, %padding : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 // CHECK-LABEL: pad_explicit_value
-func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6xi32>) -> tensor<13x21x3xf32> {
+func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   %0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
-  %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21x3xf32>, tensor<6xi32>, tensor<f32>) -> tensor<13x21x3xf32>
+  %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x3xf32>
   return %1 : tensor<13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 8df4630f9c17ff..95d9bb1b98ab74 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -58,9 +58,9 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
 
 // CHECK-LABEL: @conv2d_as_fully_connected_padded
 func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
-  // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64>}
+  // CHECK-DAG: %[[PAD_SHAPE:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor<i8>}
-  // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, tensor<8xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
+  // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
   // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
   // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
   // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index cfff6396ad486d..bbcc206e1490c7 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -46,10 +46,10 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
 
 // CHECK-LABEL: @depthwise_conv2d_as_mul_padded
 func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
-  // CHECK-DAG: %[[pad:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xi64>}
+  // CHECK-DAG: %[[pad:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10>
   // CHECK-DAG: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
-  // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<10xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+  // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
   // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
   // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
   // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index c361c7c2899fc3..96f71c349938b9 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -45,7 +45,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
 // CHECK-LABEL: @transpose_conv2d_strided
 func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
   // Manipulate the weight matrix to handle striding.
-  // CHECK-DAG: %[[PADV:.+]]  = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
+  // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[PADW:.+]]  = tosa.pad %arg1, %[[PADV]]
   // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -55,7 +55,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
   // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
 
   // Pad out the input matrix to handle the transpose conv.
-  // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
+  // CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]]
 
@@ -78,7 +78,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
 // CHECK-LABEL: @transpose_conv2d_strided_quantized
 func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
   // Manipulate the weight matrix to handle striding.
-  // CHECK-DAG: %[[PADV:.+]]  = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xi32>}
+  // CHECK-DAG: %[[PADV:.+]]  = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[PADW:.+]]  = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
   // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
@@ -88,7 +88,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // CHECK-DAG: %[[NEWWEIGHT:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
 
   // Pad out the input matrix to handle the transpose conv.
-  // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi32>}
+  // CHECK-DAG: %[[PAD:.+]]  = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
 
@@ -109,12 +109,12 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
 
 // CHECK-LABEL: @transpose_conv2d_strided_overpad
 func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
-  // CHECK-DAG: %[[WEIGHT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xi32>
+  // CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
-  // CHECK-DAG: %[[INPUT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi32>}
+  // CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
   // CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
-  // CHECK-DAG: %[[RESULT_PAD:.+]] = "tosa.const"() <{value = dense<{{\[}}0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xi32>}
+  // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
   // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
   // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index f4da66ef561b26..314cccab1d709c 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -492,22 +492,14 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>)
   return
 }
 
-// -----
-
-// CHECK-LABEL: @test_padding_no_const
-func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<4xi32>) -> () {
-  // CHECK: tosa.pad %arg0, %arg1 : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
-  %0 = tosa.pad %arg0, %arg1  : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
-  return
-}
 
 // -----
 
 // CHECK-LABEL:@test_padding_dynamic_input
 func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-  // CHECK: tosa.pad %arg0, %cst : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<4x?xf32>
-  %1 = tosa.pad %arg0, %0  : (tensor<1x?xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+  %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  // CHECK: tosa.pad %arg0, %0 : (tensor<1x?xf32>, !tosa.shape<4>) -> tensor<4x?xf32>
+  %1 = tosa.pad %arg0, %0  : (tensor<1x?xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
   return
 }
 
@@ -515,9 +507,9 @@ func.func @test_padding_dynamic_input(%arg0 : tensor<1x?xf32>) -> () {
 
 // CHECK-LABEL: @test_padding_simple
 func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-  // CHECK: tosa.pad %arg0, %cst : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<4x9xf32>
-  %1 = tosa.pad %arg0, %0  : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>
+  %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  // CHECK: tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> tensor<4x9xf32>
+  %1 = tosa.pad %arg0, %0  : (tensor<1x2xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
   return
 }
 


