[Mlir-commits] [mlir] [mlir][tosa] Add dim operation (PR #169368)

Luke Hutton llvmlistbot at llvm.org
Fri Dec 19 11:06:50 PST 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/169368

From 67e936060cbf391a0d79816250dae4faa26eed60 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 12 Nov 2025 22:27:54 +0000
Subject: [PATCH] [mlir][tosa] Add dim operation

This commit adds the ext-shape operation `DIM`. This includes
the definition, verifier and validation. It does not currently
include support for folding or shape inference; these will
be added in a later commit.

See https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce
for the specification change.

Change-Id: I67ac3d3c26ea9ab150854b0e06916f64896792c7
---
 .../Dialect/Tosa/IR/TosaComplianceData.h.inc  | 28 ++++++++++++++-
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |  8 +----
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      | 20 +++++++++++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 35 ++++++++++---------
 .../Tosa/Transforms/TosaProfileCompliance.cpp |  7 ++++
 .../Tosa/Transforms/TosaValidation.cpp        |  2 ++
 mlir/test/Dialect/Tosa/level_check.mlir       |  8 +++++
 mlir/test/Dialect/Tosa/ops.mlir               |  7 ++++
 .../tosa-validation-version-1p1-valid.mlir    |  8 +++++
 mlir/test/Dialect/Tosa/verifier.mlir          | 24 +++++++++++++
 10 files changed, 123 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index e23827f8aabf2..fd55cd82b8663 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -941,6 +941,32 @@ extensionComplianceMap = {
        {{{i8T}, SpecificationVersion::V_1_0},
         {{fp16T}, SpecificationVersion::V_1_0},
         {{fp32T}, SpecificationVersion::V_1_0}}}}},
-};
+    {"tosa.dim",
+     {{{Extension::shape},
+       {{{boolT}, SpecificationVersion::V_1_1_DRAFT},
+        {{i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e4m3, Extension::shape},
+       {{{fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::fp8e5m2, Extension::shape},
+       {{{fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::bf16, Extension::shape},
+       {{{bf16T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::mxfp, Extension::shape},
+       {{{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::int64, Extension::shape},
+       {{{i64T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf}}}};
 
 // End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 2d4e7cf8b9dbd..c5196349dbb1a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -124,15 +124,9 @@ class TosaResolvableShapeOperands
   }
 };
 
-LogicalResult verifyTosaShapeOperator(Operation *op);
 /// This class indicates that op operates on tosa shape types
 template <typename ConcreteType>
-class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
-public:
-  static LogicalResult verifyTrait(Operation *op) {
-    return verifyTosaShapeOperator(op);
-  }
-};
+class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {};
 
 LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
 /// This class indicates that op operates on tosa shape types
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d5a46b1b34312..182504676710f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -99,6 +99,26 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: Dim
+//===----------------------------------------------------------------------===//
+def Tosa_DimOp : Tosa_ShapeOp<"dim", [Pure]> {
+  let summary = "Extract size of dimension from input tensor.";
+
+  let description = [{
+    Returns a length 1 shape_t holding the size of dimension `axis` of input1.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorAtLeast1D:$input1,
+    I32Attr:$axis
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: DivCeilShape
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bead774620a4f..c9ef172e4fced 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4602,24 +4602,9 @@ LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
   return success();
 }
 
-LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
-  for (auto type : op->getOperandTypes()) {
-    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
-      return op->emitOpError("must have operands with tosa shape type");
-    }
-  }
-  for (auto type : op->getResultTypes()) {
-    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
-      return op->emitOpError("must have result with tosa shape type");
-    }
-  }
-  return success();
-}
-
 LogicalResult
 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
-  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
-      failed(verifyTosaShapeOperator(op)))
+  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
     return failure();
 
   // delegate function that returns rank of shape type
@@ -4663,6 +4648,24 @@ LogicalResult tosa::ConstShapeOp::verify() {
   return success();
 }
 
+LogicalResult tosa::DimOp::verify() {
+  const auto resultType = cast<tosa::shapeType>(getResult().getType());
+  if (resultType.getRank() != 1)
+    return emitOpError("expect output shape type to contain one element, got ")
+           << resultType;
+
+  // Axis bounds are only checkable once the input rank is known.
+  const ShapeAdaptor input(getInput1().getType());
+  if (!input.hasRank())
+    return success();
+  const int64_t rank = input.getRank();
+  const int64_t axis = getAxisAttr().getInt(); // may be negative
+  if (axis >= 0 && axis < rank)
+    return success();
+  return emitOpError("expect axis to be in the range [0, ")
+         << rank << "), got " << axis;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Attribute Definitions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index c9150d5b34d00..62f66543d74cc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -217,6 +217,12 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
   return success();
 }
 
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp dimOp) {
+  addValue(dimOp.getInput1()); // the !tosa.shape result is not recorded
+  return success();
+}
+
 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // This helper function only populates the info for the customised operands.
 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
@@ -256,6 +262,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(MatMul)
   POPULATE_PROFILE_INFO_CUSTOM(Variable)
   POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
+  POPULATE_PROFILE_INFO_CUSTOM(Dim)
 
   // For the most of tosa operators, all operands are profile/extension related
   // and hence are all considered in this profile-based compilance check.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 530c6ae85287c..897cc87529eca 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -659,6 +659,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(Variable);
   CHECK_RANKS_AND_SIZES(VariableWrite);
   CHECK_RANKS_AND_SIZES(VariableRead);
+  // Shape Operators
+  CHECK_RANKS_AND_SIZES(Dim);
 
   // For the following operators, check whether the size of each tensor
   // operand is valid in a given Level.
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 213c4ae054c51..da21c18e19783 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1682,3 +1682,11 @@ func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
   %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
   return %c : !tosa.shape<7>
 }
+
+// -----
+
+func.func @test_dim_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
+  // expected-error@+1 {{'tosa.dim' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1>
+  return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b9e4d18156898..4e1935233535d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1426,3 +1426,10 @@ func.func @test_div_floor_shape() -> !tosa.shape<4> {
   %c = tosa.div_floor_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
   return %c : !tosa.shape<4>
 }
+
+// -----
+// CHECK-LABEL: test_dim
+func.func @test_dim(%input: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+  %size = tosa.dim %input {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
+  return %size : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 10d322cf64fb7..63379ed8d8a4d 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -158,3 +158,11 @@ func.func @test_add_shape() -> !tosa.shape<4> {
   %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
   return %c : !tosa.shape<4>
 }
+
+// -----
+
+// CHECK-LABEL: test_dim
+func.func @test_dim(%input: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+  %size = tosa.dim %input {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
+  return %size : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index d73650ddd0563..7104285af8446 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1246,3 +1246,27 @@ func.func @test_elementwise_shape_op_same_input_output_rank(%arg0: !tosa.shape<4
   %0 = tosa.div_floor_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<3>
   return %0 : !tosa.shape<3>
 }
+
+// -----
+
+func.func @test_dim_invalid_output_rank(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<2> {
+  // expected-error@+1 {{'tosa.dim' op expect output shape type to contain one element, got '!tosa.shape<2>'}}
+  %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<2>
+  return %0 : !tosa.shape<2>
+}
+
+// -----
+
+func.func @test_dim_invalid_axis(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<1> {
+  // expected-error@+1 {{'tosa.dim' op expect axis to be in the range [0, 3), got 4}}
+  %0 = tosa.dim %arg0 {axis = 4 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<1>
+  return %0 : !tosa.shape<1>
+}
+
+// -----
+
+func.func @test_dim_scalar(%arg0: tensor<i32>) -> !tosa.shape<1> {
+  // expected-error@+1 {{'tosa.dim' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+  %0 = tosa.dim %arg0 {axis = 0 : i32} : (tensor<i32>) -> !tosa.shape<1>
+  return %0 : !tosa.shape<1>
+}



More information about the Mlir-commits mailing list