[Mlir-commits] [mlir] [mlir][tosa] Add add/sub/mul/div_floor/div_ceil_shape ops (PR #169321)

Luke Hutton llvmlistbot at llvm.org
Wed Dec 17 01:47:27 PST 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/169321

>From 85ce20eff20fd3cefa6f38a8cecc3de503d30a38 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 10 Nov 2025 18:30:03 +0000
Subject: [PATCH 1/2] [mlir][tosa] Add add/sub/mul/div_floor/div_ceil_shape ops

Adds initial support for the ext-shape extension, including
the operations:
- ADD_SHAPE
- SUB_SHAPE
- MUL_SHAPE
- DIV_FLOOR_SHAPE
- DIV_CEIL_SHAPE
to align with the spec change:
https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce.

This includes the operator definitions, same-rank checks
and level checks during validation. It does not currently
include support for folding or shape inference; these will
be added in a later commit.

Change-Id: I544af295552b9a9fecaba50b6131d7876113e47c
---
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |   6 +-
 .../Dialect/Tosa/IR/TosaProfileCompliance.h   |   1 +
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      | 123 +++++++++++++++++-
 mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp        |   1 +
 .../Tosa/Transforms/TosaProfileCompliance.cpp |   5 +
 .../Tosa/Transforms/TosaValidation.cpp        |  18 ++-
 mlir/test/Dialect/Tosa/invalid_extension.mlir |  10 ++
 mlir/test/Dialect/Tosa/level_check.mlir       |  22 +++-
 mlir/test/Dialect/Tosa/ops.mlir               |  45 +++++++
 .../tosa-validation-version-1p1-valid.mlir    |  12 +-
 mlir/test/Dialect/Tosa/verifier.mlir          |  16 +++
 11 files changed, 245 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index cc23955f31f23..419340256fa59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
 // DYNAMIC      : Removes all Compile Time Constant state for CTC inputs.
 // MXFP         : Microscaling formats.
+// SHAPE        : Shape calculation operators.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
 def Tosa_EXT_MXFP         : I32EnumAttrCase<"mxfp", 12>;
 def Tosa_EXT_INT64        : I32EnumAttrCase<"int64", 13>;
+def Tosa_EXT_SHAPE        : I32EnumAttrCase<"shape", 14>;
 
 
 def Tosa_ExtensionAttr
@@ -281,7 +283,7 @@ def Tosa_ExtensionAttr
       Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
       Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
       Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
-      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
+      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_SHAPE,
     ]> {
   let extraClassDeclaration = [{
     static llvm::SmallVector<Extension, 13> getAllValues() {
@@ -290,7 +292,7 @@ def Tosa_ExtensionAttr
         Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
         Extension::variable, Extension::controlflow, Extension::doubleround,
         Extension::inexactround, Extension::dynamic, Extension::mxfp,
-        Extension::int64
+        Extension::int64, Extension::shape
       };
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index ea58f49b64c44..bee253689bab7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -154,6 +154,7 @@ class TosaProfileCompliance {
     case Extension::controlflow:
     case Extension::dynamic:
     case Extension::int64:
+    case Extension::shape:
       return {Profile::pro_fp, Profile::pro_int};
     case Extension::none:
       return {};
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 90cda42d95624..7b1c7e208ebe3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -30,15 +30,8 @@ def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
 
 class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
     : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
-  list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
-  ];
-
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
-
-  let hasFolder = 1;
 }
 
 // op trait: shape operator has same ranks for operands and results
@@ -53,6 +46,29 @@ class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
 }
 
 
+//===----------------------------------------------------------------------===//
+// Operator: AddShape
+//===----------------------------------------------------------------------===//
+def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
+  let summary = "Elementwise addition of shapes.";
+
+  let description = [{
+      Elementwise addition of input1 and input2. Size of shapes must match.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: ConstShape
 //===----------------------------------------------------------------------===//
@@ -80,6 +96,99 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
   ];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: DivCeilShape
+//===----------------------------------------------------------------------===//
+def Tosa_DivCeilShapeOp : Tosa_ElementwiseShapeOp<"div_ceil_shape", [Pure]> {
+  let summary = "Elementwise ceiling divide of shapes.";
+
+  let description = [{
+      Elementwise integer divide of input1 by input2. The result of the divide is rounded up.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: DivFloorShape
+//===----------------------------------------------------------------------===//
+def Tosa_DivFloorShapeOp : Tosa_ElementwiseShapeOp<"div_floor_shape", [Pure]> {
+  let summary = "Elementwise floor divide of shapes.";
+
+  let description = [{
+      Elementwise integer divide of input1 by input2. The result of the divide is rounded down.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: MulShape
+//===----------------------------------------------------------------------===//
+def Tosa_MulShapeOp : Tosa_ElementwiseShapeOp<"mul_shape", [Pure]> {
+  let summary = "Elementwise multiplication of shapes.";
+
+  let description = [{
+      Elementwise multiplication of input1 and input2.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: SubShape
+//===----------------------------------------------------------------------===//
+def Tosa_SubShapeOp : Tosa_ElementwiseShapeOp<"sub_shape", [Pure]> {
+  let summary = "Elementwise subtraction of shapes.";
+
+  let description = [{
+      Elementwise subtraction of input2 from input1. Size of shapes must match.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>,
+  ];
 }
 
 #endif // TOSA_SHAPE_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85cf9b0b..01f78f86d427b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
     return TosaSpecificationVersion(1, 0);
   case Extension::mxfp:
   case Extension::int64:
+  case Extension::shape:
     return TosaSpecificationVersion(1, 1);
   case Extension::none:
     return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ddd9c70402fdc..c9150d5b34d00 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -317,7 +317,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   // Type Invariant Extension, a capability extension that is independent
   // of the data type, meaning any compatible type can be used. No type
   // constraint for those operations.
+  POPULATE_PROFILE_INFO_SKIP(AddShape)
   POPULATE_PROFILE_INFO_SKIP(ConstShape)
+  POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
+  POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
+  POPULATE_PROFILE_INFO_SKIP(MulShape)
+  POPULATE_PROFILE_INFO_SKIP(SubShape)
   POPULATE_PROFILE_INFO_SKIP(Yield)
   POPULATE_PROFILE_INFO_SKIP(If)
   POPULATE_PROFILE_INFO_SKIP(While)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 9c7bc83f77ec7..0f3c3a4c9abb3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -218,6 +218,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
       if (type.getRank() > highest_rank)
         return op->emitOpError() << "failed level check: " << operandOrResult
                                  << " rank(shape) <= MAX_RANK";
+    } else if (tosa::shapeType shapeType =
+                   dyn_cast<tosa::shapeType>(typeToCheck)) {
+      if (shapeType.getRank() > highest_rank)
+        return op->emitOpError()
+               << "failed shape type level check: " << typeToCheck
+               << " exceeds MAX_RANK";
     }
     return success();
   }
@@ -638,15 +644,21 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
   CHECK_RANKS_AND_SIZES(CastToBlockScaled);
   CHECK_RANKS_AND_SIZES(Rescale);
+  // Data Nodes
+  CHECK_RANKS_AND_SIZES(Const);
+  CHECK_RANKS_AND_SIZES(Identity);
   // Control Flow Operators
   CHECK_RANKS_AND_SIZES(If);
   // Variable Operators
   CHECK_RANKS_AND_SIZES(Variable);
   CHECK_RANKS_AND_SIZES(VariableWrite);
   CHECK_RANKS_AND_SIZES(VariableRead);
-  // Data Nodes
-  CHECK_RANKS_AND_SIZES(Const);
-  CHECK_RANKS_AND_SIZES(Identity);
+  // Shape Operators
+  CHECK_RANKS_AND_SIZES(AddShape);
+  CHECK_RANKS_AND_SIZES(DivCeilShape);
+  CHECK_RANKS_AND_SIZES(DivFloorShape);
+  CHECK_RANKS_AND_SIZES(MulShape);
+  CHECK_RANKS_AND_SIZES(SubShape);
 
   // For the following operators, check whether the size of each tensor
   // operand is valid in a given Level.
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 177192ba5440d..0daa0c52941e0 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -577,3 +577,13 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_mul_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.mul_shape' op illegal: requires [shape] but not enabled in target}}
+  %c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index a7087647e542b..213c4ae054c51 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1
 
 func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
   %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error at +1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
+  // expected-error at +1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
   %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
   return %0 : tensor<1x1x1x1x1x1x819xf32>
 }
@@ -1662,3 +1662,23 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
+  // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}}
+  %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13>
+  return %c : !tosa.shape<13>
+}
+
+// -----
+
+func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
+  return %c : !tosa.shape<7>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 652447bd6056e..b9e4d18156898 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1381,3 +1381,48 @@ func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
     %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
     return %0 : tensor<2x!tosa.mxint8>
 }
+
+// -----
+// CHECK-LABEL: test_add_shape
+func.func @test_add_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_sub_shape
+func.func @test_sub_shape() -> !tosa.shape<3> {
+  %a = tosa.const_shape {values = dense<[10, 5, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %b = tosa.const_shape {values = dense<[2, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %c = tosa.sub_shape %a, %b : (!tosa.shape<3>, !tosa.shape<3>) -> !tosa.shape<3>
+  return %c : !tosa.shape<3>
+}
+
+// -----
+// CHECK-LABEL: test_mul_shape
+func.func @test_mul_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[2, 3, 4, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[7, 0, 2, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_div_ceil_shape
+func.func @test_div_ceil_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[2, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.div_ceil_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_div_floor_shape
+func.func @test_div_floor_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[2, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index e8faeca6f9b03..10d322cf64fb7 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64,shape" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
 
 // -----
 
@@ -148,3 +148,13 @@ func.func @test_dynamic_dims(%arg0: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
   %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<?x8x16xi8>) -> tensor<?x16xi32>
   return %0 : tensor<?x16xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_add_shape
+func.func @test_add_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index ea64d468f151e..d73650ddd0563 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1230,3 +1230,19 @@ func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32,
     %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
     return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
 }
+
+// -----
+
+func.func @test_elementwise_shape_op_same_inputs_rank(%arg0: !tosa.shape<4>, %arg1: !tosa.shape<3>) -> !tosa.shape<4> {
+  // expected-error at +1 {{'tosa.add_shape' op operands don't have matching ranks}}
+  %0 = tosa.add_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<3>) -> !tosa.shape<4>
+  return %0 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_elementwise_shape_op_same_input_output_rank(%arg0: !tosa.shape<4>, %arg1: !tosa.shape<4>) -> !tosa.shape<3> {
+  // expected-error at +1 {{'tosa.div_floor_shape' op result shape has different rank than operands}}
+  %0 = tosa.div_floor_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<3>
+  return %0 : !tosa.shape<3>
+}

>From 815239e0918d349b54b762b3cd4519dd20c3a701 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 15 Dec 2025 18:21:59 +0000
Subject: [PATCH 2/2] address comments

Change-Id: I1aed00c170d35f555d26b8ba736402ecd52d751f
---
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      | 39 +++++--------------
 .../Tosa/Transforms/TosaValidation.cpp        | 23 ++++++++---
 2 files changed, 26 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 7b1c7e208ebe3..40e8ea2b40882 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -32,6 +32,11 @@ class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
     : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[]>,
+  ];
 }
 
 // op trait: shape operator has same ranks for operands and results
@@ -43,6 +48,10 @@ def TosaShapeOperatorWithSameRanks
 class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
     : Tosa_ShapeOp<mnemonic,
                    !listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>,
+  ];
 }
 
 
@@ -62,11 +71,6 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
-
-  list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[Tosa_EXT_SHAPE]>,
-  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -90,11 +94,6 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
 
   let results = (outs Tosa_Shape:$output);
 
-  list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
-  ];
-
   let hasVerifier = 1;
   let hasFolder = 1;
 }
@@ -115,11 +114,6 @@ def Tosa_DivCeilShapeOp : Tosa_ElementwiseShapeOp<"div_ceil_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
-
-  list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[Tosa_EXT_SHAPE]>
-  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -138,11 +132,6 @@ def Tosa_DivFloorShapeOp : Tosa_ElementwiseShapeOp<"div_floor_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
-
-  list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[Tosa_EXT_SHAPE]>
-  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -161,11 +150,6 @@ def Tosa_MulShapeOp : Tosa_ElementwiseShapeOp<"mul_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
-
-  list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[Tosa_EXT_SHAPE]>
-  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -184,11 +168,6 @@ def Tosa_SubShapeOp : Tosa_ElementwiseShapeOp<"sub_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
-
-  list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[Tosa_EXT_SHAPE]>,
-  ];
 }
 
 #endif // TOSA_SHAPE_OPS
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 0f3c3a4c9abb3..530c6ae85287c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -579,6 +579,12 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
       return failure();                                                        \
   }
 
+#define CHECK_RANKS(tosaOp)                                                    \
+  if (isa<tosa::tosaOp##Op>(op)) {                                             \
+    if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op))))                   \
+      return failure();                                                        \
+  }
+
   // Tensor Operators
   CHECK_RANKS_AND_SIZES(ArgMax);
   // Activation Functions
@@ -653,12 +659,6 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(Variable);
   CHECK_RANKS_AND_SIZES(VariableWrite);
   CHECK_RANKS_AND_SIZES(VariableRead);
-  // Shape Operators
-  CHECK_RANKS_AND_SIZES(AddShape);
-  CHECK_RANKS_AND_SIZES(DivCeilShape);
-  CHECK_RANKS_AND_SIZES(DivFloorShape);
-  CHECK_RANKS_AND_SIZES(MulShape);
-  CHECK_RANKS_AND_SIZES(SubShape);
 
   // For the following operators, check whether the size of each tensor
   // operand is valid in a given Level.
@@ -686,8 +686,19 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   // Shape Operators
   CHECK_SIZES(ConstShape);
 
+  // For the following operations, check whether the rank of each operand
+  // is valid given a level.
+
+  // Shape Operators
+  CHECK_RANKS(AddShape);
+  CHECK_RANKS(DivCeilShape);
+  CHECK_RANKS(DivFloorShape);
+  CHECK_RANKS(MulShape);
+  CHECK_RANKS(SubShape);
+
 #undef CHECK_RANKS_AND_SIZES
 #undef CHECK_SIZES
+#undef CHECK_RANKS
   return success();
 }
 



More information about the Mlir-commits mailing list