[Mlir-commits] [mlir] [mlir][tosa] Add concat/slice_shape ops (PR #174620)

Luke Hutton llvmlistbot at llvm.org
Thu Jan 8 02:18:32 PST 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/174620

>From c5b1a9c9e3de8f985dfc8b9452d24a2d1af3d32b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 10 Dec 2025 12:36:10 +0000
Subject: [PATCH 1/2] [mlir][tosa] Add concat/slice_shape ops

Adds support for the concat/slice_shape operations after spec change:
https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce.

This includes the operator definition, same rank checks and level checks
during validation. It does not currently include support for folding or
shape inference. This will be added in a later commit.

Co-authored-by: Iliyan Georgiev <Iliyan.Georgiev at arm.com>
Change-Id: I3f2fd31ca2f7adb115c16a507b01e546de1badaa
---
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      | 41 +++++++++++++++
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  1 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 50 +++++++++++++++++++
 .../Tosa/Transforms/TosaProfileCompliance.cpp |  2 +
 .../Tosa/Transforms/TosaValidation.cpp        | 14 ++++++
 mlir/test/Dialect/Tosa/dynamic_extension.mlir | 10 +++-
 mlir/test/Dialect/Tosa/invalid.mlir           | 22 +++++++-
 mlir/test/Dialect/Tosa/level_check.mlir       | 29 +++++++++++
 mlir/test/Dialect/Tosa/ops.mlir               | 38 ++++++++++++++
 mlir/test/Dialect/Tosa/verifier.mlir          | 44 ++++++++++++++++
 10 files changed, 249 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 182504676710f..79967b7c9585e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -69,6 +69,26 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
   let results = (outs Tosa_Shape:$output);
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: ConcatShape
+//===----------------------------------------------------------------------===//
+def Tosa_ConcatShapeOp : Tosa_ShapeOp<"concat_shape", [Pure]> {
+  let summary = "Concatenates a list of shapes.";
+
+  let description = [{
+    Concatenates a list of shapes into a new shape with length equal to the sum
+    of the lengths of the inputs.
+  }];
+
+  // Variadic list of input shapes. Its length is bounded separately by the
+  // MAX_TENSOR_LIST_SIZE level check in the TOSA validation pass.
+  let arguments = (ins
+    Variadic<Tosa_Shape>:$input
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  // ConcatShapeOp::verify() requires the output rank to equal the sum of the
+  // input ranks.
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: ConstShape
 //===----------------------------------------------------------------------===//
@@ -173,6 +193,27 @@ def Tosa_MulShapeOp : Tosa_ElementwiseShapeOp<"mul_shape", [Pure]> {
   let results = (outs Tosa_Shape:$output);
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: SliceShape
+//===----------------------------------------------------------------------===//
+def Tosa_SliceShapeOp : Tosa_ShapeOp<"slice_shape", [Pure]> {
+  let summary = "Extract slice of a shape.";
+
+  let description = [{
+    Extracts a contiguous slice of `size` elements from the `input` shape,
+    beginning at index `start`. The rank of the output shape is equal to
+    `size`.
+
+    When `start` and `size` are compile-time constants, `start` must be
+    non-negative, `size` must be positive, and `start + size` must not exceed
+    the rank of `input`.
+  }];
+
+  // `start` and `size` are operands 1 and 2; without the dynamic extension
+  // the validation pass requires both to be compile-time constants.
+  let arguments = (ins
+    Tosa_Shape:$input,
+    Tosa_ScalarInt32Tensor:$start,
+    Tosa_ScalarInt32Tensor:$size
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: SubShape
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 266a9e3a7d946..a105b58e57e2c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -176,6 +176,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
 def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>;
 def Tosa_ScalarInt8Tensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int8]>, TosaScalarTensorOf<[Tosa_Int8], [1]>]>;
 def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
+// int32-only counterpart of the scalar tensor constraints above: accepts an
+// unranked i32 tensor or a TosaScalarTensorOf<i32> tensor (presumably the
+// single-element rank-1 form such as tensor<1xi32> used by the tests - TODO
+// confirm the meaning of the [1] argument).
+def Tosa_ScalarInt32Tensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int32]>, TosaScalarTensorOf<[Tosa_Int32], [1]>]>;
 
 // We include unranked tensors as a supported type for all possible tosa
 // Tensors as unranked does not guarantee invalid. If unranked tensors exist
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c9ef172e4fced..5fd45ad11de70 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4666,6 +4666,56 @@ LogicalResult tosa::DimOp::verify() {
   return success();
 }
 
+// Verifies that the rank of the output shape equals the sum of the ranks of
+// all input shapes.
+LogicalResult tosa::ConcatShapeOp::verify() {
+  const tosa::shapeType outShapeType =
+      cast<tosa::shapeType>(getResult().getType());
+  const int64_t outputRank = outShapeType.getRank();
+  // The accumulator type is taken from the init value, so seed with an
+  // int64_t: a plain `0` would make the running sum (and result) an `int`.
+  const int64_t inputsRank = llvm::accumulate(
+      getInput(), int64_t{0}, [](int64_t acc, Value input) {
+        return acc + cast<tosa::shapeType>(input.getType()).getRank();
+      });
+  if (outputRank != inputsRank)
+    return emitOpError("requires output shape rank to be equal to the sum of "
+                       "the input shape ranks (")
+           << inputsRank << "), got " << outputRank;
+
+  return success();
+}
+
+// Verifies slice_shape constraints that can be checked when the `start` and/or
+// `size` operands are compile-time constants:
+//   - a constant `start` must be non-negative,
+//   - a constant `size` must be positive,
+//   - when both are constant, `start + size` must not exceed the input rank.
+// Non-constant operands are accepted here; whether they are permitted at all
+// is decided by the validation pass.
+LogicalResult tosa::SliceShapeOp::verify() {
+  std::optional<int32_t> start;
+  DenseIntElementsAttr startAttr;
+  if (matchPattern(getStart(), m_Constant(&startAttr)))
+    start = startAttr.getValues<int32_t>()[0];
+  if (start && start.value() < 0)
+    return emitOpError("expected non-negative start index, got ")
+           << start.value();
+
+  std::optional<int32_t> size;
+  DenseIntElementsAttr sizeAttr;
+  if (matchPattern(getSize(), m_Constant(&sizeAttr)))
+    size = sizeAttr.getValues<int32_t>()[0];
+  if (size && size.value() <= 0)
+    return emitOpError("expected positive size, got ") << size.value();
+
+  if (!start || !size)
+    return success();
+
+  const tosa::shapeType inShapeType =
+      cast<tosa::shapeType>(getInput().getType());
+  const int64_t inputRank = inShapeType.getRank();
+  // Widen before adding: both values are int32_t, so summing them in 32 bits
+  // can overflow (e.g. start = INT32_MAX, size = 1), which is undefined
+  // behaviour and could wrap to a value that slips past the range check.
+  const int64_t sliceSize =
+      static_cast<int64_t>(start.value()) + static_cast<int64_t>(size.value());
+  if (sliceSize > inputRank)
+    return emitOpError("expected start + size to be less than or equal to "
+                       "input shape rank (")
+           << inputRank << "), got " << sliceSize;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Attribute Definitions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 62f66543d74cc..e8a24057a96ac 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -325,10 +325,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   // of the data type, meaning any compatible type can be used. No type
   // constraint for those operations.
   POPULATE_PROFILE_INFO_SKIP(AddShape)
+  POPULATE_PROFILE_INFO_SKIP(ConcatShape)
   POPULATE_PROFILE_INFO_SKIP(ConstShape)
   POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
   POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
   POPULATE_PROFILE_INFO_SKIP(MulShape)
+  POPULATE_PROFILE_INFO_SKIP(SliceShape)
   POPULATE_PROFILE_INFO_SKIP(SubShape)
   POPULATE_PROFILE_INFO_SKIP(Yield)
   POPULATE_PROFILE_INFO_SKIP(If)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 897cc87529eca..4ccd7163d4c3d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -129,6 +129,15 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
   return success();
 }
 
+// When the target environment does not allow the dynamic extension, requires
+// the `start` and `size` operands of tosa.slice_shape to be compile-time
+// constants. Other ops are accepted unchanged.
+static LogicalResult checkConstantOperandSliceShape(Operation *op,
+                                                    const TargetEnv &env) {
+  if (!env.allows(Extension::dynamic) && isa<tosa::SliceShapeOp>(op)) {
+    // Operand #1 is 'start', operand #2 is 'size'.
+    return checkConstantOperands(op, {1, 2});
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Validation Pass.
 //===----------------------------------------------------------------------===//
@@ -177,6 +186,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     constCheckers.emplace_back(checkConstantOperandMatMul);
     constCheckers.emplace_back(checkConstantOperandAvgPool2d);
     constCheckers.emplace_back(checkConstantOperandNegate);
+    constCheckers.emplace_back(checkConstantOperandSliceShape);
   }
 
   LogicalResult levelCheckKernel(Operation *op, int32_t v,
@@ -481,6 +491,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         return failure();
       }
     }
+    if (auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
+      return levelCheckListSize(op, concat_shape.getInput().size(), "input");
     return success();
   }
 
@@ -693,9 +705,11 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
 
   // Shape Operators
   CHECK_RANKS(AddShape);
+  CHECK_RANKS(ConcatShape);
   CHECK_RANKS(DivCeilShape);
   CHECK_RANKS(DivFloorShape);
   CHECK_RANKS(MulShape);
+  CHECK_RANKS(SliceShape);
   CHECK_RANKS(SubShape);
 
 #undef CHECK_RANKS_AND_SIZES
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index 60b70b8754611..a1329afc3bb03 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,7 +2,7 @@
 // Check operations when the dynamic extension is enabled.
 //--------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=dynamic,shape" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations"
 
 // -----
 
@@ -85,3 +85,11 @@ func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp
          (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
   return %0 : tensor<1x32x32x8xf32>
 }
+
+// -----
+
+// With the dynamic extension enabled, non-constant start/size are accepted.
+func.func @test_slice_shape_non_const_start_size(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> !tosa.shape<3> {
+  %shape = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %slice = tosa.slice_shape %shape, %arg0, %arg1 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+  return %slice : !tosa.shape<3>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3d24928487ed2..e8206c24f1507 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,shape" -tosa-validate="strict-op-spec-alignment"
 
 
 func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
@@ -2067,3 +2067,23 @@ func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, ten
   %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
   return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
 }
+
+// -----
+
+// Without the dynamic extension, a non-constant `start` operand is rejected.
+func.func @test_slice_shape_non_const_start(%arg0: tensor<1xi32>) -> !tosa.shape<3> {
+  %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %2 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.slice_shape' op expected compile time resolvable constant, but got variable value for operand #1}}
+  %3 = tosa.slice_shape %0, %arg0, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+  return %3 : !tosa.shape<3>
+}
+
+// -----
+
+// Without the dynamic extension, a non-constant `size` operand is rejected.
+func.func @test_slice_shape_non_const_size(%arg0: tensor<1xi32>) -> !tosa.shape<3> {
+  %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %1 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.slice_shape' op expected compile time resolvable constant, but got variable value for operand #2}}
+  %3 = tosa.slice_shape %0, %1, %arg0 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+  return %3 : !tosa.shape<3>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index da21c18e19783..ec540bda5e57d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1690,3 +1690,32 @@ func.func @test_dim(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
   %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1>
   return %0 : !tosa.shape<1>
 }
+
+
+// -----
+
+// 65 inputs: one more than the (presumed) MAX_TENSOR_LIST_SIZE of 64.
+func.func @test_concat_shape_invalid_list_size() {
+  %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.concat_shape' op failed level check for MAX_TENSOR_LIST_SIZE: input}}
+  %1 = tosa.concat_shape %0, %0, %0, %0, %0, %0, %0, %0,
+                         %0, %0, %0, %0, %0, %0, %0, %0,
+                         %0, %0, %0, %0, %0, %0, %0, %0,
+                         %0, %0, %0, %0, %0, %0, %0, %0,
+                         %0, %0, %0, %0, %0, %0, %0, %0,
+                         %0, %0, %0, %0, %0, %0, %0, %0,
+                         %0, %0, %0, %0, %0, %0, %0, %0,
+                         %0, %0, %0, %0, %0, %0, %0, %0,
+                         %0 :
+                         (
+                          !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+                          !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+                          !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+                          !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+                          !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+                          !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+                          !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+                          !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+                          !tosa.shape<0>
+                         ) -> !tosa.shape<0>
+  return
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 4e1935233535d..276eac4d6166d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1433,3 +1433,41 @@ func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
   %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
   return %0 : !tosa.shape<1>
 }
+
+// -----
+// CHECK-LABEL: test_concat_shape
+// Ranks 1 + 2 + 2 concatenate into a rank-5 shape.
+func.func @test_concat_shape() -> !tosa.shape<5> {
+  %a = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %b = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %c = tosa.const_shape {values = dense<[5, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %concat = tosa.concat_shape %a, %b, %c : (!tosa.shape<1>, !tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<5>
+  return %concat : !tosa.shape<5>
+}
+
+// -----
+// CHECK-LABEL: test_concat_shape_rank_0
+// Concatenating only empty shapes yields an empty shape.
+func.func @test_concat_shape_rank_0() -> !tosa.shape<0> {
+  %a = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  %b = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  %c = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  %concat = tosa.concat_shape %a, %b, %c : (!tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>) -> !tosa.shape<0>
+  return %concat : !tosa.shape<0>
+}
+
+// -----
+// CHECK-LABEL: test_slice_shape
+// Constant start = 1, size = 3 on a rank-6 shape.
+func.func @test_slice_shape() -> !tosa.shape<3> {
+  %shape = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %start = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %size = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  %slice = tosa.slice_shape %shape, %start, %size : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+  return %slice : !tosa.shape<3>
+}
+
+// -----
+// CHECK-LABEL: test_slice_shape_dynamic
+// Non-constant start/size: the op verifier skips the constant-only checks.
+func.func @test_slice_shape_dynamic(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> !tosa.shape<3> {
+  %shape = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %slice = tosa.slice_shape %shape, %arg0, %arg1 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+  return %slice : !tosa.shape<3>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 7104285af8446..089dee22b2f87 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1270,3 +1270,47 @@ func.func @test_dim_scalar(%arg0: tensor<i32>) -> !tosa.shape<1> {
   %0 = tosa.dim %arg0 {axis = 4 : i32} : (tensor<i32>) -> !tosa.shape<1>
   return %0 : !tosa.shape<1>
 }
+
+// -----
+
+// Output rank 4 does not match the summed input rank 1 + 2 + 2 = 5.
+func.func @test_concat_shape_rank_mismatch() -> !tosa.shape<4> {
+  %0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.const_shape {values = dense<[5, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.concat_shape' op requires output shape rank to be equal to the sum of the input shape ranks (5), got 4}}
+  %3 = tosa.concat_shape %0, %1, %2 : (!tosa.shape<1>, !tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<4>
+  return %3 : !tosa.shape<4>
+}
+
+// -----
+
+// A constant negative start index is rejected by the op verifier.
+func.func @test_slice_shape_negative_start() -> !tosa.shape<3> {
+  %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %1 = "tosa.const"() {values = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.slice_shape' op expected non-negative start index, got -1}}
+  %3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+  return %3 : !tosa.shape<3>
+}
+
+// -----
+
+// A constant size of zero is rejected by the op verifier.
+func.func @test_slice_shape_non_positive_size() -> !tosa.shape<3> {
+  %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %1 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.slice_shape' op expected positive size, got 0}}
+  %3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+  return %3 : !tosa.shape<3>
+}
+
+// -----
+
+// start + size = 5 + 3 = 8 exceeds the input rank of 6.
+func.func @test_slice_shape_out_of_range() -> !tosa.shape<3> {
+  %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %1 = "tosa.const"() {values = dense<5> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.slice_shape' op expected start + size to be less than or equal to input shape rank (6), got 8}}
+  %3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+  return %3 : !tosa.shape<3>
+}

>From d696494525882c9290c28b2bc713831d38f9f002 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 8 Jan 2026 10:05:25 +0000
Subject: [PATCH 2/2] check slice output size matches slice size

Change-Id: I04b6d4b9c658f87432e360f29e5e88a1048abe63
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 13 ++++++++++++-
 mlir/test/Dialect/Tosa/verifier.mlir | 11 +++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5fd45ad11de70..5656f3de698c5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4701,7 +4701,18 @@ LogicalResult tosa::SliceShapeOp::verify() {
   if (size && size.value() <= 0)
     return emitOpError("expected positive size, got ") << size.value();
 
-  if (!start || !size)
+  if (!size)
+    return success();
+
+  // A constant `size` fixes the number of dimensions produced, so the
+  // declared result rank must agree with it.
+  const int64_t outputRank =
+      cast<tosa::shapeType>(getResult().getType()).getRank();
+  if (outputRank != size.value())
+    return emitOpError(
+               "expected output type size to be equal to size attribute, got ")
+           << outputRank << " vs " << size.value();
+
+  if (!start)
     return success();
 
   const tosa::shapeType inShapeType =
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 089dee22b2f87..a51ed4f09400f 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1314,3 +1314,14 @@ func.func @test_slice_out_of_range() -> !tosa.shape<3> {
   %3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
   return %3 : !tosa.shape<3>
 }
+
+// -----
+
+// Result rank 4 does not match the constant size operand 3.
+func.func @test_slice_shape_incorrect_output_size() -> !tosa.shape<4> {
+  %shape = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %start = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %size  = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.slice_shape' op expected output type size to be equal to size attribute, got 4 vs 3}}
+  %slice = tosa.slice_shape %shape, %start, %size : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<4>
+  return %slice : !tosa.shape<4>
+}



More information about the Mlir-commits mailing list