[Mlir-commits] [mlir] [mlir][tosa] Change the start and size of slice to tosa shape type (PR #124209)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 23 15:51:32 PST 2025


https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/124209

>From 80f3b6ae30e763fbd58b16adbd0e280de805f8df Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Mon, 22 Jan 2024 14:32:54 -0800
Subject: [PATCH] [TOSA] Change the start and size of slice to tosa shape type

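Turn the `start` and `size` of tosa.slice from DenseI64ArrayAttr attributes
into tosa shape operands, and use getConstShapeValue to collect the constant
shape information along the graph during shape inference.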

Change-Id: Ic6fc2341e3bcfbec06a1d08986e26dd08573bd9c
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  4 +-
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 34 +++++++--
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 45 ++++++++----
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 22 ++++--
 .../Transforms/TosaDecomposeTransposeConv.cpp |  4 +-
 .../TosaToTensor/tosa-to-tensor-invalid.mlir  |  6 +-
 .../TosaToTensor/tosa-to-tensor.mlir          | 10 ++-
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 71 +++++++++++++------
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  |  9 ++-
 mlir/test/Dialect/Tosa/invalid.mlir           |  8 ++-
 mlir/test/Dialect/Tosa/level_check.mlir       | 13 ++--
 mlir/test/Dialect/Tosa/ops.mlir               | 15 +++-
 .../Tosa/tosa-decompose-transpose-conv.mlir   | 18 +++--
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 40 ++++++++---
 14 files changed, 220 insertions(+), 79 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2953e006bbe8d1..d2ac206263bb13 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1664,8 +1664,8 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$start,
-    DenseI64ArrayAttr:$size
+    Tosa_Shape:$start,
+    Tosa_Shape:$size
   );
 
   let results = (outs
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5aa0269a675cbe..d6319c1c3fd891 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -268,12 +268,28 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
     ShapedType resultType = cast<ShapedType>(sliceOp.getType());
     if (llvm::isa<UnrankedTensorType>(resultType))
       return failure();
+
+    ElementsAttr StartElems;
+    ElementsAttr SizeElems;
+
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&StartElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&SizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+
+    llvm::SmallVector<int64_t> SliceStarts =
+        llvm::to_vector(StartElems.getValues<int64_t>());
+    llvm::SmallVector<int64_t> SliceSizes =
+        llvm::to_vector(SizeElems.getValues<int64_t>());
+
     SmallVector<int64_t> strides, sizes;
-    ArrayRef<int64_t> starts = sliceOp.getStart();
     strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
 
     SmallVector<Value> dynSizes;
-    for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
+    for (const auto &i : llvm::enumerate(SliceSizes)) {
       int64_t size = i.value();
       size_t index = i.index();
       sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
@@ -282,17 +298,27 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
 
       auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
       auto offset = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIndexAttr(starts[index]));
+          loc, rewriter.getIndexAttr(SliceStarts[index]));
       dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
     }
 
     auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
         sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
-        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
+        ValueRange({}), rewriter.getDenseI64ArrayAttr(SliceStarts),
         rewriter.getDenseI64ArrayAttr(sizes),
         rewriter.getDenseI64ArrayAttr(strides));
 
     rewriter.replaceOp(sliceOp, newSliceOp.getResult());
+
+    auto removeIfRedundant = [&](Operation *op) {
+      if (op->getResults().size() == 1 && op->getResult(0).hasOneUse())
+        rewriter.eraseOp(op);
+    };
+
+    // Remove the const_shape ops if they have no other uses.
+    removeIfRedundant(sliceOp.getStart().getDefiningOp());
+    removeIfRedundant(sliceOp.getSize().getDefiningOp());
+
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f7a596f1ccb192..90c3b081be14b5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -393,8 +393,21 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
           sliceOp, "slice input must be a static ranked tensor");
     int32_t axis = concatOp.getAxis();
 
-    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
-    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
+    DenseElementsAttr StartElems;
+    DenseElementsAttr SizeElems;
+
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&StartElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&SizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+
+    llvm::SmallVector<int64_t> SliceStarts =
+        llvm::to_vector(StartElems.getValues<int64_t>());
+    llvm::SmallVector<int64_t> SliceSizes =
+        llvm::to_vector(SizeElems.getValues<int64_t>());
 
     // Validate slice on the concatenated axis. Slicing along this
     // axis should span only one of the inputs to the concatenate
@@ -406,17 +419,20 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
         return rewriter.notifyMatchFailure(
             sliceOp, "concat input must be a static ranked tensor");
 
-      if (sliceStart[axis] >= 0 &&
-          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
-        replaceWithSlice = rewriter
-                               .create<tosa::SliceOp>(
-                                   sliceOp.getLoc(), sliceOp.getType(), input,
-                                   rewriter.getDenseI64ArrayAttr(sliceStart),
-                                   rewriter.getDenseI64ArrayAttr(sliceSize))
-                               .getResult();
+      if (SliceStarts[axis] >= 0 && (SliceStarts[axis] + SliceSizes[axis]) <=
+                                        inputType.getDimSize(axis)) {
+        auto start_op =
+            getTosaConstShape(rewriter, sliceOp.getLoc(), SliceStarts);
+        auto size_op =
+            getTosaConstShape(rewriter, sliceOp.getLoc(), SliceSizes);
+        replaceWithSlice =
+            rewriter
+                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
+                                       input, start_op, size_op)
+                .getResult();
         break;
       }
-      sliceStart[axis] -= inputType.getDimSize(axis);
+      SliceStarts[axis] -= inputType.getDimSize(axis);
     }
 
     if (!replaceWithSlice)
@@ -963,7 +979,12 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
 
   if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
       outputTy.getNumElements() == 1) {
-    llvm::SmallVector<uint64_t> indices(getStart());
+    DenseElementsAttr StartElems;
+    if (!matchPattern(getStart(), m_Constant(&StartElems)))
+      return {};
+
+    llvm::SmallVector<uint64_t> indices =
+        llvm::to_vector(StartElems.getValues<uint64_t>());
     auto value = operand.getValues<Attribute>()[indices];
     return SplatElementsAttr::get(outputTy, value);
   }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fdccce60fe1d86..2f7bd75974b79b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -891,8 +891,18 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  auto start = adaptor.getStart();
-  auto size = adaptor.getSize();
+
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+  SmallVector<int64_t> start;
+  SmallVector<int64_t> size;
+
+  if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
+      !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
+    auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  }
 
   // if size[i] is -1, all remaining elements in dimension i are included
   // in the slice, similar to TF.
@@ -933,11 +943,15 @@ LogicalResult tosa::SliceOp::verify() {
   if (!inputType)
     return success();
 
-  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
+  auto StartShapeRank =
+      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
+  if (inputType.getRank() != StartShapeRank)
     return emitOpError(
         "length of start attribute is not equal rank of input shape");
 
-  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
+  auto SizeShapeRank =
+      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
+  if (inputType.getRank() != SizeShapeRank)
     return emitOpError(
         "length of size attribute is not equal rank of input shape");
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 1b97f0b245d9ba..807f9cd683bb8c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -302,8 +302,8 @@ class TransposeConvStridedConverter
 
     auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                      rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-                     rewriter.getDenseI64ArrayAttr(sliceBegin),
-                     rewriter.getDenseI64ArrayAttr(sliceSize))
+                     getTosaConstShape(rewriter, loc, sliceBegin),
+                     getTosaConstShape(rewriter, loc, sliceSize))
                      .getResult();
 
     llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
index 36eb4d4669b07a..a72d6b333f7ea4 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
@@ -2,7 +2,9 @@
 
 // CHECK-LABEL:  @slice_resultType_unranked
 func.func @slice_resultType_unranked(%arg0: tensor<?xf32>) -> (tensor<*xf32>) {
+  %0 = tosa.const_shape  {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape  {value = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1>
   // expected-error@+1 {{failed to legalize operation 'tosa.slice'}}
-  %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 0>} : (tensor<?xf32>)  -> (tensor<*xf32>)
-  return %0 : tensor<*xf32>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<?xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
 }
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 2f11b31aad2307..dd97a872a07b46 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -437,7 +437,9 @@ func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3x
 // CHECK-LABEL: func @slice
 func.func @slice(%arg0: tensor<6xf32>) ->() {
   // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
-  %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 1>} : (tensor<6xf32>)  -> (tensor<1xf32>)
+  %0 = tosa.const_shape  {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape  {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<6xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<1xf32>
   return
 }
 
@@ -450,8 +452,10 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
   // CHECK: %[[C2:.+]] = arith.constant 2 : index
   // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
   // CHECK: tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
-  %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: -1>} : (tensor<?xf32>)  -> (tensor<?xf32>)
-  return %0 : tensor<?xf32>
+  %0 = tosa.const_shape  {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape  {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<?xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xf32>
+  return %2 : tensor<?xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e394188e9a9311..4a8694c4713e4e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -572,18 +572,22 @@ func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform<
 
 // CHECK-LABEL: @slice_fold
 func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  %0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
   // CHECK: return %arg0
-  %0 = tosa.slice %arg0 { size = array<i64: 3, 4>, start = array<i64: 0, 0>}: (tensor<3x4xf32>) -> tensor<3x4xf32>
-  return %0 : tensor<3x4xf32>
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<3x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4xf32>
+  return %3 : tensor<3x4xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @slice_nofold
 func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
   // CHECK: tosa.slice
-  %0 = tosa.slice %arg0 { size = array<i64: 3, 4>, start = array<i64: 0, 0>}: (tensor<?x4xf32>) -> tensor<?x4xf32>
-  return %0 : tensor<?x4xf32>
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return %3 : tensor<?x4xf32>
 }
 
 // -----
@@ -663,9 +667,12 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
 func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
   %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
-  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
-  %2 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
-  return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+  %1 = tosa.const_shape {value = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %2 = tosa.const_shape {value = dense<[0, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %3 = tosa.const_shape {value = dense<[1, 12, 12, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %4 = tosa.slice %0, %1, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32>
+  %5 = tosa.slice %0, %2, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32>
+  return %4, %5 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
 }
 
 // -----
@@ -675,38 +682,56 @@ func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
 func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
   %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
-  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
-  %2 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 12, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
-  return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+  %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.const_shape {value = dense<[0, 12, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %3 = tosa.const_shape {value = dense<[1, 12, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %4 = tosa.slice %0, %1, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32>
+  %5 = tosa.slice %0, %2, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32>
+  return %4, %5 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @canonicalize_cross_concat_inputs
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
-// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
-// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape  {value = dense<[1, 12, 20]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape  {value = dense<[1, 12, 15]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape  {value = dense<[0, 0, 4]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape  {value = dense<0> : tensor<3xindex>}
+// CHECK: %[[VAL_6:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]], %[[VAL_5]], %[[VAL_3]]
+// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_6]], %[[VAL_4]], %[[VAL_2]]
+// CHECK: return %[[VAL_7]], %[[VAL_8]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
 func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
   %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
-  %1 = tosa.slice %0 {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
-  %2 = tosa.slice %0 {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
-  return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+  %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.const_shape {value = dense<[0, 0, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %3 = tosa.const_shape {value = dense<[1, 12, 15]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %4 = tosa.const_shape {value = dense<[1, 12, 20]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %5 = tosa.slice %0, %1, %3 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x15xf32>
+  %6 = tosa.slice %0, %2, %4 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x20xf32>
+  return %5, %6 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.slice %[[VAL_0]] {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 0>} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
-// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape  {value = dense<[1, 3, 0]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape  {value = dense<[1, 3, 12]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape  {value = dense<0> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape  {value = dense<[1, 6, 12]> : tensor<3xindex>}
+// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_0]], %[[VAL_4]], %[[VAL_5]]
+// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]
+// CHECK: return %[[VAL_6]], %[[VAL_7]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
 func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
   %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
-  %1 = tosa.slice %0 {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
-  %2 = tosa.slice %0 {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
-  return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+  %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.const_shape {value = dense<[1, 6, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %3 = tosa.const_shape {value = dense<[1, 3, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %4 = tosa.slice %0, %1, %2 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x6x12xf32>
+  %5 = tosa.slice %0, %3, %3 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x3x12xf32>
+  return %4, %5 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 8198903b78ac05..32dbe3315ca0bc 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -506,7 +506,10 @@ func.func @reshape_splat() -> tensor<6x5x4xi32> {
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() <{value = dense<42> : tensor<1x1x1xi32>}
   %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
-  %slice = tosa.slice %splat { size = array<i64: 1, 1, 1>, start = array<i64: 1, 2, 3> } : (tensor<4x5x6xi32>) -> tensor<1x1x1xi32>
+  %start = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %slice= tosa.slice %splat, %start, %size : (tensor<4x5x6xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x1xi32>
+
   // CHECK: return %[[SLICE]]
   return %slice : tensor<1x1x1xi32>
 }
@@ -517,7 +520,9 @@ func.func @slice_splat() -> tensor<1x1x1xi32> {
 func.func @slice_singleton() -> tensor<1x1xi32> {
   %splat = "tosa.const"() {value = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
   // CHECK: %[[SLICE:.+]] = "tosa.const"() <{value = dense<4> : tensor<1x1xi32>}
-  %slice = tosa.slice %splat { size = array<i64: 1, 1>, start = array<i64: 1, 1> } : (tensor<3x3xi32>) -> tensor<1x1xi32>
+  %start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %size = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %slice= tosa.slice %splat, %start, %size : (tensor<3x3xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32>
   // CHECK: return %[[SLICE]]
   return %slice : tensor<1x1xi32>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 4808867b28bb97..3358660d89fc10 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -607,8 +607,10 @@ func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
 
 func.func @test_slice_invalid_start() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
+  %start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
   // expected-error@+1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}}
-  %1 = tosa.slice %0 {size = array<i64: 1, 1, 1>, start = array<i64: 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
+  %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32>
   return
 }
 
@@ -616,8 +618,10 @@ func.func @test_slice_invalid_start() {
 
 func.func @test_slice_invalid_size() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
+  %start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
   // expected-error@+1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}}
-  %1 = tosa.slice %0 {size = array<i64: 1>, start = array<i64: 1, 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
+  %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0fe35d88f0e73a..26bebdd898a0d6 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -86,10 +86,11 @@ func.func @test_reverse(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13
 // -----
 // CHECK-LABEL: slice
 func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32> {
+  %0 = tosa.const_shape {value = dense<[0, 0, 0, 0, 6, 8, 0]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  %1 = tosa.const_shape {value = dense<[1, 1, 1, 1, 4, 11, 1]> : tensor<7xindex>} : () -> !tosa.shape<7>
   // expected-error@+1 {{'tosa.slice' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.slice"(%arg0) {start = array<i64: 0, 0, 0, 0, 6, 8, 0>, size = array<i64: 1, 1, 1, 1, 4, 11, 1>} :
-          (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32>
-  return %0 : tensor<1x1x1x1x4x11x1xf32>
+  %2= tosa.slice %arg0, %0, %1 : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<7>, !tosa.shape<7>) -> tensor<1x1x1x1x4x11x1xf32>
+  return %2 : tensor<1x1x1x1x4x11x1xf32>
 }
 
 // -----
@@ -736,8 +737,10 @@ func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x
 
 // CHECK-LABEL: unranked_tensor
 func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
+  %0 = tosa.const_shape {value = dense<[0]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
+
   // expected-error@+1 {{'tosa.slice' op failed level check: unranked tensor}}
-  %0 = "tosa.slice"(%arg0) {start = array<i64>, size = array<i64>} :
-          (tensor<*xf32>) -> tensor<*xf32>
+  %2= tosa.slice %arg0, %0, %1 : (tensor<*xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 563c5fa457d351..98505bb9a48312 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -557,8 +557,19 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 // -----
 // CHECK-LABEL: slice
 func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
-  %0 = tosa.slice %arg0 {size = array<i64: 4, 11, 1>, start = array<i64: 6, 8, 0>} : (tensor<13x21x3xf32>) -> tensor<4x11x1xf32>
-  return %0 : tensor<4x11x1xf32>
+  %0 = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %1 = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32>
+  return %2 : tensor<4x11x1xf32>
+}
+
+// -----
+// CHECK-LABEL: slice_size
+func.func @test_slice_size(%arg0: tensor<13x21x3xf32>) -> tensor<7x11x1xf32> {
+  %0 = tosa.const_shape {value = dense<[-1, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %1 = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x11x1xf32>
+  return %2 : tensor<7x11x1xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 96f71c349938b9..12691f2e325a28 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -43,6 +43,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
 // -----
 
 // CHECK-LABEL: @transpose_conv2d_strided
+
 func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
   // Manipulate the weight matrix to handle striding.
   // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
@@ -64,9 +65,11 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
   // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
   // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
   // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
-  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array<i64: 2, 36, 48, 5>}
-  // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>}
-  // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 5>}
+  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]]
+  // CHECK-DAG: %[[START:.*]] = tosa.const_shape  {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape  {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
+  // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
   // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
   %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
@@ -76,6 +79,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
 // -----
 
 // CHECK-LABEL: @transpose_conv2d_strided_quantized
+
 func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
   // Manipulate the weight matrix to handle striding.
   // CHECK-DAG: %[[PADV:.+]]  = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
@@ -97,9 +101,11 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
   // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
   // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
-  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array<i64: 2, 36, 48, 5>}
-  // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>}
-  // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 5>}
+  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]]
+  // CHECK-DAG: %[[START:.*]] = tosa.const_shape  {value = dense<0> : tensor<4xindex>}
+  // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape  {value = dense<[2, 35, 47, 5]> : tensor<4xindex>}
+  // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
+  // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
   // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
   return %0 : tensor<2x35x47x5xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 6beb1ad6296135..0c59cb089b9357 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -517,8 +517,12 @@ func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
 
 // CHECK-LABEL: @test_slice
 func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
-  // CHECK: tosa.slice %arg0 {size = array<i64: 2>, start = array<i64: 1>} : (tensor<?xi32>) -> tensor<2xi32>
-  %0 = tosa.slice %arg0 { size = array<i64: 2>, start = array<i64: 1> } : (tensor<?xi32>) -> tensor<?xi32>
+  // CHECK: %0 = tosa.const_shape  {value = dense<1> : tensor<1xindex>}
+  // CHECK: %1 = tosa.const_shape  {value = dense<2> : tensor<1xindex>}
+  // CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
+  %0 = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2= tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xi32>
   return
 }
 
@@ -526,13 +530,17 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_size_minus_one
 func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
-  // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+  // CHECK: %[[Start:.+]] = tosa.const_shape
+  // CHECK: %[[Size:.+]] = tosa.const_shape
+  // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[Start]], %[[Size]] : (tensor<?x8x8x8xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x7x?x?xi32>
   // this checks following
   //  dim 0: size=-1, input dim=? => inferred output dim is ?
   //  dim 1: size=-1 => inferred output dim is input_dim - start
   //  dim 2: size=-1, start=-1 => inferred output dim is ?
   //  dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
-  %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+  %start = tosa.const_shape {value = dense<[0, 1, -1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {value = dense<[-1, -1, -1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %2= tosa.slice %arg0, %start, %size : (tensor<?x8x8x8xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x?x?x?xi32>
   return
 }
 
@@ -540,13 +548,17 @@ func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_size_out_of_bound
 func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
-  // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+  // CHECK: %[[Start:.+]] = tosa.const_shape
+  // CHECK: %[[Size:.+]] = tosa.const_shape
+  // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[Start]], %[[Size]] : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x?x?x4xi32>
   // this checks following
   //  dim 0: size=0 => inferred output dim is ?
   //  dim 1: size=-2 => inferred output dim is ?
   //  dim 2: start+size out of bound because size too big: inferred output dim is ?
   //  dim 3: size=4, input dim=? => inferred output dim is 4
-  %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+  %start = tosa.const_shape {value = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {value = dense<[0, -2, 9, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %2= tosa.slice %arg0, %start, %size : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x?x?x?xi32>
   return
 }
 
@@ -554,13 +566,17 @@ func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_start_out_of_bound
 func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
-  // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+  // CHECK: %[[Start:.+]] = tosa.const_shape
+  // CHECK: %[[Size:.+]] = tosa.const_shape
+  // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[Start]], %[[Size]] : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x?x?x4xi32>
   // this checks following
   //  dim 0: start=-1 => inferred output dim is ?
   //  dim 1: start=8 => inferred output dim is ?
   //  dim 2: start+size out of bound: inferred output dim is ?
   //  dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
-  %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+  %start = tosa.const_shape {value = dense<[-1, 8, 6, 8000000]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {value = dense<[1, 1, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %2= tosa.slice %arg0, %start, %size : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x?x?x?xi32>
   return
 }
 
@@ -568,8 +584,12 @@ func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
 
 // CHECK-LABEL: @test_slice_dynamic
 func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
-  // CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
-  %0 = tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<?x?x?xf32>
+  // CHECK: %0 = tosa.const_shape  {value = dense<[1, 0, 0]> : tensor<3xindex>}
+  // CHECK: %1 = tosa.const_shape  {value = dense<[7, -1, 1]> : tensor<3xindex>}
+  // CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
+  %0 = tosa.const_shape {value = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %1 = tosa.const_shape {value = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2= tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<?x?x?xf32>
   return
 }
 



More information about the Mlir-commits mailing list