[Mlir-commits] [mlir] [mlir][tosa] Remove Quantization Attribute (PR #125479)
Jack Frankland
llvmlistbot at llvm.org
Mon Feb 3 03:05:58 PST 2025
https://github.com/FranklandJack created https://github.com/llvm/llvm-project/pull/125479
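Remove the TOSA quantization attribute (`quantization_info`) used by various MLIR TOSA dialect operations in favour of separate, optional builtin `I32Attr` zero-point attributes (`input_zp`, `input1_zp`, `weight_zp`, `a_zp`, `b_zp`, `output_zp`).
Update the affected lit tests, conversions and transformations accordingly.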
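Rename operands, results and regions as follows to align with the TOSA v1.0 specification:
* `tosa.cond_if` operand `cond` -> `condition`
* `tosa.cond_if` region `then_branch` -> `then_graph`
* `tosa.cond_if` region `else_branch` -> `else_graph`
* `tosa.while_loop` operand `inputs` -> `input_list`
* `tosa.cond_if` / `tosa.while_loop` result `output` -> `output_list`
* `tosa.while_loop` region `cond` -> `cond_graph`
* `tosa.while_loop` region `body` -> `body_graph`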
From 0549a1203096329a442c6aec02b730aa5e567017 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Tue, 13 Feb 2024 19:35:14 +0000
Subject: [PATCH] [mlir][tosa] Remove Quantization Attribute
Remove the TOSA quantization attribute (`quantization_info`) used by
various MLIR TOSA dialect operations in favour of separate, optional
builtin `I32Attr` zero-point attributes.
Update the affected lit tests, conversions and transformations accordingly.
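Rename operands, results and regions as follows to align with the
TOSA v1.0 specification:
* `tosa.cond_if` operand `cond` -> `condition`
* `tosa.cond_if` region `then_branch` -> `then_graph`
* `tosa.cond_if` region `else_branch` -> `else_graph`
* `tosa.while_loop` operand `inputs` -> `input_list`
* `tosa.cond_if` / `tosa.while_loop` result `output` -> `output_list`
* `tosa.while_loop` region `cond` -> `cond_graph`
* `tosa.while_loop` region `body` -> `body_graph`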
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: Ic2244e15fa63a4508898151c9a45d95d3d2b3738
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 30 +++---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 100 +++++++++---------
.../TosaToLinalg/TosaToLinalgNamed.cpp | 31 ++----
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 12 +--
.../Conversion/TosaToTensor/TosaToTensor.cpp | 6 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 6 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 83 ++++++++++-----
.../Tosa/Transforms/TosaDecomposeConv2D.cpp | 6 +-
.../Transforms/TosaDecomposeTransposeConv.cpp | 5 +-
.../TosaToLinalg/tosa-to-linalg-named.mlir | 4 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 28 +++--
.../TosaToTensor/tosa-to-tensor.mlir | 2 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 2 +-
.../Dialect/Tosa/tosa-decompose-conv2d.mlir | 6 +-
.../Tosa/tosa-decompose-transpose-conv.mlir | 10 +-
15 files changed, 184 insertions(+), 147 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 819547855d1015..fef0f2d98d95c6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp,
+ OptionalAttr<I32Attr>:$output_zp
);
let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
Tosa_Tensor2D:$input,
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp,
+ OptionalAttr<I32Attr>:$weight_zp
);
let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
let arguments = (ins
Tosa_Tensor3D:$a,
Tosa_Tensor3D:$b,
- OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$a_zp,
+ OptionalAttr<I32Attr>:$b_zp
);
let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
let arguments = (ins
Tosa_Tensor:$input1,
- OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input1_zp,
+ OptionalAttr<I32Attr>:$output_zp
);
let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
Tosa_RankedTensor:$input1,
Tosa_Shape:$padding,
Optional<Tosa_ScalarTensor>:$pad_const,
- OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp
);
let results = (outs
@@ -2071,17 +2075,17 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
}];
let arguments = (ins
- Tosa_I1Tensor:$cond,
+ Tosa_I1Tensor:$condition,
Variadic<Tosa_Tensor>:$inputs
);
let results = (outs
- Variadic<Tosa_Tensor>:$output
+ Variadic<Tosa_Tensor>:$output_list
);
let regions = (region
- SizedRegion<1>:$then_branch,
- SizedRegion<1>:$else_branch
+ SizedRegion<1>:$then_graph,
+ SizedRegion<1>:$else_graph
);
let hasCustomAssemblyFormat = 1;
@@ -2108,16 +2112,16 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
}];
let arguments = (ins
- Variadic<Tosa_Tensor>:$inputs
+ Variadic<Tosa_Tensor>:$input_list
);
let results = (outs
- Variadic<Tosa_Tensor>:$output
+ Variadic<Tosa_Tensor>:$output_list
);
let regions = (region
- SizedRegion<1>:$cond,
- SizedRegion<1>:$body
+ SizedRegion<1>:$cond_graph,
+ SizedRegion<1>:$body_graph
);
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..449baad0edeafe 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -141,63 +141,67 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
// tosa::NegateOp
- if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+ if (isa<tosa::NegateOp>(op)) {
+ if (isa<FloatType>(elementTy))
+ return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
- int64_t inZp = 0, outZp = 0;
+ auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
+ auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+ int32_t inputZpVal = inputZpAttr ? inputZpAttr.getInt() : 0;
+ int32_t outputZpVal = outputZpAttr ? outputZpAttr.getInt() : 0;
- if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
- auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
- inZp = quantizationInfo.value().getInputZp();
- outZp = quantizationInfo.value().getOutputZp();
- }
-
- int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- if (!inZp && !outZp) {
+ if (isa<IntegerType>(elementTy) && inputZpVal == 0 && outputZpVal == 0) {
auto constant = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, 0));
return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
args[0]);
}
- // Compute the maximum value that can occur in the intermediate buffer.
- int64_t zpAdd = inZp + outZp;
- int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
- std::abs(zpAdd) + 1;
-
- // Convert that maximum value into the maximum bitwidth needed to represent
- // it. We assume 48-bit numbers may be supported further in the pipeline.
- int intermediateBitWidth = 64;
- if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
- intermediateBitWidth = 16;
- } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
- intermediateBitWidth = 32;
- } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
- intermediateBitWidth = 48;
- }
+ if (isa<IntegerType>(elementTy) && (inputZpVal != 0 || outputZpVal != 0)) {
+ int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+ int64_t inZp = inputZpVal;
+ int64_t outZp = outputZpVal;
+
+ // Compute the maximum value that can occur in the intermediate buffer.
+ int64_t zpAdd = inZp + outZp;
+ int64_t maxValue =
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+ std::abs(zpAdd) + 1;
+
+ // Convert that maximum value into the maximum bitwidth needed to
+ // represent it. We assume 48-bit numbers may be supported further in the
+ // pipeline.
+ int intermediateBitWidth = 64;
+ if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+ intermediateBitWidth = 16;
+ } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+ intermediateBitWidth = 32;
+ } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+ intermediateBitWidth = 48;
+ }
- Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- Value zpAddValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
-
- // The negation can be applied by doing:
- // outputValue = inZp + outZp - inputValue
- auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
- auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
-
- // Clamp to the negation range.
- Value min = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
- intermediateType);
- Value max = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
- intermediateType);
- auto clamp =
- clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
-
- // Truncate to the final value.
- return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+ Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ Value zpAddValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+ // The negation can be applied by doing:
+ // outputValue = inZp + outZp - inputValue
+ auto ext =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
+ auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+
+ // Clamp to the negation range.
+ Value min = rewriter.create<arith::ConstantIntOp>(
+ loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
+ intermediateType);
+ Value max = rewriter.create<arith::ConstantIntOp>(
+ loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
+ intermediateType);
+ auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
+
+ // Truncate to the final value.
+ return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+ }
}
// tosa::BitwiseAndOp
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e05cf7c9..1e02301f7c23d5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -590,18 +590,15 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
.result();
- if (!op.getQuantizationInfo()) {
+ if (!op.getAZp() && !op.getBZp()) {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
return success();
}
- auto quantizationInfo = *op.getQuantizationInfo();
- auto aZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
- auto bZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+ auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
+ auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ class FullyConnectedConverter
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
- if (!op.getQuantizationInfo()) {
+ if (!op.getInputZp() && !op.getWeightZp()) {
Value matmul = rewriter
.create<linalg::MatmulOp>(
loc, TypeRange{op.getType()},
@@ -672,11 +669,8 @@ class FullyConnectedConverter
return success();
}
- auto quantizationInfo = *op.getQuantizationInfo();
- auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
- auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+ auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+ auto outputZp = rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
Value matmul =
rewriter
.create<linalg::QuantizedMatmulOp>(
@@ -958,10 +952,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// If we have quantization information we need to apply an offset
// for the input zp value.
- if (op.getQuantizationInfo()) {
- auto quantizationInfo = *op.getQuantizationInfo();
+ if (op.getInputZp()) {
auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+ loc, op.getInputZpAttr());
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
poolVal =
@@ -1013,11 +1006,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// If we have quantization information we need to apply output
// zeropoint.
- if (op.getQuantizationInfo()) {
- auto quantizationInfo = *op.getQuantizationInfo();
- auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(scaled.getType(),
- quantizationInfo.getOutputZp()));
+ if (op.getOutputZp()) {
+ auto outputZp =
+ rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 9139bf191fdf11..80c58bdc0550cc 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
LogicalResult matchAndRewrite(tosa::IfOp op,
PatternRewriter &rewriter) const final {
auto condition =
- rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
+ rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
condition, true);
- inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
+ inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputs(),
rewriter);
- inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
+ inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputs(),
rewriter);
rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
LogicalResult matchAndRewrite(tosa::WhileOp op,
PatternRewriter &rewriter) const final {
auto newWhile = rewriter.create<scf::WhileOp>(
- op.getLoc(), op.getResultTypes(), op.getInputs());
+ op.getLoc(), op.getResultTypes(), op.getInputList());
rewriter.createBlock(&newWhile.getBefore());
rewriter.createBlock(&newWhile.getAfter());
- inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
- inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
+ inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
+ inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
rewriter.replaceOp(op, newWhile.getResults());
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..2a9b4d111bdfa2 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
TypedAttr constantAttr;
if (isa<FloatType>(elementTy)) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+ } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
- int64_t value = padOp.getQuantizationInfo()->getInputZp();
+ } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+ int64_t value = padOp.getInputZpAttr().getInt();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
if (constantAttr)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..8e22c879753a33 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
Attribute constantAttr;
if (llvm::isa<FloatType>(elementTy)) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+ } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
- auto value = op.getQuantizationInfo()->getInputZp();
+ } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+ int64_t value = op.getInputZpAttr().getInt();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..9bde6a85935255 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -124,7 +124,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
//===----------------------------------------------------------------------===//
/// Returns the while loop body.
-SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
+SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
+ return {&getBodyGraph()};
+}
//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
@@ -271,11 +273,11 @@ static LogicalResult verifyConvOp(T op) {
}
}
- bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
- bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+ bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+ bool weightIsFloat = llvm::isa<FloatType>(weightEType);
- // Either both must be quantized or both unquantized.
- if (inputIsQuant != weightIsQuant) {
+ // Either both must be float or both non-float.
+ if (inputIsFloat != weightIsFloat) {
op.emitOpError(
"expect both input and weight to be float or not together, got ")
<< inputEType << " and " << weightEType;
@@ -527,7 +529,12 @@ static void buildTransConvOpWithQuantInfo(
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("weight_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getWeightZp())));
result.addTypes(
buildConvOpResultTypeInfo(builder, outputType, input, weight));
} else {
@@ -563,7 +570,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
+ result.addAttribute("a_zp", builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getAZp())));
+ result.addAttribute("b_zp", builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getBZp())));
auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +613,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
result.addAttribute("pad", pad);
result.addAttribute("acc_type", accType);
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("output_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getOutputZp())));
+ }
result.types.push_back(outputType);
}
@@ -616,8 +632,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
Value input) {
result.addOperands(input);
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ // note: negateOp has attributes input1_zp and output_zp
+ result.addAttribute("input1_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("output_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getOutputZp())));
+ }
result.types.push_back(outputType);
}
@@ -629,8 +652,11 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Value paddings) {
result.addOperands({input, paddings});
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ }
result.types.push_back(outputType);
}
@@ -643,8 +669,11 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
Value padConst) {
result.addOperands({input, paddings, padConst});
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ }
result.types.push_back(outputType);
}
@@ -898,9 +927,8 @@ LogicalResult FullyConnectedOp::verify() {
// Quantized type must have constructed the quantizationattr, and unquantized
// types should not have a quantizationattr.
- if ((inputIsQuant && !getQuantizationInfo()) ||
- (!inputIsQuant && getQuantizationInfo())) {
- emitOpError("quantizationattr is required for quantized type, and not "
+ if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
+ emitOpError("input zero point is required for quantized type, and not "
"allowed for float type");
return failure();
}
@@ -2229,7 +2257,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
WhileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
- for (auto &block : adaptor.getBody())
+ for (auto &block : adaptor.getBodyGraph())
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
yieldOps.push_back(returnOp);
@@ -2309,19 +2337,19 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
void IfOp::print(OpAsmPrinter &p) {
bool printBlockTerminators = false;
- p << " " << getCond();
+ p << " " << getCondition();
if (!getResults().empty()) {
p << " -> (" << getResultTypes() << ")";
// Print yield explicitly if the op defines values.
printBlockTerminators = true;
}
p << ' ';
- p.printRegion(getThenBranch(),
+ p.printRegion(getThenGraph(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
// Print the 'else' regions if it exists and has a block.
- auto &elseRegion = getElseBranch();
+ auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
@@ -2419,14 +2447,15 @@ static void printInitializationList(OpAsmPrinter &parser,
}
void WhileOp::print(OpAsmPrinter &parser) {
- printInitializationList(parser, getCond().front().getArguments(), getInputs(),
- " ");
+ printInitializationList(parser, getCondGraph().front().getArguments(),
+ getInputList(), " ");
parser << " : ";
- parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
+ parser.printFunctionalType(getInputList().getTypes(),
+ getResults().getTypes());
parser << ' ';
- parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
+ parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
parser << " do ";
- parser.printRegion(getBody());
+ parser.printRegion(getBodyGraph());
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe..4eba89b59bbd79 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -130,13 +130,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto maybeZps = failureOrMaybeZps.value();
Value fullyConnectedValue;
if (maybeZps) {
- auto zeroPointAttr = rewriter.getAttr<tosa::ConvOpQuantizationAttr>(
- maybeZps->inputZp, maybeZps->weightZp);
fullyConnectedValue =
rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType, reshapedInput,
- reshapedWeight, op.getBias(), zeroPointAttr)
+ reshapedWeight, op.getBias(),
+ rewriter.getI32IntegerAttr(maybeZps->inputZp),
+ rewriter.getI32IntegerAttr(maybeZps->weightZp))
.getResult();
} else {
fullyConnectedValue = rewriter
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f2..b5b3e9d76c47e2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -143,8 +143,7 @@ class TransposeConvStridedConverter
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->weightZp));
-
+ rewriter.getI32IntegerAttr(maybeZps->weightZp));
} else {
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
@@ -203,7 +202,7 @@ class TransposeConvStridedConverter
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->inputZp));
+ rewriter.getI32IntegerAttr(maybeZps->inputZp));
} else {
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 116cd045aa0d3a..87c388b6f5ee30 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -23,7 +23,7 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
// CHECK: [[ONE:%.+]] = arith.constant 1
// CHECK: [[TWO:%.+]] = arith.constant 2
// CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
- %0 = tosa.matmul %arg0, %arg1 {quantization_info = #tosa.matmul_quant<a_zp = 1, b_zp = 2>} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+ %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
return %0 : tensor<1x5x6xi32>
}
@@ -124,7 +124,7 @@ func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8
// CHECK: %[[C2:.+]] = arith.constant 2 : i32
// CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
+ %0 = tosa.fully_connected %arg0, %arg1, %arg2 {input_zp = 1 : i32, weight_zp = 2 : i32} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
return %0 : tensor<5x6xi32>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..6e8501aaaf2afe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -880,26 +880,36 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[CNST:%.+]] = arith.constant 7
+ // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
+ // CHECK: linalg.yield [[SUB]]
+ %0 = tosa.negate %arg0 {input_zp1 = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: [[C32639:%.+]] = arith.constant 32639
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
+ // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
// CHECK: [[MIN:%.+]] = arith.constant -128
// CHECK: [[MAX:%.+]] = arith.constant 127
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
- %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
-
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- %1 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32639, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: [[C32640:%.+]] = arith.constant 32640
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
- %2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+ // CHECK: [[MIN:%.+]] = arith.constant -128
+ // CHECK: [[MAX:%.+]] = arith.constant 127
+ // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
+ // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
+ // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
+ // CHECK: linalg.yield [[TRUNC]]
+ %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index f95de798474641..e83e898644bc09 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -492,7 +492,7 @@ func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
// CHECK: [[CST:%.+]] = arith.constant 42 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..e0e1de6a94d10d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -317,7 +317,7 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
// CHECK-LABEL: @pad_determine_val_quant
func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
- // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 685f799bd3d2bf..e4a2897908072a 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -28,7 +28,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
// CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
// CHECK-SAME: -> tensor<3x2xi8>
// CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
- // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>
+ // CHECK-SAME: {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK-SAME: -> tensor<400x3xi32>
// CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
// CHECK-SAME: -> tensor<4x10x10x3xi32>
@@ -48,7 +48,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: -1, 64>} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 384, 64>} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
-// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
// CHECK: }
@@ -67,7 +67,7 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1:
// CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
// CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
// CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
- // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
+ // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
%input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
%weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index bb6de82ee10532..82838cc7e15451 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -91,7 +91,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// Manipulate the weight matrix to handle striding.
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
+ // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
// CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
@@ -101,7 +101,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
+ // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {input_zp = -22 : i32}
// Manipulate the final shape.
// CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
@@ -132,14 +132,14 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
- // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
+ // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
// CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array<i64: 2, 2, 1, 1>}
// CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
- // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = -103>}
+ // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32}
// CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]
- // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
+ // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], input_zp = -103 : i32, weight_zp = 93 : i32, stride = [1, 1]}
// CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
// CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]] {new_shape = array<i64: 1, 17, 2, 1>}
More information about the Mlir-commits
mailing list