[Mlir-commits] [mlir] 15e8079 - [mlir][tosa] Add support for boolean cast/gather/scatter operations (#177693)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 4 06:30:31 PST 2026


Author: Luke Hutton
Date: 2026-02-04T14:30:25Z
New Revision: 15e807960a4f464a94ca093bb531cce8d86410ce

URL: https://github.com/llvm/llvm-project/commit/15e807960a4f464a94ca093bb531cce8d86410ce
DIFF: https://github.com/llvm/llvm-project/commit/15e807960a4f464a94ca093bb531cce8d86410ce.diff

LOG: [mlir][tosa] Add support for boolean cast/gather/scatter operations (#177693)

Aligns with the spec change:
https://github.com/arm/tosa-specification/pull/32

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
    mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
    mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 009775293a987..a4eb3dd12bd54 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -371,7 +371,10 @@ profileComplianceMap = {
         {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
       {{Profile::pro_fp},
        {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0},
-        {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+        {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}},
+      {{Profile::pro_fp, Profile::pro_int},
+       {{{boolT, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}},
+       anyOf}}},
     {"tosa.scatter",
      {{{Profile::pro_int},
        {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -379,7 +382,12 @@ profileComplianceMap = {
         {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
       {{Profile::pro_fp},
        {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0},
-        {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+        {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}},
+      {{Profile::pro_fp, Profile::pro_int},
+       {
+           {{boolT, i32T, boolT, boolT}, SpecificationVersion::V_1_1_DRAFT},
+       },
+       anyOf}}},
     {"tosa.resize",
      {{{Profile::pro_int},
        {{{i8T, i32T}, SpecificationVersion::V_1_0},
@@ -402,7 +410,8 @@ profileComplianceMap = {
         {{i32T, i8T}, SpecificationVersion::V_1_0},
         {{i32T, i16T}, SpecificationVersion::V_1_0}}},
       {{Profile::pro_fp},
-       {{{i8T, fp16T}, SpecificationVersion::V_1_0},
+       {{{boolT, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i8T, fp16T}, SpecificationVersion::V_1_0},
         {{i8T, fp32T}, SpecificationVersion::V_1_0},
         {{i16T, fp16T}, SpecificationVersion::V_1_0},
         {{i16T, fp32T}, SpecificationVersion::V_1_0},
@@ -415,7 +424,8 @@ profileComplianceMap = {
         {{fp32T, i8T}, SpecificationVersion::V_1_0},
         {{fp32T, i16T}, SpecificationVersion::V_1_0},
         {{fp32T, i32T}, SpecificationVersion::V_1_0},
-        {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}},
+        {{fp32T, fp16T}, SpecificationVersion::V_1_0},
+        {{fp32T, boolT}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.rescale",
      {{{Profile::pro_int},
        {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -824,7 +834,8 @@ extensionComplianceMap = {
         {{i32T, i64T, i32T}, SpecificationVersion::V_1_1_DRAFT},
         {{i64T, i64T, i64T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp16T, i64T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp32T, i64T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+        {{fp32T, i64T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::fp8e4m3, Extension::int64},
        {{{fp8e4m3T, i64T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}},
        allOf},
@@ -847,7 +858,8 @@ extensionComplianceMap = {
         {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
         {{i64T, i64T, i64T, i64T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp16T, i64T, fp16T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp32T, i64T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+        {{fp32T, i64T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T, boolT, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::fp8e4m3, Extension::int64},
        {{{fp8e4m3T, i64T, fp8e4m3T, fp8e4m3T},
          SpecificationVersion::V_1_1_DRAFT}},
@@ -876,7 +888,9 @@ extensionComplianceMap = {
         {{fp32T, bf16T}, SpecificationVersion::V_1_0}}},
       {{Extension::int64},
        {{{i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
-        {{i64T, i32T}, SpecificationVersion::V_1_1_DRAFT}}},
+        {{i64T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i64T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::bf16, Extension::fp8e4m3},
        {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0},
         {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}},

diff  --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
index fe38e2f61b2e8..fbd935d56fcc6 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -62,6 +62,70 @@ func.func @test_transpose_conv2d_fp8_acc32(%arg0: tensor<1x32x32x8xf8E5M2>, %arg
 
 // -----
 
+func.func @test_gather_bool_i64(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xi1> {
+  // expected-error@+1 {{'tosa.gather' op illegal: requires [int64] but not enabled in target}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi64>) -> tensor<13x26x3xi1>
+  return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+func.func @test_gather_bool_i32(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xi1> {
+  // expected-error@+1 {{'tosa.gather' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi32>) -> tensor<13x26x3xi1>
+  return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+func.func @test_scatter_bool_i64(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+  // expected-error@+1 {{'tosa.scatter' op illegal: requires [int64] but not enabled in target}}
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi64>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+  return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+func.func @test_scatter_bool_i32(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+  // expected-error@+1 {{'tosa.scatter' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi32>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+  return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+func.func @test_cast_bool_fp32(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xf32> {
+  // expected-error@+1 {{'tosa.cast' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_cast_bool_i64(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi64> {
+  // expected-error@+1 {{'tosa.cast' op illegal: requires [int64] but not enabled in target}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi64>
+  return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+func.func @test_cast_fp32_bool(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+  // expected-error@+1 {{'tosa.cast' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+func.func @test_cast_i64_bool(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi1> {
+  // expected-error@+1 {{'tosa.cast' op illegal: requires [int64] but not enabled in target}}
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
 func.func @test_dyanmic_dims(%arg0: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
   // expected-error@+1 {{'tosa.argmax' op failed level check: operand shape dimension cannot be dynamic when targeting TOSA specification version 1.0 or below}}
   %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<?x8x16xi8>) -> tensor<?x16xi32>

diff  --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 97fb14927f7e8..72269d21f3d98 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -183,6 +183,70 @@ func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: te
 
 // -----
 
+// CHECK-LABEL: test_gather_bool_i64
+func.func @test_gather_bool_i64(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xi1> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi64>) -> tensor<13x26x3xi1>
+  return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_gather_bool_i32
+func.func @test_gather_bool_i32(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xi1> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi32>) -> tensor<13x26x3xi1>
+  return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_scatter_bool_i64
+func.func @test_scatter_bool_i64(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi64>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+  return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_scatter_bool_i32
+func.func @test_scatter_bool_i32(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi32>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+  return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_bool_fp32
+func.func @test_cast_bool_fp32(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xf32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_bool_i64
+func.func @test_cast_bool_i64(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi64> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi64>
+  return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_fp32_bool
+func.func @test_cast_fp32_bool(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_i64_bool
+func.func @test_cast_i64_bool(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi1> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
 // CHECK-LABEL: test_dynamic_dims
 func.func @test_dynamic_dims(%arg0: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
   %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<?x8x16xi8>) -> tensor<?x16xi32>
@@ -250,4 +314,4 @@ func.func @test_assert_equal_shape() {
   %1 = tosa.const_shape {values = dense<[5, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
   tosa.assert_equal_shape %0, %1 {allow_broadcast = true} : (!tosa.shape<2>, !tosa.shape<2>) -> ()
   return
-}
\ No newline at end of file
+}


        


More information about the Mlir-commits mailing list