[Mlir-commits] [mlir] d4570ea - [mlir][tosa] Disallow invalid datatype combinations in the validation pass (#131595)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 25 03:05:42 PDT 2025


Author: Luke Hutton
Date: 2025-03-25T10:05:39Z
New Revision: d4570ea8138f3cc52030f9c3a5772ba41ebb9ced

URL: https://github.com/llvm/llvm-project/commit/d4570ea8138f3cc52030f9c3a5772ba41ebb9ced
DIFF: https://github.com/llvm/llvm-project/commit/d4570ea8138f3cc52030f9c3a5772ba41ebb9ced.diff

LOG: [mlir][tosa] Disallow invalid datatype combinations in the validation pass (#131595)

This commit checks whether the datatype combination of an operator's
operands/results can be found in the profile compliance mapping; if it
cannot, the operator is considered invalid. As a result, operator datatype
combinations that are not listed under the "Supported Data Types" sections
of the TOSA specification are disallowed and the validation pass fails.
The check can be disabled with the new
`allow-invalid-op-datatype-combinations` option.

Signed-off-by: Luke Hutton <luke.hutton at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
    mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
    mlir/test/Dialect/Tosa/dynamic_extension.mlir
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/invalid_extension.mlir
    mlir/test/Dialect/Tosa/level_check.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 195a58432737b..f4823858e3893 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
     // Note: Default to 'none' level unless otherwise specified.
     std::optional<tosa::TosaValidationOptions> validationOptions =
         tosa::TosaValidationOptions{
-            {"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
+            {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
 
 /// Populates TOSA to linalg pipelines
 /// Currently, this includes only the "tosa-to-linalg-pipeline".

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 1df1761d38455..d73b288d2c8bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -115,6 +115,7 @@ class TosaProfileCompliance {
   // environment.
   LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
   LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
+  LogicalResult checkInvalid(Operation *op);
 
   template <typename T>
   LogicalResult checkProfileOrExtension(
@@ -163,6 +164,10 @@ class TosaProfileCompliance {
   stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
 
 private:
+  template <typename T>
+  FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
+                                                  CheckCondition &condition);
+
   OperationProfileComplianceMap profileComplianceMap;
   OperationExtensionComplianceMap extensionComplianceMap;
 };

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index f6ead2b6ba3dd..2d5b0b39df078 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -94,6 +94,11 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
       Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
              /*default=*/"false",
              "Verify if the properties of certain operations align the spec requirement">,
+      Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
+             /*default=*/"false",
+             "Disable checks for operations that are determined to be invalid due to their "
+             "operand/result datatypes not aligning with the 'Supported Data Types' "
+             "sections of the specification">,
       Option<"level", "level", "mlir::tosa::TosaLevelEnum",
              /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
              "Validate if operator parameters are within specfication for the given level",

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index bfadebba12708..4cf232a7bc767 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -119,6 +119,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
         validationOptions.profile = {"none"};
         validationOptions.extension = {"none"};
         validationOptions.strictOpSpecAlignment = false;
+        validationOptions.allowInvalidOpDatatypeCombinations = false;
         validationOptions.level = tosa::TosaLevelEnum::EightK;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
                                     tosaToLinalgNamedOptions,

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 4aeb095ffff07..eb7981b313d1d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -140,6 +140,7 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
   addValue(op.getValues());
+  addValue(op.getIndices());
   addValue(op.getOutput());
   return success();
 }
@@ -147,6 +148,7 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
   addValue(op.getValuesIn());
+  addValue(op.getIndices());
   addValue(op.getInput());
   addValue(op.getValuesOut());
   return success();
@@ -347,6 +349,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // Tosa Profile And Extension Compliance Checker
 //===----------------------------------------------------------------------===//
 
+// Returns failure() if `op` has no entry in the compliance map for T, else
+// the (possibly empty) list of T whose type signature matches op's values.
+template <typename T>
+FailureOr<SmallVector<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op,
+                                             CheckCondition &condition) {
+  const auto complianceMap = getProfileComplianceMap<T>();
+  const auto it = complianceMap.find(op->getName().getStringRef().str());
+  if (it == complianceMap.end())
+    return failure();
+  return findMatchedProfile<T>(op, it->second, condition);
+}
+
 template <typename T>
 LogicalResult TosaProfileCompliance::checkProfileOrExtension(
     Operation *op, const tosa::TargetEnv &targetEnv,
@@ -356,11 +371,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
   if (specRequiredModeSet.size() == 0)
     return success();
 
-  auto opName = op->getName().getStringRef().str();
-  auto compMap = getProfileComplianceMap<T>();
-  auto it = compMap.find(opName);
-
-  if (it == compMap.end()) {
+  CheckCondition condition = CheckCondition::invalid;
+  const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
+  if (failed(maybeOpRequiredMode)) {
     // Operators such as control-flow and shape ops do not have an operand type
     // restriction. When the profile compliance information of operation is not
     // found, confirm if the target have enabled the profile required from the
@@ -381,12 +394,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
     return failure();
   }
 
-  CheckCondition condition = CheckCondition::invalid;
-  // Find the profiles or extensions requirement according to the signature of
-  // type of the operand list.
-  SmallVector<T> opRequiredMode =
-      findMatchedProfile<T>(op, it->second, condition);
-
+  // Find the required profiles or extensions according to the operand type
+  // combination.
+  const auto opRequiredMode = maybeOpRequiredMode.value();
   if (opRequiredMode.size() == 0) {
     // No matched restriction found.
     return success();
@@ -466,6 +476,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
   return success();
 }
 
+// Fail iff op is in a compliance map but no entry matches its data types.
+LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
+  CheckCondition condition = CheckCondition::invalid;
+  const auto prof = getOperatorDefinition<Profile>(op, condition);
+  const auto ext = getOperatorDefinition<Extension>(op, condition);
+  if (failed(prof) && failed(ext))
+    return success(); // No compliance entry: no datatype restriction.
+  return success((succeeded(prof) && !prof->empty()) ||
+                 (succeeded(ext) && !ext->empty()));
+}
+
 // Find the profiles or extensions requirement according to the signature of
 // type of the operand list.
 template <typename T>
@@ -483,7 +504,6 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
 
   for (size_t i = 0; i < compInfo.size(); i++) {
     SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
-
     for (SmallVector<TypeInfo> expected : sets) {
       assert(present.size() == expected.size() &&
              "the entries for profile-based compliance do not match between "

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 79c13793d7713..3ec7354562d23 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -165,6 +165,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     this->profile = options.profile;
     this->extension = options.extension;
     this->strictOpSpecAlignment = options.strictOpSpecAlignment;
+    this->allowInvalidOpDatatypeCombinations =
+        options.allowInvalidOpDatatypeCombinations;
     this->level = options.level;
   }
   void runOnOperation() final;
@@ -1042,6 +1044,12 @@ void TosaValidation::runOnOperation() {
       }
     }
 
+    if (!allowInvalidOpDatatypeCombinations &&
+        failed(profileComp.checkInvalid(op))) {
+      op->emitOpError("illegal: operand/result data types not supported");
+      return signalPassFailure();
+    }
+
     // Some uses of TOSA rely on the constant operands of particular
     // operations.
     if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index ecd5c792e08b6..731e134ed1a07 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -4,20 +4,20 @@
 // -----
 
 // check that -tosa-validate of stateful ops kick in
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
-  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
   return
 }
 
 // -----
 
 // check that -tosa-validate level checking kick in
-func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
   // expected-error at +1 {{'tosa.abs' op failed level check: unranked tensor}}
-  %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
-  return %0 : tensor<*xi8>
+  %0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index fd9b3d5f23483..0ec46022157d7 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,13 +2,13 @@
 // Check operations when the dynamic extension is enabled.
 //--------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"
 
 // -----
 
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %0 : tensor<13x21x3xi8>
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3203c64b439da..ac8a247da24a7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -616,17 +616,17 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
 
 // -----
 
-func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error at +1 {{'tosa.variable' op name has already been declared}}
-  tosa.variable @stored_var = dense<3> : tensor<1x4x8xi32>
+  tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
   return
 }
 
 // -----
 
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error at +1 {{'tosa.variable.read' op result type does not equal variable type}}
   %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
   return
@@ -634,8 +634,8 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
 
 // -----
 
-func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error at +1 {{'tosa.variable.read' op result type does not equal variable type}}
   %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
   return
@@ -644,7 +644,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
 // -----
 
 func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
   tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
   return
@@ -652,10 +652,10 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
 
 // -----
 
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
   // expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
-  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+  tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
   return
 }
 
@@ -1921,3 +1921,21 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
   %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
   return %1 : tensor<f32>
 }
+
+// -----
+
+// CHECK-LABEL: test_add_i1
+func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+  // expected-error at +1 {{'tosa.add' op illegal: operand/result data types not supported}}
+  %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_out_i16
+func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+  // expected-error at +1 {{'tosa.mul' op illegal: operand/result data types not supported}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}

diff  --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index bde5b5ec7cffe..d1594232e4e1d 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<
 
 // -----
 
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
   // expected-error at +1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
-  return %0 : tensor<13x21x3xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bdf18ec823128..0f469761d89e3 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso
 
 // -----
 
-func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
+func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
   // expected-error at +1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
-    %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
-    return %0 : tensor<1x1x1x1x1x1x64xi16>
+    %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
+    return %0 : tensor<1x1x1x1x1x1x64xi32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list