[Mlir-commits] [mlir] d4570ea - [mlir][tosa] Disallow invalid datatype combinations in the validation pass (#131595)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 25 03:05:42 PDT 2025
Author: Luke Hutton
Date: 2025-03-25T10:05:39Z
New Revision: d4570ea8138f3cc52030f9c3a5772ba41ebb9ced
URL: https://github.com/llvm/llvm-project/commit/d4570ea8138f3cc52030f9c3a5772ba41ebb9ced
DIFF: https://github.com/llvm/llvm-project/commit/d4570ea8138f3cc52030f9c3a5772ba41ebb9ced.diff
LOG: [mlir][tosa] Disallow invalid datatype combinations in the validation pass (#131595)
This commit checks whether the operands/results of an operator can be found
in the profile compliance mapping; if they cannot, the operator is
considered invalid. As a result, operator datatype combinations that are
not listed under the "Supported Data Types" sections of the TOSA
specification are disallowed and the validation pass results in failure.
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Added:
Modified:
mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
mlir/test/Dialect/Tosa/dynamic_extension.mlir
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/invalid_extension.mlir
mlir/test/Dialect/Tosa/level_check.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 195a58432737b..f4823858e3893 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -40,7 +40,7 @@ void addTosaToLinalgPasses(
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
tosa::TosaValidationOptions{
- {"none"}, {"none"}, false, tosa::TosaLevelEnum::None});
+ {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 1df1761d38455..d73b288d2c8bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -115,6 +115,7 @@ class TosaProfileCompliance {
// environment.
LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv);
LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv);
+ LogicalResult checkInvalid(Operation *op);
template <typename T>
LogicalResult checkProfileOrExtension(
@@ -163,6 +164,10 @@ class TosaProfileCompliance {
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
private:
+ template <typename T>
+ FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
+ CheckCondition &condition);
+
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
};
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index f6ead2b6ba3dd..2d5b0b39df078 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -94,6 +94,11 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
+ Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
+ /*default=*/"false",
+ "Disable checks for operations that are determined to be invalid due to their "
+ "operand/result datatypes not aligning with the 'Supported Data Types' "
+ "sections of the specifciation">,
Option<"level", "level", "mlir::tosa::TosaLevelEnum",
/*default=*/"mlir::tosa::TosaLevelEnum::EightK",
"Validate if operator parameters are within specfication for the given level",
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index bfadebba12708..4cf232a7bc767 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -119,6 +119,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
validationOptions.profile = {"none"};
validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
+ validationOptions.allowInvalidOpDatatypeCombinations = false;
validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 4aeb095ffff07..eb7981b313d1d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -140,6 +140,7 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
addValue(op.getValues());
+ addValue(op.getIndices());
addValue(op.getOutput());
return success();
}
@@ -147,6 +148,7 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
addValue(op.getValuesIn());
+ addValue(op.getIndices());
addValue(op.getInput());
addValue(op.getValuesOut());
return success();
@@ -347,6 +349,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// Tosa Profile And Extension Compliance Checker
//===----------------------------------------------------------------------===//
+template <typename T>
+FailureOr<SmallVector<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op,
+ CheckCondition &condition) {
+ const std::string opName = op->getName().getStringRef().str();
+ const auto complianceMap = getProfileComplianceMap<T>();
+ const auto it = complianceMap.find(opName);
+ if (it == complianceMap.end())
+ return {};
+
+ return findMatchedProfile<T>(op, it->second, condition);
+}
+
template <typename T>
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
Operation *op, const tosa::TargetEnv &targetEnv,
@@ -356,11 +371,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
- auto opName = op->getName().getStringRef().str();
- auto compMap = getProfileComplianceMap<T>();
- auto it = compMap.find(opName);
-
- if (it == compMap.end()) {
+ CheckCondition condition = CheckCondition::invalid;
+ const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
+ if (failed(maybeOpRequiredMode)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
@@ -381,12 +394,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
return failure();
}
- CheckCondition condition = CheckCondition::invalid;
- // Find the profiles or extensions requirement according to the signature of
- // type of the operand list.
- SmallVector<T> opRequiredMode =
- findMatchedProfile<T>(op, it->second, condition);
-
+ // Find the required profiles or extensions according to the operand type
+ // combination.
+ const auto opRequiredMode = maybeOpRequiredMode.value();
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -466,6 +476,17 @@ TosaProfileCompliance::checkExtension(Operation *op,
return success();
}
+LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
+ CheckCondition condition = CheckCondition::invalid;
+ const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
+ const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
+ !maybeProfDef.value().size() && !maybeExtDef.value().size())
+ return failure();
+
+ return success();
+}
+
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
@@ -483,7 +504,6 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
for (size_t i = 0; i < compInfo.size(); i++) {
SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
-
for (SmallVector<TypeInfo> expected : sets) {
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 79c13793d7713..3ec7354562d23 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -165,6 +165,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
this->profile = options.profile;
this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
+ this->allowInvalidOpDatatypeCombinations =
+ options.allowInvalidOpDatatypeCombinations;
this->level = options.level;
}
void runOnOperation() final;
@@ -1042,6 +1044,12 @@ void TosaValidation::runOnOperation() {
}
}
+ if (!allowInvalidOpDatatypeCombinations &&
+ failed(profileComp.checkInvalid(op))) {
+ op->emitOpError("illegal: operand/result data types not supported");
+ return signalPassFailure();
+ }
+
// Some uses of TOSA rely on the constant operands of particular
// operations.
if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index ecd5c792e08b6..731e134ed1a07 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -4,20 +4,20 @@
// -----
// check that -tosa-validate of stateful ops kick in
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
- tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+ tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
// -----
// check that -tosa-validate level checking kick in
-func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// expected-error at +1 {{'tosa.abs' op failed level check: unranked tensor}}
- %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
- return %0 : tensor<*xi8>
+ %0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+ return %0 : tensor<*xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index fd9b3d5f23483..0ec46022157d7 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,13 +2,13 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations"
// -----
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
- %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
- return %0 : tensor<13x21x3xi8>
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3203c64b439da..ac8a247da24a7 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -616,17 +616,17 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
// -----
-func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error at +1 {{'tosa.variable' op name has already been declared}}
- tosa.variable @stored_var = dense<3> : tensor<1x4x8xi32>
+ tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
return
}
// -----
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error at +1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
return
@@ -634,8 +634,8 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
// -----
-func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error at +1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
return
@@ -644,7 +644,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
// -----
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
return
@@ -652,10 +652,10 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
// -----
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error at +1 {{'tosa.variable.write' op operand type does not equal variable type}}
- tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
+ tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
@@ -1921,3 +1921,21 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
return %1 : tensor<f32>
}
+
+// -----
+
+// CHECK-LABEL: test_add_i1
+func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ // expected-error at +1 {{'tosa.add' op illegal: operand/result data types not supported}}
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_out_i16
+func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
+ // expected-error at +1 {{'tosa.mul' op illegal: operand/result data types not supported}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index bde5b5ec7cffe..d1594232e4e1d 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor<
// -----
-func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> {
+func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
// expected-error at +1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}}
- %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
- return %0 : tensor<13x21x3xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bdf18ec823128..0f469761d89e3 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso
// -----
-func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
+func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> {
// expected-error at +1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
- return %0 : tensor<1x1x1x1x1x1x64xi16>
+ %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32>
+ return %0 : tensor<1x1x1x1x1x1x64xi32>
}
// -----
More information about the Mlir-commits
mailing list