[Mlir-commits] [mlir] 4e9c3ce - [mlir][tosa] Improve invalid operator data types error message (#140756)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 22 09:06:35 PDT 2025
Author: Luke Hutton
Date: 2025-05-22T17:06:32+01:00
New Revision: 4e9c3ce39189fc68f83be03f85a6a504de537049
URL: https://github.com/llvm/llvm-project/commit/4e9c3ce39189fc68f83be03f85a6a504de537049
DIFF: https://github.com/llvm/llvm-project/commit/4e9c3ce39189fc68f83be03f85a6a504de537049.diff
LOG: [mlir][tosa] Improve invalid operator data types error message (#140756)
The error message emitted by the validation pass for invalid operator data
types was not very clear. This commit improves the error message as follows:
Previously:
```
'tosa.add' op illegal: operand/result data types not supported
```
Improved:
```
'tosa.add' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,i1,i1), did you mean (i32,i32,i32)? Otherwise, please refer to the 'supported data types' for 'tosa.add' in the specification.
```
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 716e55706c625..8f5c72bc5f7a9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -164,6 +164,8 @@ class TosaProfileCompliance {
SmallVector<StringRef>
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);
+ static llvm::SmallString<7> stringifyTypeInfo(const TypeInfo &typeInfo);
+
private:
template <typename T>
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 66ea00b23b9d4..1a896c1464e1c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -485,9 +485,52 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+
if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
- !maybeProfDef.value().size() && !maybeExtDef.value().size())
+ !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+ // Neither a profile nor an extension definition matched the op's data
+ // types. Report the types that were found and, when possible, the closest
+ // supported combination.
+ std::string message;
+ llvm::raw_string_ostream os(message);
+
+ // Print a type-info list as "(t0,t1,...)". llvm::interleave also copes
+ // with an empty list, where calling back() on the vector would be UB.
+ const auto printTypeInfos = [&](ArrayRef<TypeInfo> typeInfos) {
+ os << "(";
+ llvm::interleave(
+ typeInfos, os,
+ [&](const TypeInfo &typeInfo) { os << stringifyTypeInfo(typeInfo); },
+ ",");
+ os << ")";
+ };
+
+ os << "illegal: operation operand/result data types did not align with any "
+ "profile or extension, got ";
+ ProfileInfoDepot depot(op);
+ const SmallVector<TypeInfo> current = depot.getInfo();
+ printTypeInfos(current);
+
+ // Avoid polluting the error message by suggesting only the best match:
+ // the supported combination that agrees with the actual types in the most
+ // positions. Combinations of a different length are skipped, because
+ // llvm::zip_equal asserts that both ranges have the same length.
+ const std::string opName = op->getName().getStringRef().str();
+ int maxMatches = -1;
+ SmallVector<TypeInfo> bestTypeInfo;
+ // Look the op up with find() on a const reference instead of operator[],
+ // which would default-insert an entry for a missing key.
+ const auto searchBestMatch = [&](const auto &map) {
+ const auto it = map.find(opName);
+ if (it == map.end())
+ return;
+ for (const auto &complianceInfos : it->second) {
+ for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
+ if (typeInfos.size() != current.size())
+ continue;
+ const int matches = llvm::count_if(
+ llvm::zip_equal(current, typeInfos), [&](const auto &zipType) {
+ return isSameTypeInfo(std::get<0>(zipType),
+ std::get<1>(zipType));
+ });
+ if (matches > maxMatches) {
+ maxMatches = matches;
+ bestTypeInfo = typeInfos;
+ }
+ }
+ }
+ };
+ searchBestMatch(getProfileComplianceMap<Profile>());
+ searchBestMatch(getProfileComplianceMap<Extension>());
+
+ // Only offer a suggestion when a same-length candidate was found.
+ if (maxMatches >= 0) {
+ os << ", did you mean ";
+ printTypeInfos(bestTypeInfo);
+ os << "? Otherwise, please refer";
+ } else {
+ os << ". Please refer";
+ }
+ os << " to the 'supported data types' for '" << opName
+ << "' in the specification.";
+ op->emitOpError(message);
return failure();
+ }
return success();
}
@@ -562,3 +605,21 @@ SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(
return debugStrings;
}
+
+// Returns a short, human-readable spelling of an element TypeInfo, e.g.
+// "i32", "bf16" or "fp8e5m2". In this file it is used to build diagnostics.
+llvm::SmallString<7>
+TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
+ // Integer types are told apart only by their bit width.
+ if (typeInfo.typeID == mlir::IntegerType::getTypeID())
+ return {"i" + llvm::utostr(typeInfo.bitWidth)};
+ if (typeInfo.typeID == mlir::Float16Type::getTypeID())
+ return {"f16"};
+ if (typeInfo.typeID == mlir::Float32Type::getTypeID())
+ return {"f32"};
+ if (typeInfo.typeID == mlir::BFloat16Type::getTypeID())
+ return {"bf16"};
+ if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID())
+ return {"fp8e4m3"};
+ if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID())
+ return {"fp8e5m2"};
+ // The TypeInfo can originate from any op handed to the public
+ // checkInvalid() entry point, so an element type outside this table must
+ // not abort the diagnostic (llvm_unreachable is undefined behavior in
+ // release builds). Fall back to a placeholder instead.
+ return {"unknown"};
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index feedc5057bea0..e6d8e7834bf2c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1248,10 +1248,8 @@ void TosaValidation::runOnOperation() {
return signalPassFailure();
if (!allowInvalidOpDatatypeCombinations &&
- failed(profileComp.checkInvalid(op))) {
- op->emitOpError("illegal: operand/result data types not supported");
+ failed(profileComp.checkInvalid(op)))
return signalPassFailure();
- }
// Some uses of TOSA rely on the constant operands of particular
// operations.
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 732c980f3ab92..7b589fa839b44 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -35,7 +35,7 @@ func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2:
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.conv2d' op illegal: operation operand/result data types did not align with any profile or extension, got (i8,i8,i8,i8,i8,i32,i8), did you mean (i8,i8,i32,i8,i8,i32,i32)?}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
@@ -1888,7 +1888,7 @@ func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
// CHECK-LABEL: test_add_i1
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
- // expected-error at +1 {{'tosa.add' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.add' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,i1,i1), did you mean (i32,i32,i32)? Otherwise, please refer to the 'supported data types' for 'tosa.add' in the specification.}}
%0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
}
@@ -1897,7 +1897,7 @@ func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) ->
// CHECK-LABEL: test_mul_out_i16
func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
- // expected-error at +1 {{'tosa.mul' op illegal: operand/result data types not supported}}
+ // expected-error at +1 {{'tosa.mul' op illegal: operation operand/result data types did not align with any profile or extension, got (i8,i8,i16), did you mean (i8,i8,i32)?}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
More information about the Mlir-commits mailing list