[Mlir-commits] [mlir] Dim shape op (PR #169368)
Luke Hutton
llvmlistbot at llvm.org
Mon Nov 24 09:10:18 PST 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/169368
This commit adds the ext-shape operation `DIM`, including its definition, verifier and validation. Folding and shape inference are not yet supported and will be added in a later commit.
See https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce for the specification change.
Based on work originally implemented by @Tai78641.
Note: this change depends on https://github.com/llvm/llvm-project/pull/169321, so it also includes the changes from that PR.
>From 4d44aaa63b49390ae8b379a3458efd68e058312e Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 10 Nov 2025 18:30:03 +0000
Subject: [PATCH 1/2] [mlir][tosa] Add add/sub/mul/div_floor/div_ceil_shape ops
Adds initial support for the ext-shape extension, including
the operations:
- ADD_SHAPE
- SUB_SHAPE
- MUL_SHAPE
- DIV_FLOOR_SHAPE
- DIV_CEIL_SHAPE
to align with the spec change:
https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce.
This includes the operator definitions, same-rank checks
and level checks during validation. Folding and shape
inference are not yet supported and will be added in a
later commit.
Change-Id: I544af295552b9a9fecaba50b6131d7876113e47c
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 6 +-
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 1 +
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 123 +++++++++++++++++-
mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 1 +
.../Tosa/Transforms/TosaProfileCompliance.cpp | 5 +
.../Tosa/Transforms/TosaValidation.cpp | 18 ++-
mlir/test/Dialect/Tosa/invalid_extension.mlir | 10 ++
mlir/test/Dialect/Tosa/level_check.mlir | 22 +++-
mlir/test/Dialect/Tosa/ops.mlir | 45 +++++++
.../tosa-validation-version-1p1-valid.mlir | 11 +-
mlir/test/Dialect/Tosa/verifier.mlir | 16 +++
11 files changed, 244 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index cc23955f31f23..419340256fa59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
// DYNAMIC : Removes all Compile Time Constant state for CTC inputs.
// MXFP : Microscaling formats.
+// SHAPE : Shape calculation operators.
//===----------------------------------------------------------------------===//
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>;
def Tosa_EXT_INT64 : I32EnumAttrCase<"int64", 13>;
+def Tosa_EXT_SHAPE : I32EnumAttrCase<"shape", 14>;
def Tosa_ExtensionAttr
@@ -281,7 +283,7 @@ def Tosa_ExtensionAttr
Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
- Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
+ Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_SHAPE,
]> {
let extraClassDeclaration = [{
static llvm::SmallVector<Extension, 13> getAllValues() {
@@ -290,7 +292,7 @@ def Tosa_ExtensionAttr
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
Extension::variable, Extension::controlflow, Extension::doubleround,
Extension::inexactround, Extension::dynamic, Extension::mxfp,
- Extension::int64
+ Extension::int64, Extension::shape
};
}
}];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index ea58f49b64c44..bee253689bab7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -154,6 +154,7 @@ class TosaProfileCompliance {
case Extension::controlflow:
case Extension::dynamic:
case Extension::int64:
+ case Extension::shape:
return {Profile::pro_fp, Profile::pro_int};
case Extension::none:
return {};
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 90cda42d95624..7b1c7e208ebe3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -30,15 +30,8 @@ def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
: Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
- list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[]>,
- ];
-
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
-
- let hasFolder = 1;
}
// op trait: shape operator has same ranks for operands and results
@@ -53,6 +46,29 @@ class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
}
+//===----------------------------------------------------------------------===//
+// Operator: AddShape
+//===----------------------------------------------------------------------===//
+def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
+ let summary = "Elementwise addition of shapes.";
+
+ let description = [{
+ Elementwise addition of input1 and input2. Size of shapes must match.
+ }];
+
+ let arguments = (ins
+ Tosa_Shape:$input1,
+ Tosa_Shape:$input2
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_SHAPE]>,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// Operator: ConstShape
//===----------------------------------------------------------------------===//
@@ -80,6 +96,99 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
];
let hasVerifier = 1;
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: DivCeilShape
+//===----------------------------------------------------------------------===//
+def Tosa_DivCeilShapeOp : Tosa_ElementwiseShapeOp<"div_ceil_shape", [Pure]> {
+ let summary = "Elementwise ceiling divide of shapes.";
+
+ let description = [{
+ Elementwise integer divide of input1 by input2. The result of the divide is rounded up. Size of shapes must match.
+ }];
+
+ let arguments = (ins
+ Tosa_Shape:$input1,
+ Tosa_Shape:$input2
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_SHAPE]>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: DivFloorShape
+//===----------------------------------------------------------------------===//
+def Tosa_DivFloorShapeOp : Tosa_ElementwiseShapeOp<"div_floor_shape", [Pure]> {
+ let summary = "Elementwise floor divide of shapes.";
+
+ let description = [{
+ Elementwise integer divide of input1 by input2. The result of the divide is rounded down. Size of shapes must match.
+ }];
+
+ let arguments = (ins
+ Tosa_Shape:$input1,
+ Tosa_Shape:$input2
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_SHAPE]>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: MulShape
+//===----------------------------------------------------------------------===//
+def Tosa_MulShapeOp : Tosa_ElementwiseShapeOp<"mul_shape", [Pure]> {
+ let summary = "Elementwise multiplication of shapes.";
+
+ let description = [{
+ Elementwise multiplication of input1 and input2. Size of shapes must match.
+ }];
+
+ let arguments = (ins
+ Tosa_Shape:$input1,
+ Tosa_Shape:$input2
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_SHAPE]>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: SubShape
+//===----------------------------------------------------------------------===//
+def Tosa_SubShapeOp : Tosa_ElementwiseShapeOp<"sub_shape", [Pure]> {
+ let summary = "Elementwise subtraction of shapes.";
+
+ let description = [{
+ Elementwise subtraction of input2 from input1. Size of shapes must match.
+ }];
+
+ let arguments = (ins
+ Tosa_Shape:$input1,
+ Tosa_Shape:$input2
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_SHAPE]>,
+ ];
}
#endif // TOSA_SHAPE_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85cf9b0b..01f78f86d427b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
return TosaSpecificationVersion(1, 0);
case Extension::mxfp:
case Extension::int64:
+ case Extension::shape:
return TosaSpecificationVersion(1, 1);
case Extension::none:
return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ddd9c70402fdc..c9150d5b34d00 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -317,7 +317,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// Type Invariant Extension, a capability extension that is independent
// of the data type, meaning any compatible type can be used. No type
// constraint for those operations.
+ POPULATE_PROFILE_INFO_SKIP(AddShape)
POPULATE_PROFILE_INFO_SKIP(ConstShape)
+ POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
+ POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
+ POPULATE_PROFILE_INFO_SKIP(MulShape)
+ POPULATE_PROFILE_INFO_SKIP(SubShape)
POPULATE_PROFILE_INFO_SKIP(Yield)
POPULATE_PROFILE_INFO_SKIP(If)
POPULATE_PROFILE_INFO_SKIP(While)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b54ed5585d72d..421ef237e628f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -218,6 +218,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
if (type.getRank() > highest_rank)
return op->emitOpError() << "failed level check: " << operandOrResult
<< " rank(shape) <= MAX_RANK";
+ } else if (tosa::shapeType shapeType =
+ dyn_cast<tosa::shapeType>(typeToCheck)) {
+ if (shapeType.getRank() > highest_rank)
+ return op->emitOpError()
+ << "failed shape type level check: " << typeToCheck
+ << " exceeds MAX_RANK";
}
return success();
}
@@ -638,15 +644,21 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
CHECK_RANKS_AND_SIZES(CastToBlockScaled);
CHECK_RANKS_AND_SIZES(Rescale);
+ // Data Nodes
+ CHECK_RANKS_AND_SIZES(Const);
+ CHECK_RANKS_AND_SIZES(Identity);
// Control Flow Operators
CHECK_RANKS_AND_SIZES(If);
// Variable Operators
CHECK_RANKS_AND_SIZES(Variable);
CHECK_RANKS_AND_SIZES(VariableWrite);
CHECK_RANKS_AND_SIZES(VariableRead);
- // Data Nodes
- CHECK_RANKS_AND_SIZES(Const);
- CHECK_RANKS_AND_SIZES(Identity);
+ // Shape Operators
+ CHECK_RANKS_AND_SIZES(AddShape);
+ CHECK_RANKS_AND_SIZES(DivCeilShape);
+ CHECK_RANKS_AND_SIZES(DivFloorShape);
+ CHECK_RANKS_AND_SIZES(MulShape);
+ CHECK_RANKS_AND_SIZES(SubShape);
// For the following operators, check whether the size of each tensor
// operand is valid in a given Level.
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 68a95787b81c7..a06406fcdab1f 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -584,3 +584,13 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
}
+
+// -----
+
+func.func @test_mul_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.mul_shape' op illegal: requires [shape] but not enabled in target}}
+ %c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+ return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index a7087647e542b..213c4ae054c51 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1
func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
%1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
- // expected-error@+1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
+ // expected-error@+1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
%0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
return %0 : tensor<1x1x1x1x1x1x819xf32>
}
@@ -1662,3 +1662,23 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>
}
+
+// -----
+
+func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> {
+ %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
+ %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
+ // expected-error@+1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}}
+ %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13>
+ return %c : !tosa.shape<13>
+}
+
+// -----
+
+func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
+ %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+ %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+ // expected-error@+1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+ %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
+ return %c : !tosa.shape<7>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4591f7ffd393..2c4ec857ad20e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1374,3 +1374,48 @@ func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
%0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
return %0 : tensor<2x!tosa.mxint8>
}
+
+// -----
+// CHECK-LABEL: test_add_shape
+func.func @test_add_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+ return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_sub_shape
+func.func @test_sub_shape() -> !tosa.shape<3> {
+ %a = tosa.const_shape {values = dense<[10, 5, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %b = tosa.const_shape {values = dense<[2, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %c = tosa.sub_shape %a, %b : (!tosa.shape<3>, !tosa.shape<3>) -> !tosa.shape<3>
+ return %c : !tosa.shape<3>
+}
+
+// -----
+// CHECK-LABEL: test_mul_shape
+func.func @test_mul_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[2, 3, 4, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.const_shape {values = dense<[7, 0, 2, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+ return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_div_ceil_shape
+func.func @test_div_ceil_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.const_shape {values = dense<[2, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %c = tosa.div_ceil_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+ return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_div_floor_shape
+func.func @test_div_floor_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.const_shape {values = dense<[2, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %c = tosa.div_floor_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+ return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index c285ae3cf44ee..66a94559348a8 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64,shape" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
// -----
@@ -156,3 +156,12 @@ func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: te
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
return %0 : tensor<2x52x3xf32>
}
+
+// -----
+// CHECK-LABEL: test_add_shape
+func.func @test_add_shape() -> !tosa.shape<4> {
+ %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+ return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 6cf76cdc7ad8e..a70709b4ecc6a 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1222,3 +1222,19 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
}
+
+// -----
+
+func.func @test_elementwise_shape_op_same_inputs_rank(%arg0: !tosa.shape<4>, %arg1: !tosa.shape<3>) -> !tosa.shape<4> {
+ // expected-error@+1 {{'tosa.add_shape' op operands don't have matching ranks}}
+ %0 = tosa.add_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<3>) -> !tosa.shape<4>
+ return %0 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_elementwise_shape_op_same_input_output_rank(%arg0: !tosa.shape<4>, %arg1: !tosa.shape<4>) -> !tosa.shape<3> {
+ // expected-error@+1 {{'tosa.div_floor_shape' op result shape has different rank than operands}}
+ %0 = tosa.div_floor_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<3>
+ return %0 : !tosa.shape<3>
+}
>From 91b0ca1fcd07d4804746c7f75b6a559453cc0c94 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 12 Nov 2025 22:27:54 +0000
Subject: [PATCH 2/2] [mlir][tosa] Add dim operation
This commit adds the ext-shape operation `DIM`, including
its definition, verifier and validation. Folding and shape
inference are not yet supported and will be added in a
later commit.
See https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce
for the specification change.
Change-Id: I67ac3d3c26ea9ab150854b0e06916f64896792c7
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 22 +++++++++++++++
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 28 +++++++++++++++++++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 18 ++++++++++++
.../Tosa/Transforms/TosaProfileCompliance.cpp | 7 +++++
.../Tosa/Transforms/TosaValidation.cpp | 1 +
mlir/test/Dialect/Tosa/level_check.mlir | 8 ++++++
mlir/test/Dialect/Tosa/ops.mlir | 7 +++++
.../tosa-validation-version-1p1-valid.mlir | 8 ++++++
mlir/test/Dialect/Tosa/verifier.mlir | 24 ++++++++++++++++
9 files changed, 123 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 0005402cd1f44..3b98b36dbd654 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -949,6 +949,28 @@ extensionComplianceMap = {
{{{i8T}, SpecificationVersion::V_1_0},
{{fp16T}, SpecificationVersion::V_1_0},
{{fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.dim",
+ {{{Extension::shape},
+ {{{boolT}, SpecificationVersion::V_1_1_DRAFT},
+ {{i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e4m3, Extension::shape},
+ {{{fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::fp8e5m2, Extension::shape},
+ {{{fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::bf16, Extension::shape},
+ {{{bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::mxfp, Extension::shape},
+ {{{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
+ {{mxint8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::int64, Extension::shape},
+ {{{i64T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}}
};
// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 7b1c7e208ebe3..9cea37b37eee1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -99,6 +99,34 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: Dim
+//===----------------------------------------------------------------------===//
+def Tosa_DimOp : Tosa_Op<"dim", [Pure]> {
+ let summary = "Extract size of dimension from input tensor.";
+
+ let description = [{
+ Returns a length 1 shape_t of the size of the input tensor for the given axis.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorAtLeast1D:$input1,
+ I32Attr:$axis
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
+
+ list<Availability> availability = [
+ Profile<[]>,
+ Extension<[Tosa_EXT_SHAPE]>
+ ];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: DivCeilShape
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 65e0a59d39168..3e7791af3b5b2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4664,6 +4664,24 @@ LogicalResult tosa::ConstShapeOp::verify() {
return success();
}
+LogicalResult tosa::DimOp::verify() {
+ const tosa::shapeType outShapeType =
+ cast<tosa::shapeType>(getResult().getType());
+ if (outShapeType.getRank() != 1)
+ return emitOpError("expect output shape type to contain one element, got ")
+ << outShapeType;
+
+ const ShapeAdaptor inputType(getInput1().getType());
+ if (inputType.hasRank()) {
+ const int64_t inputRank = inputType.getRank();
+ const int64_t axis = getAxisAttr().getInt();
+ if (axis < 0 || axis >= inputRank)
+ return emitOpError("expect axis to be in the range [0, ")
+ << inputRank << "), got " << axis;
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index c9150d5b34d00..62f66543d74cc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -217,6 +217,12 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
return success();
}
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
+ addValue(op.getInput1());
+ return success();
+}
+
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function only populates the info for the customised operands.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
@@ -256,6 +262,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(MatMul)
POPULATE_PROFILE_INFO_CUSTOM(Variable)
POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
+ POPULATE_PROFILE_INFO_CUSTOM(Dim)
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 421ef237e628f..0ec103a2ce73a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -655,6 +655,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(VariableRead);
// Shape Operators
CHECK_RANKS_AND_SIZES(AddShape);
+ CHECK_RANKS_AND_SIZES(Dim);
CHECK_RANKS_AND_SIZES(DivCeilShape);
CHECK_RANKS_AND_SIZES(DivFloorShape);
CHECK_RANKS_AND_SIZES(MulShape);
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 213c4ae054c51..da21c18e19783 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1682,3 +1682,11 @@ func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
%c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
return %c : !tosa.shape<7>
}
+
+// -----
+
+func.func @test_dim(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
+ // expected-error@+1 {{'tosa.dim' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 2c4ec857ad20e..2e1fb6f1f24f6 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1419,3 +1419,10 @@ func.func @test_div_floor_shape() -> !tosa.shape<4> {
%c = tosa.div_floor_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
return %c : !tosa.shape<4>
}
+
+// -----
+// CHECK-LABEL: test_dim
+func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 66a94559348a8..c38cd435aec88 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -165,3 +165,11 @@ func.func @test_add_shape() -> !tosa.shape<4> {
%c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
return %c : !tosa.shape<4>
}
+
+// -----
+
+// CHECK-LABEL: test_dim
+func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index a70709b4ecc6a..56688544d990f 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1238,3 +1238,27 @@ func.func @test_elementwise_shape_op_same_input_output_rank(%arg0: !tosa.shape<4
%0 = tosa.div_floor_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<3>
return %0 : !tosa.shape<3>
}
+
+// -----
+
+func.func @test_dim_invalid_output_rank(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<2> {
+ // expected-error@+1 {{'tosa.dim' op expect output shape type to contain one element, got '!tosa.shape<2>'}}
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<2>
+ return %0 : !tosa.shape<2>
+}
+
+// -----
+
+func.func @test_dim_invalid_axis(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<1> {
+ // expected-error@+1 {{'tosa.dim' op expect axis to be in the range [0, 3), got 4}}
+ %0 = tosa.dim %arg0 {axis = 4 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
+
+// -----
+
+func.func @test_dim_scalar(%arg0: tensor<i32>) -> !tosa.shape<1> {
+ // expected-error@+1 {{'tosa.dim' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+ %0 = tosa.dim %arg0 {axis = 4 : i32} : (tensor<i32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
More information about the Mlir-commits
mailing list