[Mlir-commits] [mlir] [mlir][tosa] Disallow invalid datatype combinations in the validation pass (PR #131595)

Luke Hutton llvmlistbot at llvm.org
Wed Mar 19 07:37:06 PDT 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/131595

>From 3bb61abd02fa9fcbf95bf31bd0d3f8ee5d77ace6 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 6 Mar 2025 21:16:51 +0000
Subject: [PATCH] [mlir][tosa] Disallow invalid datatype combinations in the
 validation pass

This commit checks whether the operand/result datatypes of an operator
can be found in the profile compliance mapping; if they cannot, the
operator is considered invalid. As a result, operator datatype
combinations that are not listed under the "Supported Data Types"
sections of the TOSA specification are disallowed and cause the
validation pass to fail.

Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: Iab36dd84cdbf188015c80b066c321edbb2efc0ff
---
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  2 +-
 .../Dialect/Tosa/IR/TosaProfileCompliance.h   |  5 +++
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  5 +++
 .../TosaToLinalg/TosaToLinalgPass.cpp         |  1 +
 .../Tosa/Transforms/TosaProfileCompliance.cpp | 41 ++++++++++++++-----
 .../Tosa/Transforms/TosaValidation.cpp        |  8 ++++
 .../TosaToLinalg/tosa-to-linalg-pipeline.mlir |  6 +--
 mlir/test/Dialect/Tosa/dynamic_extension.mlir |  8 ++--
 mlir/test/Dialect/Tosa/invalid.mlir           | 18 ++++++++
 mlir/test/Dialect/Tosa/invalid_extension.mlir |  6 +--
 mlir/test/Dialect/Tosa/level_check.mlir       |  6 +--
 11 files changed, 81 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 195a58432737b..f4823858e3893 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
     // Note: Default to 'none' level unless otherwise specified.
     std::optional<tosa::TosaValidationOptions> validationOptions =
         tosa::TosaValidationOptions{
-            {"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
+            {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
 
 /// Populates TOSA to linalg pipelines
 /// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 69b827fe14dee..da187d8316989 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -115,6 +115,7 @@ class TosaProfileCompliance {
   // environment.
   LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
   LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
+  LogicalResult checkInvalid(Operation *op);
 
   template <typename T>
   LogicalResult checkProfileOrExtension(
@@ -163,6 +164,10 @@ class TosaProfileCompliance {
   stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
 
 private:
+  template <typename T>
+  FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
+                                                  CheckCondition &condition);
+
   OperationProfileComplianceMap profileComplianceMap;
   OperationExtensionComplianceMap extensionComplianceMap;
 };
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index f6ead2b6ba3dd..2d5b0b39df078 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -94,6 +94,11 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
       Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
              /*default=*/"false",
              "Verify if the properties of certain operations align the spec requirement">,
+      Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
+             /*default=*/"false",
+             "Disable checks for operations that are determined to be invalid due to their "
+             "operand/result datatypes not aligning with the 'Supported Data Types' "
+             "sections of the specification">,
       Option<"level", "level", "mlir::tosa::TosaLevelEnum",
              /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
              "Validate if operator parameters are within specfication for the given level",
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index bfadebba12708..4cf232a7bc767 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -119,6 +119,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
         validationOptions.profile = {"none"};
         validationOptions.extension = {"none"};
         validationOptions.strictOpSpecAlignment = false;
+        validationOptions.allowInvalidOpDatatypeCombinations = false;
         validationOptions.level = tosa::TosaLevelEnum::EightK;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
                                     tosaToLinalgNamedOptions,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ed2c40598458c..9523146581f10 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -300,6 +300,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // Tosa Profile And Extension Compliance Checker
 //===----------------------------------------------------------------------===//
 
+template <typename T>
+FailureOr<SmallVector<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op,
+                                             CheckCondition &condition) {
+  const std::string opName = op->getName().getStringRef().str();
+  const auto complianceMap = getProfileComplianceMap<T>();
+  const auto it = complianceMap.find(opName);
+  if (it == complianceMap.end())
+    return {};
+
+  return findMatchedProfile<T>(op, it->second, condition);
+}
+
 template <typename T>
 LogicalResult TosaProfileCompliance::checkProfileOrExtension(
     Operation *op, const tosa::TargetEnv &targetEnv,
@@ -309,11 +322,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
   if (specRequiredModeSet.size() == 0)
     return success();
 
-  auto opName = op->getName().getStringRef().str();
-  auto compMap = getProfileComplianceMap<T>();
-  auto it = compMap.find(opName);
-
-  if (it == compMap.end()) {
+  CheckCondition condition = CheckCondition::invalid;
+  const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
+  if (failed(maybeOpRequiredMode)) {
     // Operators such as variable and shape ops do not have an operand type
     // restriction. When the profile compliance information of operation is not
     // found, confirm if the target have enabled the profile required from the
@@ -334,12 +345,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
     return failure();
   }
 
-  CheckCondition condition = CheckCondition::invalid;
-  // Find the profiles or extensions requirement according to the signature of
-  // type of the operand list.
-  SmallVector<T> opRequiredMode =
-      findMatchedProfile<T>(op, it->second, condition);
-
+  // Find the required profiles or extensions according to the operand type
+  // combination.
+  const auto opRequiredMode = maybeOpRequiredMode.value();
   if (opRequiredMode.size() == 0) {
     // No matched restriction found.
     return success();
@@ -419,6 +427,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
   return success();
 }
 
+LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
+  CheckCondition condition = CheckCondition::invalid;
+  const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
+  const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+  if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
+      !maybeProfDef.value().size() && !maybeExtDef.value().size())
+    return failure();
+
+  return success();
+}
+
 // Find the profiles or extensions requirement according to the signature of
 // type of the operand list.
 template <typename T>
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 79c13793d7713..3ec7354562d23 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -165,6 +165,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     this->profile = options.profile;
     this->extension = options.extension;
     this->strictOpSpecAlignment = options.strictOpSpecAlignment;
+    this->allowInvalidOpDatatypeCombinations =
+        options.allowInvalidOpDatatypeCombinations;
     this->level = options.level;
   }
   void runOnOperation() final;
@@ -1042,6 +1044,12 @@ void TosaValidation::runOnOperation() {
       }
     }
 
+    if (!allowInvalidOpDatatypeCombinations &&
+        failed(profileComp.checkInvalid(op))) {
+      op->emitOpError("illegal: operand/result data types not supported");
+      return signalPassFailure();
+    }
+
     // Some uses of TOSA rely on the constant operands of particular
     // operations.
     if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index ecd5c792e08b6..22b07e69d3b87 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -14,10 +14,10 @@ func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
 // -----
 
 // check that -tosa-validate level checking kick in
-func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
   // expected-error at +1 {{'tosa.abs' op failed level check: unranked tensor}}
-  %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
-  return %0 : tensor<*xi8>
+  %0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index fd9b3d5f23483..0ec46022157d7 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,13 +2,13 @@
 // Check operations when the dynamic extension is enabled.
 //--------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"
 
 // -----
 
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %0 : tensor<13x21x3xi8>
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ca7c71cd3b137..d3e4452dfc1e3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1921,3 +1921,21 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
   %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
   return %1 : tensor<f32>
 }
+
+// -----
+
+// CHECK-LABEL: test_add_i1
+func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+  // expected-error at +1 {{'tosa.add' op illegal: operand/result data types not supported}}
+  %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_out_i16
+func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+  // expected-error at +1 {{'tosa.mul' op illegal: operand/result data types not supported}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 13952716a9611..10140cc0a1e9b 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<
 
 // -----
 
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
   // expected-error at +1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %0 : tensor<13x21x3xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bdf18ec823128..0f469761d89e3 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso
 
 // -----
 
-func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
+func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
   // expected-error at +1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
-    %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
-    return %0 : tensor<1x1x1x1x1x1x64xi16>
+    %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
+    return %0 : tensor<1x1x1x1x1x1x64xi32>
 }
 
 // -----



More information about the Mlir-commits mailing list