[Mlir-commits] [mlir] [mlir][tosa] Add support for assert equal shape op (PR #176900)
Iliyan Georgiev
llvmlistbot at llvm.org
Tue Jan 20 03:27:50 PST 2026
https://github.com/iliyan-georgiev-arm created https://github.com/llvm/llvm-project/pull/176900
Adds support for the assert_equal_shape operation, following the TOSA specification change: https://github.com/arm/tosa-specification/commit/575a50016de50d227eb517775eb4e7b137421fa1
This includes:
- Operator definition
- Tests
Change-Id: I6652bbcbd5e3716f140681b9d73ef8940564d7d3
>From 915cfc8f3d6728408f9d4573e7253556978e6c53 Mon Sep 17 00:00:00 2001
From: Iliyan Georgiev <Iliyan.Georgiev at arm.com>
Date: Fri, 16 Jan 2026 10:01:07 +0000
Subject: [PATCH] [mlir][tosa] Add support for assert equal shape op
Adds support for the assert_equal_shape operation, following the TOSA specification change:
https://github.com/arm/tosa-specification/commit/575a50016de50d227eb517775eb4e7b137421fa1
This includes:
- Operator definition
- Tests
Signed-off-by: Iliyan Georgiev <Iliyan.Georgiev at arm.com>
Change-Id: I6652bbcbd5e3716f140681b9d73ef8940564d7d3
---
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 20 ++++++++++++++++++-
.../Tosa/Transforms/TosaProfileCompliance.cpp | 1 +
.../Tosa/Transforms/TosaValidation.cpp | 10 ++++++----
mlir/test/Dialect/Tosa/level_check.mlir | 10 ++++++++++
mlir/test/Dialect/Tosa/ops.mlir | 9 +++++++++
.../tosa-validation-version-1p1-valid.mlir | 8 ++++++++
mlir/test/Dialect/Tosa/verifier.mlir | 8 ++++++++
7 files changed, 61 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d8597151714c3..8fd176d3ea390 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -29,7 +29,9 @@ def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
}
+// Pure is deliberately not part of this base class because side-effecting shape
+// ops (e.g. assert_equal_shape) derive from it. Pure shape ops must list it.
class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
- : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
+ : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator])> {
list<Availability> availability = [
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[Tosa_EXT_SHAPE]>,
@@ -337,4 +339,25 @@ def Tosa_SubShapeOp : Tosa_ElementwiseShapeOp<"sub_shape", [Pure]> {
let results = (outs Tosa_Shape:$output);
}
+//===----------------------------------------------------------------------===//
+// Operator: AssertEqualShape
+//===----------------------------------------------------------------------===//
+def Tosa_AssertEqualShapeOp
+ : Tosa_ShapeOp<"assert_equal_shape", [TosaShapeOperatorWithSameRanks]> {
+ let summary = "Verify two shapes are equal.";
+
+ let description = [{
+ Verifies that input1 and input2 are equal. If allow_broadcast is true, the
+ shapes are instead only required to be broadcast compatible.
+
+ The op has no results, so it is intentionally not Pure: an op without
+ results or side effects is trivially dead and would be erased.
+ }];
+
+ let arguments = (ins Tosa_Shape:$input1, Tosa_Shape:$input2,
+ BoolAttr:$allow_broadcast);
+
+ let results = (outs);
+}
+
#endif // TOSA_SHAPE_OPS
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index f26554fb5768a..f69b287e20a42 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -338,6 +338,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// of the data type, meaning any compatible type can be used. No type
// constraint for those operations.
POPULATE_PROFILE_INFO_SKIP(AddShape)
+ POPULATE_PROFILE_INFO_SKIP(AssertEqualShape) // shape operands only, no results
POPULATE_PROFILE_INFO_SKIP(ConcatShape)
POPULATE_PROFILE_INFO_SKIP(ConstShape)
POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 5f26adabf409c..d4241d1428956 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -302,9 +302,11 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
return failure();
}
- if (failed(levelCheckShapeLength(tosaOp, tosaOp.getResult().getType(),
- "result")))
- return failure();
+ for (const auto &v : tosaOp->getResults()) {
+ if (failed(levelCheckShapeLength(tosaOp, v.getType(), "result")))
+ return failure();
+ }
+
return success();
}
@@ -772,6 +774,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
// Shape Operators
CHECK_SHAPE_LEN(AddShape);
+ CHECK_SHAPE_LEN(AssertEqualShape);
CHECK_SHAPE_LEN(ConcatShape);
CHECK_SHAPE_LEN(DivCeilShape);
CHECK_SHAPE_LEN(DivFloorShape);
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 3ebf0ff8a2f69..fe02864c8b28f 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1826,3 +1826,13 @@ func.func @test_conv2d_block_scaled_stride_y(%arg0: tensor<1x8194x33x32xf8E4M3FN
(tensor<1x8194x33x32xf8E4M3FN>, tensor<1x8194x33x1xf8E8M0FNU>, tensor<16x2x2x32xf8E4M3FN>, tensor<16x2x2x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x2x32x16xf32>
return %0 : tensor<1x2x32x16xf32>
}
+
+// -----
+
+func.func @test_assert_equal_shape_invalid_rank() -> () {
+ %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error@+1 {{'tosa.assert_equal_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ tosa.assert_equal_shape %a, %b {allow_broadcast = true} : (!tosa.shape<17>, !tosa.shape<17>) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 626d2b6caafd1..8b4b972d2633d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1578,3 +1578,12 @@ func.func @test_conv2d_block_scaled_dynamic(%arg0: tensor<*xf4E2M1FN>, %arg1: te
%3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2 {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}
+
+// -----
+// CHECK-LABEL: test_assert_equal_shape
+func.func @test_assert_equal_shape() {
+ %0 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.const_shape {values = dense<[1, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<2>, !tosa.shape<2>) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index ef52b90f194de..97fb14927f7e8 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -243,3 +243,11 @@ func.func @test_conv2d_block_scaled(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: ten
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
+
+// CHECK-LABEL: test_assert_equal_shape
+func.func @test_assert_equal_shape() {
+ %0 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.const_shape {values = dense<[1, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<2>, !tosa.shape<2>) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index e16a12b94b923..742bae3847da5 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1467,3 +1467,13 @@ func.func @test_conv2d_block_scaled_invalid_bias_size(%arg0: tensor<1x4x4x64xf4E
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<6xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
+
+// -----
+
+func.func @test_assert_equal_shape_mismatched_ranks() {
+ %0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // expected-error@+1 {{'tosa.assert_equal_shape' op operands don't have matching ranks}}
+ tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<1>, !tosa.shape<2>) -> ()
+ return
+}
More information about the Mlir-commits
mailing list