[Mlir-commits] [mlir] [mlir][tosa] Fix invalid data type combinations check (PR #150066)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 22 10:03:48 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

Previously, this check assumed that if an operator exists in the profile compliance data (TosaProfileComplianceData.h.inc), it has an entry in both the profiles and the extensions sections. However, this is not necessarily the case.

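This commit changes the check so that it no longer makes this assumption. Previously, the invalid-combination error was only reported when both the profile and the extension lookups succeeded, so any operator with an entry in only one of the two sections was silently skipped. Now, the error is reported whenever at least one lookup succeeds and neither yields a matching entry. As a result, more operators are checked for invalid data type combinations.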

---
Full diff: https://github.com/llvm/llvm-project/pull/150066.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+9-2) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+16) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 88b0f3650ca01..09f09487129ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -464,9 +464,16 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
   CheckCondition condition = CheckCondition::invalid;
   const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
   const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+  if (failed(maybeProfDef) && failed(maybeExtDef))
+    return success();
+
+  bool hasEntry = false;
+  if (succeeded(maybeProfDef))
+    hasEntry |= maybeProfDef.value().size();
+  if (succeeded(maybeExtDef))
+    hasEntry |= maybeExtDef.value().size();
 
-  if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
-      !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+  if (!hasEntry) {
     std::string message;
     llvm::raw_string_ostream os(message);
     os << "illegal: operation operand/result data types did not align with any "
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 5a424c41775c9..95ebe0dfef0f7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2027,3 +2027,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
   %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
   return %0 : tensor<2x52x3xf32>
 }
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+  // expected-error at +1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+  return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+  // expected-error at +1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+  %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+  return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26fb1f85..ef4e0ae13efa7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
 
 // -----
 
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
   // expected-error at +1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
+  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/150066


More information about the Mlir-commits mailing list