[Mlir-commits] [mlir] [mlir][tosa] Replace UniformQuantizedType by the more generic Quantiz… (PR #126275)

Tai Ly llvmlistbot at llvm.org
Fri Feb 7 09:38:53 PST 2025


https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/126275

Replaces UniformQuantizedType with the more generic QuantizedType in the Conv verifiers.

Also fixes buildTransConvOpWithQuantInfo to insert the input/weight zero-point (zp) operands.

Change-Id: Ie1961af931864f801914a62976bc988881ee075e

>From b67944bc0c3bb7e307dc0ccee41d0087bccaac7c Mon Sep 17 00:00:00 2001
From: Thibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon at arm.com>
Date: Wed, 8 May 2024 18:09:42 +0100
Subject: [PATCH] [mlir][tosa] Replace UniformQuantizedType by the more generic
 QuantizedType in Conv verifiers

also fixed buildTransConvOpWithQuantInfo to insert input/weight zp
operands

Change-Id: Ie1961af931864f801914a62976bc988881ee075e
Signed-off-by: Tai Ly <tai.ly at arm.com>
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 42 +++++++++++++++++-----------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 031c279ff09e275..6143a9d23a00394 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
   bool biasIsFloat = llvm::isa<FloatType>(biasEType);
   bool resultIsFloat = llvm::isa<FloatType>(resultEType);
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
     biasEType = quantType.getStorageType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
     resultEType = quantType.getStorageType();
 
   if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
   auto inputEType =
       llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
   auto accType = op.getAccType();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
   if (inputEType.isF32() && !accType.isF32())
     return op.emitOpError("accumulator type for f32 tensor is not f32");
 
-  return success();
+  auto resultEType =
+      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
+    resultEType = quantType.getStorageType();
+
+  // check allowed input/result element types combinations
+  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
+      (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
+      (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
+      (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
+      (inputEType.isF16() && resultEType.isF16()) ||
+      (inputEType.isBF16() && resultEType.isBF16()) ||
+      (inputEType.isF32() && resultEType.isF32()))
+    return success();
+
+  return op.emitOpError("input/output element types are incompatible.");
 }
 
 // verify that inType and outType have same element types
@@ -519,7 +531,8 @@ static void buildTransConvOpWithQuantInfo(
     OpBuilder &builder, OperationState &result, Type outputType, Value input,
     Value weight, Value bias, DenseI64ArrayAttr outpad,
     DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
-  result.addOperands({input, weight, bias});
+  auto zps = createZPsAsConst(builder, input, weight);
+  result.addOperands({input, weight, bias, zps.first, zps.second});
   result.addAttribute("out_pad", outpad);
   result.addAttribute("stride", stride);
   result.addAttribute("out_shape", outputShape);
@@ -2478,18 +2491,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
   return failure();
 }
 
-// Create a rank-0 const tensor for zero point of the source tensor.
+// Create a rank-1 const tensor for zero point of the source tensor.
 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                        Location loc,
                                                        Type srcElemType,
                                                        int64_t zp) {
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
-    srcElemType = quantType.getStorageType();
-
-  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
+  srcElemType = getElementTypeOrSelf(srcElemType);
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
     srcElemType = quantType.getStorageType();
+  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
   if (llvm::isa<FloatType>(srcElemType)) {
     auto zpAttr = DenseElementsAttr::get(
         zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));



More information about the Mlir-commits mailing list