[Mlir-commits] [mlir] [mlir][tosa] Allow unsigned types for rescale ops during validation (PR #138253)
Luke Hutton
llvmlistbot at llvm.org
Fri May 2 04:26:46 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/138253
This commit allows unsigned element types (ui8/ui16/ui32) to pass the valid-element-type check during TOSA validation, but only for rescale operators.
>From e4d57166b9cce9c99518bab9837ef295c0c2f51b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 30 Apr 2025 22:09:04 +0000
Subject: [PATCH] [mlir][tosa] Allow unsigned types for rescale ops during
validation
This commit allows unsigned element types (ui8/ui16/ui32) to pass
the valid-element-type check, but only for rescale operators.
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: I0525c5a5542e20e832d1bf150635be7423d3799a
---
.../Tosa/Transforms/TosaValidation.cpp | 23 +++++++++++++-----
mlir/test/Dialect/Tosa/invalid.mlir | 24 +++++++++++++++++++
2 files changed, 41 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index e8b52d48347ab..feedc5057bea0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -562,7 +562,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);
- bool isValidElementType(Type type);
+ bool isValidElementType(Type type, const bool allowUnsigned = false);
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
@@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
return success();
}
-bool TosaValidation::isValidElementType(Type type) {
+bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
Float8E5M2Type>(type);
@@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) {
case 48:
return true;
}
+ } else if (allowUnsigned && intTy.isUnsigned()) {
+ switch (intTy.getWidth()) {
+ case 8:
+ case 16:
+ case 32:
+ return true;
+ }
}
} else if (mlir::isa<tosa::shapeType>(type)) {
return true;
@@ -1209,11 +1216,15 @@ void TosaValidation::runOnOperation() {
if (op->getDialect() != tosaDialect)
return;
- // perform valid element type check at the beginning to
- // protect rest of code against quantized element types
+ // validate operator element types:
+ // - rescale operator is allowed to have ui8/ui16/ui32
+ // operands/results
+ // - perform valid element type check at the beginning to
+ // protect rest of code against quantized element types
+ const bool opIsRescale = isa<tosa::RescaleOp>(op);
for (Value operand : op->getOperands()) {
auto elementTy = getElementTypeOrSelf(operand);
- if (!isValidElementType(elementTy)) {
+ if (!isValidElementType(elementTy, opIsRescale)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
@@ -1221,7 +1232,7 @@ void TosaValidation::runOnOperation() {
}
for (Type resultTy : op->getResultTypes()) {
auto elementTy = getElementTypeOrSelf(resultTy);
- if (!isValidElementType(elementTy)) {
+ if (!isValidElementType(elementTy, opIsRescale)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c4f95b47628d1..c1f4f22887c79 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1937,3 +1937,27 @@ func.func @test_clamp_min_larger_than_max_fp32(%arg0: tensor<13x21x3xf32>) -> te
%0 = tosa.clamp %arg0 {min_val = 2.0 : f32, max_val = -1.1: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_rescale_input_unsigned
+func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
+ %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+ return %r : tensor<1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_rescale_output_unsigned
+func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
+ %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+ return %r : tensor<1x1xui8>
+}
More information about the Mlir-commits
mailing list