[Mlir-commits] [mlir] [mlir][tosa] Allow unsigned types for rescale ops during validation (PR #138253)

Luke Hutton llvmlistbot at llvm.org
Fri May 2 04:26:46 PDT 2025


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/138253

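This commit allows unsigned integer element types (ui8/ui16/ui32) to pass the TOSA validation pass's element-type check, but only for the operands and results of the rescale operator (tosa.rescale); the check for all other operators is unchanged.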

From e4d57166b9cce9c99518bab9837ef295c0c2f51b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 30 Apr 2025 22:09:04 +0000
Subject: [PATCH] [mlir][tosa] Allow unsigned types for rescale ops during
 validation

This commit allows unsigned element types (ui8/ui16/ui32) to pass
the valid-element-type check, but only for the rescale operator.

Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: I0525c5a5542e20e832d1bf150635be7423d3799a
---
 .../Tosa/Transforms/TosaValidation.cpp        | 23 +++++++++++++-----
 mlir/test/Dialect/Tosa/invalid.mlir           | 24 +++++++++++++++++++
 2 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index e8b52d48347ab..feedc5057bea0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -562,7 +562,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   bool CheckVariable(Operation *op);
   bool CheckVariableReadOrWrite(Operation *op);
-  bool isValidElementType(Type type);
+  bool isValidElementType(Type type, const bool allowUnsigned = false);
 
   SmallVector<
       std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
@@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   return success();
 }
 
-bool TosaValidation::isValidElementType(Type type) {
+bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
   if (isa<FloatType>(type)) {
     return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
                Float8E5M2Type>(type);
@@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) {
       case 48:
         return true;
       }
+    } else if (allowUnsigned && intTy.isUnsigned()) {
+      switch (intTy.getWidth()) {
+      case 8:
+      case 16:
+      case 32:
+        return true;
+      }
     }
   } else if (mlir::isa<tosa::shapeType>(type)) {
     return true;
@@ -1209,11 +1216,15 @@ void TosaValidation::runOnOperation() {
     if (op->getDialect() != tosaDialect)
       return;
 
-    // perform valid element type check at the beginning to
-    // protect rest of code against quantized element types
+    // validate operator element types:
+    // - rescale operator is allowed to have ui8/ui16/ui32
+    //   operands/results
+    // - perform valid element type check at the beginning to
+    //   protect rest of code against quantized element types
+    const bool opIsRescale = isa<tosa::RescaleOp>(op);
     for (Value operand : op->getOperands()) {
       auto elementTy = getElementTypeOrSelf(operand);
-      if (!isValidElementType(elementTy)) {
+      if (!isValidElementType(elementTy, opIsRescale)) {
         op->emitOpError() << "is not profile-aligned: element type "
                           << elementTy << " is not legal";
         return signalPassFailure();
@@ -1221,7 +1232,7 @@ void TosaValidation::runOnOperation() {
     }
     for (Type resultTy : op->getResultTypes()) {
       auto elementTy = getElementTypeOrSelf(resultTy);
-      if (!isValidElementType(elementTy)) {
+      if (!isValidElementType(elementTy, opIsRescale)) {
         op->emitOpError() << "is not profile-aligned: element type "
                           << elementTy << " is not legal";
         return signalPassFailure();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c4f95b47628d1..c1f4f22887c79 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1937,3 +1937,27 @@ func.func @test_clamp_min_larger_than_max_fp32(%arg0: tensor<13x21x3xf32>) -> te
   %0 = tosa.clamp %arg0 {min_val = 2.0 : f32, max_val = -1.1: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_rescale_input_unsigned
+func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
+  %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+  return %r : tensor<1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_rescale_output_unsigned
+func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
+  %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+  return %r : tensor<1x1xui8>
+}



More information about the Mlir-commits mailing list