[Mlir-commits] [mlir] [mlir][tosa] Fix invalid data type combinations check (PR #150066)

Luke Hutton llvmlistbot at llvm.org
Tue Jul 22 10:03:15 PDT 2025


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/150066

Previously this check assumed that if an operator exists in profile compliance (TosaProfileComplianceData.h.inc), an entry exists in both the profiles and extensions sections. However, this is not necessarily the case.

This commit changes the check such that it doesn't assume the above. In doing so, it allows more operators to be checked for invalid data type combinations, which were otherwise skipped previously.

>From 0ece9c5c9817ac39a958998efca999aa2ccdadc6 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 22 Jul 2025 12:53:38 +0000
Subject: [PATCH] [mlir][tosa] Fix invalid data type combinations check

Previously this check assumed that if an operator exists in profile
compliance (TosaProfileComplianceData.h.inc), an entry exists in both
the profiles and extensions sections. However, this is not necessarily
the case.

This commit changes the check such that it doesn't assume the above.
In doing so, it allows more operators to be checked for invalid data
type combinations, which were otherwise skipped previously.

Change-Id: I2a7bc9be167463d29bf5d9ab1de946c26594845e
---
 .../Tosa/Transforms/TosaProfileCompliance.cpp    | 11 +++++++++--
 mlir/test/Dialect/Tosa/invalid.mlir              | 16 ++++++++++++++++
 mlir/test/Dialect/Tosa/level_check.mlir          |  6 +++---
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 88b0f3650ca01..09f09487129ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -464,9 +464,16 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
   CheckCondition condition = CheckCondition::invalid;
   const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
   const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+  if (failed(maybeProfDef) && failed(maybeExtDef))
+    return success();
+
+  bool hasEntry = false;
+  if (succeeded(maybeProfDef))
+    hasEntry |= maybeProfDef.value().size();
+  if (succeeded(maybeExtDef))
+    hasEntry |= maybeExtDef.value().size();
 
-  if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
-      !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+  if (!hasEntry) {
     std::string message;
     llvm::raw_string_ostream os(message);
     os << "illegal: operation operand/result data types did not align with any "
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 5a424c41775c9..95ebe0dfef0f7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2027,3 +2027,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
   %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
   return %0 : tensor<2x52x3xf32>
 }
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+  // expected-error at +1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+  return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+  // expected-error at +1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+  %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+  return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26fb1f85..ef4e0ae13efa7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
 
 // -----
 
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
   // expected-error at +1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
+  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
 }
 
 // -----



More information about the Mlir-commits mailing list