[Mlir-commits] [mlir] [mlir][tosa] Remove Quantization Attribute (PR #125479)
Jack Frankland
llvmlistbot at llvm.org
Tue Feb 4 09:06:04 PST 2025
https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/125479
From 9688657c8791b32e0de0af3fc77885431c795d23 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Tue, 13 Feb 2024 19:35:14 +0000
Subject: [PATCH] [mlir][tosa] Remove Quantization Attribute
Remove the TOSA quantization attribute used in various MLIR TOSA
dialect operations in favour of builtin integer zero-point attributes.
Update the affected lit tests, conversions and transformations
accordingly.
Signed-off-by: Tai Ly <tai.ly at arm.com>
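As an illustration, drawn from the updated lit tests below, the zero
points that used to live in the dedicated quantization attribute now
appear as individual builtin integer attributes on the op:

  // Before:
  %0 = tosa.matmul %arg0, %arg1 {quantization_info = #tosa.matmul_quant<a_zp = 1, b_zp = 2>} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
  // After:
  %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>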
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 14 ++-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 106 +++++++++---------
.../TosaToLinalg/TosaToLinalgNamed.cpp | 34 +++---
.../Conversion/TosaToTensor/TosaToTensor.cpp | 6 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 6 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 60 +++++++---
.../Tosa/Transforms/TosaDecomposeConv2D.cpp | 6 +-
.../Transforms/TosaDecomposeTransposeConv.cpp | 5 +-
.../TosaToLinalg/tosa-to-linalg-named.mlir | 4 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 28 +++--
.../TosaToTensor/tosa-to-tensor.mlir | 2 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 2 +-
.../Dialect/Tosa/tosa-decompose-conv2d.mlir | 6 +-
.../Tosa/tosa-decompose-transpose-conv.mlir | 10 +-
14 files changed, 161 insertions(+), 128 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 819547855d1015..48d8f1bd4836c9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp,
+ OptionalAttr<I32Attr>:$output_zp
);
let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
Tosa_Tensor2D:$input,
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp,
+ OptionalAttr<I32Attr>:$weight_zp
);
let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
let arguments = (ins
Tosa_Tensor3D:$a,
Tosa_Tensor3D:$b,
- OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$a_zp,
+ OptionalAttr<I32Attr>:$b_zp
);
let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
let arguments = (ins
Tosa_Tensor:$input1,
- OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input1_zp,
+ OptionalAttr<I32Attr>:$output_zp
);
let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
Tosa_RankedTensor:$input1,
Tosa_Shape:$padding,
Optional<Tosa_ScalarTensor>:$pad_const,
- OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp
);
let results = (outs
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..67218cee518d59 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -141,63 +141,65 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
// tosa::NegateOp
- if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+ if (isa<tosa::NegateOp>(op)) {
+ if (isa<FloatType>(elementTy))
+ return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
- int64_t inZp = 0, outZp = 0;
+ if (isa<IntegerType>(elementTy)) {
+ auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1Zp();
+ auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZp();
- if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
- auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
- inZp = quantizationInfo.value().getInputZp();
- outZp = quantizationInfo.value().getOutputZp();
- }
+ const int64_t inZp = inputZpAttr ? *inputZpAttr : 0;
+ const int64_t outZp = outputZpAttr ? *outputZpAttr : 0;
- int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- if (!inZp && !outZp) {
- auto constant = rewriter.create<arith::ConstantOp>(
- loc, IntegerAttr::get(elementTy, 0));
- return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
- args[0]);
- }
+ if (!inZp && !outZp) {
+ auto constant = rewriter.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, 0));
+ return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
+ args[0]);
+ }
- // Compute the maximum value that can occur in the intermediate buffer.
- int64_t zpAdd = inZp + outZp;
- int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
- std::abs(zpAdd) + 1;
-
- // Convert that maximum value into the maximum bitwidth needed to represent
- // it. We assume 48-bit numbers may be supported further in the pipeline.
- int intermediateBitWidth = 64;
- if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
- intermediateBitWidth = 16;
- } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
- intermediateBitWidth = 32;
- } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
- intermediateBitWidth = 48;
- }
+ // Compute the maximum value that can occur in the intermediate buffer.
+ const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+ const int64_t zpAdd = inZp + outZp;
+ const int64_t maxValue =
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+ std::abs(zpAdd) + 1;
+
+ // Convert that maximum value into the maximum bitwidth needed to
+ // represent it. We assume 48-bit numbers may be supported further in
+ // the pipeline.
+ int intermediateBitWidth = 64;
+ if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+ intermediateBitWidth = 16;
+ } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+ intermediateBitWidth = 32;
+ } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+ intermediateBitWidth = 48;
+ }
- Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- Value zpAddValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
-
- // The negation can be applied by doing:
- // outputValue = inZp + outZp - inputValue
- auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
- auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
-
- // Clamp to the negation range.
- Value min = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
- intermediateType);
- Value max = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
- intermediateType);
- auto clamp =
- clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
-
- // Truncate to the final value.
- return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+ Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ Value zpAddValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+ // The negation can be applied by doing:
+ // outputValue = inZp + outZp - inputValue
+ auto ext =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
+ auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+
+ // Clamp to the negation range.
+ Value min = rewriter.create<arith::ConstantIntOp>(
+ loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
+ intermediateType);
+ Value max = rewriter.create<arith::ConstantIntOp>(
+ loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
+ intermediateType);
+      auto clamp =
+          clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
+
+ // Truncate to the final value.
+ return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+ }
}
// tosa::BitwiseAndOp
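To see how the intermediate bitwidth is selected, take the i8 negate
lit tests below: with input1_zp = 32639 and output_zp = 0 the maximum
intermediate value is 127 + |32639| + 1 = 32767, which still fits in a
signed 16-bit integer, so the input is extended to i16; with
input1_zp = 32640 it becomes 32768, one past INT16_MAX, so the
lowering extends to i32 instead.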
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e05cf7c9..6321cb6087394a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -590,18 +590,15 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
.result();
- if (!op.getQuantizationInfo()) {
+ if (!op.getAZp() && !op.getBZp()) {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
return success();
}
- auto quantizationInfo = *op.getQuantizationInfo();
- auto aZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
- auto bZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+ auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
+ auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ class FullyConnectedConverter
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
- if (!op.getQuantizationInfo()) {
+ if (!op.getInputZp() && !op.getWeightZp()) {
Value matmul = rewriter
.create<linalg::MatmulOp>(
loc, TypeRange{op.getType()},
@@ -672,11 +669,9 @@ class FullyConnectedConverter
return success();
}
- auto quantizationInfo = *op.getQuantizationInfo();
- auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
- auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+ auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+ auto outputZp =
+ rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
Value matmul =
rewriter
.create<linalg::QuantizedMatmulOp>(
@@ -958,10 +953,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// If we have quantization information we need to apply an offset
// for the input zp value.
- if (op.getQuantizationInfo()) {
- auto quantizationInfo = *op.getQuantizationInfo();
- auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+ if (op.getInputZp()) {
+ auto inputZp =
+ rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
poolVal =
@@ -1013,11 +1007,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// If we have quantization information we need to apply output
// zeropoint.
- if (op.getQuantizationInfo()) {
- auto quantizationInfo = *op.getQuantizationInfo();
- auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(scaled.getType(),
- quantizationInfo.getOutputZp()));
+ if (op.getOutputZp()) {
+ auto outputZp =
+ rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}
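For reference, the zero-point handling in this converter follows the
TOSA avg_pool2d semantics; as a sketch of the intended arithmetic (the
names here are illustrative, not taken from the code):

  acc = sum(window) - count * input_zp   // input offset folded into the accumulator
  out = clamp(scale(acc) + output_zp)    // output zero point applied after scaling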
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..2a9b4d111bdfa2 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
TypedAttr constantAttr;
if (isa<FloatType>(elementTy)) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+ } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
- int64_t value = padOp.getQuantizationInfo()->getInputZp();
+ } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+ int64_t value = padOp.getInputZpAttr().getInt();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
if (constantAttr)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..8e22c879753a33 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
Attribute constantAttr;
if (llvm::isa<FloatType>(elementTy)) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+ } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
- auto value = op.getQuantizationInfo()->getInputZp();
+ } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+ int64_t value = op.getInputZpAttr().getInt();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
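Concretely, as in the updated pad tests further down, an integer pad
now reads its fill value straight from the new attribute:

  %1 = "tosa.pad"(%arg0, %0) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)

and the materialized pad constant becomes 42 rather than 0.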
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..031c279ff09e27 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -271,11 +271,11 @@ static LogicalResult verifyConvOp(T op) {
}
}
- bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
- bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+ bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+ bool weightIsFloat = llvm::isa<FloatType>(weightEType);
- // Either both must be quantized or both unquantized.
- if (inputIsQuant != weightIsQuant) {
+ // Either both must be float or both non-float.
+ if (inputIsFloat != weightIsFloat) {
op.emitOpError(
"expect both input and weight to be float or not together, got ")
<< inputEType << " and " << weightEType;
@@ -527,7 +527,12 @@ static void buildTransConvOpWithQuantInfo(
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("weight_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getWeightZp())));
result.addTypes(
buildConvOpResultTypeInfo(builder, outputType, input, weight));
} else {
@@ -563,7 +568,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
+ result.addAttribute("a_zp", builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getAZp())));
+ result.addAttribute("b_zp", builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getBZp())));
auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +611,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
result.addAttribute("pad", pad);
result.addAttribute("acc_type", accType);
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("output_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getOutputZp())));
+ }
result.types.push_back(outputType);
}
@@ -616,8 +630,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
Value input) {
result.addOperands(input);
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+    // Note: NegateOp has attributes input1_zp and output_zp.
+ result.addAttribute("input1_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("output_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getOutputZp())));
+ }
result.types.push_back(outputType);
}
@@ -629,8 +650,11 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Value paddings) {
result.addOperands({input, paddings});
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ }
result.types.push_back(outputType);
}
@@ -643,8 +667,11 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
Value padConst) {
result.addOperands({input, paddings, padConst});
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ }
result.types.push_back(outputType);
}
@@ -898,9 +925,8 @@ LogicalResult FullyConnectedOp::verify() {
// Quantized type must have constructed the quantizationattr, and unquantized
// types should not have a quantizationattr.
- if ((inputIsQuant && !getQuantizationInfo()) ||
- (!inputIsQuant && getQuantizationInfo())) {
- emitOpError("quantizationattr is required for quantized type, and not "
+ if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
+ emitOpError("input zero point is required for quantized type, and not "
"allowed for float type");
return failure();
}
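The verifier thus keeps the same contract at the IR level: a quantized
fully_connected must carry the zero points while a float one must not,
e.g. (from the updated named-lowering test above):

  %0 = tosa.fully_connected %arg0, %arg1, %arg2 {input_zp = 1 : i32, weight_zp = 2 : i32} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>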
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe..4eba89b59bbd79 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -130,13 +130,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto maybeZps = failureOrMaybeZps.value();
Value fullyConnectedValue;
if (maybeZps) {
- auto zeroPointAttr = rewriter.getAttr<tosa::ConvOpQuantizationAttr>(
- maybeZps->inputZp, maybeZps->weightZp);
fullyConnectedValue =
rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType, reshapedInput,
- reshapedWeight, op.getBias(), zeroPointAttr)
+ reshapedWeight, op.getBias(),
+ rewriter.getI32IntegerAttr(maybeZps->inputZp),
+ rewriter.getI32IntegerAttr(maybeZps->weightZp))
.getResult();
} else {
fullyConnectedValue = rewriter
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f2..b5b3e9d76c47e2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -143,8 +143,7 @@ class TransposeConvStridedConverter
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->weightZp));
-
+ rewriter.getI32IntegerAttr(maybeZps->weightZp));
} else {
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
@@ -203,7 +202,7 @@ class TransposeConvStridedConverter
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->inputZp));
+ rewriter.getI32IntegerAttr(maybeZps->inputZp));
} else {
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 116cd045aa0d3a..87c388b6f5ee30 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -23,7 +23,7 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
// CHECK: [[ONE:%.+]] = arith.constant 1
// CHECK: [[TWO:%.+]] = arith.constant 2
// CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
- %0 = tosa.matmul %arg0, %arg1 {quantization_info = #tosa.matmul_quant<a_zp = 1, b_zp = 2>} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+ %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
return %0 : tensor<1x5x6xi32>
}
@@ -124,7 +124,7 @@ func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8
// CHECK: %[[C2:.+]] = arith.constant 2 : i32
// CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
+ %0 = tosa.fully_connected %arg0, %arg1, %arg2 {input_zp = 1 : i32, weight_zp = 2 : i32} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
return %0 : tensor<5x6xi32>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..6e8501aaaf2afe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -880,26 +880,36 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[CNST:%.+]] = arith.constant 7
+ // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
+ // CHECK: linalg.yield [[SUB]]
+  %0 = tosa.negate %arg0 {input1_zp = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: [[C32639:%.+]] = arith.constant 32639
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
+ // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
// CHECK: [[MIN:%.+]] = arith.constant -128
// CHECK: [[MAX:%.+]] = arith.constant 127
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
- %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
-
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- %1 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32639, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: [[C32640:%.+]] = arith.constant 32640
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
- %2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+ // CHECK: [[MIN:%.+]] = arith.constant -128
+ // CHECK: [[MAX:%.+]] = arith.constant 127
+ // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
+ // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
+ // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
+ // CHECK: linalg.yield [[TRUNC]]
+ %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index f95de798474641..e83e898644bc09 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -492,7 +492,7 @@ func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
// CHECK: [[CST:%.+]] = arith.constant 42 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..e0e1de6a94d10d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -317,7 +317,7 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
// CHECK-LABEL: @pad_determine_val_quant
func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
- // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 685f799bd3d2bf..e4a2897908072a 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -28,7 +28,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
// CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
// CHECK-SAME: -> tensor<3x2xi8>
// CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
- // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>
+ // CHECK-SAME: {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK-SAME: -> tensor<400x3xi32>
// CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
// CHECK-SAME: -> tensor<4x10x10x3xi32>
@@ -48,7 +48,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: -1, 64>} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 384, 64>} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
-// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
// CHECK: }
@@ -67,7 +67,7 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1:
// CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
// CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
// CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
- // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
+ // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
%input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
%weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index bb6de82ee10532..82838cc7e15451 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -91,7 +91,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// Manipulate the weight matrix to handle striding.
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
+ // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
// CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
@@ -101,7 +101,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
+ // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {input_zp = -22 : i32}
// Manipulate the final shape.
// CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
@@ -132,14 +132,14 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
- // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
+ // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
// CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array<i64: 2, 2, 1, 1>}
// CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
- // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = -103>}
+ // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32}
// CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]
- // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
+ // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], input_zp = -103 : i32, weight_zp = 93 : i32, stride = [1, 1]}
// CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
// CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]] {new_shape = array<i64: 1, 17, 2, 1>}