[Mlir-commits] [mlir] [mlir][tosa] Add support for assert equal shape op (PR #176900)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 20 03:28:23 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Iliyan Georgiev (iliyan-georgiev-arm)

<details>
<summary>Changes</summary>

Adds support for the `tosa.assert_equal_shape` operation, following the TOSA specification change: https://github.com/arm/tosa-specification/commit/575a50016de50d227eb517775eb4e7b137421fa1

This includes:
- Operator definition
- Profile-compliance and level-check validation support
- Tests


Change-Id: I6652bbcbd5e3716f140681b9d73ef8940564d7d3

---
Full diff: https://github.com/llvm/llvm-project/pull/176900.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+19-1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+6-4) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+10) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+9) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+8) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d8597151714c3..8fd176d3ea390 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -29,7 +29,7 @@ def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
 }
 
 class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
-    : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
+    : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator])> {
   list<Availability> availability = [
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_SHAPE]>,
@@ -337,4 +337,22 @@ def Tosa_SubShapeOp : Tosa_ElementwiseShapeOp<"sub_shape", [Pure]> {
   let results = (outs Tosa_Shape:$output);
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: AssertEqualShape
+//===----------------------------------------------------------------------===//
+def Tosa_AssertEqualShapeOp
+    : Tosa_ShapeOp<"assert_equal_shape", [TosaShapeOperatorWithSameRanks]> {
+  let summary = "Verify two shapes are equal.";
+
+  let description = [{
+      Verify input1 and input2 are equal. If allow_broadcast is set, shapes which
+      are broadcast compatible are allowed.
+  }];
+
+  let arguments = (ins Tosa_Shape:$input1, Tosa_Shape:$input2,
+      BoolAttr:$allow_broadcast);
+
+  let results = (outs);
+}
+
 #endif // TOSA_SHAPE_OPS
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index f26554fb5768a..f69b287e20a42 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -338,6 +338,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   // of the data type, meaning any compatible type can be used. No type
   // constraint for those operations.
   POPULATE_PROFILE_INFO_SKIP(AddShape)
+  POPULATE_PROFILE_INFO_SKIP(AssertEqualShape)
   POPULATE_PROFILE_INFO_SKIP(ConcatShape)
   POPULATE_PROFILE_INFO_SKIP(ConstShape)
   POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 5f26adabf409c..d4241d1428956 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -293,7 +293,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     }
     return success();
   }
-
   // Level check shape lengths of all operands and results of an operation that
   // are tosa.shape type.
   template <typename T>
@@ -302,9 +301,11 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
       if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
         return failure();
     }
-    if (failed(levelCheckShapeLength(tosaOp, tosaOp.getResult().getType(),
-                                     "result")))
-      return failure();
+    for (const auto &v : tosaOp->getResults()) {
+      if (failed(levelCheckShapeLength(tosaOp, v.getType(), "result")))
+        return failure();
+    }
+
     return success();
   }
 
@@ -772,6 +773,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
 
   // Shape Operators
   CHECK_SHAPE_LEN(AddShape);
+  CHECK_SHAPE_LEN(AssertEqualShape);
   CHECK_SHAPE_LEN(ConcatShape);
   CHECK_SHAPE_LEN(DivCeilShape);
   CHECK_SHAPE_LEN(DivFloorShape);
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 3ebf0ff8a2f69..fe02864c8b28f 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1826,3 +1826,13 @@ func.func @test_conv2d_block_scaled_stride_y(%arg0: tensor<1x8194x33x32xf8E4M3FN
             (tensor<1x8194x33x32xf8E4M3FN>, tensor<1x8194x33x1xf8E8M0FNU>, tensor<16x2x2x32xf8E4M3FN>, tensor<16x2x2x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x2x32x16xf32>
   return %0 : tensor<1x2x32x16xf32>
 }
+
+// -----
+
+func.func @test_assert_equal_shape_invalid_rank() -> () {
+  %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error@+1 {{'tosa.assert_equal_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  tosa.assert_equal_shape %a, %b {allow_broadcast = true} : (!tosa.shape<17>, !tosa.shape<17>) -> ()
+  return
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 626d2b6caafd1..8b4b972d2633d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1578,3 +1578,12 @@ func.func @test_conv2d_block_scaled_dynamic(%arg0: tensor<*xf4E2M1FN>, %arg1: te
   %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2 {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
   return %3 : tensor<*xf32>
 }
+
+// -----
+// CHECK-LABEL: test_assert_equal_shape
+func.func @test_assert_equal_shape() {
+  %0 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[5, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<2>, !tosa.shape<2>) -> ()
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index ef52b90f194de..97fb14927f7e8 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -243,3 +243,11 @@ func.func @test_conv2d_block_scaled(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: ten
   %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
   return %0 : tensor<1x4x4x8xf32>
 }
+
+// CHECK-LABEL: test_assert_equal_shape
+func.func @test_assert_equal_shape() {
+  %0 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[5, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<2>, !tosa.shape<2>) -> ()
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index e16a12b94b923..742bae3847da5 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1467,3 +1467,11 @@ func.func @test_conv2d_block_scaled_invalid_bias_size(%arg0: tensor<1x4x4x64xf4E
   %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<6xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
   return %0 : tensor<1x4x4x8xf32>
 }
+
+func.func @test_missmatched_ranks() {
+  %0 = tosa.const_shape {values = dense<[10]> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape {values = dense<[10, 15]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.assert_equal_shape' op operands don't have matching ranks}}
+  tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<1>, !tosa.shape<2>) -> ()
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/176900


More information about the Mlir-commits mailing list