[Mlir-commits] [mlir] [mlir][tosa] Add dim operation (PR #169368)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 19 11:02:53 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
This commit adds the ext-shape operation `DIM`, including its definition, verifier and validation. It does not yet include support for folding or shape inference; these will be added in a later commit.
See https://github.com/arm/tosa-specification/commit/efc88a100e2db06c2d6bc479fa63b26daab899ce for the specification change.
Based on work originally implemented by @Tai78641.
Note: this change is dependent on https://github.com/llvm/llvm-project/pull/169321 so also contains its contents.
---
Full diff: https://github.com/llvm/llvm-project/pull/169368.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+22)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+1-7)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+20)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+19-16)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+7)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+2)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+8)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+7)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+8)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+24)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index e23827f8aabf2..b7933e4154575 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -941,6 +941,28 @@ extensionComplianceMap = {
{{{i8T}, SpecificationVersion::V_1_0},
{{fp16T}, SpecificationVersion::V_1_0},
{{fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.dim",
+ {{{Extension::shape},
+ {{{boolT}, SpecificationVersion::V_1_1_DRAFT},
+ {{i8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e4m3, Extension::shape},
+ {{{fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::fp8e5m2, Extension::shape},
+ {{{fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::bf16, Extension::shape},
+ {{{bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::mxfp, Extension::shape},
+ {{{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
+ {{mxint8T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
+ {{Extension::int64, Extension::shape},
+ {{{i64T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}}
};
// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 2d4e7cf8b9dbd..c5196349dbb1a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -124,15 +124,9 @@ class TosaResolvableShapeOperands
}
};
-LogicalResult verifyTosaShapeOperator(Operation *op);
/// This class indicates that op operates on tosa shape types
template <typename ConcreteType>
-class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
-public:
- static LogicalResult verifyTrait(Operation *op) {
- return verifyTosaShapeOperator(op);
- }
-};
+class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {};
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
/// This class indicates that op operates on tosa shape types
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d5a46b1b34312..182504676710f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -99,6 +99,26 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: Dim
+//===----------------------------------------------------------------------===//
+def Tosa_DimOp : Tosa_ShapeOp<"dim", [Pure]> {
+ let summary = "Extract size of dimension from input tensor.";
+
+ let description = [{
+ Returns a length 1 shape_t of the size of the input tensor for the given axis.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorAtLeast1D:$input1,
+ I32Attr:$axis
+ );
+
+ let results = (outs Tosa_Shape:$output);
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: DivCeilShape
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bead774620a4f..c9ef172e4fced 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4602,24 +4602,9 @@ LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
return success();
}
-LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
- for (auto type : op->getOperandTypes()) {
- if (!mlir::isa<mlir::tosa::shapeType>(type)) {
- return op->emitOpError("must have operands with tosa shape type");
- }
- }
- for (auto type : op->getResultTypes()) {
- if (!mlir::isa<mlir::tosa::shapeType>(type)) {
- return op->emitOpError("must have result with tosa shape type");
- }
- }
- return success();
-}
-
LogicalResult
OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
- if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
- failed(verifyTosaShapeOperator(op)))
+ if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
return failure();
// delegate function that returns rank of shape type
@@ -4663,6 +4648,24 @@ LogicalResult tosa::ConstShapeOp::verify() {
return success();
}
+LogicalResult tosa::DimOp::verify() {
+ const tosa::shapeType outShapeType =
+ cast<tosa::shapeType>(getResult().getType());
+ if (outShapeType.getRank() != 1)
+ return emitOpError("expect output shape type to contain one element, got ")
+ << outShapeType;
+
+ const ShapeAdaptor inputType(getInput1().getType());
+ if (inputType.hasRank()) {
+ const int64_t inputRank = inputType.getRank();
+ const int64_t axis = getAxisAttr().getInt();
+ if (axis < 0 || axis >= inputRank)
+ return emitOpError("expect axis to be in the range [0, ")
+ << inputRank << "), got " << axis;
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index c9150d5b34d00..62f66543d74cc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -217,6 +217,12 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
return success();
}
+template <>
+LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::DimOp op) {
+ addValue(op.getInput1());
+ return success();
+}
+
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function only populates the info for the customised operands.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
@@ -256,6 +262,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(MatMul)
POPULATE_PROFILE_INFO_CUSTOM(Variable)
POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
+ POPULATE_PROFILE_INFO_CUSTOM(Dim)
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 530c6ae85287c..897cc87529eca 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -659,6 +659,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Variable);
CHECK_RANKS_AND_SIZES(VariableWrite);
CHECK_RANKS_AND_SIZES(VariableRead);
+ // Shape Operators
+ CHECK_RANKS_AND_SIZES(Dim);
// For the following operators, check whether the size of each tensor
// operand is valid in a given Level.
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 213c4ae054c51..da21c18e19783 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1682,3 +1682,11 @@ func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
%c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
return %c : !tosa.shape<7>
}
+
+// -----
+
+func.func @test_dim(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
+ // expected-error@+1 {{'tosa.dim' op failed level check: operand rank(shape) <= MAX_RANK}}
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b9e4d18156898..4e1935233535d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1426,3 +1426,10 @@ func.func @test_div_floor_shape() -> !tosa.shape<4> {
%c = tosa.div_floor_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
return %c : !tosa.shape<4>
}
+
+// -----
+// CHECK-LABEL: test_dim
+func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 10d322cf64fb7..63379ed8d8a4d 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -158,3 +158,11 @@ func.func @test_add_shape() -> !tosa.shape<4> {
%c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
return %c : !tosa.shape<4>
}
+
+// -----
+
+// CHECK-LABEL: test_dim
+func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index d73650ddd0563..7104285af8446 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1246,3 +1246,27 @@ func.func @test_elementwise_shape_op_same_input_output_rank(%arg0: !tosa.shape<4
%0 = tosa.div_floor_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<3>
return %0 : !tosa.shape<3>
}
+
+// -----
+
+func.func @test_dim_invalid_output_rank(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<2> {
+ // expected-error@+1 {{'tosa.dim' op expect output shape type to contain one element, got '!tosa.shape<2>'}}
+ %0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<2>
+ return %0 : !tosa.shape<2>
+}
+
+// -----
+
+func.func @test_dim_invalid_axis(%arg0: tensor<1x2x3xi32>) -> !tosa.shape<1> {
+ // expected-error@+1 {{'tosa.dim' op expect axis to be in the range [0, 3), got 4}}
+ %0 = tosa.dim %arg0 {axis = 4 : i32} : (tensor<1x2x3xi32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
+
+// -----
+
+func.func @test_dim_scalar(%arg0: tensor<i32>) -> !tosa.shape<1> {
+ // expected-error@+1 {{'tosa.dim' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+ %0 = tosa.dim %arg0 {axis = 4 : i32} : (tensor<i32>) -> !tosa.shape<1>
+ return %0 : !tosa.shape<1>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/169368
More information about the Mlir-commits
mailing list