[Mlir-commits] [mlir] [mlir][tosa] Remove Quantization Attribute (PR #125479)

Jack Frankland llvmlistbot at llvm.org
Mon Feb 3 04:00:44 PST 2025


https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/125479

>From 16b43c8c0caa16f2b76ee2d261e299346ec18431 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Tue, 13 Feb 2024 19:35:14 +0000
Subject: [PATCH] [mlir][tosa] Remove Quantization Attribute

Remove the TOSA quantization attributes used by several MLIR TOSA
dialect operations and replace them with optional builtin I32
zero-point attributes (e.g. input_zp, output_zp, weight_zp, a_zp, b_zp).

Update the affected lit tests, conversions and transformations accordingly.

Signed-off-by: Tai Ly <tai.ly at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  14 ++-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 108 +++++++++---------
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  34 +++---
 .../Conversion/TosaToTensor/TosaToTensor.cpp  |   6 +-
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |   6 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |  60 +++++++---
 .../Tosa/Transforms/TosaDecomposeConv2D.cpp   |   6 +-
 .../Transforms/TosaDecomposeTransposeConv.cpp |   5 +-
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |   4 +-
 .../TosaToLinalg/tosa-to-linalg.mlir          |  28 +++--
 .../TosaToTensor/tosa-to-tensor.mlir          |   2 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |   2 +-
 .../Dialect/Tosa/tosa-decompose-conv2d.mlir   |   6 +-
 .../Tosa/tosa-decompose-transpose-conv.mlir   |  10 +-
 14 files changed, 162 insertions(+), 129 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 819547855d1015..48d8f1bd4836c9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
     Tosa_Tensor2D:$input,
     TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
     Tosa_Tensor1D:$bias,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$weight_zp
   );
 
   let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$a_zp,
+    OptionalAttr<I32Attr>:$b_zp
   );
 
   let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
 
   let arguments = (ins
       Tosa_Tensor:$input1,
-      OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+      OptionalAttr<I32Attr>:$input1_zp,
+      OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Tosa_RankedTensor:$input1,
     Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
-    OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp
   );
 
   let results = (outs
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..3d7cf3e1959e44 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -141,63 +141,65 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::NegateOp
-  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+  if (isa<tosa::NegateOp>(op)) {
+    if (isa<FloatType>(elementTy))
+      return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
-    int64_t inZp = 0, outZp = 0;
-
-    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
-      inZp = quantizationInfo.value().getInputZp();
-      outZp = quantizationInfo.value().getOutputZp();
-    }
+    if (isa<IntegerType>(elementTy)) {
+      auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
+      auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+      const int64_t inZp = inputZpAttr ? inputZpAttr.getInt() : 0;
+      const int64_t outZp = outputZpAttr ? outputZpAttr.getInt() : 0;
+      if (!inZp && !outZp) {
+        auto zero = rewriter.create<arith::ConstantOp>(
+            loc, IntegerAttr::get(elementTy, 0));
+        return rewriter.create<arith::SubIOp>(loc, resultTypes, zero, args[0]);
+      }

-    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    if (!inZp && !outZp) {
-      auto constant = rewriter.create<arith::ConstantOp>(
-          loc, IntegerAttr::get(elementTy, 0));
-      return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
-                                            args[0]);
-    }
+      // Non-zero zero point(s) (absent attributes count as 0): compute
+      // inZp + outZp - x in a wider type and clamp back into range.
+
+      // Compute the maximum value that can occur in the intermediate buffer.
+      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+      const int64_t zpAdd = inZp + outZp;
+      const int64_t maxValue =
+          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+          std::abs(zpAdd) + 1;
+
+      // Convert that maximum value into the maximum bitwidth needed to
+      // represent it. We assume 48-bit numbers may be supported further in
+      // the pipeline.
+      int intermediateBitWidth = 64;
+      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+        intermediateBitWidth = 16;
+      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+        intermediateBitWidth = 32;
+      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+        intermediateBitWidth = 48;
+      }
 
-    // Compute the maximum value that can occur in the intermediate buffer.
-    int64_t zpAdd = inZp + outZp;
-    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
-                       std::abs(zpAdd) + 1;
-
-    // Convert that maximum value into the maximum bitwidth needed to represent
-    // it. We assume 48-bit numbers may be supported further in the pipeline.
-    int intermediateBitWidth = 64;
-    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
-      intermediateBitWidth = 16;
-    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
-      intermediateBitWidth = 32;
-    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
-      intermediateBitWidth = 48;
+      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+      Value zpAddValue = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+      // The negation can be applied by doing:
+      //  outputValue = inZp + outZp - inputValue
+      auto ext =
+          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
+      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+
+      // Clamp to the negation range.
+      Value min = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      Value max = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
+
+      // Truncate to the final value.
+      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
     }
-
-    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-    Value zpAddValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
-
-    // The negation can be applied by doing:
-    //  outputValue = inZp + outZp - inputValue
-    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
-    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
-
-    // Clamp to the negation range.
-    Value min = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    Value max = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    auto clamp =
-        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
-
-    // Truncate to the final value.
-    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
   }
 
   // tosa::BitwiseAndOp
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e05cf7c9..6321cb6087394a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -590,18 +590,15 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
                            .create<linalg::FillOp>(loc, ValueRange{zero},
                                                    ValueRange{emptyTensor})
                            .result();
-    if (!op.getQuantizationInfo()) {
+    if (!op.getAZp() && !op.getBZp()) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto aZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
-    auto bZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
+    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ class FullyConnectedConverter
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (!op.getQuantizationInfo()) {
+    if (!op.getInputZp() && !op.getWeightZp()) {
       Value matmul = rewriter
                          .create<linalg::MatmulOp>(
                              loc, TypeRange{op.getType()},
@@ -672,11 +669,9 @@ class FullyConnectedConverter
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto inputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
-    auto outputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+    auto outputZp =
+        rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
     Value matmul =
         rewriter
             .create<linalg::QuantizedMatmulOp>(
@@ -958,10 +953,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply an offset
             // for the input zp value.
-            if (op.getQuantizationInfo()) {
-              auto quantizationInfo = *op.getQuantizationInfo();
-              auto inputZp = rewriter.create<arith::ConstantOp>(
-                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+            if (op.getInputZp()) {
+              auto inputZp =
+                  rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
               Value offset =
                   rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
               poolVal =
@@ -1013,11 +1007,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply output
             // zeropoint.
-            if (op.getQuantizationInfo()) {
-              auto quantizationInfo = *op.getQuantizationInfo();
-              auto outputZp = rewriter.create<arith::ConstantOp>(
-                  loc, b.getIntegerAttr(scaled.getType(),
-                                        quantizationInfo.getOutputZp()));
+            if (op.getOutputZp()) {
+              auto outputZp =
+                  rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                            .getResult();
             }
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..2a9b4d111bdfa2 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
       TypedAttr constantAttr;
       if (isa<FloatType>(elementTy)) {
         constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+      } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
         constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
-        int64_t value = padOp.getQuantizationInfo()->getInputZp();
+      } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+        int64_t value = padOp.getInputZpAttr().getInt();
         constantAttr = rewriter.getIntegerAttr(elementTy, value);
       }
       if (constantAttr)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..8e22c879753a33 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     Attribute constantAttr;
     if (llvm::isa<FloatType>(elementTy)) {
       constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
-      auto value = op.getQuantizationInfo()->getInputZp();
+    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+      int64_t value = op.getInputZpAttr().getInt();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..031c279ff09e27 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -271,11 +271,11 @@ static LogicalResult verifyConvOp(T op) {
     }
   }
 
-  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
-  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
 
-  // Either both must be quantized or both unquantized.
-  if (inputIsQuant != weightIsQuant) {
+  // Either both must be float or both non-float.
+  if (inputIsFloat != weightIsFloat) {
     op.emitOpError(
         "expect both input and weight to be float or not together, got ")
         << inputEType << " and " << weightEType;
@@ -527,7 +527,12 @@ static void buildTransConvOpWithQuantInfo(
   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("weight_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getWeightZp())));
     result.addTypes(
         buildConvOpResultTypeInfo(builder, outputType, input, weight));
   } else {
@@ -563,7 +568,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
   auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("a_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getAZp())));
+    result.addAttribute("b_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getBZp())));
 
     auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
     assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +611,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.addAttribute("pad", pad);
   result.addAttribute("acc_type", accType);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -616,8 +630,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                       Value input) {
   result.addOperands(input);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    // note: negateOp has attributes input1_zp and output_zp
+    result.addAttribute("input1_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -629,8 +650,11 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Value paddings) {
   result.addOperands({input, paddings});
   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -643,8 +667,11 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                  Value padConst) {
   result.addOperands({input, paddings, padConst});
   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -898,9 +925,8 @@ LogicalResult FullyConnectedOp::verify() {
 
   // Quantized type must have constructed the quantizationattr, and unquantized
   // types should not have a quantizationattr.
-  if ((inputIsQuant && !getQuantizationInfo()) ||
-      (!inputIsQuant && getQuantizationInfo())) {
-    emitOpError("quantizationattr is required for quantized type, and not "
+  if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
+    emitOpError("input zero point is required for quantized type, and not "
                 "allowed for float type");
     return failure();
   }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe..4eba89b59bbd79 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -130,13 +130,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     auto maybeZps = failureOrMaybeZps.value();
     Value fullyConnectedValue;
     if (maybeZps) {
-      auto zeroPointAttr = rewriter.getAttr<tosa::ConvOpQuantizationAttr>(
-          maybeZps->inputZp, maybeZps->weightZp);
       fullyConnectedValue =
           rewriter
               .create<tosa::FullyConnectedOp>(
                   op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.getBias(), zeroPointAttr)
+                  reshapedWeight, op.getBias(),
+                  rewriter.getI32IntegerAttr(maybeZps->inputZp),
+                  rewriter.getI32IntegerAttr(maybeZps->weightZp))
               .getResult();
     } else {
       fullyConnectedValue = rewriter
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f2..b5b3e9d76c47e2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -143,8 +143,7 @@ class TransposeConvStridedConverter
       weight = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
           weightPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->weightZp));
-
+          rewriter.getI32IntegerAttr(maybeZps->weightZp));
     } else {
       weight = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
@@ -203,7 +202,7 @@ class TransposeConvStridedConverter
       input = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
           inputPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->inputZp));
+          rewriter.getI32IntegerAttr(maybeZps->inputZp));
     } else {
       input = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 116cd045aa0d3a..87c388b6f5ee30 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -23,7 +23,7 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
   // CHECK: [[ONE:%.+]] = arith.constant 1
   // CHECK: [[TWO:%.+]] = arith.constant 2
   // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
-  %0 = tosa.matmul %arg0, %arg1 {quantization_info = #tosa.matmul_quant<a_zp = 1, b_zp = 2>} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+  %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
   return %0 : tensor<1x5x6xi32>
 }
 
@@ -124,7 +124,7 @@ func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8
   // CHECK: %[[C2:.+]] = arith.constant 2 : i32
   // CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
 
-  %0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
+  %0 = tosa.fully_connected %arg0, %arg1, %arg2 {input_zp = 1 : i32, weight_zp = 2 : i32} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
   return %0 : tensor<5x6xi32>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..6e8501aaaf2afe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -880,26 +880,36 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[CNST:%.+]] = arith.constant 7
+  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
+  // CHECK: linalg.yield [[SUB]]
+  %0 = tosa.negate %arg0 {input1_zp = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+  // CHECK: [[C32639:%.+]] = arith.constant 32639
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
-  // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
+  // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
   // CHECK: [[MIN:%.+]] = arith.constant -128
   // CHECK: [[MAX:%.+]] = arith.constant 127
   // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
   // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
   // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
-
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
-  %1 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32639, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+  %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+  // CHECK: [[C32640:%.+]] = arith.constant 32640
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
-  %2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+  // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+  // CHECK: [[MIN:%.+]] = arith.constant -128
+  // CHECK: [[MAX:%.+]] = arith.constant 127
+  // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
+  // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index f95de798474641..e83e898644bc09 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -492,7 +492,7 @@ func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
   // CHECK: [[CST:%.+]] = arith.constant 42 : i32
   // CHECK: tensor.pad
   // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, !tosa.shape<4>)  -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>)  -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..e0e1de6a94d10d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -317,7 +317,7 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
 
 // CHECK-LABEL: @pad_determine_val_quant
 func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
   // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
   %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 685f799bd3d2bf..e4a2897908072a 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -28,7 +28,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
   // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
   // CHECK-SAME: -> tensor<3x2xi8>
   // CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
-  // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>
+  // CHECK-SAME: {input_zp = 42 : i32, weight_zp = 24 : i32}
   // CHECK-SAME: -> tensor<400x3xi32>
   // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
   // CHECK-SAME: -> tensor<4x10x10x3xi32>
@@ -48,7 +48,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
 func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
 // CHECK:           %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: -1, 64>} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
 // CHECK:           %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 384, 64>} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
-// CHECK:           %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK:           %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
 // CHECK:           %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
 // CHECK:           return %[[VAL_6]] : tensor<?x14x14x384xi32>
 // CHECK:         }
@@ -67,7 +67,7 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1:
   // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
   // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
   // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
-  // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
+  // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {input_zp = 42 : i32, weight_zp = 24 : i32}
   // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
   %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index bb6de82ee10532..82838cc7e15451 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -91,7 +91,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // Manipulate the weight matrix to handle striding.
   // CHECK-DAG: %[[PADV:.+]]  = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
-  // CHECK-DAG: %[[PADW:.+]]  = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
+  // CHECK-DAG: %[[PADW:.+]]  = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32}
   // CHECK-DAG: %[[RESW1:.+]]  = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
   // CHECK-DAG: %[[TRANS:.+]]  = tosa.transpose %[[RESW1]], %[[TRANSV]]
   // CHECK-DAG: %[[RESW2:.+]]  = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
@@ -101,7 +101,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // Pad out the input matrix to handle the transpose conv.
   // CHECK-DAG: %[[PAD:.+]]  = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
-  // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
+  // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {input_zp = -22 : i32}
 
   // Manipulate the final shape.
   // CHECK-DAG: %[[BIAS:.+]]  = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
@@ -132,14 +132,14 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
   // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
   // CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape  {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
-  // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
+  // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32}
   // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
   // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
   // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array<i64: 2, 2, 1, 1>}
   // CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
-  // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = -103>}
+  // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32}
   // CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]
-  // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
+  // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], input_zp = -103 : i32, weight_zp = 93 : i32, stride = [1, 1]}
   // CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
   // CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
   // CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]] {new_shape = array<i64: 1, 17, 2, 1>}



More information about the Mlir-commits mailing list