[Mlir-commits] [mlir] [mlir][tosa] Fix conv op build functions (PR #126321)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 7 14:50:33 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

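This patch fixes several issues:
  - buildConvOpWithQuantInfo:
       when a quantization attribute can be built from the input and weight,
       call buildConvOpResultTypeInfo to compute the final output type
  - buildTransConvOpWithQuantInfo:
       add input_zp and weight_zp operands
       remove the input_zp/weight_zp attributes
  - createZeroPointTensor:
       call getElementTypeOrSelf so that a shaped type passed as srcElemType
       is reduced to its element type
       remove a duplicated quantized-type check left over from a bad
       auto-merge, and build the zero-point tensor type only after
       converting a quantized element type to its storage type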


Change-Id: Idbf88f500ce57a865da4b7be7b7b8bf2ba194b24

---
Full diff: https://github.com/llvm/llvm-project/pull/126321.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+17-20) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 955021abdd67b12..fd166cc1322cef9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -510,7 +510,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.addAttribute("stride", stride);
   result.addAttribute("dilation", dilation);
   result.addAttribute("acc_type", accType);
-  result.addTypes(outputType);
+  Type finalOutputType = outputType;
+  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
+  if (quantAttr) {
+    finalOutputType =
+        buildConvOpResultTypeInfo(builder, outputType, input, weight);
+  }
+  result.addTypes(finalOutputType);
 }
 
 /// Handles tosa.transpose_conv2d which has outpad and output shape
@@ -519,25 +525,19 @@ static void buildTransConvOpWithQuantInfo(
     OpBuilder &builder, OperationState &result, Type outputType, Value input,
     Value weight, Value bias, DenseI64ArrayAttr outpad,
     DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
-  result.addOperands({input, weight, bias});
+  auto zps = createZPsAsConst(builder, input, weight);
+  result.addOperands({input, weight, bias, zps.first, zps.second});
   result.addAttribute("out_pad", outpad);
   result.addAttribute("stride", stride);
   result.addAttribute("out_shape", outputShape);
   result.addAttribute("acc_type", accType);
-  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
-
+  Type finalOutputType = outputType;
+  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
   if (quantAttr) {
-    result.addAttribute("input_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getInputZp())));
-    result.addAttribute("weight_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getWeightZp())));
-    result.addTypes(
-        buildConvOpResultTypeInfo(builder, outputType, input, weight));
-  } else {
-    result.addTypes(outputType);
+    finalOutputType =
+        buildConvOpResultTypeInfo(builder, outputType, input, weight);
   }
+  result.addTypes(finalOutputType);
 }
 
 /// The tosa.fully_connected op has its own builder as it does not have
@@ -2492,18 +2492,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
   return failure();
 }
 
-// Create a rank-0 const tensor for zero point of the source tensor.
+// Create a rank-1 const tensor for zero point of the source tensor.
 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                        Location loc,
                                                        Type srcElemType,
                                                        int64_t zp) {
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
-    srcElemType = quantType.getStorageType();
-
-  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
+  srcElemType = getElementTypeOrSelf(srcElemType);
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
     srcElemType = quantType.getStorageType();
+  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
   if (llvm::isa<FloatType>(srcElemType)) {
     auto zpAttr = DenseElementsAttr::get(
         zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));

``````````

</details>


https://github.com/llvm/llvm-project/pull/126321


More information about the Mlir-commits mailing list