[Mlir-commits] [mlir] 356bd2c - [mlir][tosa] Allow unsigned types for rescale ops during validation (#138253)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 8 03:33:40 PDT 2025
Author: Luke Hutton
Date: 2025-05-08T11:33:37+01:00
New Revision: 356bd2c9605761121b49f318a187560ec306718e
URL: https://github.com/llvm/llvm-project/commit/356bd2c9605761121b49f318a187560ec306718e
DIFF: https://github.com/llvm/llvm-project/commit/356bd2c9605761121b49f318a187560ec306718e.diff
LOG: [mlir][tosa] Allow unsigned types for rescale ops during validation (#138253)
This commit allows unsigned types (ui8/ui16/ui32) when checking for
valid element types, but only for rescale operators.
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Added:
Modified:
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index e8b52d48347ab..feedc5057bea0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -562,7 +562,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);
- bool isValidElementType(Type type);
+ bool isValidElementType(Type type, const bool allowUnsigned = false);
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
@@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
return success();
}
-bool TosaValidation::isValidElementType(Type type) {
+bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
Float8E5M2Type>(type);
@@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) {
case 48:
return true;
}
+ } else if (allowUnsigned && intTy.isUnsigned()) {
+ switch (intTy.getWidth()) {
+ case 8:
+ case 16:
+ case 32:
+ return true;
+ }
}
} else if (mlir::isa<tosa::shapeType>(type)) {
return true;
@@ -1209,11 +1216,15 @@ void TosaValidation::runOnOperation() {
if (op->getDialect() != tosaDialect)
return;
- // perform valid element type check at the beginning to
- // protect rest of code against quantized element types
+ // validate operator element types:
+ // - rescale operator is allowed to have ui8/ui16/ui32
+ // operands/results
+ // - perform valid element type check at the beginning to
+ // protect rest of code against quantized element types
+ const bool opIsRescale = isa<tosa::RescaleOp>(op);
for (Value operand : op->getOperands()) {
auto elementTy = getElementTypeOrSelf(operand);
- if (!isValidElementType(elementTy)) {
+ if (!isValidElementType(elementTy, opIsRescale)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
@@ -1221,7 +1232,7 @@ void TosaValidation::runOnOperation() {
}
for (Type resultTy : op->getResultTypes()) {
auto elementTy = getElementTypeOrSelf(resultTy);
- if (!isValidElementType(elementTy)) {
+ if (!isValidElementType(elementTy, opIsRescale)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 56d76585be71b..732c980f3ab92 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1937,3 +1937,27 @@ func.func @test_clamp_min_larger_than_max_fp32(%arg0: tensor<13x21x3xf32>) -> te
%0 = tosa.clamp %arg0 {min_val = 2.0 : f32, max_val = -1.1: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_rescale_input_unsigned
+func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) { // positive case (no expected-error): a ui8 operand on tosa.rescale must pass element-type validation, since operands of rescale ops are checked with allowUnsigned = true
+ %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+ return %r : tensor<1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_rescale_output_unsigned
+func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) { // positive case (no expected-error): a ui8 result on tosa.rescale must pass element-type validation, since result types of rescale ops are also checked with allowUnsigned = true
+ %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+ return %r : tensor<1x1xui8>
+}
More information about the Mlir-commits mailing list