[Mlir-commits] [mlir] [tosa] : Get quantized element type with sign info. (PR #169387)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 24 10:35:50 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir
Author: Sayan Saha (sahas3)
<details>
<summary>Changes</summary>
As mentioned in https://github.com/llvm/llvm-project/blob/a27bb38ee6f5762e715803d8eb6ffc5a8dd09575/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h#L109, the type returned by `QuantType::getStorageType` doesn't capture the sign information. This led to the following IR failing verification:
```
func.func @clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
%0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
}
```
with the error `'tosa.clamp' op min/max attributes types are incompatible with input/output element types`,
since `getStorageType` was returning a signed integer type while the clamp attributes were unsigned.
This PR updates the uses of `getStorageType` in the TOSA codebase so that they correctly take the signedness of the quantized type into account.
---
Full diff: https://github.com/llvm/llvm-project/pull/169387.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h (+4-1)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+2-1)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+10-10)
- (modified) mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp (+13)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+30)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+7)
- (modified) mlir/test/Dialect/Tosa/quant-test.mlir (+11-3)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 9d9a934cdfd5e..0e751911df94d 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -19,7 +19,7 @@
#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
-namespace mlir {
+namespace mlir {
namespace tosa {
//===----------------------------------------------------------------------===//
@@ -88,6 +88,9 @@ TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
IntegerAttr quantBits, int filterQuantDim,
bool isSigned, BoolAttr narrowRange);
+Type
+getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType);
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index a118ac9c4b111..c420a4c9596ff 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
}
Attribute newMinValAttr, newMaxValAttr;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 65e0a59d39168..1c175f9ab0207 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
static Type getStorageElementTypeOrSelf(Type type) {
auto srcType = getElementTypeOrSelf(type);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
- srcType = quantType.getStorageType();
+ srcType = getStorageElementTypeFromQuantized(quantType);
return srcType;
}
@@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) {
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
- weightEType = quantType.getStorageType();
+ weightEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
- biasEType = quantType.getStorageType();
+ biasEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
// for now, only enforce bias element type == result element type for
@@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() {
if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
outputType.getElementType())) {
- if (result.getStorageType() == attrType.getElementType())
+ if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
return success();
}
@@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
auto accType = op.getAccType();
if (inputEType.isInteger(8) && !accType.isInteger(32))
@@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
return success();
}
@@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() {
llvm::cast<ShapedType>(getInput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
- inputETy = quantType.getStorageType();
+ inputETy = getStorageElementTypeFromQuantized(quantType);
}
mlir::Type outputETy =
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
- outputETy = quantType.getStorageType();
+ outputETy = getStorageElementTypeFromQuantized(quantType);
}
if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 02c86a090e6d4..c55b13dc98cc5 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
maxAttr, quantBits, filterQuantDim,
isSigned, narrowRange));
}
+
+Type mlir::tosa::getStorageElementTypeFromQuantized(
+ quant::QuantizedType quantType) {
+ auto quantEty = quantType.getStorageType();
+ // StorageType doesn't capture the sign information
+ // Explicitly create unsigned type if needed
+ if (!quantType.isSigned()) {
+ quantEty = IntegerType::get(quantEty.getContext(),
+ quantEty.getIntOrFloatBitWidth(),
+ IntegerType::Unsigned);
+ }
+ return quantEty;
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fc5ea7710e2c4..84776c47b628d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -360,6 +360,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens
return %1 : tensor<4xi8>
}
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8}
+func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp
+// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8}
+func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp
+// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8}
+// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8}
+func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+ return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>
+}
+
+
// -----
// CHECK-LABEL: @concat_fold
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4591f7ffd393..652447bd6056e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -279,6 +279,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 1.000000e-01:-127>>
}
+// -----
+// CHECK-LABEL: clamp_quantized_unsigned
+func.func @clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
+
// -----
// CHECK-LABEL: sigmoid
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index f0ad4eb4fdb0b..88dffe7fdd2e8 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -1,13 +1,21 @@
// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
// -----
-// CHECK-LABEL: test_build_qtype
-func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+// CHECK-LABEL: test_build_qtype_unsigned
+func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
// CHECK: tosa.negate
- %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+ %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
}
+// -----
+// CHECK-LABEL: test_build_qtype_signed
+func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> {
+ // CHECK: tosa.negate
+ %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+ return %0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>
+}
+
// -----
// CHECK-LABEL: test_build_mult_and_shift
func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 6cf76cdc7ad8e..ea64d468f151e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
}
+
+// -----
+
+func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) {
+ // expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}}
+ %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+ return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/169387
More information about the Mlir-commits
mailing list