[Mlir-commits] [mlir] [mlir][tosa] Remove Quantization Attribute (PR #125479)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 3 03:06:34 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Jack Frankland (FranklandJack)

<details>
<summary>Changes</summary>

Remove the TOSA quantization attribute used in various MLIR TOSA dialect operations in favour of builtin `I32Attr` zero-point attributes (e.g. `input_zp`, `output_zp`).

Update any lit tests, conversions and transformations appropriately.

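Rename operands, results and regions as follows to align with the TOSA-v1.0 specification:
* `cond` -> `condition` (`cond_if` operand)
* `then_branch` -> `then_graph` (`cond_if` region)
* `else_branch` -> `else_graph` (`cond_if` region)
* `inputs` -> `input_list` (`while_loop` operand)
* `output` -> `output_list` (`cond_if` and `while_loop` results)
* `cond` -> `cond_graph` (`while_loop` region)
* `body` -> `body_graph` (`while_loop` region)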

---

Patch is 39.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125479.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+17-13) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+52-48) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+11-20) 
- (modified) mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp (+6-6) 
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+56-27) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+2-3) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+2-2) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+19-9) 
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+5-5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 819547855d1015..fef0f2d98d95c6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
     Tosa_Tensor2D:$input,
     TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
     Tosa_Tensor1D:$bias,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$weight_zp
   );
 
   let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$a_zp,
+    OptionalAttr<I32Attr>:$b_zp
   );
 
   let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
 
   let arguments = (ins
       Tosa_Tensor:$input1,
-      OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+      OptionalAttr<I32Attr>:$input1_zp,
+      OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Tosa_RankedTensor:$input1,
     Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
-    OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp
   );
 
   let results = (outs
@@ -2071,17 +2075,17 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   }];
 
   let arguments = (ins
-    Tosa_I1Tensor:$cond,
+    Tosa_I1Tensor:$condition,
     Variadic<Tosa_Tensor>:$inputs
   );
 
   let results = (outs
-    Variadic<Tosa_Tensor>:$output
+    Variadic<Tosa_Tensor>:$output_list
   );
 
   let regions = (region
-    SizedRegion<1>:$then_branch,
-    SizedRegion<1>:$else_branch
+    SizedRegion<1>:$then_graph,
+    SizedRegion<1>:$else_graph
   );
 
   let hasCustomAssemblyFormat = 1;
@@ -2108,16 +2112,16 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$inputs
+    Variadic<Tosa_Tensor>:$input_list
   );
 
   let results = (outs
-    Variadic<Tosa_Tensor>:$output
+    Variadic<Tosa_Tensor>:$output_list
   );
 
   let regions = (region
-    SizedRegion<1>:$cond,
-    SizedRegion<1>:$body
+    SizedRegion<1>:$cond_graph,
+    SizedRegion<1>:$body_graph
   );
 
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..449baad0edeafe 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -141,63 +141,67 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::NegateOp
-  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+  if (isa<tosa::NegateOp>(op)) {
+    if (isa<FloatType>(elementTy))
+      return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
-    int64_t inZp = 0, outZp = 0;
+    auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
+    auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+    int32_t inputZpVal = inputZpAttr ? inputZpAttr.getInt() : 0;
+    int32_t outputZpVal = outputZpAttr ? outputZpAttr.getInt() : 0;
 
-    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
-      inZp = quantizationInfo.value().getInputZp();
-      outZp = quantizationInfo.value().getOutputZp();
-    }
-
-    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    if (!inZp && !outZp) {
+    if (isa<IntegerType>(elementTy) && inputZpVal == 0 && outputZpVal == 0) {
       auto constant = rewriter.create<arith::ConstantOp>(
           loc, IntegerAttr::get(elementTy, 0));
       return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                             args[0]);
     }
 
-    // Compute the maximum value that can occur in the intermediate buffer.
-    int64_t zpAdd = inZp + outZp;
-    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
-                       std::abs(zpAdd) + 1;
-
-    // Convert that maximum value into the maximum bitwidth needed to represent
-    // it. We assume 48-bit numbers may be supported further in the pipeline.
-    int intermediateBitWidth = 64;
-    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
-      intermediateBitWidth = 16;
-    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
-      intermediateBitWidth = 32;
-    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
-      intermediateBitWidth = 48;
-    }
+    if (isa<IntegerType>(elementTy) && (inputZpVal != 0 || outputZpVal != 0)) {
+      int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+      int64_t inZp = inputZpVal;
+      int64_t outZp = outputZpVal;
+
+      // Compute the maximum value that can occur in the intermediate buffer.
+      int64_t zpAdd = inZp + outZp;
+      int64_t maxValue =
+          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+          std::abs(zpAdd) + 1;
+
+      // Convert that maximum value into the maximum bitwidth needed to
+      // represent it. We assume 48-bit numbers may be supported further in the
+      // pipeline.
+      int intermediateBitWidth = 64;
+      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+        intermediateBitWidth = 16;
+      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+        intermediateBitWidth = 32;
+      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+        intermediateBitWidth = 48;
+      }
 
-    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-    Value zpAddValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
-
-    // The negation can be applied by doing:
-    //  outputValue = inZp + outZp - inputValue
-    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
-    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
-
-    // Clamp to the negation range.
-    Value min = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    Value max = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    auto clamp =
-        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
-
-    // Truncate to the final value.
-    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+      Value zpAddValue = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+      // The negation can be applied by doing:
+      //  outputValue = inZp + outZp - inputValue
+      auto ext =
+          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
+      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+
+      // Clamp to the negation range.
+      Value min = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      Value max = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
+
+      // Truncate to the final value.
+      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+    }
   }
 
   // tosa::BitwiseAndOp
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e05cf7c9..1e02301f7c23d5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -590,18 +590,15 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
                            .create<linalg::FillOp>(loc, ValueRange{zero},
                                                    ValueRange{emptyTensor})
                            .result();
-    if (!op.getQuantizationInfo()) {
+    if (!op.getAZp() && !op.getBZp()) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto aZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
-    auto bZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
+    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ class FullyConnectedConverter
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (!op.getQuantizationInfo()) {
+    if (!op.getInputZp() && !op.getWeightZp()) {
       Value matmul = rewriter
                          .create<linalg::MatmulOp>(
                              loc, TypeRange{op.getType()},
@@ -672,11 +669,8 @@ class FullyConnectedConverter
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto inputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
-    auto outputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+    auto outputZp = rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
     Value matmul =
         rewriter
             .create<linalg::QuantizedMatmulOp>(
@@ -958,10 +952,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply an offset
             // for the input zp value.
-            if (op.getQuantizationInfo()) {
-              auto quantizationInfo = *op.getQuantizationInfo();
+            if (op.getInputZp()) {
               auto inputZp = rewriter.create<arith::ConstantOp>(
-                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+                  loc, op.getInputZpAttr());
               Value offset =
                   rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
               poolVal =
@@ -1013,11 +1006,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply output
             // zeropoint.
-            if (op.getQuantizationInfo()) {
-              auto quantizationInfo = *op.getQuantizationInfo();
-              auto outputZp = rewriter.create<arith::ConstantOp>(
-                  loc, b.getIntegerAttr(scaled.getType(),
-                                        quantizationInfo.getOutputZp()));
+            if (op.getOutputZp()) {
+              auto outputZp =
+                  rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                            .getResult();
             }
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 9139bf191fdf11..80c58bdc0550cc 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
   LogicalResult matchAndRewrite(tosa::IfOp op,
                                 PatternRewriter &rewriter) const final {
     auto condition =
-        rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
+        rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
     auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
                                             condition, true);
 
-    inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
+    inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputs(),
                  rewriter);
-    inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
+    inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputs(),
                  rewriter);
 
     rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
   LogicalResult matchAndRewrite(tosa::WhileOp op,
                                 PatternRewriter &rewriter) const final {
     auto newWhile = rewriter.create<scf::WhileOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs());
+        op.getLoc(), op.getResultTypes(), op.getInputList());
     rewriter.createBlock(&newWhile.getBefore());
     rewriter.createBlock(&newWhile.getAfter());
 
-    inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
-    inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
+    inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
+    inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
 
     rewriter.replaceOp(op, newWhile.getResults());
 
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..2a9b4d111bdfa2 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
       TypedAttr constantAttr;
       if (isa<FloatType>(elementTy)) {
         constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+      } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
         constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
-        int64_t value = padOp.getQuantizationInfo()->getInputZp();
+      } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+        int64_t value = padOp.getInputZpAttr().getInt();
         constantAttr = rewriter.getIntegerAttr(elementTy, value);
       }
       if (constantAttr)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..8e22c879753a33 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     Attribute constantAttr;
     if (llvm::isa<FloatType>(elementTy)) {
       constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
-      auto value = op.getQuantizationInfo()->getInputZp();
+    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+      int64_t value = op.getInputZpAttr().getInt();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..9bde6a85935255 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -124,7 +124,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
 //===----------------------------------------------------------------------===//
 
 /// Returns the while loop body.
-SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
+SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
+  return {&getBodyGraph()};
+}
 
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
@@ -271,11 +273,11 @@ static LogicalResult verifyConvOp(T op) {
     }
   }
 
-  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
-  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
 
-  // Either both must be quantized or both unquantized.
-  if (inputIsQuant != weightIsQuant) {
+  // Either both must be float or both non-float.
+  if (inputIsFloat != weightIsFloat) {
     op.emitOpError(
         "expect both input and weight to be float or not together, got ")
         << inputEType << " and " << weightEType;
@@ -527,7 +529,12 @@ static void buildTransConvOpWithQuantInfo(
   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("weight_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getWeightZp())));
     result.addTypes(
         buildConvOpResultTypeInfo(builder, outputType, input, weight));
   } else {
@@ -563,7 +570,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
   auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("a_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getAZp())));
+    result.addAttribute("b_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getBZp())));
 
     auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
     assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +613,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.addAttribute("pad", pad);
   result.addAttribute("acc_type", accType);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -616,8 +632,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                       Value input) {
   result.addOperands(input);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    // note: negateOp has attributes input1_zp and output_zp
+    result.addAttribute("input1_zp",
+                     ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/125479


More information about the Mlir-commits mailing list