[Mlir-commits] [mlir] [mlir][tosa] Add dim operation (PR #169368)
Luke Hutton
llvmlistbot at llvm.org
Fri Dec 19 11:02:11 PST 2025
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/169368
>From 8a5e343bf96af760ce639be90f808616a23bbfc1 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 12 Nov 2025 22:27:54 +0000
Subject: [PATCH] [mlir][tosa] Add dim operation
This commit adds the ext-shape operation `DIM`. This includes
the definition, verifier and validation. It does not currently
include support for folding or shape inference; both will
be added in a later commit.
See https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce
for the specification change.
Change-Id: I67ac3d3c26ea9ab150854b0e06916f64896792c7
---
.../Dialect/Tosa/IR/TosaComplianceData.h.inc | 22 ++++++++++++
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 8 +----
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 20 +++++++++++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 35 ++++++++++---------
.../Tosa/Transforms/TosaProfileCompliance.cpp | 7 ++++
.../Tosa/Transforms/TosaValidation.cpp | 2 ++
mlir/test/Dialect/Tosa/level_check.mlir | 8 +++++
mlir/test/Dialect/Tosa/ops.mlir | 7 ++++
.../tosa-validation-version-1p1-valid.mlir | 8 +++++
mlir/test/Dialect/Tosa/verifier.mlir | 24 +++++++++++++
10 files changed, 118 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index e23827f8aabf2..b7933e4154575 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -941,6 +941,28 @@ extensionComplianceMap = {
{{{i8T}, SpecificationVersion::V_1_0},
{{fp16T}, SpecificationVersion::V_1_0},
{{fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.dim",
+ {{{Extension::shape},
+ {{{boolT}, SpecificationVersion::V_1_1_DRAFT},
+ {{i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e4m3, Extension::shape},
+ {{{fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::fp8e5m2, Extension::shape},
+ {{{fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::bf16, Extension::shape},
+ {{{bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::mxfp, Extension::shape},
+ {{{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
+ {{mxint8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::int64, Extension::shape},
+ {{{i64T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}}
};
// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 2d4e7cf8b9dbd..c5196349dbb1a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -124,15 +124,9 @@ class TosaResolvableShapeOperands
}
};
-LogicalResult verifyTosaShapeOperator(Operation *op);
/// This class indicates that op operates on tosa shape types
template <typename ConcreteType>
-class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
-public:
- static LogicalResult verifyTrait(Operation *op) {
- return verifyTosaShapeOperator(op);
- }
-};
+class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {};
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
/// This class indicates that op operates on tosa shape types
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d5a46b1b34312..182504676710f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -99,6 +99,27 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: Dim
+//===----------------------------------------------------------------------===//
+def Tosa_DimOp : Tosa_ShapeOp<"dim", [Pure]> {
+  let summary = "Extract size of dimension from input tensor.";
+
+  let description = [{
+    Returns a length 1 shape_t holding the size of dimension `axis` of
+    `input1`. `axis` must satisfy 0 <= axis < rank(input1).
+  }];
+
+  let arguments = (ins
+    Tosa_TensorAtLeast1D:$input1,
+    I32Attr:$axis
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: DivCeilShape
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bead774620a4f..c9ef172e4fced 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4602,24 +4602,9 @@ LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
return success();
}
-LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
- for (auto type : op->getOperandTypes()) {
- if (!mlir::isa<mlir::tosa::shapeType>(type)) {
- return op->emitOpError("must have operands with tosa shape type");
- }
- }
- for (auto type : op->getResultTypes()) {
- if (!mlir::isa<mlir::tosa::shapeType>(type)) {
- return op->emitOpError("must have result with tosa shape type");
- }
- }
- return success();
-}
-
LogicalResult
OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
- if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
- failed(verifyTosaShapeOperator(op)))
+ if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
return failure();
// delegate function that returns rank of shape type
@@ -4663,6 +4648,28 @@ LogicalResult tosa::ConstShapeOp::verify() {
return success();
}
+/// Verifies tosa.dim: the result must be a single-element shape and, for a
+/// ranked input, `axis` must name one of the input's dimensions.
+LogicalResult tosa::DimOp::verify() {
+  const auto resultShapeType = cast<tosa::shapeType>(getResult().getType());
+  if (resultShapeType.getRank() != 1)
+    return emitOpError("expect output shape type to contain one element, got ")
+           << resultShapeType;
+
+  // The axis can only be range-checked once the input rank is known.
+  const ShapeAdaptor inputShape(getInput1().getType());
+  if (!inputShape.hasRank())
+    return success();
+
+  const int64_t inputRank = inputShape.getRank();
+  const int64_t axis = getAxisAttr().getInt();
+  const bool axisInRange = axis >= 0 && axis < inputRank;
+  if (!axisInRange)
+    return emitOpError("expect axis to be in the range [0, ")
+           << inputRank << "), got " << axis;
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index c9150d5b34d00..62f66543d74cc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -217,6 +217,12 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
return success();
}
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
+ addValue(op.getInput1());
+ return success();
+}
+
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function only populates the info for the customised operands.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
@@ -256,6 +262,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(MatMul)
POPULATE_PROFILE_INFO_CUSTOM(Variable)
POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
+ POPULATE_PROFILE_INFO_CUSTOM(Dim)
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 530c6ae85287c..897cc87529eca 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -659,6 +659,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Variable);
CHECK_RANKS_AND_SIZES(VariableWrite);
CHECK_RANKS_AND_SIZES(VariableRead);
+ // Shape Operators
+ CHECK_RANKS_AND_SIZES(Dim);
// For the following operators, check whether the size of each tensor
// operand is valid in a given Level.
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 213c4ae054c51..da21c18e19783 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1682,3 +1682,11 @@ func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
%c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
return %c : !tosa.shape<7>
}
+
+// -----
+
+func.func @test_dim_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
+  // expected-error@+1 {{'tosa.dim' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1>
+  return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b9e4d18156898..4e1935233535d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1426,3 +1426,10 @@ func.func @test_div_floor_shape() -> !tosa.shape<4> {
%c = tosa.div_floor_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
return %c : !tosa.shape<4>
}
+
+// -----
+// CHECK-LABEL: test_dim
+func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 10d322cf64fb7..63379ed8d8a4d 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -158,3 +158,11 @@ func.func @test_add_shape() -> !tosa.shape<4> {
%c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
return %c : !tosa.shape<4>
}
+
+// -----
+
+// CHECK-LABEL: test_dim
+func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index d73650ddd0563..7104285af8446 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1246,3 +1246,35 @@ func.func @test_elementwise_shape_op_same_input_output_rank(%arg0: !tosa.shape<4
%0 = tosa.div_floor_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<3>
return %0 : !tosa.shape<3>
}
+
+// -----
+
+func.func @test_dim_invalid_output_rank(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<2> {
+  // expected-error@+1 {{'tosa.dim' op expect output shape type to contain one element, got '!tosa.shape<2>'}}
+  %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<2>
+  return %0 : !tosa.shape<2>
+}
+
+// -----
+
+func.func @test_dim_invalid_axis(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<1> {
+  // expected-error@+1 {{'tosa.dim' op expect axis to be in the range [0, 3), got 4}}
+  %0 = tosa.dim %arg0 {axis = 4 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<1>
+  return %0 : !tosa.shape<1>
+}
+
+// -----
+
+func.func @test_dim_negative_axis(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<1> {
+  // expected-error@+1 {{'tosa.dim' op expect axis to be in the range [0, 3), got -1}}
+  %0 = tosa.dim %arg0 {axis = -1 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<1>
+  return %0 : !tosa.shape<1>
+}
+
+// -----
+
+func.func @test_dim_scalar(%arg0: tensor<i32>) -> !tosa.shape<1> {
+  // expected-error@+1 {{'tosa.dim' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+  %0 = tosa.dim %arg0 {axis = 0 : i32} : (tensor<i32>) -> !tosa.shape<1>
+  return %0 : !tosa.shape<1>
+}
More information about the Mlir-commits
mailing list