[Mlir-commits] [mlir] [mlir][tosa] Remove Quantization Attribute (PR #125479)
Jack Frankland
llvmlistbot at llvm.org
Mon Feb 3 03:07:07 PST 2025
https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/125479
>From babe874e1e78ee76da93bbc307e7485088bd95fb Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Tue, 13 Feb 2024 19:35:14 +0000
Subject: [PATCH] [mlir][tosa] Remove Quantization Attribute
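Remove the TOSA quantization attribute used by various MLIR TOSA
dialect operations in favour of builtin attributes: each zero point is
now a separate optional `I32Attr` (e.g. `input_zp`, `weight_zp`,
`a_zp`/`b_zp`, `input1_zp`, `output_zp`).
Update lit tests, conversions and transformations accordingly.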
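Rename the following operands, results and regions to align with the
TOSA-v1.0 specification:
* `tosa.cond_if`: operand `cond` -> `condition`, result `output` ->
  `output_list`, regions `then_branch` -> `then_graph` and
  `else_branch` -> `else_graph`
* `tosa.while_loop`: operand `inputs` -> `input_list`, result `output`
  -> `output_list`, regions `cond` -> `cond_graph` and `body` ->
  `body_graph`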
Signed-off-by: Tai Ly <tai.ly at arm.com>
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 30 +++---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 100 +++++++++---------
.../TosaToLinalg/TosaToLinalgNamed.cpp | 31 ++----
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 12 +--
.../Conversion/TosaToTensor/TosaToTensor.cpp | 6 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 6 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 83 ++++++++++-----
.../Tosa/Transforms/TosaDecomposeConv2D.cpp | 6 +-
.../Transforms/TosaDecomposeTransposeConv.cpp | 5 +-
.../TosaToLinalg/tosa-to-linalg-named.mlir | 4 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 28 +++--
.../TosaToTensor/tosa-to-tensor.mlir | 2 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 2 +-
.../Dialect/Tosa/tosa-decompose-conv2d.mlir | 6 +-
.../Tosa/tosa-decompose-transpose-conv.mlir | 10 +-
15 files changed, 184 insertions(+), 147 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 819547855d1015b..fef0f2d98d95c63 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp,
+ OptionalAttr<I32Attr>:$output_zp
);
let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
Tosa_Tensor2D:$input,
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp,
+ OptionalAttr<I32Attr>:$weight_zp
);
let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
let arguments = (ins
Tosa_Tensor3D:$a,
Tosa_Tensor3D:$b,
- OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$a_zp,
+ OptionalAttr<I32Attr>:$b_zp
);
let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
let arguments = (ins
Tosa_Tensor:$input1,
- OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input1_zp,
+ OptionalAttr<I32Attr>:$output_zp
);
let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
Tosa_RankedTensor:$input1,
Tosa_Shape:$padding,
Optional<Tosa_ScalarTensor>:$pad_const,
- OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+ OptionalAttr<I32Attr>:$input_zp
);
let results = (outs
@@ -2071,17 +2075,17 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
}];
let arguments = (ins
- Tosa_I1Tensor:$cond,
+ Tosa_I1Tensor:$condition,
Variadic<Tosa_Tensor>:$inputs
);
let results = (outs
- Variadic<Tosa_Tensor>:$output
+ Variadic<Tosa_Tensor>:$output_list
);
let regions = (region
- SizedRegion<1>:$then_branch,
- SizedRegion<1>:$else_branch
+ SizedRegion<1>:$then_graph,
+ SizedRegion<1>:$else_graph
);
let hasCustomAssemblyFormat = 1;
@@ -2108,16 +2112,16 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
}];
let arguments = (ins
- Variadic<Tosa_Tensor>:$inputs
+ Variadic<Tosa_Tensor>:$input_list
);
let results = (outs
- Variadic<Tosa_Tensor>:$output
+ Variadic<Tosa_Tensor>:$output_list
);
let regions = (region
- SizedRegion<1>:$cond,
- SizedRegion<1>:$body
+ SizedRegion<1>:$cond_graph,
+ SizedRegion<1>:$body_graph
);
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b63..449baad0edeafe4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -141,63 +141,67 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
// tosa::NegateOp
- if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+ if (isa<tosa::NegateOp>(op)) {
+ if (isa<FloatType>(elementTy))
+ return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
- if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
- int64_t inZp = 0, outZp = 0;
+ auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
+ auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+ int32_t inputZpVal = inputZpAttr ? inputZpAttr.getInt() : 0;
+ int32_t outputZpVal = outputZpAttr ? outputZpAttr.getInt() : 0;
- if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
- auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
- inZp = quantizationInfo.value().getInputZp();
- outZp = quantizationInfo.value().getOutputZp();
- }
-
- int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- if (!inZp && !outZp) {
+ if (isa<IntegerType>(elementTy) && inputZpVal == 0 && outputZpVal == 0) {
auto constant = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, 0));
return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
args[0]);
}
- // Compute the maximum value that can occur in the intermediate buffer.
- int64_t zpAdd = inZp + outZp;
- int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
- std::abs(zpAdd) + 1;
-
- // Convert that maximum value into the maximum bitwidth needed to represent
- // it. We assume 48-bit numbers may be supported further in the pipeline.
- int intermediateBitWidth = 64;
- if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
- intermediateBitWidth = 16;
- } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
- intermediateBitWidth = 32;
- } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
- intermediateBitWidth = 48;
- }
+ if (isa<IntegerType>(elementTy) && (inputZpVal != 0 || outputZpVal != 0)) {
+ int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+ int64_t inZp = inputZpVal;
+ int64_t outZp = outputZpVal;
+
+ // Compute the maximum value that can occur in the intermediate buffer.
+ int64_t zpAdd = inZp + outZp;
+ int64_t maxValue =
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+ std::abs(zpAdd) + 1;
+
+ // Convert that maximum value into the maximum bitwidth needed to
+ // represent it. We assume 48-bit numbers may be supported further in the
+ // pipeline.
+ int intermediateBitWidth = 64;
+ if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+ intermediateBitWidth = 16;
+ } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+ intermediateBitWidth = 32;
+ } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+ intermediateBitWidth = 48;
+ }
- Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- Value zpAddValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
-
- // The negation can be applied by doing:
- // outputValue = inZp + outZp - inputValue
- auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
- auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
-
- // Clamp to the negation range.
- Value min = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
- intermediateType);
- Value max = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
- intermediateType);
- auto clamp =
- clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
-
- // Truncate to the final value.
- return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+ Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ Value zpAddValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+ // The negation can be applied by doing:
+ // outputValue = inZp + outZp - inputValue
+ auto ext =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
+ auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+
+ // Clamp to the negation range.
+ Value min = rewriter.create<arith::ConstantIntOp>(
+ loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
+ intermediateType);
+ Value max = rewriter.create<arith::ConstantIntOp>(
+ loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
+ intermediateType);
+ auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
+
+ // Truncate to the final value.
+ return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+ }
}
// tosa::BitwiseAndOp
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e05cf7c9f..1e02301f7c23d5a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -590,18 +590,19 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
.result();
- if (!op.getQuantizationInfo()) {
+ if (!op.getAZp() && !op.getBZp()) {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
return success();
}
- auto quantizationInfo = *op.getQuantizationInfo();
- auto aZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
- auto bZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+ // Either zero point may be present without the other; default a missing
+ // one to 0 rather than building a constant from a null attribute.
+ auto aZp = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(op.getAZp().value_or(0)));
+ auto bZp = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(op.getBZp().value_or(0)));
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +662,7 @@ class FullyConnectedConverter
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
- if (!op.getQuantizationInfo()) {
+ if (!op.getInputZp() && !op.getWeightZp()) {
Value matmul = rewriter
.create<linalg::MatmulOp>(
loc, TypeRange{op.getType()},
@@ -672,11 +673,12 @@ class FullyConnectedConverter
return success();
}
- auto quantizationInfo = *op.getQuantizationInfo();
- auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
- auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+ // Either zero point may be present without the other; default a missing
+ // one to 0 rather than building a constant from a null attribute.
+ auto inputZp = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(op.getInputZp().value_or(0)));
+ auto outputZp = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(op.getWeightZp().value_or(0)));
Value matmul =
rewriter
.create<linalg::QuantizedMatmulOp>(
@@ -958,10 +960,10 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// If we have quantization information we need to apply an offset
// for the input zp value.
- if (op.getQuantizationInfo()) {
- auto quantizationInfo = *op.getQuantizationInfo();
+ if (op.getInputZp()) {
+ // Build the constant in the accumulator type so it matches `count`.
auto inputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+ loc, b.getIntegerAttr(accETy, op.getInputZpAttr().getInt()));
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
poolVal =
@@ -1013,11 +1015,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// If we have quantization information we need to apply output
// zeropoint.
- if (op.getQuantizationInfo()) {
- auto quantizationInfo = *op.getQuantizationInfo();
- auto outputZp = rewriter.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(scaled.getType(),
- quantizationInfo.getOutputZp()));
+ if (op.getOutputZp()) {
+ // Build the constant in the type of `scaled` so the AddIOp operands match.
+ auto outputZp = rewriter.create<arith::ConstantOp>(
+ loc, b.getIntegerAttr(scaled.getType(),
+ op.getOutputZpAttr().getInt()));
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 9139bf191fdf113..80c58bdc0550ccd 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
LogicalResult matchAndRewrite(tosa::IfOp op,
PatternRewriter &rewriter) const final {
auto condition =
- rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
+ rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
condition, true);
- inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
+ inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputs(),
rewriter);
- inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
+ inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputs(),
rewriter);
rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
LogicalResult matchAndRewrite(tosa::WhileOp op,
PatternRewriter &rewriter) const final {
auto newWhile = rewriter.create<scf::WhileOp>(
- op.getLoc(), op.getResultTypes(), op.getInputs());
+ op.getLoc(), op.getResultTypes(), op.getInputList());
rewriter.createBlock(&newWhile.getBefore());
rewriter.createBlock(&newWhile.getAfter());
- inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
- inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
+ inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
+ inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
rewriter.replaceOp(op, newWhile.getResults());
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b01..2a9b4d111bdfa2d 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -358,10 +358,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
TypedAttr constantAttr;
if (isa<FloatType>(elementTy)) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+ } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
- int64_t value = padOp.getQuantizationInfo()->getInputZp();
+ } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+ int64_t value = padOp.getInputZpAttr().getInt();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
if (constantAttr)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb7..8e22c879753a339 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
Attribute constantAttr;
if (llvm::isa<FloatType>(elementTy)) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+ } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
- auto value = op.getQuantizationInfo()->getInputZp();
+ } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+ int64_t value = op.getInputZpAttr().getInt();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf6..9bde6a859352559 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -124,7 +124,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
//===----------------------------------------------------------------------===//
/// Returns the while loop body.
-SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
+SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
+ return {&getBodyGraph()};
+}
//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
@@ -271,11 +273,11 @@ static LogicalResult verifyConvOp(T op) {
}
}
- bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
- bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+ bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+ bool weightIsFloat = llvm::isa<FloatType>(weightEType);
- // Either both must be quantized or both unquantized.
- if (inputIsQuant != weightIsQuant) {
+ // Either both must be float or both non-float.
+ if (inputIsFloat != weightIsFloat) {
op.emitOpError(
"expect both input and weight to be float or not together, got ")
<< inputEType << " and " << weightEType;
@@ -527,7 +529,12 @@ static void buildTransConvOpWithQuantInfo(
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("weight_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getWeightZp())));
result.addTypes(
buildConvOpResultTypeInfo(builder, outputType, input, weight));
} else {
@@ -563,7 +570,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
+ result.addAttribute("a_zp", builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getAZp())));
+ result.addAttribute("b_zp", builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getBZp())));
auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +613,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
result.addAttribute("pad", pad);
result.addAttribute("acc_type", accType);
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("output_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getOutputZp())));
+ }
result.types.push_back(outputType);
}
@@ -616,8 +632,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
Value input) {
result.addOperands(input);
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ // note: negateOp has attributes input1_zp and output_zp
+ result.addAttribute("input1_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ result.addAttribute("output_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getOutputZp())));
+ }
result.types.push_back(outputType);
}
@@ -629,8 +652,11 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Value paddings) {
result.addOperands({input, paddings});
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ }
result.types.push_back(outputType);
}
@@ -643,8 +669,11 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
Value padConst) {
result.addOperands({input, paddings, padConst});
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
- if (quantAttr)
- result.addAttribute("quantization_info", quantAttr);
+ if (quantAttr) {
+ result.addAttribute("input_zp",
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(quantAttr.getInputZp())));
+ }
result.types.push_back(outputType);
}
@@ -898,9 +927,8 @@ LogicalResult FullyConnectedOp::verify() {
// Quantized type must have constructed the quantizationattr, and unquantized
// types should not have a quantizationattr.
- if ((inputIsQuant && !getQuantizationInfo()) ||
- (!inputIsQuant && getQuantizationInfo())) {
- emitOpError("quantizationattr is required for quantized type, and not "
+ if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
+ emitOpError("input zero point is required for quantized type, and not "
"allowed for float type");
return failure();
}
@@ -2229,7 +2257,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
WhileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
- for (auto &block : adaptor.getBody())
+ for (auto &block : adaptor.getBodyGraph())
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
yieldOps.push_back(returnOp);
@@ -2309,19 +2337,19 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
void IfOp::print(OpAsmPrinter &p) {
bool printBlockTerminators = false;
- p << " " << getCond();
+ p << " " << getCondition();
if (!getResults().empty()) {
p << " -> (" << getResultTypes() << ")";
// Print yield explicitly if the op defines values.
printBlockTerminators = true;
}
p << ' ';
- p.printRegion(getThenBranch(),
+ p.printRegion(getThenGraph(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
// Print the 'else' regions if it exists and has a block.
- auto &elseRegion = getElseBranch();
+ auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
@@ -2419,14 +2447,15 @@ static void printInitializationList(OpAsmPrinter &parser,
}
void WhileOp::print(OpAsmPrinter &parser) {
- printInitializationList(parser, getCond().front().getArguments(), getInputs(),
- " ");
+ printInitializationList(parser, getCondGraph().front().getArguments(),
+ getInputList(), " ");
parser << " : ";
- parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
+ parser.printFunctionalType(getInputList().getTypes(),
+ getResults().getTypes());
parser << ' ';
- parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
+ parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
parser << " do ";
- parser.printRegion(getBody());
+ parser.printRegion(getBodyGraph());
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe1..4eba89b59bbd79f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -130,13 +130,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto maybeZps = failureOrMaybeZps.value();
Value fullyConnectedValue;
if (maybeZps) {
- auto zeroPointAttr = rewriter.getAttr<tosa::ConvOpQuantizationAttr>(
- maybeZps->inputZp, maybeZps->weightZp);
fullyConnectedValue =
rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType, reshapedInput,
- reshapedWeight, op.getBias(), zeroPointAttr)
+ reshapedWeight, op.getBias(),
+ rewriter.getI32IntegerAttr(maybeZps->inputZp),
+ rewriter.getI32IntegerAttr(maybeZps->weightZp))
.getResult();
} else {
fullyConnectedValue = rewriter
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f28..b5b3e9d76c47e23 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -143,8 +143,7 @@ class TransposeConvStridedConverter
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->weightZp));
-
+ rewriter.getI32IntegerAttr(maybeZps->weightZp));
} else {
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
@@ -203,7 +202,7 @@ class TransposeConvStridedConverter
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->inputZp));
+ rewriter.getI32IntegerAttr(maybeZps->inputZp));
} else {
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 116cd045aa0d3af..87c388b6f5ee30b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -23,7 +23,7 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
// CHECK: [[ONE:%.+]] = arith.constant 1
// CHECK: [[TWO:%.+]] = arith.constant 2
// CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
- %0 = tosa.matmul %arg0, %arg1 {quantization_info = #tosa.matmul_quant<a_zp = 1, b_zp = 2>} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+ %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
return %0 : tensor<1x5x6xi32>
}
@@ -124,7 +124,7 @@ func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8
// CHECK: %[[C2:.+]] = arith.constant 2 : i32
// CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
+ %0 = tosa.fully_connected %arg0, %arg1, %arg2 {input_zp = 1 : i32, weight_zp = 2 : i32} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
return %0 : tensor<5x6xi32>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317aa..6e8501aaaf2afe2 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -880,26 +880,36 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[CNST:%.+]] = arith.constant 7
+ // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
+ // CHECK: linalg.yield [[SUB]]
+ %0 = tosa.negate %arg0 {input1_zp = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: [[C32639:%.+]] = arith.constant 32639
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
+ // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
// CHECK: [[MIN:%.+]] = arith.constant -128
// CHECK: [[MAX:%.+]] = arith.constant 127
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
- %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
-
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- %1 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32639, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: [[C32640:%.+]] = arith.constant 32640
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
- %2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+ // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+ // CHECK: [[MIN:%.+]] = arith.constant -128
+ // CHECK: [[MAX:%.+]] = arith.constant 127
+ // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
+ // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
+ // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
+ // CHECK: linalg.yield [[TRUNC]]
+ %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index f95de7984746412..e83e898644bc091 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -492,7 +492,7 @@ func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
// CHECK: [[CST:%.+]] = arith.constant 42 : i32
// CHECK: tensor.pad
// CHECK: tensor.yield [[CST]]
- %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
+ %1 = "tosa.pad"(%arg0, %0) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63ccd..e0e1de6a94d10d0 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -317,7 +317,7 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
// CHECK-LABEL: @pad_determine_val_quant
func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
- // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+ // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
%0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 685f799bd3d2bfa..e4a2897908072a6 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -28,7 +28,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
// CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
// CHECK-SAME: -> tensor<3x2xi8>
// CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
- // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>
+ // CHECK-SAME: {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK-SAME: -> tensor<400x3xi32>
// CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
// CHECK-SAME: -> tensor<4x10x10x3xi32>
@@ -48,7 +48,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: -1, 64>} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 384, 64>} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
-// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
// CHECK: }
@@ -67,7 +67,7 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1:
// CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
// CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
// CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
- // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
+ // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
%input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
%weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index bb6de82ee10532e..82838cc7e154514 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -91,7 +91,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// Manipulate the weight matrix to handle striding.
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
+ // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
// CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
// CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
@@ -101,7 +101,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
- // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
+ // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {input_zp = -22 : i32}
// Manipulate the final shape.
// CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
@@ -132,14 +132,14 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
- // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
+ // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
// CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array<i64: 2, 2, 1, 1>}
// CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
- // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = -103>}
+ // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32}
// CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]
- // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
+ // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], input_zp = -103 : i32, weight_zp = 93 : i32, stride = [1, 1]}
// CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
// CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]] {new_shape = array<i64: 1, 17, 2, 1>}
More information about the Mlir-commits
mailing list