[Mlir-commits] [mlir] [TOSA] Change PadOp padding to tosa.shape (PR #123133)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 15 14:33:33 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jerry-Ge (Jerry-Ge)

<details>
<summary>Changes</summary>

This patch changes PadOp's padding input to type !tosa.shape<2 * rank> (where rank is the rank of the PadOp's input), instead of a <rank x 2> tensor.

This patch is also part of the TOSA v1.0 effort: https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708

---

Patch is 42.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123133.diff


17 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+2) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+5-5) 
- (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+5) 
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+14-13) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+40-30) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+1-5) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+1-5) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+11-18) 
- (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+14) 
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+35-16) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+24-21) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+21-17) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+6-4) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+7-7) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+6-14) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index e4f5d09064cd75..d54d4e17de8b54 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -39,6 +39,8 @@ ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
 void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                      Attribute attr);
 
+bool collectShapeValue(Operation* op, llvm::SmallVector<int64_t>& newShape);
+
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
 } // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e1efa7a3001b9f..2953e006bbe8d1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1557,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:
 
     ```mlir
-    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<4x9xf32>)
+    %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>)  -> (tensor<4x9xf32>)
     ```
 
     Example 2:
 
     ```mlir
-    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<?x9xf32>)
+    %0 = tosa.const_shape { value = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>)  -> (tensor<?x9xf32>)
     ```
   }];
 
   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
+    Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 90fea1f68beb58..9b406f1083135c 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -229,6 +229,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
   return permuted;
 }
 
+// Computes shape value using tosa const_shape op.
+Value getTosaConstShape(PatternRewriter& rewriter, Location loc,
+                    llvm::ArrayRef<int64_t> shape);
+SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index b5a0da15e780e0..5aa0269a675cbe 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -306,7 +306,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
-    auto padding = padOp.getPadding();
+
+    ElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          padOp, "padding must be a static shape value");
+    }
+    llvm::SmallVector<int64_t> paddingVals;
+    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
+      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
+    }
 
     ShapedType inputTy = cast<ShapedType>(input.getType());
     Type elementTy = inputTy.getElementType();
@@ -345,18 +354,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     highValues.reserve(rank);
 
     for (int i = 0; i < rank; i++) {
-      Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
-      Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
-      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({lowIndex}));
-      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({highIndex}));
-
-      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), lowVal);
-      highVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), highVal);
-
+      Value lowVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+      Value highVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
       lowValues.push_back(lowVal);
       highValues.push_back(highVal);
     }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 83cf4a9415fe68..ee0f31652ba36a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -210,6 +210,26 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA shape inference helper
+//===----------------------------------------------------------------------===//
+bool mlir::tosa::collectShapeValue(Operation* op, llvm::SmallVector<int64_t>& newShape) {
+  if (!op) {
+    return false;
+  }
+  if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
+    Attribute constOpAttr = constOp->getAttr("value");
+    DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
+    for (int i = 0; i < elementsAttr.size(); i++) {
+      int64_t val = elementsAttr.getValues<int64_t>()[i];
+      newShape.push_back(val);
+    }
+    return true;
+  }
+  // for undefined op, return false.
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -823,51 +843,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     PadOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
-  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
+  auto paddingRank =
+      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
   SmallVector<int64_t> outputShape;
 
-  // If both inputs have unknown shape, we cannot determine the shape of the
-  // output.
-  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
-    return success();
-  }
-
-  // If the input rank is unknown we can info the output rank using the
-  // padding shape's first dim.
+  // If the input rank is unknown, we can infer the output rank using the
+  // padding shape's rank divided by 2.
   if (!inputShape.hasRank()) {
-    if (paddingShape.isDynamicDim(0)) {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
-      return success();
-    }
-
-    outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
+    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  DenseIntElementsAttr paddings;
+  SmallVector<int64_t> paddingValues;
   // If the paddings value is not a constant, all dimensions must be dynamic.
-  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
+  if (!tosa::collectShapeValue(adaptor.getPadding().getDefiningOp(),
+                               paddingValues)) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  SmallVector<int64_t> paddingValues;
-  for (auto val : paddings) {
-    paddingValues.push_back(val.getSExtValue());
-  }
-
   outputShape.reserve(inputShape.getRank());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
     if (inputShape.isDynamicDim(i)) {
       outputShape.push_back(ShapedType::kDynamic);
       continue;
     }
+    auto padFront = paddingValues[i * 2];
+    auto padBack = paddingValues[i * 2 + 1];
+    if (padFront < 0 || padBack < 0) {
+      // if either padding for dim i is -1, output dim is unknown
+      outputShape.push_back(ShapedType::kDynamic);
+      continue;
+    }
 
-    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
-                          paddingValues[i * 2 + 1]);
+    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -877,17 +888,16 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  RankedTensorType paddingType = getPadding().getType();
+  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
 
   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";
-
-  if (!paddingType.isDynamicDim(0) &&
-      paddingType.getDimSize(0) != inputType.getRank() * 2)
+  
+  if (paddingRank != inputType.getRank() * 2)
     return emitOpError() << "expected padding tensor dim 0 to have size "
                          << inputType.getRank() * 2
                          << " (2*rank(shape1)) but got size "
-                         << paddingType.getDimSize(0);
+                         << paddingRank;
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 04a709c5967795..cb08360f902286 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -81,11 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
         }
       }
 
-      auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
-      auto padSize =
-          DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-      Value padSizeVal =
-          rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
       auto padTy = RankedTensorType::get({}, inputETy);
       auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 14f392ab8c45c1..45f4419875b485 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -108,11 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
         }
       }
 
-      auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
-      auto padSize =
-          DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-      Value padSizeVal =
-          rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
       auto padTy = RankedTensorType::get({}, inputETy);
       auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index db1e219b601b30..1b97f0b245d9ba 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -135,15 +135,14 @@ class TransposeConvStridedConverter
     int64_t inputChannels = weightTy.getDimSize(3);
 
     // Pad the weight so that it is modulo of the striding.
-    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     weightPadding[3] =
         (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
     weightPadding[5] =
-        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
-    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
-    Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+
+    Value weightPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), weightPadding);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -197,17 +196,14 @@ class TransposeConvStridedConverter
         /* axis = */ rewriter.getI32IntegerAttr(2));
 
     // We need to pad the input far enough that we can pull all values.
-    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
 
-    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
-
-    Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+    Value inputPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), inputPadding);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -310,17 +306,14 @@ class TransposeConvStridedConverter
                      rewriter.getDenseI64ArrayAttr(sliceSize))
                      .getResult();
 
-    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     resultPadding[2] = resultPadTop;
     resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
     resultPadding[4] = resultPadLeft;
     resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
 
-    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
-
-    Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
+    Value resultPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), resultPadding);
 
     Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), slice,
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 1f6e3b2ab83919..3a56ec2a10bcf6 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -160,3 +160,17 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
 
   return success();
 }
+
+Value mlir::tosa::getTosaConstShape(PatternRewriter& rewriter, Location loc,
+                    llvm::ArrayRef<int64_t> shape) {
+  auto attr = rewriter.getIndexTensorAttr(shape);
+  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
+  mlir::Operation *mlir_op = rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  return mlir_op->getResult(0);
+}
+
+SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return ShapedType::isDynamic(dim) ? -1 : dim;
+  }));
+}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 0b9a64494bc0f1..2f11b31aad2307 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -459,65 +459,84 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 // CHECK-LABEL: @pad_float
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<4xi32>)  -> (tensor<4x9xf32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, !tosa.shape<4>)  -> (tensor<4x9xf32>)
   return %1 : tensor<4x9xf32>
 }
+// -----
 
 func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: [[CST:%.+]] = arith.constant 0 : i32
   // CHECK: tensor.pad
   // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, tensor<4xi32>)  -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, !tosa.shape<4>)  -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
+// -----
 
 func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: [[CST:%.+]] = arith.constant 42 : i32
   // CHECK: tensor.pad
   // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<4xi32>)  -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, !tosa.shape<4>)  -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
 
 // -----
 
 func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
   %1 = arith.constant dense<42.0> : tensor<f32>
-  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, tensor<4xi32>, tensor<f32>)  -> (tensor<4x9xf32>)
+  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>)  -> (tensor<4x9xf32>)
   return %2 : tensor<4x9xf32>
 }
 
 // -----
 
 func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %0 = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
   // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, %{{.*}}] high{{\[}}%{{.*}}, %{{.*}}] {
+  // CHECK: tens...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/123133


More information about the Mlir-commits mailing list