[Mlir-commits] [mlir] [mlir][tosa] Fix invalid data type combinations check (PR #150066)

Luke Hutton llvmlistbot at llvm.org
Wed Jul 30 01:51:51 PDT 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/150066

>From 1cad3536b1ccba3e0ed880899355e4aec3694555 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 22 Jul 2025 12:53:38 +0000
Subject: [PATCH] [mlir][tosa] Fix invalid data type combinations check

Previously, this check assumed that if an operator exists in the profile
compliance data (TosaProfileComplianceData.h.inc), it has an entry in
both the profiles and the extensions sections. However, this is not
necessarily the case.

This commit changes the check so that it no longer makes this
assumption. As a result, operators that were previously skipped are now
checked for invalid data type combinations.

Change-Id: I2a7bc9be167463d29bf5d9ab1de946c26594845e
---
 .../Tosa/Transforms/TosaProfileCompliance.cpp    |  7 +++++--
 mlir/test/Dialect/Tosa/invalid.mlir              | 16 ++++++++++++++++
 mlir/test/Dialect/Tosa/level_check.mlir          |  6 +++---
 3 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 88b0f3650ca01..9543fa1fe39d8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
   CheckCondition condition = CheckCondition::invalid;
   const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
   const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+  if (failed(maybeProfDef) && failed(maybeExtDef))
+    return success();
 
-  if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
-      !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+  const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
+                        (succeeded(maybeExtDef) && !maybeExtDef->empty());
+  if (!hasEntry) {
     std::string message;
     llvm::raw_string_ostream os(message);
     os << "illegal: operation operand/result data types did not align with any "
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 5a424c41775c9..95ebe0dfef0f7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2027,3 +2027,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
   %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
   return %0 : tensor<2x52x3xf32>
 }
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+  // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+  return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+  %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+  return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0dddf26fb1f85..ef4e0ae13efa7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
 
 // -----
 
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
   // expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
+  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
 }
 
 // -----



More information about the Mlir-commits mailing list