[Mlir-commits] [mlir] [mlir][tosa] Disallow invalid datatype combinations in the validation pass (PR #131595)
Luke Hutton
llvmlistbot at llvm.org
Wed Mar 19 08:02:04 PDT 2025
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/131595
>From edc7f32a7a859b4bc3e2fc421c3871d10a3a4d24 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 6 Mar 2025 21:16:51 +0000
Subject: [PATCH] [mlir][tosa] Disallow invalid datatype combinations in the
validation pass
This commit checks whether the operand/result data types of an operator can
be found in the profile compliance mapping; if they cannot, the operator is
considered invalid. As a result, operator datatype combinations that are not
listed under the "Supported Data Types" sections of the TOSA specification are
disallowed and the validation pass fails. The check can be disabled with the
new `allow-invalid-op-datatype-combinations` option.
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: Iab36dd84cdbf188015c80b066c321edbb2efc0ff
---
.../Conversion/TosaToLinalg/TosaToLinalg.h | 2 +-
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 5 +++
.../mlir/Dialect/Tosa/Transforms/Passes.td | 5 +++
.../TosaToLinalg/TosaToLinalgPass.cpp | 1 +
.../Tosa/Transforms/TosaProfileCompliance.cpp | 41 ++++++++++++++-----
.../Tosa/Transforms/TosaValidation.cpp | 8 ++++
.../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 6 +--
mlir/test/Dialect/Tosa/dynamic_extension.mlir | 8 ++--
mlir/test/Dialect/Tosa/invalid.mlir | 18 ++++++++
mlir/test/Dialect/Tosa/invalid_extension.mlir | 6 +--
mlir/test/Dialect/Tosa/level_check.mlir | 6 +--
11 files changed, 81 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 195a58432737b..f4823858e3893 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
tosa::TosaValidationOptions{
- {"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
+ {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 69b827fe14dee..da187d8316989 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -115,6 +115,7 @@ class TosaProfileCompliance {
// environment.
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
+ /// Returns failure if the operand/result data type combination of `op` is
+ /// considered invalid according to the profile and extension compliance maps.
+ LogicalResult checkInvalid(Operation *op);
template <typename T>
LogicalResult checkProfileOrExtension(
@@ -163,6 +164,10 @@ class TosaProfileCompliance {
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
private:
+ /// Look up the compliance entries of kind T (Profile or Extension) for `op`.
+ /// Returns failure if `op` has no entry in the compliance map; otherwise the
+ /// (possibly empty) list of entries matching its operand/result types.
+ template <typename T>
+ FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
+ CheckCondition &condition);
+
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
};
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index f6ead2b6ba3dd..2d5b0b39df078 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -94,6 +94,11 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
+    // When set, tosa-validate skips rejecting operators whose operand/result
+    // data type combination is not listed in the specification.
+    Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
+           /*default=*/"false",
+           "Disable checks for operations that are determined to be invalid due to their "
+           "operand/result datatypes not aligning with the 'Supported Data Types' "
+           "sections of the specification">,
Option<"level", "level", "mlir::tosa::TosaLevelEnum",
/*default=*/"mlir::tosa::TosaLevelEnum::EightK",
"Validate if operator parameters are within specfication for the given level",
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index bfadebba12708..4cf232a7bc767 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -119,6 +119,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
validationOptions.profile = {"none"};
validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
+ validationOptions.allowInvalidOpDatatypeCombinations = false;
validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ed2c40598458c..9523146581f10 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -300,6 +300,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// Tosa Profile And Extension Compliance Checker
//===----------------------------------------------------------------------===//
+/// Look up the compliance information of kind T (Profile or Extension) for
+/// `op`.
+///
+/// Returns failure() when `op` has no entry in the compliance map at all
+/// (e.g. operators without an operand type restriction). Otherwise returns
+/// the entries whose type signature matches the operands/results of `op`;
+/// this list is empty when the op's data type combination is not listed.
+/// `condition` is forwarded to findMatchedProfile as an out-parameter
+/// (presumably receiving the matched entry's check condition — confirm).
+template <typename T>
+FailureOr<SmallVector<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op,
+                                             CheckCondition &condition) {
+  const std::string opName = op->getName().getStringRef().str();
+  // Bind by const reference: avoids a copy if the accessor returns a
+  // reference, and extends the temporary's lifetime if it returns by value.
+  const auto &complianceMap = getProfileComplianceMap<T>();
+  const auto it = complianceMap.find(opName);
+  // Spell out failure() rather than `{}`: an empty SmallVector means "found
+  // but no matching data types", which callers treat very differently.
+  if (it == complianceMap.end())
+    return failure();
+
+  return findMatchedProfile<T>(op, it->second, condition);
+}
+
template <typename T>
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
Operation *op, const tosa::TargetEnv &targetEnv,
@@ -309,11 +322,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
- auto opName = op->getName().getStringRef().str();
- auto compMap = getProfileComplianceMap<T>();
- auto it = compMap.find(opName);
-
- if (it == compMap.end()) {
+ CheckCondition condition = CheckCondition::invalid;
+ const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
+ if (failed(maybeOpRequiredMode)) {
// Operators such as variable and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
@@ -334,12 +345,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
return failure();
}
- CheckCondition condition = CheckCondition::invalid;
- // Find the profiles or extensions requirement according to the signature of
- // type of the operand list.
- SmallVector<T> opRequiredMode =
- findMatchedProfile<T>(op, it->second, condition);
-
+ // Find the required profiles or extensions according to the operand type
+ // combination.
+ const auto opRequiredMode = maybeOpRequiredMode.value();
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -419,6 +427,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
return success();
}
+/// Check whether the operand/result data type combination of `op` is listed
+/// in the profile or extension compliance maps.
+///
+/// Operators with no entry in either map have no data type restriction and
+/// are accepted. Otherwise the op is invalid (failure) unless at least one of
+/// the maps it appears in has an entry matching its data types.
+LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
+  CheckCondition condition = CheckCondition::invalid;
+  const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
+  const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+  if (failed(maybeProfDef) && failed(maybeExtDef))
+    return success();
+
+  // An op may be listed in only one of the maps (e.g. it has no extension
+  // entries). Not being found in one map must not exempt an unmatched data
+  // type combination in the other map.
+  const bool hasMatchingEntry =
+      (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
+      (succeeded(maybeExtDef) && !maybeExtDef->empty());
+  return success(hasMatchingEntry);
+}
+
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 79c13793d7713..3ec7354562d23 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -165,6 +165,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
this->profile = options.profile;
this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
+ this->allowInvalidOpDatatypeCombinations =
+ options.allowInvalidOpDatatypeCombinations;
this->level = options.level;
}
void runOnOperation() final;
@@ -1042,6 +1044,12 @@ void TosaValidation::runOnOperation() {
}
}
+ if (!allowInvalidOpDatatypeCombinations &&
+ failed(profileComp.checkInvalid(op))) {
+ op->emitOpError("illegal: operand/result data types not supported");
+ return signalPassFailure();
+ }
+
// Some uses of TOSA rely on the constant operands of particular
// operations.
if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index ecd5c792e08b6..22b07e69d3b87 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -14,10 +14,10 @@ func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
// -----
// check that -tosa-validate level checking kick in
-func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// expected-error at +1 {{'tosa.abs' op failed level check: unranked tensor}}
- %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
- return %0 : tensor<*xi8>
+ %0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+ return %0 : tensor<*xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index fd9b3d5f23483..0ec46022157d7 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,13 +2,13 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"
// -----
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
- %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
- return %0 : tensor<13x21x3xi8>
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ca7c71cd3b137..80886e1cb6d46 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1921,3 +1921,21 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
return %1 : tensor<f32>
}
+
+// -----
+
+// CHECK-LABEL: test_add_i1
+// tosa.add with i1 operands/results is expected to be rejected by the
+// validation pass as an unsupported data type combination.
+func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ // expected-error at +1 {{'tosa.add' op illegal: operand/result data types not supported}}
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_out_i16
+// tosa.mul with i8 inputs and an i16 result is expected to be rejected by the
+// validation pass as an unsupported data type combination.
+func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+ // expected-error at +1 {{'tosa.mul' op illegal: operand/result data types not supported}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 13952716a9611..10140cc0a1e9b 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<
// -----
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
// expected-error at +1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
- %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
- return %0 : tensor<13x21x3xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bdf18ec823128..0f469761d89e3 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso
// -----
-func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
+func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
// expected-error at +1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
- return %0 : tensor<1x1x1x1x1x1x64xi16>
+ %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
+ return %0 : tensor<1x1x1x1x1x1x64xi32>
}
// -----
More information about the Mlir-commits
mailing list