[Mlir-commits] [mlir] [mlir][tosa] Add concat/slice_shape ops (PR #174620)
Luke Hutton
llvmlistbot at llvm.org
Thu Jan 8 02:18:32 PST 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/174620
>From c5b1a9c9e3de8f985dfc8b9452d24a2d1af3d32b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 10 Dec 2025 12:36:10 +0000
Subject: [PATCH 1/2] [mlir][tosa] Add concat/slice_shape ops
Adds support for the concat_shape and slice_shape operations following the spec change:
https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce
This includes the operator definitions, verifier rank checks, and level checks
during validation. It does not yet include support for folding or shape
inference; these will be added in a later commit.
Co-authored-by: Iliyan Georgiev <Iliyan.Georgiev at arm.com>
Change-Id: I3f2fd31ca2f7adb115c16a507b01e546de1badaa
---
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 41 +++++++++++++++
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 1 +
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 50 +++++++++++++++++++
.../Tosa/Transforms/TosaProfileCompliance.cpp | 2 +
.../Tosa/Transforms/TosaValidation.cpp | 14 ++++++
mlir/test/Dialect/Tosa/dynamic_extension.mlir | 10 +++-
mlir/test/Dialect/Tosa/invalid.mlir | 22 +++++++-
mlir/test/Dialect/Tosa/level_check.mlir | 29 +++++++++++
mlir/test/Dialect/Tosa/ops.mlir | 38 ++++++++++++++
mlir/test/Dialect/Tosa/verifier.mlir | 44 ++++++++++++++++
10 files changed, 249 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 182504676710f..79967b7c9585e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -69,6 +69,26 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
let results = (outs Tosa_Shape:$output);
}
+//===----------------------------------------------------------------------===//
+// Operator: ConcatShape
+//===----------------------------------------------------------------------===//
+// Takes a variadic list of shape operands. The custom verifier
+// (hasVerifier = 1) requires the result rank to equal the sum of the input
+// ranks.
+def Tosa_ConcatShapeOp : Tosa_ShapeOp<"concat_shape", [Pure]> {
+ let summary = "Concatenates a list of shapes.";
+
+ let description = [{
+ Concatenates a list of shapes into a new shape with length equal to the sum
+ of the lengths of the inputs.
+ }];
+
+ let arguments = (ins
+ Variadic<Tosa_Shape>:$input
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: ConstShape
//===----------------------------------------------------------------------===//
@@ -173,6 +193,27 @@ def Tosa_MulShapeOp : Tosa_ElementwiseShapeOp<"mul_shape", [Pure]> {
let results = (outs Tosa_Shape:$output);
}
+//===----------------------------------------------------------------------===//
+// Operator: SliceShape
+//===----------------------------------------------------------------------===//
+def Tosa_SliceShapeOp : Tosa_ShapeOp<"slice_shape", [Pure]> {
+  let summary = "Extracts a slice of a shape.";
+
+  let description = [{
+    Extracts `size` consecutive elements from the input shape, starting at
+    element index `start`. The result shape has length `size`.
+
+    When `start` and `size` are compile-time constants, `start` must be
+    non-negative, `size` must be positive, and `start + size` must not exceed
+    the length of the input shape.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input,
+    Tosa_ScalarInt32Tensor:$start,
+    Tosa_ScalarInt32Tensor:$size
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: SubShape
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 266a9e3a7d946..a105b58e57e2c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -176,6 +176,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>;
def Tosa_ScalarInt8Tensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int8]>, TosaScalarTensorOf<[Tosa_Int8], [1]>]>;
def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
+// Scalar int32 tensor operand (TosaScalarTensorOf) or an unranked int32
+// tensor; used for the `start` and `size` operands of tosa.slice_shape.
+def Tosa_ScalarInt32Tensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int32]>, TosaScalarTensorOf<[Tosa_Int32], [1]>]>;
// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c9ef172e4fced..5fd45ad11de70 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4666,6 +4666,56 @@ LogicalResult tosa::DimOp::verify() {
return success();
}
+/// Verifies that the rank of the result shape equals the sum of the ranks of
+/// all input shapes.
+LogicalResult tosa::ConcatShapeOp::verify() {
+  const auto outShapeType = cast<tosa::shapeType>(getResult().getType());
+  const int64_t outputRank = outShapeType.getRank();
+
+  // Seed the fold with an int64_t. With a plain `0` literal the accumulator
+  // type is deduced as int, and every int64_t partial sum would be narrowed.
+  const int64_t inputsRank = llvm::accumulate(
+      getInput(), int64_t{0}, [](int64_t acc, Value input) {
+        return acc + cast<tosa::shapeType>(input.getType()).getRank();
+      });
+  if (outputRank != inputsRank)
+    return emitOpError("requires output shape rank to be equal to the sum of "
+                       "the input shape ranks (")
+           << inputsRank << "), got " << outputRank;
+
+  return success();
+}
+
+/// Verifies tosa.slice_shape for whichever of `start` and `size` are
+/// compile-time constants: `start` must be non-negative, `size` must be
+/// positive, and, when both are known, the slice [start, start + size) must
+/// lie within the input shape. Non-constant operands are not checked here.
+LogicalResult tosa::SliceShapeOp::verify() {
+  std::optional<int32_t> start;
+  DenseIntElementsAttr startAttr;
+  if (matchPattern(getStart(), m_Constant(&startAttr)))
+    start = startAttr.getValues<int32_t>()[0];
+  if (start && start.value() < 0)
+    return emitOpError("expected non-negative start index, got ")
+           << start.value();
+
+  std::optional<int32_t> size;
+  DenseIntElementsAttr sizeAttr;
+  if (matchPattern(getSize(), m_Constant(&sizeAttr)))
+    size = sizeAttr.getValues<int32_t>()[0];
+  if (size && size.value() <= 0)
+    return emitOpError("expected positive size, got ") << size.value();
+
+  if (!start || !size)
+    return success();
+
+  const tosa::shapeType inShapeType =
+      cast<tosa::shapeType>(getInput().getType());
+  const int64_t inputRank = inShapeType.getRank();
+  // Widen before adding: both operands are int32_t, and their sum can exceed
+  // INT32_MAX. That would be signed overflow (UB) and could let an
+  // out-of-range slice pass the check below.
+  const int64_t sliceSize = static_cast<int64_t>(start.value()) + size.value();
+  if (sliceSize > inputRank)
+    return emitOpError("expected start + size to be less than or equal to "
+                       "input shape rank (")
+           << inputRank << "), got " << sliceSize;
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 62f66543d74cc..e8a24057a96ac 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -325,10 +325,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// of the data type, meaning any compatible type can be used. No type
// constraint for those operations.
POPULATE_PROFILE_INFO_SKIP(AddShape)
+ POPULATE_PROFILE_INFO_SKIP(ConcatShape)
POPULATE_PROFILE_INFO_SKIP(ConstShape)
POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
POPULATE_PROFILE_INFO_SKIP(MulShape)
+ POPULATE_PROFILE_INFO_SKIP(SliceShape)
POPULATE_PROFILE_INFO_SKIP(SubShape)
POPULATE_PROFILE_INFO_SKIP(Yield)
POPULATE_PROFILE_INFO_SKIP(If)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 897cc87529eca..4ccd7163d4c3d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -129,6 +129,15 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
return success();
}
+/// Unless the target enables the dynamic extension, requires the `start`
+/// (operand #1) and `size` (operand #2) operands of tosa.slice_shape to be
+/// compile-time constants. Any other op is accepted unchanged.
+// NOTE(review): "Silce" is a typo for "Slice". Renaming also requires updating
+// the registration in the TosaValidation constructor.
+static LogicalResult checkConstantOperandSilceShape(Operation *op,
+                                                    const TargetEnv &env) {
+  const bool dynamicAllowed = env.allows(Extension::dynamic);
+  if (dynamicAllowed || !isa<tosa::SliceShapeOp>(op))
+    return success();
+  return checkConstantOperands(op, {1, 2});
+}
+
//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//
@@ -177,6 +186,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
constCheckers.emplace_back(checkConstantOperandMatMul);
constCheckers.emplace_back(checkConstantOperandAvgPool2d);
constCheckers.emplace_back(checkConstantOperandNegate);
+ constCheckers.emplace_back(checkConstantOperandSilceShape);
}
LogicalResult levelCheckKernel(Operation *op, int32_t v,
@@ -481,6 +491,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return failure();
}
}
+ if (auto concat_shape = dyn_cast<tosa::ConcatShapeOp>(op))
+ return levelCheckListSize(op, concat_shape.getInput().size(), "input");
return success();
}
@@ -693,9 +705,11 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
// Shape Operators
CHECK_RANKS(AddShape);
+ CHECK_RANKS(ConcatShape);
CHECK_RANKS(DivCeilShape);
CHECK_RANKS(DivFloorShape);
CHECK_RANKS(MulShape);
+ CHECK_RANKS(SliceShape);
CHECK_RANKS(SubShape);
#undef CHECK_RANKS_AND_SIZES
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index 60b70b8754611..a1329afc3bb03 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,7 +2,7 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=dynamic,shape" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations"
// -----
@@ -85,3 +85,11 @@ func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
+
+// -----
+
+// The RUN line of this file enables the dynamic extension, so non-constant
+// `start`/`size` operands of tosa.slice_shape must pass validation.
+func.func @test_slice_shape_non_const_start_size(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> !tosa.shape<3> {
+ %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %3 = tosa.slice_shape %0, %arg0, %arg1 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+ return %3 : !tosa.shape<3>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3d24928487ed2..e8206c24f1507 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,shape" -tosa-validate="strict-op-spec-alignment"
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
@@ -2067,3 +2067,23 @@ func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, ten
%0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
}
+
+// -----
+
+// The dynamic extension is not enabled here, so a non-constant `start`
+// (operand #1) must be rejected.
+func.func @test_slice_shape_non_const_start(%arg0: tensor<1xi32>) -> !tosa.shape<3> {
+ %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %2 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.slice_shape' op expected compile time resolvable constant, but got variable value for operand #1}}
+ %3 = tosa.slice_shape %0, %arg0, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+ return %3 : !tosa.shape<3>
+}
+
+// -----
+
+// The dynamic extension is not enabled here, so a non-constant `size`
+// (operand #2) must be rejected.
+func.func @test_slice_shape_non_const_size(%arg0: tensor<1xi32>) -> !tosa.shape<3> {
+ %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %1 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.slice_shape' op expected compile time resolvable constant, but got variable value for operand #2}}
+ %3 = tosa.slice_shape %0, %1, %arg0 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+ return %3 : !tosa.shape<3>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index da21c18e19783..ec540bda5e57d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1690,3 +1690,32 @@ func.func @test_dim(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
%0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1>
return %0 : !tosa.shape<1>
}
+
+// -----
+
+// concat_shape with 65 inputs must fail the MAX_TENSOR_LIST_SIZE level check.
+func.func @test_concat_shape_invalid_list_size() {
+ %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+ // expected-error@+1 {{'tosa.concat_shape' op failed level check for MAX_TENSOR_LIST_SIZE: input}}
+ %1 = tosa.concat_shape %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0,
+ %0, %0, %0, %0, %0, %0, %0, %0,
+ %0 :
+ (
+ !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+ !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+ !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+ !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+ !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+ !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+ !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+ !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>,
+ !tosa.shape<0>
+ ) -> !tosa.shape<0>
+ return
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 4e1935233535d..276eac4d6166d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1433,3 +1433,41 @@ func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
%0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
return %0 : !tosa.shape<1>
}
+
+// -----
+// CHECK-LABEL: test_concat_shape
+// Result rank 5 equals the sum of the input ranks (1 + 2 + 2).
+func.func @test_concat_shape() -> !tosa.shape<5> {
+ %0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %2 = tosa.const_shape {values = dense<[5, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %3 = tosa.concat_shape %0, %1, %2 : (!tosa.shape<1>, !tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<5>
+ return %3 : !tosa.shape<5>
+}
+
+// -----
+// CHECK-LABEL: test_concat_shape_rank_0
+// Concatenating only rank-0 shapes yields a rank-0 shape.
+func.func @test_concat_shape_rank_0() -> !tosa.shape<0> {
+ %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+ %1 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+ %2 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+ %3 = tosa.concat_shape %0, %1, %2 : (!tosa.shape<0>, !tosa.shape<0>, !tosa.shape<0>) -> !tosa.shape<0>
+ return %3 : !tosa.shape<0>
+}
+
+// -----
+// CHECK-LABEL: test_slice_shape
+// Constant start = 1 and size = 3 on a rank-6 shape (start + size <= 6).
+func.func @test_slice_shape() -> !tosa.shape<3> {
+ %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %1 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+ return %3 : !tosa.shape<3>
+}
+
+// -----
+// CHECK-LABEL: test_slice_shape_dynamic
+// Non-constant start/size: the op verifier skips its constant range checks.
+func.func @test_slice_shape_dynamic(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> !tosa.shape<3> {
+ %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %3 = tosa.slice_shape %0, %arg0, %arg1 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+ return %3 : !tosa.shape<3>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 7104285af8446..089dee22b2f87 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1270,3 +1270,47 @@ func.func @test_dim_scalar(%arg0: tensor<i32>) -> !tosa.shape<1> {
%0 = tosa.dim %arg0 {axis = 4 : i32} : (tensor<i32>) -> !tosa.shape<1>
return %0 : !tosa.shape<1>
}
+
+// -----
+
+// Declared result rank 4 differs from the sum of the input ranks (1 + 2 + 2).
+func.func @test_concat_shape_rank_mismatch() -> !tosa.shape<4> {
+ %0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %2 = tosa.const_shape {values = dense<[5, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error@+1 {{'tosa.concat_shape' op requires output shape rank to be equal to the sum of the input shape ranks (5), got 4}}
+ %3 = tosa.concat_shape %0, %1, %2 : (!tosa.shape<1>, !tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<4>
+ return %3 : !tosa.shape<4>
+}
+
+// -----
+
+// A constant negative start index is rejected.
+func.func @test_slice_shape_negative_start() -> !tosa.shape<3> {
+ %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %1 = "tosa.const"() {values = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.slice_shape' op expected non-negative start index, got -1}}
+ %3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+ return %3 : !tosa.shape<3>
+}
+
+// -----
+
+// A constant size of zero is rejected.
+func.func @test_slice_shape_non_positive_size() -> !tosa.shape<3> {
+ %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %1 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.slice_shape' op expected positive size, got 0}}
+ %3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+ return %3 : !tosa.shape<3>
+}
+
+// -----
+
+// start + size (5 + 3 = 8) exceeds the input shape rank (6).
+func.func @test_slice_shape_out_of_range() -> !tosa.shape<3> {
+ %0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %1 = "tosa.const"() {values = dense<5> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.slice_shape' op expected start + size to be less than or equal to input shape rank (6), got 8}}
+ %3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
+ return %3 : !tosa.shape<3>
+}
>From d696494525882c9290c28b2bc713831d38f9f002 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 8 Jan 2026 10:05:25 +0000
Subject: [PATCH 2/2] check slice output size matches slice size
Change-Id: I04b6d4b9c658f87432e360f29e5e88a1048abe63
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 13 ++++++++++++-
mlir/test/Dialect/Tosa/verifier.mlir | 11 +++++++++++
2 files changed, 23 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5fd45ad11de70..5656f3de698c5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4701,7 +4701,18 @@ LogicalResult tosa::SliceShapeOp::verify() {
if (size && size.value() <= 0)
return emitOpError("expected positive size, got ") << size.value();
- if (!start || !size)
+ if (!size)
+ return success();
+
+ const tosa::shapeType outShapeType =
+ cast<tosa::shapeType>(getResult().getType());
+ const int64_t outputRank = outShapeType.getRank();
+ if (outputRank != size)
+ return emitOpError(
+ "expected output type size to be equal to size attribute, got ")
+ << outputRank << " vs " << size.value();
+
+ if (!start)
return success();
const tosa::shapeType inShapeType =
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 089dee22b2f87..a51ed4f09400f 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1314,3 +1314,14 @@ func.func @test_slice_out_of_range() -> !tosa.shape<3> {
%3 = tosa.slice_shape %0, %1, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
return %3 : !tosa.shape<3>
}
+
+// -----
+
+// size = 3, but the declared result type is !tosa.shape<4>; the result rank
+// must equal size.
+func.func @test_slice_shape_incorrect_output_size() -> !tosa.shape<4> {
+ %shape = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %start = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %size = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.slice_shape' op expected output type size to be equal to size attribute, got 4 vs 3}}
+ %slice = tosa.slice_shape %shape, %start, %size : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<4>
+ return %slice : !tosa.shape<4>
+}
More information about the Mlir-commits
mailing list