[Mlir-commits] [mlir] [tosa] : Get quantized element type with sign info. (PR #169387)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 24 10:35:50 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Sayan Saha (sahas3)

<details>
<summary>Changes</summary>

As documented in QuantTypes.h (https://github.com/llvm/llvm-project/blob/a27bb38ee6f5762e715803d8eb6ffc5a8dd09575/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h#L109), `QuantizedType::getStorageType` does not capture sign information. This caused the following IR to fail verification:
```
func.func @clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
    %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
    return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
}
```
with the error `'tosa.clamp' op min/max attributes types are incompatible with input/output element types`,
because `getStorageType` returned a signless integer type (`i8`) while the clamp attributes were unsigned (`ui8`).

This PR replaces uses of `getStorageType` in the TOSA dialect with a new helper, `getStorageElementTypeFromQuantized`, which returns an unsigned integer storage type for unsigned quantized types.

---
Full diff: https://github.com/llvm/llvm-project/pull/169387.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h (+4-1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+10-10) 
- (modified) mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp (+13) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+30) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+7) 
- (modified) mlir/test/Dialect/Tosa/quant-test.mlir (+11-3) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 9d9a934cdfd5e..0e751911df94d 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -19,7 +19,7 @@
 #include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
 #include "mlir/Dialect/Quant/Utils/UniformSupport.h"
 
-namespace mlir {
+namespace mlir {    
 namespace tosa {
 
 //===----------------------------------------------------------------------===//
@@ -88,6 +88,9 @@ TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
                                   IntegerAttr quantBits, int filterQuantDim,
                                   bool isSigned, BoolAttr narrowRange);
 
+/// Like getStorageType(), but unsigned integer storage yields `ui<N>`.
+Type getStorageElementTypeFromQuantized(quant::QuantizedType quantType);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index a118ac9c4b111..c420a4c9596ff 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
@@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
     auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
     if (auto quantType =
             llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
-      inputEType = quantType.getStorageType();
+      inputEType = getStorageElementTypeFromQuantized(quantType);
     }
 
     Attribute newMinValAttr, newMaxValAttr;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 65e0a59d39168..1c175f9ab0207 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
 static Type getStorageElementTypeOrSelf(Type type) {
   auto srcType = getElementTypeOrSelf(type);
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
-    srcType = quantType.getStorageType();
+    srcType = getStorageElementTypeFromQuantized(quantType);
   return srcType;
 }
 
@@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) {
   bool resultIsFloat = llvm::isa<FloatType>(resultEType);
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
-    inputEType = quantType.getStorageType();
+    inputEType = getStorageElementTypeFromQuantized(quantType);
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
-    weightEType = quantType.getStorageType();
+    weightEType = getStorageElementTypeFromQuantized(quantType);
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
-    biasEType = quantType.getStorageType();
+    biasEType = getStorageElementTypeFromQuantized(quantType);
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
-    resultEType = quantType.getStorageType();
+    resultEType = getStorageElementTypeFromQuantized(quantType);
 
   if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
     // for now, only enforce bias element type == result element type for
@@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() {
 
   if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
           outputType.getElementType())) {
-    if (result.getStorageType() == attrType.getElementType())
+    if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
       return success();
   }
 
@@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) {
       llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
-    inputEType = quantType.getStorageType();
+    inputEType = getStorageElementTypeFromQuantized(quantType);
 
   auto accType = op.getAccType();
   if (inputEType.isInteger(8) && !accType.isInteger(32))
@@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) {
       llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
 
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
-    resultEType = quantType.getStorageType();
+    resultEType = getStorageElementTypeFromQuantized(quantType);
 
   return success();
 }
@@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() {
       llvm::cast<ShapedType>(getInput().getType()).getElementType();
   if (auto quantType =
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
-    inputETy = quantType.getStorageType();
+    inputETy = getStorageElementTypeFromQuantized(quantType);
   }
   mlir::Type outputETy =
       llvm::cast<ShapedType>(getOutput().getType()).getElementType();
   if (auto quantType =
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
-    outputETy = quantType.getStorageType();
+    outputETy = getStorageElementTypeFromQuantized(quantType);
   }
   if (inputETy != outputETy)
     return emitOpError("input/output element types are incompatible.");
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 02c86a090e6d4..c55b13dc98cc5 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
                                             maxAttr, quantBits, filterQuantDim,
                                             isSigned, narrowRange));
 }
+
+Type mlir::tosa::getStorageElementTypeFromQuantized(
+    quant::QuantizedType quantType) {
+  Type storageType = quantType.getStorageType();
+  // getStorageType() returns a signless integer even for unsigned quantized
+  // types; rebuild it as ui<N> so it matches unsigned element types.
+  // Non-integer storage carries no signedness and is returned unchanged.
+  auto intType = llvm::dyn_cast<IntegerType>(storageType);
+  if (!intType || quantType.isSigned())
+    return storageType;
+  return IntegerType::get(intType.getContext(), intType.getWidth(),
+                          IntegerType::Unsigned);
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fc5ea7710e2c4..84776c47b628d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -360,6 +360,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens
   return %1 : tensor<4xi8>
 }
 
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8}
+func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+    %0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+    %1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+    return %1 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8}
+func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+    %0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+    %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+    return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp
+// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8}
+// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8}
+func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+    %0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+    %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+    return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+
 // -----
 
 // CHECK-LABEL: @concat_fold
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4591f7ffd393..652447bd6056e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -279,6 +279,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
 }
 
+// -----
+// CHECK-LABEL: clamp_quantized_unsigned
+func.func @test_clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+    %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+    return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
 // -----
 // CHECK-LABEL: sigmoid
 func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index f0ad4eb4fdb0b..88dffe7fdd2e8 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -1,13 +1,21 @@
 // RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
 
 // -----
-// CHECK-LABEL: test_build_qtype
-func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+// CHECK-LABEL: test_build_qtype_unsigned
+func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
   //  CHECK: tosa.negate
-  %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+  %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
   return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
 }
 
+// -----
+// CHECK-LABEL: test_build_qtype_signed
+func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> {
+  //  CHECK: tosa.negate
+  %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+  return %0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+}
+
 // -----
 // CHECK-LABEL: test_build_mult_and_shift
 func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 6cf76cdc7ad8e..ea64d468f151e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+    // expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
+    %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+    return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/169387


More information about the Mlir-commits mailing list