[Mlir-commits] [mlir] 15e8079 - [mlir][tosa] Add support for boolean cast/gather/scatter operations (#177693)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 4 06:30:31 PST 2026
Author: Luke Hutton
Date: 2026-02-04T14:30:25Z
New Revision: 15e807960a4f464a94ca093bb531cce8d86410ce
URL: https://github.com/llvm/llvm-project/commit/15e807960a4f464a94ca093bb531cce8d86410ce
DIFF: https://github.com/llvm/llvm-project/commit/15e807960a4f464a94ca093bb531cce8d86410ce.diff
LOG: [mlir][tosa] Add support for boolean cast/gather/scatter operations (#177693)
Aligns with the spec change:
https://github.com/arm/tosa-specification/pull/32
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 009775293a987..a4eb3dd12bd54 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -371,7 +371,10 @@ profileComplianceMap = {
{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
{{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0},
- {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp, Profile::pro_int},
+ {{{boolT, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}},
+ anyOf}}},
{"tosa.scatter",
{{{Profile::pro_int},
{{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -379,7 +382,12 @@ profileComplianceMap = {
{{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
{{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0},
- {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp, Profile::pro_int},
+ {
+ {{boolT, i32T, boolT, boolT}, SpecificationVersion::V_1_1_DRAFT},
+ },
+ anyOf}}},
{"tosa.resize",
{{{Profile::pro_int},
{{{i8T, i32T}, SpecificationVersion::V_1_0},
@@ -402,7 +410,8 @@ profileComplianceMap = {
{{i32T, i8T}, SpecificationVersion::V_1_0},
{{i32T, i16T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{{i8T, fp16T}, SpecificationVersion::V_1_0},
+ {{{boolT, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i8T, fp16T}, SpecificationVersion::V_1_0},
{{i8T, fp32T}, SpecificationVersion::V_1_0},
{{i16T, fp16T}, SpecificationVersion::V_1_0},
{{i16T, fp32T}, SpecificationVersion::V_1_0},
@@ -415,7 +424,8 @@ profileComplianceMap = {
{{fp32T, i8T}, SpecificationVersion::V_1_0},
{{fp32T, i16T}, SpecificationVersion::V_1_0},
{{fp32T, i32T}, SpecificationVersion::V_1_0},
- {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}},
+ {{fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, boolT}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
{{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -824,7 +834,8 @@ extensionComplianceMap = {
{{i32T, i64T, i32T}, SpecificationVersion::V_1_1_DRAFT},
{{i64T, i64T, i64T}, SpecificationVersion::V_1_1_DRAFT},
{{fp16T, i64T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
- {{fp32T, i64T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{fp32T, i64T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{boolT, i64T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e4m3, Extension::int64},
{{{fp8e4m3T, i64T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}},
allOf},
@@ -847,7 +858,8 @@ extensionComplianceMap = {
{{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
{{i64T, i64T, i64T, i64T}, SpecificationVersion::V_1_1_DRAFT},
{{fp16T, i64T, fp16T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
- {{fp32T, i64T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{fp32T, i64T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{boolT, i64T, boolT, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e4m3, Extension::int64},
{{{fp8e4m3T, i64T, fp8e4m3T, fp8e4m3T},
SpecificationVersion::V_1_1_DRAFT}},
@@ -876,7 +888,9 @@ extensionComplianceMap = {
{{fp32T, bf16T}, SpecificationVersion::V_1_0}}},
{{Extension::int64},
{{{i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
- {{i64T, i32T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{i64T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{boolT, i64T}, SpecificationVersion::V_1_1_DRAFT},
+ {{i64T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16, Extension::fp8e4m3},
{{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0},
{{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}},
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
index fe38e2f61b2e8..fbd935d56fcc6 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -62,6 +62,70 @@ func.func @test_transpose_conv2d_fp8_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg
// -----
+func.func @test_gather_bool_i64(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xi1> {
+ // expected-error at +1 {{'tosa.gather' op illegal: requires [int64] but not enabled in target}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi64>) -> tensor<13x26x3xi1>
+ return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+func.func @test_gather_bool_i32(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xi1> {
+ // expected-error at +1 {{'tosa.gather' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi32>) -> tensor<13x26x3xi1>
+ return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+func.func @test_scatter_bool_i64(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+ // expected-error at +1 {{'tosa.scatter' op illegal: requires [int64] but not enabled in target}}
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi64>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+ return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+func.func @test_scatter_bool_i32(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+ // expected-error at +1 {{'tosa.scatter' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi32>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+ return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+func.func @test_cast_bool_fp32(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xf32> {
+ // expected-error at +1 {{'tosa.cast' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_cast_bool_i64(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi64> {
+ // expected-error at +1 {{'tosa.cast' op illegal: requires [int64] but not enabled in target}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+func.func @test_cast_fp32_bool(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ // expected-error at +1 {{'tosa.cast' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+func.func @test_cast_i64_bool(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi1> {
+ // expected-error at +1 {{'tosa.cast' op illegal: requires [int64] but not enabled in target}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
func.func @test_dyanmic_dims(%arg0: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
// expected-error at +1 {{'tosa.argmax' op failed level check: operand shape dimension cannot be dynamic when targeting TOSA specification version 1.0 or below}}
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<?x8x16xi8>) -> tensor<?x16xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 97fb14927f7e8..72269d21f3d98 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -183,6 +183,70 @@ func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: te
// -----
+// CHECK-LABEL: test_gather_bool_i64
+func.func @test_gather_bool_i64(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xi1> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi64>) -> tensor<13x26x3xi1>
+ return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_gather_bool_i32
+func.func @test_gather_bool_i32(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xi1> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi32>) -> tensor<13x26x3xi1>
+ return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_scatter_bool_i64
+func.func @test_scatter_bool_i64(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi64>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+ return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_scatter_bool_i32
+func.func @test_scatter_bool_i32(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi32>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+ return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_bool_fp32
+func.func @test_cast_bool_fp32(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xf32> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_bool_i64
+func.func @test_cast_bool_i64(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi64> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_fp32_bool
+func.func @test_cast_fp32_bool(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_i64_bool
+func.func @test_cast_i64_bool(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi1> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
// CHECK-LABEL: test_dynamic_dims
func.func @test_dynamic_dims(%arg0: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<?x8x16xi8>) -> tensor<?x16xi32>
@@ -250,4 +314,4 @@ func.func @test_assert_equal_shape() {
%1 = tosa.const_shape {values = dense<[5, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<2>, !tosa.shape<2>) -> ()
return
-}
\ No newline at end of file
+}
More information about the Mlir-commits mailing list