[Mlir-commits] [mlir] 852b648 - [mlir] Improvements to the 'quant' dialect (#100667)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 26 11:09:35 PDT 2024
Author: Rafael Ubal
Date: 2024-09-26T14:09:28-04:00
New Revision: 852b6486246141e44cc9f126f542a2ae0d73b3d6
URL: https://github.com/llvm/llvm-project/commit/852b6486246141e44cc9f126f542a2ae0d73b3d6
DIFF: https://github.com/llvm/llvm-project/commit/852b6486246141e44cc9f126f542a2ae0d73b3d6.diff
LOG: [mlir] Improvements to the 'quant' dialect (#100667)
Full revamp of the 'quant' dialect. This is an implementation for the
RFC at
https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942
Added:
mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Quant/IR/Quant.h
mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
mlir/test/Dialect/Quant/invalid.mlir
mlir/test/Dialect/Quant/lower-quant-ops.mlir
mlir/test/Dialect/Quant/ops.mlir
mlir/test/Dialect/Quant/strip-func-quant-types.mlir
Modified:
mlir/include/mlir/Dialect/Quant/CMakeLists.txt
mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/CAPI/Dialect/Quant.cpp
mlir/lib/Dialect/Quant/CMakeLists.txt
mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
mlir/lib/Dialect/Quant/IR/QuantOps.cpp
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
mlir/lib/Dialect/Quant/IR/TypeParser.cpp
mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Quant/canonicalize.mlir
mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
Removed:
mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td
mlir/include/mlir/Dialect/Quant/QuantOps.h
mlir/include/mlir/Dialect/Quant/QuantOps.td
mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
mlir/include/mlir/Dialect/Quant/QuantTypes.h
mlir/include/mlir/Dialect/Quant/UniformSupport.h
################################################################################
diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
index c08f399ee182d8..9f57627c321fb0 100644
--- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
@@ -1,6 +1,2 @@
-add_mlir_dialect(QuantOps quant)
-add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
-
-set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
-mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
-add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..c08f399ee182d8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_dialect(QuantOps quant)
+add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
+mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
+add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
similarity index 59%
rename from mlir/include/mlir/Dialect/Quant/QuantOps.h
rename to mlir/include/mlir/Dialect/Quant/IR/Quant.h
index 14fb3035ab0d38..11a969a3ee5191 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h
@@ -1,4 +1,4 @@
-//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===//
+//===- Quant.h - Quantization Ops -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_QUANT_QUANTOPS_H_
-#define MLIR_DIALECT_QUANT_QUANTOPS_H_
+#ifndef MLIR_DIALECT_QUANT_IR_QUANT_H_
+#define MLIR_DIALECT_QUANT_IR_QUANT_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -19,9 +19,19 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"
-#include "mlir/Dialect/Quant/QuantOpsDialect.h.inc"
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc"
+
+namespace mlir {
+namespace quant {
+
+class QuantizedType;
+class UniformQuantizedType;
+class UniformQuantizedPerAxisType;
+
+} // namespace quant
+} // namespace mlir
#define GET_OP_CLASSES
-#include "mlir/Dialect/Quant/QuantOps.h.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.h.inc"
-#endif // MLIR_DIALECT_QUANT_QUANTOPS_H_
+#endif // MLIR_DIALECT_QUANT_IR_QUANT_H_
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
new file mode 100644
index 00000000000000..791cb9de48d058
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -0,0 +1,297 @@
+//===- QuantBase.td - Quantization dialect base ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Quantization dialect, types, and traits.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef QUANT_BASE
+#define QUANT_BASE
+
+include "mlir/IR/OpBase.td"
+
+def Quant_Dialect : Dialect {
+ let name = "quant";
+ let description = [{
+ The `quant` dialect offers a framework for defining and manipulating
+ quantized values. Central to this framework is the `!quant.uniform` data
+ type, used to represent quantized values. This dialect also provides a
+ suite of operations to handle and convert quantized values between their
+ original floating-point representations and the optimized, lower bit-width
+ integer representations. The `quant` dialect is instrumented with
+ transformation passes to lower these operations into other core MLIR
+ dialects, while also flattening all occurrences of quantized types into
+ their integer counterparts.
+
+
+ ## The `!quant.uniform` type
+
+ The quantization process establishes a relationship between two types of
+ values: an *expressed value* and a *stored value*. The former refers to the
+ floating-point representation used in an original machine learning model,
+ capturing the precise numerical characteristics needed for accurate
+ calculations. The latter is the simplified integer representation that
+ resides in memory after quantization. The `!quant.uniform` data type
+ encodes the necessary information for (lossy) round-trip conversion between
+ an expressed and a stored value.
+
+ The `!quant.uniform` type has two variants: per-layer quantization and
+ per-channel (or per-axis) quantization. In per-layer quantization, the
+ quantization information affects an entire tensor uniformly. Conversely, in
+ per-channel quantization, the data type encodes the specific tensor axis
+ that serves as the channel and includes quantization information for each
+ individual channel within the tensor. Below are the specific syntactic and
+ semantic considerations for each modality.
+
+
+ ### Per-layer quantization
+
+ This is the general syntax of the `!quant.uniform` type representing
+ per-layer quantization:
+
+ ```
+ `!quant.uniform` `<`
+ storedType (`<` storageMin `:` storageMax `>`)? `:`
+ expressedType `,`
+ scale (`:` zeroPoint)?
+ `>`
+ ```
+
+ The type contains the following parameters:
+
+ - `storedType`: Integer type of the value stored in memory. This type
+ conveys the bit width and signedness of the quantized stored value.
+ Signed integer types are represented as `'i' bitWidth` (e.g., `i8`),
+ while unsigned integer types are represented as `'u' bitWidth` (e.g.,
+ `u8`).
+
+ - `storageMin`, `storageMax`: Optional bounds for the stored value. If
+ given, they must be within the range of `storedType`. If omitted, the
+ entire range of `storedType` is allowed (e.g., `-128...127` for `i8` or
+ `0...255` for `u8`).
+
+ - `expressedType`: Floating-point type of the value expressed by this
+ quantized type (e.g., `f32`, `f80`, `bf16`, or `tf32`).
+
+ - `scale`: Floating-point value of type `expressedType` used in the
+ conversion between stored and expressed values.
+
+ - `zeroPoint`: Optional integer value of type `storedType` used in the
+ conversion between stored and expressed values. If omitted, the default
+ is 0.
+
+ Type conversions, rounding methods, and clamping actions aside, the
+ relationship between the expressed and stored values as encoded in a
+ quantized type is denoted by the following formula:
+
+ $$
+ expressedValue = (storedValue ~-~ zeroPoint) ~\times~ scale
+ $$
+
+ Operations `quant.qcast` (quantize cast) and `quant.dcast` (dequantize
+ cast) can be used to quantize a floating-point value and dequantize a
+ stored value, respectively. See the documentation for these operations for
+ details on how the quantization and dequantization processes are influenced
+ by the `!quant.uniform` type parameters.
+
+ Here are some examples of the use of `!quant.uniform` with per-layer
+ quantization:
+
+ ```
+ // An 8-bit signed integer type is used to represent a 32-bit float. No
+ // clamping information is provided, so the full [-128, 127] range is
+ // available. The scale is set to 3.0, and the zero point takes its default
+ // 0 value.
+ !quant.uniform<i8:f32, 3.0>
+
+ // A 16-bit unsigned integer type is used to represent a 32-bit float. Out
+ // of the 16 bits, only 10 are used, according to the 0..1023 clamping
+ // range. The type sets the scale to 1.23 and the zero point to 512.
+ !quant.uniform<u16<0:1023>:f32, 1.23:512>
+ ```
+
+ ### Per-channel quantization
+
+ The general syntax of the `!quant.uniform` type representing per-channel
+ quantization is as follows:
+
+ ```
+ `!quant.uniform` `<`
+ storedType (`<` storageMin `:` storageMax `>`)? `:`
+ expressedType `:`
+ channelAxis `,`
+ `{`
+ scale0 (`:` zeroPoint0)? `,`
+ scale1 (`:` zeroPoint1)? ...
+ `}`
+ `>`
+ ```
+
+ In this data type, there are multiple pairs of `scale` and `zeroPoint`
+ values. The `channelAxis` field represents the dimension of the containing
+ tensor acting as the channel. The size of the tensor along this dimension
+ is expected to match the number of provided `scale`-`zeroPoint` pairs, and
+ a given pair *i* applies to all elements in the tensor whose index along
+ dimension `channelAxis` is *i*. A quantized data type using per-channel
+ quantization is always expected to be contained within a tensor type.
+
+ Here are some examples:
+
+ ```
+ // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
+ // floats. Dimension 1 of the tensor acts as the channel dimension. Its
+ // size 3 matches the number of provided scale values. Tensor elements at
+ // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
+ // 5.0, respectively.
+ tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
+
+ // A 2D dynamically sized tensor contains 16-bit unsigned integers
+ // representing 32-bit floats. Dimension 0 of the tensor acts as the
+ // channel dimension. Since 2 scale and zero-point values are provided, the
+ // size of dimension 0 is expected to be 2 at runtime. Tensor elements
+ // [0][*] use scale 2.0 and zero point 10, while elements [1][*] use scale
+ // 3.0 and zero point 20.
+ tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
+ ```
+
+
+ ## Per-axis quantization integrity
+
+ When type `!quant.uniform` contains per-axis quantization information, the
+ rules below are enforced. These rules guarantee that the quantization
+ information encoded in the data type is applicable to the context in which
+ the quantized type is used. For efficiency, these rules are actively
+ enforced by the verifiers of `quant` dialect ops, but they must be
+ respected in any context in which the `!quant.uniform` data type is used,
+ such as the header of a `func.func` op, or the input of an arithmetic
+ operation.
+
+ - A quantized type with per-channel quantization information must be the
+ element type of a tensor container type, and may not occur directly as
+ the data type of a scalar value.
+
+ ```
+ // Incorrect. Type `!quant.uniform` specifies per-channel quantization for a
+ // scalar type.
+ %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:0, {1.0, 2.0}>
+
+ // Correct. Type `!quant.uniform` with per-channel quantization is wrapped
+ // in a `tensor` type.
+ %result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform<i8:f32:0, {1.0, 2.0}>>
+ ```
+
+ - If the tensor containing the `!quant.uniform` type is ranked, its rank
+ must be greater than the channel axis specified in the quantized type.
+
+ ```
+ // Incorrect. The tensor rank (2) is not greater than the channel axis in
+ // the quantized type (3).
+ %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:3, {1.0, 2.0}>>
+
+ // Correct. The tensor rank (2) is now greater than the channel axis (1):
+ %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform<i8:f32:1, {1.0, 2.0}>>
+ ```
+
+ - If the axis dimension in the containing tensor is static, its size must
+ be equal to the number of scales present in the quantized type.
+
+ ```
+ // Incorrect. The channel axis is 1, and the size of dimension 1 in the
+ // containing tensor is 3. However, there are 4 scale values present in the
+ // quantized type.
+ %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {1.0, 2.0, 3.0, 4.0}>>
+
+ // Correct. The quantized type now includes 3 scale values, matching the
+ // size of dimension 1 of the result tensor.
+ %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+ ```
+ }];
+ let cppNamespace = "::mlir::quant";
+ let useDefaultTypePrinterParser = 1;
+}
+
+
+//===----------------------------------------------------------------------===//
+// Type predicates
+//===----------------------------------------------------------------------===//
+
+class quant_ScalarOrTensorOf<Type etype> :
+ Type<Or<[etype.predicate, TensorOf<[etype]>.predicate]>,
+ "scalar or tensor of " # etype.summary>;
+
+def quant_QuantizedType :
+ Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "quantized type">;
+
+def quant_ScalarType :
+ Type<Or<[
+ AnySignlessInteger.predicate,
+ AnyFloat.predicate,
+ quant_QuantizedType.predicate
+ ]>,
+ "signless integer, float, or quantized scalar">;
+
+def quant_IntegerOrQuantizedType :
+ Type<Or<[
+ AnySignlessInteger.predicate,
+ quant_QuantizedType.predicate
+ ]>,
+ "signless integer or quantized type">;
+
+def quant_FloatScalarOrTensor :
+ quant_ScalarOrTensorOf<AnyFloat>;
+
+def quant_IntegerScalarOrTensor :
+ quant_ScalarOrTensorOf<AnySignlessInteger>;
+
+def quant_QuantizedScalarOrTensor :
+ quant_ScalarOrTensorOf<quant_QuantizedType>;
+
+def quant_IntegerOrQuantizedScalarOrTensor :
+ quant_ScalarOrTensorOf<quant_IntegerOrQuantizedType>;
+
+
+//===----------------------------------------------------------------------===//
+// Traits
+//===----------------------------------------------------------------------===//
+
+def quant_SameScalarOrTensorShape :
+ PredOpTrait<
+ "input and result are both scalars or both tensors with matching shape",
+ Or<[
+ And<[
+ TypeIsPred<"input", quant_ScalarType>,
+ TypeIsPred<"result", quant_ScalarType>
+ ]>,
+ And<[
+ TypeIsPred<"input", AnyUnrankedTensor>,
+ TypeIsPred<"result", AnyUnrankedTensor>
+ ]>,
+ And<[
+ TypeIsPred<"input", AnyRankedTensor>,
+ TypeIsPred<"result", AnyRankedTensor>,
+ AllShapesMatch<["input", "result"]>.predicate
+ ]>
+ ]>
+ >;
+
+def quant_IntegerAndQuantizedCombination :
+ PredOpTrait<
+ "input must be integer and result must be quantized, or vice versa",
+ Or<[
+ And<[
+ TypeIsPred<"input", quant_QuantizedScalarOrTensor>,
+ TypeIsPred<"result", quant_IntegerScalarOrTensor>
+ ]>,
+ And<[
+ TypeIsPred<"input", quant_IntegerScalarOrTensor>,
+ TypeIsPred<"result", quant_QuantizedScalarOrTensor>
+ ]>
+ ]>
+ >;
+
+#endif // QUANT_BASE
diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
similarity index 100%
rename from mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td
rename to mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
new file mode 100644
index 00000000000000..6ef925146dce66
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td
@@ -0,0 +1,243 @@
+//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the operation definition file for Quantization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef QUANT_OPS
+#define QUANT_OPS
+
+include "mlir/Dialect/Quant/IR/QuantBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Base classes
+//===----------------------------------------------------------------------===//
+
+class quant_Op<string mnemonic, list<Trait> traits> :
+ Op<Quant_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Quantization casts
+//===----------------------------------------------------------------------===//
+
+def quant_DequantizeCastOp : quant_Op<"dcast", [
+ Pure,
+ quant_SameScalarOrTensorShape]> {
+ let summary = "Dequantize cast operation";
+ let description = [{
+ Convert an input quantized value into its expressed floating-point value.
+ The dequantization process consists of the following steps:
+
+ ```
+ def dequantize(quantizedValue: quantizedType) -> expressedType:
+ storedValue = reinterpretCast(quantizedValue, storageType)
+ storedValueFloat = convertIntToFloat(storedValue, expressedType)
+ zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+ expressedValue = (storedValueFloat - zeroPointFloat) * scale
+ return expressedValue
+ ```
+
+ Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained
+ from the corresponding parameters encoded in `quantizedType`. For
+ per-channel quantization, the appropriate `scale` and `zeroPoint` values
+ are used for each tensor element computation according to the channel the
+ element belongs to.
+
+ The numerical results produced by the algorithm above may vary depending on
+ the rounding methods used by `convertIntToFloat()`, subtraction (`-`), and
+ multiplication (`*`). This operation does not define specific rounding
+ methods; instead, it is the responsibility of a transform pipeline to
+ determine which rounding method to apply when this operation is broken down
+ into lower-level dialects.
+
+ The operation must satisfy the following syntactic constraints:
+
+ - Operand `input` must be a scalar or tensor of type `!quant.uniform`.
+
+ - The result type must be a floating-point scalar or tensor.
+
+ - The `expressedType` parameter of the `!quant.uniform` type of the input
+ must match the floating-point type of the result.
+
+ - The operand and result types must be both scalars or both tensors. If
+ tensors, they must be both ranked or both unranked. If ranked, both must
+ have the same shape, including matching static and dynamic dimensions.
+
+ - If the operand uses per-channel quantization, its `!quant.uniform` type
+ must adhere to the [Per-axis quantization
+ integrity](#per-axis-quantization-integrity) guidelines.
+
+ Examples:
+
+ ```
+ // Dequantize a scalar quantized value
+ %result = quant.dcast %input : !quant.uniform<i8:f32, 2.0> to f32
+
+ // Dequantize a dynamically shaped tensor of quantized values
+ %result = quant.dcast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xf32>
+
+ // Dequantize an unranked tensor using per-axis quantization information
+ %result = quant.dcast %input : tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>> to tensor<*xf32>
+ ```
+ }];
+ let arguments = (ins quant_QuantizedScalarOrTensor:$input);
+ let results = (outs quant_FloatScalarOrTensor:$result);
+ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+ let hasVerifier = 1;
+ let hasFolder = 1;
+ let extraClassDeclaration = [{
+ /// Return the float type of the scalar or tensor result.
+ FloatType getFloatType();
+
+ /// Return the quantized type of the scalar or tensor input.
+ quant::QuantizedType getQuantizedType();
+ }];
+}
+
+def quant_QuantizeCastOp : quant_Op<"qcast", [
+ Pure,
+ quant_SameScalarOrTensorShape]> {
+ let summary = "Quantize cast operation";
+ let description = [{
+ Convert a floating-point value to a quantized type. The quantization
+ process consists of the following steps:
+
+ ```
+ def quantize(expressedValue: expressedType) -> quantizedType:
+ zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+ scaledValue = expressedValue / scale
+ storedValueFloat = scaledValue + zeroPointFloat
+ storedValue = convertFloatToInt(storedValueFloat, storageType)
+ storedValueClamped = clamp(storedValue, storageMin, storageMax)
+ quantizedValue = reinterpretCast(storedValueClamped, quantizedType)
+ return quantizedValue
+ ```
+
+ Here, `storageType`, `storageMin`, `storageMax`, `expressedType`, `scale`,
+ and `zeroPoint` are obtained from the corresponding parameters encoded in
+ `quantizedType`. For per-channel quantization, the appropriate `scale` and
+ `zeroPoint` values are used for each tensor element computation according
+ to the channel the element belongs to.
+
+ The numerical results produced by the algorithm above may vary depending on
+ the rounding methods used by `convertIntToFloat()`, `convertFloatToInt()`,
+ `clamp()`, division (`/`), and addition (`+`). This operation does not
+ define specific rounding methods; instead, it is the responsibility of a
+ transform pipeline to determine which rounding method to apply when this
+ operation is broken down into lower-level dialects.
+
+ The operation must satisfy the following syntactic constraints:
+
+ - Operand `input` must be a floating-point scalar or tensor.
+
+ - The result type must be a scalar or tensor of type `!quant.uniform`.
+
+ - The `expressedType` parameter in the `!quant.uniform` type of the result
+ must match the floating-point type of the input.
+
+ - The operand and result types must be both scalars or both tensors. If
+ tensors, they must be both ranked or both unranked. If ranked, both must
+ have the same shape, including matching static and dynamic dimensions.
+
+ - If the result uses per-channel quantization, its `!quant.uniform` type
+ must adhere to the [Per-axis quantization
+ integrity](#per-axis-quantization-integrity) guidelines.
+
+ Examples:
+
+ ```
+ // Quantize a scalar floating-point value
+ %result = quant.qcast %input : f32 to !quant.uniform<i8:f32, 2.0>
+
+ // Quantize a dynamically shaped tensor of quantized values
+ %result = quant.qcast %input : tensor<?xf32> to tensor<?x!quant.uniform<i8:f32, 2.0>>
+
+ // Quantize an unranked tensor using per-axis quantization information
+ %result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
+ ```
+ }];
+ let arguments = (ins quant_FloatScalarOrTensor:$input);
+ let results = (outs quant_QuantizedScalarOrTensor:$result);
+ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+ let hasVerifier = 1;
+ let hasFolder = 1;
+ let extraClassDeclaration = [{
+ /// Return the float type of the scalar or tensor input.
+ FloatType getFloatType();
+
+ /// Return the quantized type of the scalar or tensor result.
+ quant::QuantizedType getQuantizedType();
+ }];
+}
+
+def quant_StorageCastOp : quant_Op<"scast", [
+ Pure,
+ quant_SameScalarOrTensorShape,
+ quant_IntegerAndQuantizedCombination]> {
+ let summary = "Storage cast operation";
+ let description = [{
+ Convert a value from a quantized type to the corresponding signless integer
+ storage type, or vice versa. This conversion simply involves a
+ reinterpretation of the input bits and does not involve any data
+ manipulation.
+
+ The following syntactic restrictions must be met:
+
+ - Operand `input` must be a scalar or tensor of a signless integer or
+ `!quant.uniform` type.
+
+ - The result must be a scalar or tensor of a signless integer or
+ `!quant.uniform` type.
+
+ - If the operand is a scalar or tensor of type integer, the result must be
+ a scalar or tensor of type `!quant.uniform`, and vice versa.
+
+ - The operand and result must be both scalars or both tensors. If tensors,
+ they must be both ranked or both unranked. If ranked, both must have the
+ same shape, including matching static and dynamic dimensions.
+
+ - The width of the `storageType` parameter of the quantized type of the
+ operand or result must match the width of the signless integer type of
+ the operand or result.
+
+ - If the operand or result uses per-channel quantization, its
+ `!quant.uniform` type must adhere to the [Per-axis quantization
+ integrity](#per-axis-quantization-integrity) guidelines.
+
+ Examples:
+
+ ```
+ // Cast a scalar quantized value into its storage type
+ %result = quant.scast %input : !quant.uniform<i8:f32, 2.0> to i8
+
+ // Cast a dynamically shaped tensor of quantized values into their storage type
+ %result = quant.scast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xi8>
+
+ // Cast an unranked tensor of signless integers into a quantized type using
+ // per-channel quantization
+ %result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>>
+ ```
+ }];
+ let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input);
+ let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result);
+ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+ let hasVerifier = 1;
+ let hasFolder = 1;
+ let extraClassDeclaration = [{
+ /// Return the integer type used either in the input or the result.
+ IntegerType getIntegerType();
+
+ /// Return the quantized type used either in the input or the result.
+ quant::QuantizedType getQuantizedType();
+ }];
+}
+
+#endif // QUANT_OPS
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
similarity index 98%
rename from mlir/include/mlir/Dialect/Quant/QuantTypes.h
rename to mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 57a2aa29833657..43440ba623b9c1 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H
-#define MLIR_DIALECT_QUANT_QUANTTYPES_H
+#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
+#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -114,6 +114,10 @@ class QuantizedType : public Type {
/// The maximum value that storageType can take.
int64_t getStorageTypeMax() const;
+ /// Return whether the storage type has explicit min or max boundaries
+ /// different from the minimum and maximum representable values.
+ bool hasStorageTypeBounds() const;
+
/// Gets the integral bit width that the underlying storage type can exactly
/// represent. For integral storage types, this will just be their width.
unsigned getStorageTypeIntegralWidth() const;
@@ -413,4 +417,4 @@ class CalibratedQuantizedType
} // namespace quant
} // namespace mlir
-#endif // MLIR_DIALECT_QUANT_QUANTTYPES_H
+#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td
deleted file mode 100644
index 7937265ce2f209..00000000000000
--- a/mlir/include/mlir/Dialect/Quant/QuantOps.td
+++ /dev/null
@@ -1,103 +0,0 @@
-//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This is the operation definition file for Quantization.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef DIALECT_QUANT_QUANT_OPS_
-#define DIALECT_QUANT_QUANT_OPS_
-
-include "mlir/Dialect/Quant/QuantOpsBase.td"
-include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-
-//===----------------------------------------------------------------------===//
-// Base classes
-//===----------------------------------------------------------------------===//
-
-class quant_Op<string mnemonic, list<Trait> traits> :
- Op<Quantization_Dialect, mnemonic, traits>;
-
-//===----------------------------------------------------------------------===//
-// Quantization casts
-//===----------------------------------------------------------------------===//
-
-def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> {
- let summary = "convert a quantizable type to a quantized type";
- let description = [{
- A QuantizeCast `qcast` represents a potential type shift from a quantizable
- type to a quantized type.
-
- At runtime, a `qcast` will apply the transformation expressed by its
- operand and result type. For flexibility during transformation, it is also
- possible to have a `qcast` that performs no transformation (both its
- operand and result type are quantizable).
-
- A `qcast` will typically originate from either:
- a) An expressed or implied constraint in the source dialect which signals
- that a certain level of quantization is possible or required.
- b) An inference made by a quantization algorithm indicating that a
- quantized representation may be acceptable.
-
- Especially early in transformation, it is common to have pairs of
- `qcast` and `dcast` at points where a transition to a quantized type is
- required. In addition, it is also common to have an identity `qcast`
- (where the operand and result type are not quantized) at all points where
- it is legal to use a quantized representation (but is not known to be
- acceptable).
- }];
- let arguments = (ins quant_RealValueType:$arg);
- let results = (outs quant_RealValueType:$res);
-}
-
-def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> {
- let summary = "convert back from a quantized to quantizable (expressed) type operation";
- let description = [{
- A DequantizeCast op `dcast` represents the inverse of a `qcast`,
- converting back from a quantized to quantizable (expressed) type.
-
- Like `qcast`s, a `dcast` is allowed to have both its operand and result
- as non quantized types. This facilitates transformations and marks edges
- where the computation must be carried out in the expressed type.
-
- Especially early in transformation, it is common to have `dcast`s on
- all operands to ops that must operate with the expressed type (typically
- math ops prior to lowering to target-specific, quantized kernels).
- }];
- let arguments = (ins quant_RealValueType:$arg);
- let results = (outs quant_RealValueType:$res);
-}
-
-def quant_StorageCastOp : quant_Op<"scast", [Pure]> {
- let summary = "cast from or to a type based on the storage type and the corresponding quantized type";
- let description = [{
- A StorageCast `scast` represents a cast from or to a type based on the
- storage type and a type based on a corresponding quantized type.
-
- This op exists to ensure type coherency for between parts of the computation
- which are operating directly on an underlying storage type and those which
- operate on quantized values.
-
- Examples from storage to quantized type:
- ```
- i8 -> !quant<"uniform[i8:f32]{1.0}">
- ```
- ```
- tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- ```
- ```
- vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
- ```
- }];
- let arguments = (ins quant_RealOrStorageValueType:$arg);
- let results = (outs quant_RealOrStorageValueType:$res);
- let hasFolder = 1;
-}
-
-#endif // DIALECT_QUANT_QUANT_OPS_
diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
deleted file mode 100644
index da822d0a61deb2..00000000000000
--- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td
+++ /dev/null
@@ -1,74 +0,0 @@
-//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Predicates for types in the Quantization dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef DIALECT_QUANT_QUANT_OPS_BASE_
-#define DIALECT_QUANT_QUANT_OPS_BASE_
-
-include "mlir/IR/OpBase.td"
-
-def Quantization_Dialect : Dialect {
- let name = "quant";
- let cppNamespace = "::mlir::quant";
-
- let useDefaultTypePrinterParser = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// Quantization type definitions
-//===----------------------------------------------------------------------===//
-
-class quant_TypedPrimitiveOrContainer<Type etype> :
- Type<Or<[etype.predicate,
- TensorOf<[etype]>.predicate,
- VectorOf<[etype]>.predicate]>,
- "primitive/tensor/vector of " # etype.summary>;
-
-// An implementation of QuantizedType.
-def quant_QuantizedType :
- Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "QuantizedType">;
-
-// A primitive type that can represent a real value. This is either a
-// floating point value or a quantized type.
-def quant_RealPrimitiveType :
- Type<Or<[AnyFloat.predicate, quant_QuantizedType.predicate]>,
- "real valued primitive (float or quantized type)">;
-
-// A primitive type that can represent a storage value. This is either an
-// integer or quantized type.
-def quant_StoragePrimitiveType :
- Type<Or<[AnySignlessInteger.predicate, quant_QuantizedType.predicate]>,
- "quantized storage primitive (integer or quantized type)">;
-
-// A primitive or container of RealPrimitiveType.
-def quant_RealValueType :
- quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
-
-// A primitive or container of StoragePrimitiveType.
-def quant_StorageValueType :
- quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
-
-// Either a real valued or storage primitive or container type.
-def quant_RealOrStorageValueType :
- Type<Or<[quant_RealValueType.predicate, quant_StorageValueType.predicate]>,
- "real valued or storage primitive or container type">;
-
-// An implementation of UniformQuantizedType.
-def quant_UniformQuantizedType :
- DialectType<Quantization_Dialect,
- CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
- "UniformQuantizedType">;
-
-// Predicate for detecting a container or primitive of UniformQuantizedType.
-def quant_UniformQuantizedValueType :
- quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
-
-#endif // DIALECT_QUANT_QUANT_OPS_BASE_
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..30f7c1696bdb9b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant)
+add_public_tablegen_target(MLIRQuantTransformsIncGen)
+
+add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
new file mode 100644
index 00000000000000..84be2a21b34ed2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h
@@ -0,0 +1,29 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+void populateLowerQuantOpsPatterns(RewritePatternSet &patterns);
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
new file mode 100644
index 00000000000000..b25296d4db5a99
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -0,0 +1,49 @@
+//===-- Passes.td - Quant pass definition file --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
+#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
+ let summary = "Lower quant.dcast and quant.qcast ops";
+ let description = [{
+ Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops
+ into other core dialects.
+
+ The lowering process generates storage type casts in the form of
+ `quant.scast` ops to act as an interface between the original quantized
+ types of operands and results and their corresponding storage types used in
+ the generated arithmetic computations.
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "linalg::LinalgDialect",
+ "quant::QuantDialect",
+ "shape::ShapeDialect",
+ "tensor::TensorDialect"
+ ];
+}
+
+def StripFuncQuantTypes : Pass<"strip-func-quant-types"> {
+ let summary = "Strip quantized types from function headers";
+ let description = [{
+ Identify occurrences of function arguments using a quantized type and
+ replace them with a new value of the corresponding storage (signless
+ integer) type. For each converted argument, a `quant.scast` op is introduced
+ at the head of the function's entry block converting the new integer
+ argument into the original quantized value.
+ }];
+ let dependentDialects = [
+ "func::FuncDialect",
+ "quant::QuantDialect"
+ ];
+}
+
+#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
similarity index 93%
rename from mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
rename to mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
index 367d468b2acf1a..6551efc6242a60 100644
--- a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h
@@ -34,10 +34,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
-#define MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
+#ifndef MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
+#define MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
namespace mlir {
namespace quant {
@@ -64,4 +64,4 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
} // namespace quant
} // namespace mlir
-#endif // MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_
+#endif // MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_
diff --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
similarity index 97%
rename from mlir/include/mlir/Dialect/Quant/UniformSupport.h
rename to mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
index 4119aced4c0752..6773f45069c874 100644
--- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
-#define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
+#ifndef MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
+#define MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
#include <utility>
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
@@ -218,4 +218,4 @@ class UniformQuantizedPerAxisValueConverter {
} // namespace quant
} // namespace mlir
-#endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_
+#endif // MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 64bacd0e432fe5..67b41187e5bfb7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -40,7 +40,7 @@ def Tosa_Dialect : Dialect {
there will be tools to lower from the ML frameworks into TOSA.
}];
- let dependentDialects = ["tensor::TensorDialect", "quant::QuantizationDialect"];
+ let dependentDialects = ["tensor::TensorDialect", "quant::QuantDialect"];
let cppNamespace = "mlir::tosa";
let hasConstantMaterializer = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 298c97015fe2eb..5e80745777b3b3 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -16,8 +16,8 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
namespace mlir {
namespace tosa {
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 73dccdb017ee14..7fd0432ddce1bb 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -65,7 +65,7 @@
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
@@ -137,7 +137,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
pdl_interp::PDLInterpDialect,
polynomial::PolynomialDialect,
ptr::PtrDialect,
- quant::QuantizationDialect,
+ quant::QuantDialect,
ROCDL::ROCDLDialect,
scf::SCFDialect,
shape::ShapeDialect,
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 1b9c1b193ace6e..dd8b292a87344e 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -35,6 +35,7 @@
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
memref::registerMemRefPasses();
mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
+ quant::registerQuantPasses();
registerSCFPasses();
registerShapePasses();
spirv::registerSPIRVPasses();
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index 0a7181d8bc17c3..c94dbb5692fdb0 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -8,12 +8,12 @@
#include "mlir-c/Dialect/Quant.h"
#include "mlir/CAPI/Registration.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
using namespace mlir;
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect)
//===---------------------------------------------------------------------===//
// QuantizedType
diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt
index 037bba8dcb5c9b..31167e6af908b9 100644
--- a/mlir/lib/Dialect/Quant/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
index c0c00fb4893cb3..6a4ac310eb0524 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -9,8 +9,8 @@
#include "QuantDialectBytecode.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVector.h"
@@ -31,7 +31,7 @@ static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader,
return success();
}
-#include "mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc"
/// This class implements the bytecode interface for the Quant dialect.
struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
@@ -64,6 +64,6 @@ struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
};
} // namespace
-void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) {
+void quant::detail::addBytecodeInterface(QuantDialect *dialect) {
dialect->addInterfaces<QuantDialectBytecodeInterface>();
}
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
index 9e9cbf66d84d92..eef2b5bbefecc0 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
@@ -15,12 +15,12 @@
#define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
namespace mlir::quant {
-class QuantizationDialect;
+class QuantDialect;
namespace detail {
/// Add the interfaces necessary for encoding the quantization dialect
/// components in bytecode.
-void addBytecodeInterface(QuantizationDialect *dialect);
+void addBytecodeInterface(QuantDialect *dialect);
} // namespace detail
} // namespace mlir::quant
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index c9a6bbc9ceeea5..c584903f3a15de 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -6,44 +6,209 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantOps.h"
#include "QuantDialectBytecode.h"
#include "TypeDetail.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/MathExtras.h"
-#include <numeric>
+#include "mlir/IR/TypeUtilities.h"
-using namespace mlir;
-using namespace mlir::quant;
-using namespace mlir::quant::detail;
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
-#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
-void QuantizationDialect::initialize() {
+namespace mlir {
+namespace quant {
+
+namespace {
+
+// Verify the integrity of per-axis quantization information, if present.
+//
+// - quantizedType
+// Any quantized type. Any quantized type with no per-axis quantization is
+// ignored.
+//
+// - containerType
+// Original input or result type of the operation using the provided quantized
+// type. Used to ensure that the quantized type appears within a tensor and
+// that the tensor is compatible with per-axis quantization information.
+//
+LogicalResult verifyPerAxisQuantization(Operation *op,
+ QuantizedType quantizedType,
+ Type containerType) {
+ auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+ if (!quantizedPerAxisType)
+ return success();
+
+ auto tensorType = dyn_cast<TensorType>(containerType);
+ if (!tensorType)
+ return op->emitError("scalar types may not use per-axis quantization");
+
+ if (!tensorType.hasRank())
+ return success();
+
+ int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
+ if (quantizedDimension >= tensorType.getRank())
+ return op->emitError("quantized dimension must be less than tensor rank");
+
+ int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
+ if (quantizedDimensionSize != ShapedType::kDynamic &&
+ quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+ return op->emitError(
+ "quantized dimension size does not match number of scales");
+
+ return success();
+}
+
+// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
+//
+// - quantizedType
+// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
+// whether as a primitive type or in a tensor.
+//
+// - floatType
+// Float type used in the input ('quant.qcast') or result ('quant.dcast'),
+// whether as a primitive type or in a tensor.
+//
+// - containerType
+// Type of original input or result.
+//
+LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
+ FloatType floatType, Type containerType) {
+ // The quantized type must be expressed in exactly the float type found on
+ // the other side of the cast.
+ if (quantizedType.getExpressedType() != floatType)
+ return op->emitError(
+ "expressed type in quantized type expected to match float type");
+
+ // Verify integrity of per-axis quantization information, if present.
+ return verifyPerAxisQuantization(op, quantizedType, containerType);
+}
+
+} // namespace
+
+
+//===----------------------------------------------------------------------===//
+// Dialect
+//===----------------------------------------------------------------------===//
+
+void QuantDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
addOperations<
#define GET_OP_LIST
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
>();
- addBytecodeInterface(this);
+ detail::addBytecodeInterface(this);
+}
+
+
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DequantizeCastOp::verify() {
+ return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+ getInput().getType());
+}
+
+OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
+ // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
+ // with the value of x. Values x and y are guaranteed to be of the same type
+ // in this pattern.
+ auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
+ if (!srcQcastOp)
+ return {};
+ assert(srcQcastOp.getInput().getType() == getType());
+ return srcQcastOp.getInput();
+}
+
+FloatType DequantizeCastOp::getFloatType() {
+ return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+QuantizedType DequantizeCastOp::getQuantizedType() {
+ return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+
+//===----------------------------------------------------------------------===//
+// QuantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult QuantizeCastOp::verify() {
+ return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+ getInput().getType());
+}
+
+OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
+ // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
+ // with the value of x if the casts invert each other. Contrary to the folding
+ // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
+ // x and y are not guaranteed to be of the same type here, as they may use
+ // different quantization parameters.
+ auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
+ if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
+ return {};
+ return srcDcastOp.getInput();
+}
+
+FloatType QuantizeCastOp::getFloatType() {
+ return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
+}
+
+QuantizedType QuantizeCastOp::getQuantizedType() {
+ return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
+}
+
+
+//===----------------------------------------------------------------------===//
+// StorageCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult StorageCastOp::verify() {
+ auto quantizedType = getQuantizedType();
+ auto integerType = getIntegerType();
+ if (quantizedType.getStorageType() != integerType)
+ return emitError(
+ "storage type in quantized type expected to match integer type");
+
+ // Verify integrity of per-axis quantization information, if available. While
+ // the quantization type may appear in the input or the result, their tensor
+ // shapes are guaranteed to be identical at this point.
+ return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
}
OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
- // Matches x -> [scast -> scast] -> y, replacing the second scast with the
- // value of x if the casts invert each other.
- auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
- if (!srcScastOp || srcScastOp.getArg().getType() != getType())
- return OpFoldResult();
- return srcScastOp.getArg();
+ // Matches x -> quant.scast -> quant.scast -> y, replacing the second
+ // quant.scast with the value of x if the casts invert each other.
+ auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
+ if (!srcScastOp || srcScastOp.getInput().getType() != getType())
+ return {};
+ return srcScastOp.getInput();
+}
+
+IntegerType StorageCastOp::getIntegerType() {
+ auto inputScalarType = getElementTypeOrSelf(getInput().getType());
+ if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
+ return integerType;
+
+ auto resultScalarType = getElementTypeOrSelf(getResult().getType());
+ return cast<IntegerType>(resultScalarType);
+}
+
+QuantizedType StorageCastOp::getQuantizedType() {
+ auto inputScalarType = getElementTypeOrSelf(getInput().getType());
+ if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
+ return quantizedType;
+
+ auto resultScalarType = getElementTypeOrSelf(getResult().getType());
+ return cast<QuantizedType>(resultScalarType);
}
+
+} // namespace quant
+} // namespace mlir
+
#define GET_OP_CLASSES
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
+
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index c2ba9c04e8771d..ac01b37a553077 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantTypes.h"
#include "TypeDetail.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -20,12 +20,28 @@ using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;
+namespace {
+
+// Convert 'value' to a double, rounding to nearest. Calling
+// APFloat::convertToDouble() directly asserts when the source semantics are
+// not representable by IEEE double (e.g. f80, f128). Those expressed types are
+// accepted by the type verifiers, so convert explicitly instead: magnitudes
+// below the double range round to 0.0 and magnitudes above it round to +inf.
+double convertToDoubleRounded(APFloat value) {
+  bool losesInfo = false;
+  (void)value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
+                      &losesInfo);
+  return value.convertToDouble();
+}
+
+// Return the minimum scale (smallest positive, possibly denormal value)
+// representable in a given float type. For float types whose range exceeds
+// that of double this is 0.0, leaving the lower bound to the 'scale <= 0.0'
+// check performed by the verifiers.
+double getMinScale(Type expressedType) {
+  auto floatType = cast<FloatType>(expressedType);
+  return convertToDoubleRounded(
+      APFloat::getSmallest(floatType.getFloatSemantics()));
+}
+
+// Return the maximum scale representable in a given float type. For float
+// types whose range exceeds that of double this is +inf, leaving the upper
+// bound to the 'std::isinf(scale)' check performed by the verifiers.
+double getMaxScale(Type expressedType) {
+  auto floatType = cast<FloatType>(expressedType);
+  return convertToDoubleRounded(
+      APFloat::getLargest(floatType.getFloatSemantics()));
+}
+
+} // namespace
+
unsigned QuantizedType::getFlags() const {
return static_cast<ImplType *>(impl)->flags;
}
bool QuantizedType::classof(Type type) {
- return llvm::isa<QuantizationDialect>(type.getDialect());
+ return llvm::isa<QuantDialect>(type.getDialect());
}
LogicalResult
@@ -73,6 +89,17 @@ int64_t QuantizedType::getStorageTypeMax() const {
return static_cast<ImplType *>(impl)->storageTypeMax;
}
+// Return true if this type declares non-default storage bounds, i.e. if
+// storageTypeMin or storageTypeMax differs from the full range of an integer
+// with the same signedness and integral width as the storage type.
+bool QuantizedType::hasStorageTypeBounds() const {
+ unsigned int integralWidth = getStorageTypeIntegralWidth();
+ bool isSignedInteger = isSigned();
+ // Full-range bounds for the storage integer; anything else is explicit.
+ int64_t defaultIntegerMin =
+ getDefaultMinimumForInteger(isSignedInteger, integralWidth);
+ int64_t defaultIntegerMax =
+ getDefaultMaximumForInteger(isSignedInteger, integralWidth);
+ return defaultIntegerMin != getStorageTypeMin() ||
+ defaultIntegerMax != getStorageTypeMax();
+}
+
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
// NOTE: If ever supporting non-integral storage types, some other scheme
// for determining the width will be needed.
@@ -293,8 +320,13 @@ LogicalResult UniformQuantizedType::verifyInvariants(
return emitError() << "expressed type must be floating point";
// Verify scale.
+ double minScale = getMinScale(expressedType);
+ double maxScale = getMaxScale(expressedType);
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitError() << "illegal scale: " << scale;
+ if (scale < minScale || scale > maxScale)
+ return emitError() << "scale out of expressed type range [" << minScale
+ << ", " << maxScale << "]";
return success();
}
@@ -353,11 +385,20 @@ LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
<< scales.size() << ", " << zeroPoints.size();
// Verify scale.
+ double minScale = getMinScale(expressedType);
+ double maxScale = getMaxScale(expressedType);
for (double scale : scales) {
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitError() << "illegal scale: " << scale;
+ if (scale < minScale || scale > maxScale)
+ return emitError() << "scale out of expressed type range [" << minScale
+ << ", " << maxScale << "]";
}
+ // Verify quantized dimension.
+ if (quantizedDimension < 0)
+ return emitError() << "illegal quantized dimension: " << quantizedDimension;
+
return success();
}
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 926a8a0aa13d5c..851763d8942e83 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantOps.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
@@ -317,7 +317,7 @@ static Type parseCalibratedType(DialectAsmParser &parser) {
}
/// Parse a type registered to this dialect.
-Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
+Type QuantDialect::parseType(DialectAsmParser &parser) const {
// All types start with an identifier that we switch on.
StringRef typeNameSpelling;
if (failed(parser.parseKeyword(&typeNameSpelling)))
@@ -346,12 +346,7 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
}
// storageTypeMin and storageTypeMax if not default.
- int64_t defaultIntegerMin =
- QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
- int64_t defaultIntegerMax =
- QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
- if (defaultIntegerMin != type.getStorageTypeMin() ||
- defaultIntegerMax != type.getStorageTypeMax()) {
+ if (type.hasStorageTypeBounds()) {
out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
<< ">";
}
@@ -419,7 +414,7 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type,
}
/// Print a type registered to this dialect.
-void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
+void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
printAnyQuantizedType(anyType, os);
else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..2fd4a41999d456
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -0,0 +1,26 @@
+add_mlir_dialect_library(MLIRQuantTransforms
+  LowerQuantOps.cpp
+  StripFuncQuantTypes.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
+
+  DEPENDS
+  MLIRQuantTransformsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRFuncTransforms
+  MLIRIndexDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRLinalgUtils
+  MLIRPass
+  MLIRQuantDialect
+  MLIRShapeDialect
+  MLIRTensorDialect
+  MLIRTransforms
+  MLIRTransformUtils
+
+  )
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
new file mode 100644
index 00000000000000..4adeb9218ff8ec
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -0,0 +1,676 @@
+//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_LOWERQUANTOPS
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+// Return the element type of 'inputType' when it is a tensor type. Any other
+// type is treated as a scalar and returned unchanged.
+Type getScalarType(Type inputType) {
+  auto tensorType = dyn_cast<TensorType>(inputType);
+  return tensorType ? tensorType.getElementType() : inputType;
+}
+
+// Collect the dimensions of 'input' as OpFoldResults: static sizes become
+// attributes and dynamic sizes become SSA values. A scalar has no dimensions,
+// so an empty list is produced for it.
+SmallVector<OpFoldResult>
+getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
+  if (!isa<TensorType>(input.getType()))
+    return {};
+  return tensor::getMixedSizes(builder, loc, input);
+}
+
+// Build a type with 'elementType' elements that mirrors 'referenceType'. A
+// tensor reference yields a tensor of the same shape whose elements are
+// 'elementType'. A scalar reference yields 'elementType' itself.
+Type getScalarOrTensorType(Type elementType, Type referenceType) {
+  auto tensorType = dyn_cast<TensorType>(referenceType);
+  if (!tensorType)
+    return elementType;
+  return tensorType.clone(elementType);
+}
+
+// Broadcast 'scalar' so that it matches 'referenceType'.
+//
+// - For a tensor reference, a 'tensor.splat' of 'scalar' with shape
+//   'referenceShape' is created.
+// - For a scalar reference, 'scalar' is returned unchanged and
+//   'referenceShape' must be empty.
+Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
+                                Type referenceType,
+                                ArrayRef<OpFoldResult> referenceShape) {
+  if (isa<TensorType>(referenceType))
+    return builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
+
+  // Scalar result: nothing to broadcast.
+  assert(referenceShape.empty());
+  return scalar;
+}
+
+// Reshape an unranked tensor into a 1D ranked tensor.
+//
+// - input
+//   Unranked tensor.
+//
+// Return values:
+//
+// - flatInput
+//   1D ranked, dynamically shaped tensor with the elements of 'input'.
+//
+// - inputShape
+//   1D extent tensor holding the runtime shape of 'input'. It can be passed
+//   to 'restoreUnrankedTensorShape()' to undo the flattening.
+//
+std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
+                                              Value input) {
+  MLIRContext *ctx = builder.getContext();
+
+  // Runtime shape of the input and its total element count.
+  Value inputShape = builder.create<shape::ShapeOfOp>(
+      loc, shape::getExtentTensorType(ctx), input);
+  Value numElements = builder.create<shape::NumElementsOp>(
+      loc, builder.getIndexType(), inputShape);
+
+  // Target shape of the reshape: a single-element extent tensor holding the
+  // element count.
+  Value flatShape = builder.create<tensor::FromElementsOp>(
+      loc, shape::getExtentTensorType(ctx, 1), numElements);
+
+  // Collapse the input into a dynamically sized 1D tensor.
+  Type elementType = cast<UnrankedTensorType>(input.getType()).getElementType();
+  auto flatType = RankedTensorType::get({ShapedType::kDynamic}, elementType);
+  Value flatInput =
+      builder.create<tensor::ReshapeOp>(loc, flatType, input, flatShape);
+
+  return {flatInput, inputShape};
+}
+
+// Reshape an unranked tensor into a 3D ranked tensor where the central
+// dimension of the result tensor corresponds to dimension 'axis' of the input
+// tensor. All dimensions before 'axis' are collapsed into the first result
+// dimension, and all dimensions after it into the last result dimension.
+//
+// - input
+//   Unranked tensor.
+//
+// - axis
+//   Index of the input dimension around which other input dimensions will be
+//   collapsed.
+//
+// - axisSize
+//   Size of input dimension 'axis'. It becomes the static size of the central
+//   result dimension and is not checked against the runtime shape here.
+//
+// Return values:
+//
+// - flatInput
+//   3D ranked tensor of shape [?, axisSize, ?].
+//
+// - inputShape
+//   1D extent tensor containing the shape of the original unranked input.
+//
+std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
+                                                        Location loc,
+                                                        Value input,
+                                                        int64_t axis,
+                                                        int64_t axisSize) {
+  // Get full tensor shape
+  auto *context = builder.getContext();
+  auto indexType = builder.getIndexType();
+  auto shapeType = shape::getExtentTensorType(context);
+  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
+
+  // Get shape and sizes on left and right of axis. The head of the split at
+  // 'axis' holds the leading dimensions; the tail of the split at 'axis + 1'
+  // holds the trailing dimensions.
+  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
+  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
+  auto shapeLeft =
+      builder
+          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+                                    inputShape, axisValue)
+          .getResult(0);
+  auto sizeLeft =
+      builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
+  auto shapeRight =
+      builder
+          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+                                    inputShape, axisNextValue)
+          .getResult(1);
+  auto sizeRight =
+      builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
+
+  // Compute flat input shape as a 3-element 1D tensor
+  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
+  auto flatShapeType = shape::getExtentTensorType(context, 3);
+  auto flatInputShape = builder.create<tensor::FromElementsOp>(
+      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});
+
+  // Reshape input to 3D tensor
+  auto inputType = cast<UnrankedTensorType>(input.getType());
+  auto elementType = inputType.getElementType();
+  auto flatInputType = RankedTensorType::get(
+      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
+  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+                                                     flatInputShape);
+
+  return std::make_pair(flatInput, inputShape);
+}
+
+// Reshape a ranked tensor back into an unranked tensor whose runtime shape is
+// given by 'inputShape'.
+//
+// - input
+//   Ranked tensor. In this file it is the result of converting a tensor
+//   produced by 'flattenUnrankedTensor()' or
+//   'flattenUnrankedTensorAroundAxis()'.
+//
+// - inputShape
+//   1D extent tensor holding the target shape.
+//
+Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
+                                 Value inputShape) {
+  Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
+  return builder.create<tensor::ReshapeOp>(
+      loc, UnrankedTensorType::get(elementType), input, inputShape);
+}
+
+// Create a tensor constant containing all scales in a per-channel quantized
+// type, in channel order, with the expressed type as element type. Example:
+//
+//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
+//
+Value materializePerChannelScales(OpBuilder &builder, Location loc,
+                                  UniformQuantizedPerAxisType quantizedType) {
+  auto scales = quantizedType.getScales();
+  auto expressedType = quantizedType.getExpressedType();
+  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
+    return builder.getFloatAttr(expressedType, scale);
+  });
+  auto tensorType = RankedTensorType::get(
+      {static_cast<int64_t>(scales.size())}, expressedType);
+  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
+  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
+}
+
+// Create a tensor constant containing all zero points in a per-channel
+// quantized type, in channel order, with the storage type as element type.
+// Example:
+//
+//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
+//
+Value materializePerChannelZeroPoints(
+    OpBuilder &builder, Location loc,
+    UniformQuantizedPerAxisType quantizedType) {
+  auto zeroPoints = quantizedType.getZeroPoints();
+  auto storageType = quantizedType.getStorageType();
+  auto zeroPointAttrs =
+      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
+        return builder.getIntegerAttr(storageType, zeroPoint);
+      });
+  auto tensorType = RankedTensorType::get(
+      {static_cast<int64_t>(zeroPoints.size())}, storageType);
+  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
+  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
+}
+
+// Clamp the given scalar or tensor input using the storage bounds encoded in
+// the given quantized type, if present. Signed storage types are clamped with
+// signed min/max ops and unsigned storage types with unsigned ones.
+//
+// - input
+//   Scalar or ranked tensor input. The element type must match the storage
+//   type of 'quantizedType'.
+//
+// - inputShape
+//   If 'input' is a tensor, combination of attributes/values representing its
+//   static/dynamic dimensions. If 'input' is a scalar, empty list.
+//
+// - quantizedType
+//   Per-layer or per-channel quantized type.
+//
+Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
+                          ArrayRef<OpFoldResult> inputShape,
+                          QuantizedType quantizedType) {
+  // If quantized type does not narrow down the storage type range, there is
+  // nothing to do.
+  if (!quantizedType.hasStorageTypeBounds())
+    return input;
+
+  // Materialize bounds
+  auto inputType = input.getType();
+  auto storageType = quantizedType.getStorageType();
+  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
+      loc, quantizedType.getStorageTypeMin(), storageType);
+  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
+      loc, quantizedType.getStorageTypeMax(), storageType);
+  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
+                                              inputType, inputShape);
+  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
+                                              inputType, inputShape);
+
+  // Clamp
+  if (quantizedType.isSigned()) {
+    input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
+    input = builder.create<arith::MinSIOp>(loc, input, storageMax);
+  } else {
+    input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
+    input = builder.create<arith::MinUIOp>(loc, input, storageMax);
+  }
+  return input;
+}
+
+// Convert a floating-point scalar or tensor to 'resultType'. This emits
+// 'arith.fptoui' when 'isSigned' is false and 'arith.fptosi' otherwise.
+Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
+                            Type resultType, bool isSigned) {
+  if (!isSigned)
+    return builder.create<arith::FPToUIOp>(loc, resultType, input);
+  return builder.create<arith::FPToSIOp>(loc, resultType, input);
+}
+
+// Convert an integer scalar or tensor to 'resultType'. This emits
+// 'arith.uitofp' when 'isSigned' is false and 'arith.sitofp' otherwise.
+Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
+                            Type resultType, bool isSigned) {
+  if (!isSigned)
+    return builder.create<arith::UIToFPOp>(loc, resultType, input);
+  return builder.create<arith::SIToFPOp>(loc, resultType, input);
+}
+
+// Quantize a scalar or ranked tensor value. The stored value is clamped using
+// the storage bounds encoded in the given quantized type.
+//
+// The emitted computation is clamp(toInt(input / scale + zeroPoint)). The
+// zero point addition is skipped when 'zeroPoint' matches a constant zero.
+//
+// NOTE(review): the float-to-integer conversion ('arith.fptosi' /
+// 'arith.fptoui') happens before the clamp below. The clamp therefore does
+// not saturate inputs that fall outside the range of the storage type itself,
+// and the rounding mode is that of the conversion op. Confirm this matches
+// the intended quantization semantics.
+//
+// See function 'convertRanked()' below for a description of the arguments.
+Value quantizeValue(OpBuilder &builder, Location loc, Value input,
+                    ArrayRef<OpFoldResult> inputShape, Value scale,
+                    Value zeroPoint, QuantizedType quantizedType) {
+  // Convert scale to tensor if necessary
+  auto inputType = input.getType();
+  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
+
+  // Scale input
+  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
+
+  // Skip unnecessary computations if no zero point is given
+  Value storedValueFloat = scaledValue;
+  if (!matchPattern(zeroPoint, m_Zero())) {
+    // Convert zero point to tensor if necessary
+    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
+                                          inputShape);
+
+    // Convert zero point from storage to expressed type
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
+                                      quantizedType.isSigned());
+
+    // Add zero point to stored value
+    storedValueFloat =
+        builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
+  }
+
+  // Convert stored value to storage type
+  auto storageScalarOrTensorType =
+      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
+  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
+                                              storageScalarOrTensorType,
+                                              quantizedType.isSigned());
+
+  // Clamp the stored value if the quantized type narrows the storage range
+  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
+                                                inputShape, quantizedType);
+  return storedValueClamped;
+}
+
+// Dequantize a scalar or ranked tensor input.
+//
+// The emitted computation is (toFloat(input) - zeroPoint) * scale. The zero
+// point subtraction is skipped when 'zeroPoint' matches a constant zero.
+//
+// See function 'convertRanked()' below for a description of the arguments.
+Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
+                      ArrayRef<OpFoldResult> inputShape, Value scale,
+                      Value zeroPoint, QuantizedType quantizedType) {
+  // Convert scale to tensor if necessary
+  auto inputType = input.getType();
+  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
+
+  // Convert stored value to float
+  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
+                                      quantizedType.isSigned());
+
+  // Skip unnecessary computations if no zero point is given
+  if (!matchPattern(zeroPoint, m_Zero())) {
+    // Convert zero point to tensor if necessary
+    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
+                                          inputShape);
+
+    // Convert zero point from storage to expressed type
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
+                                      quantizedType.isSigned());
+
+    // Subtract zero point from stored value
+    result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
+  }
+
+  // Multiply by scale
+  result = builder.create<arith::MulFOp>(loc, result, scale);
+  return result;
+}
+
+// Convert a scalar or ranked tensor input with the given scale and zero point
+// values. 'quant.qcast' ops are quantized with 'quantizeValue()' and
+// 'quant.dcast' ops are dequantized with 'dequantizeValue()'.
+//
+// - input
+//   Scalar or ranked tensor value.
+//
+// - inputShape
+//   If 'input' is a tensor, combination of attributes/values representing its
+//   static/dynamic dimensions. If 'input' is a scalar, empty list.
+//
+// - scale
+//   Scale as a floating-point scalar value.
+//
+// - zeroPoint
+//   Zero point as an integer scalar value.
+//
+// - quantizedType
+//   Scalar quantized type of the result ('quant.qcast') or of the input
+//   ('quant.dcast').
+//
+Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
+                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
+                    Value zeroPoint, QuantizedType quantizedType) {
+  if (isa<QuantizeCastOp>(op))
+    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
+                         quantizedType);
+  if (isa<DequantizeCastOp>(op))
+    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
+                           quantizedType);
+  llvm_unreachable("unexpected quant op");
+}
+
+// Convert an operation using per-layer quantization with a scalar or ranked
+// tensor input. The single scale and zero point of 'quantizedType' are
+// materialized as 'arith.constant' ops, and the work is delegated to
+// 'convertRanked()'.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar or ranked tensor.
+//
+// - quantizedType
+//   Per-layer quantized type.
+//
+Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
+                            Value input, UniformQuantizedType quantizedType) {
+  Type expressedType = quantizedType.getExpressedType();
+  Type storageType = quantizedType.getStorageType();
+
+  Value scale = builder.create<arith::ConstantOp>(
+      loc, expressedType,
+      builder.getFloatAttr(expressedType, quantizedType.getScale()));
+  Value zeroPoint = builder.create<arith::ConstantOp>(
+      loc, storageType,
+      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()));
+
+  SmallVector<OpFoldResult> inputShape =
+      getScalarOrTensorShape(builder, loc, input);
+  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
+                       quantizedType);
+}
+
+// Convert an operation using per-layer quantization.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar, ranked tensor, or unranked tensor. Unranked tensors are flattened
+//   to 1D, converted, and reshaped back to their original shape.
+//
+// - quantizedType
+//   Per-layer quantized type.
+//
+Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
+                      Value input, UniformQuantizedType quantizedType) {
+  // Scalars and ranked tensors are converted directly.
+  if (!isa<UnrankedTensorType>(input.getType()))
+    return convertPerLayerRanked(builder, loc, op, input, quantizedType);
+
+  // Unranked tensors go through a 1D ranked view.
+  Value flatInput;
+  Value inputShape;
+  std::tie(flatInput, inputShape) = flattenUnrankedTensor(builder, loc, input);
+  Value flatResult =
+      convertPerLayerRanked(builder, loc, op, flatInput, quantizedType);
+  return restoreUnrankedTensorShape(builder, loc, flatResult, inputShape);
+}
+
+// Convert an operation using per-channel quantization and a ranked tensor
+// input. The conversion is emitted as a 'linalg.generic' whose scalar body
+// applies the scale and zero point of the channel each element belongs to.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Ranked tensor. Scalars are not accepted because the input type is cast to
+//   'RankedTensorType'.
+//
+// - quantizedType
+//   Per-channel quantized type.
+//
+// - channelAxis
+//   Dimension of 'input' indexing the channels of 'quantizedType'. This may
+//   differ from the quantized dimension of 'quantizedType' when 'input' is a
+//   flattened view of the original operand.
+//
+Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
+                              Value input,
+                              UniformQuantizedPerAxisType quantizedType,
+                              int64_t channelAxis) {
+  auto *context = builder.getContext();
+
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto inputRank = inputType.getRank();
+
+  // 1D constants holding one scale / zero point per channel.
+  auto scales = materializePerChannelScales(builder, loc, quantizedType);
+  auto zeroPoints =
+      materializePerChannelZeroPoints(builder, loc, quantizedType);
+
+  // A floating-point input is being quantized (result in storage type).
+  // Otherwise the input is a stored value being dequantized (result in
+  // expressed type).
+  auto elementType = isa<FloatType>(inputType.getElementType())
+                         ? quantizedType.getStorageType()
+                         : quantizedType.getExpressedType();
+  auto initShape = tensor::getMixedSizes(builder, loc, input);
+  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
+
+  // Input and output are accessed element-wise. Scales and zero points are
+  // indexed by the channel dimension only.
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  auto channelAxisAffineMap = AffineMap::get(
+      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
+  SmallVector<AffineMap> indexingMaps{
+      builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
+      channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
+  auto result =
+      builder
+          .create<linalg::GenericOp>(
+              loc,
+              init.getType(),                        // resultType
+              ValueRange{input, scales, zeroPoints}, // inputs
+              ValueRange{init},                      // outputs
+              indexingMaps, iteratorTypes,
+              [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                  ValueRange args) {
+                // Block arguments: input element, scale, zero point, and the
+                // (unused) output element.
+                assert(args.size() == 4);
+                Value element = args[0];
+                Value scale = args[1];
+                Value zeroPoint = args[2];
+
+                Value converted =
+                    convertRanked(nestedBuilder, nestedLoc, op, element, {},
+                                  scale, zeroPoint, quantizedType);
+
+                nestedBuilder.create<linalg::YieldOp>(nestedLoc, converted);
+              })
+          .getResult(0);
+
+  return result;
+}
+
+// Convert an operation using per-channel quantization.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Ranked or unranked tensor. Unranked tensors are flattened to 3D around the
+//   channel axis, converted, and reshaped back to their original shape.
+//
+// - quantizedType
+//   Per-channel quantized type.
+//
+Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
+                        Value input,
+                        UniformQuantizedPerAxisType quantizedType) {
+  // Flatten unranked tensor into a 3D ranked tensor if necessary
+  bool isUnranked = isa<UnrankedTensorType>(input.getType());
+  int64_t channelAxis = quantizedType.getQuantizedDimension();
+  int64_t channelAxisSize =
+      static_cast<int64_t>(quantizedType.getScales().size());
+  Value inputShape;
+  if (isUnranked) {
+    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
+        builder, loc, input, channelAxis, channelAxisSize);
+    // The channel dimension is the middle dimension of the flattened tensor.
+    channelAxis = 1;
+  }
+
+  // Work on a ranked tensor
+  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
+                                        channelAxis);
+
+  // Restore original tensor shape if unranked
+  if (isUnranked)
+    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
+
+  return result;
+}
+
+// Convert a quantization operation by dispatching on the kind of
+// 'quantizedType'.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Scalar, ranked tensor, or unranked tensor. Its element type is the storage
+//   type ('quant.dcast') or the expressed type ('quant.qcast') of
+//   'quantizedType'.
+//
+// - quantizedType
+//   Per-layer ('UniformQuantizedType') or per-channel
+//   ('UniformQuantizedPerAxisType') quantized type. Any other type reaches
+//   'llvm_unreachable'.
+//
+Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
+                       Value input, Type quantizedType) {
+  auto perLayerType = dyn_cast<UniformQuantizedType>(quantizedType);
+  if (perLayerType)
+    return convertPerLayer(builder, loc, op, input, perLayerType);
+
+  auto perChannelType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+  if (perChannelType)
+    return convertPerChannel(builder, loc, op, input, perChannelType);
+
+  llvm_unreachable("unexpected quantized type");
+}
+
+// Lowering pattern for 'quant.dcast'. The quantized input is reinterpreted in
+// its storage type with 'quant::StorageCastOp', and the resulting integer
+// value is then dequantized by 'convertQuantized()'.
+struct DequantizeCastOpConversion
+    : public OpConversionPattern<quant::DequantizeCastOp> {
+  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Type inputType = op.getInput().getType();
+    auto quantizedType = cast<QuantizedType>(getScalarType(inputType));
+
+    // Convert quantized input to storage type. The operand is read through
+    // the adaptor so that any value remapped by the conversion driver is
+    // used instead of the original operand.
+    auto storageScalarOrTensorType =
+        getScalarOrTensorType(quantizedType.getStorageType(), inputType);
+    Value input = rewriter.create<quant::StorageCastOp>(
+        loc, storageScalarOrTensorType, adaptor.getInput());
+
+    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+// Lowering pattern for 'quant.qcast'. The floating-point input is quantized
+// into the storage type by 'convertQuantized()', and the integer result is
+// reinterpreted as the quantized result type with 'quant::StorageCastOp'.
+struct QuantizeCastOpConversion
+    : public OpConversionPattern<quant::QuantizeCastOp> {
+  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    // Read the operand through the adaptor so that any value remapped by the
+    // conversion driver is used instead of the original operand.
+    Value input = adaptor.getInput();
+    auto quantizedType = getScalarType(op.getResult().getType());
+
+    // Quantize the input (scalar, ranked, or unranked) into the storage type
+    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
+
+    // Cast stored value to result quantized value
+    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
+        op, op.getResult().getType(), result);
+    return success();
+  }
+};
+
+// Pass lowering 'quant.qcast' and 'quant.dcast' into arith, linalg, shape,
+// and tensor ops. 'quant::StorageCastOp' stays legal because the lowering
+// itself emits it to move between quantized and storage types. The rest of
+// the 'quant' dialect is marked illegal.
+struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
+  void runOnOperation() override {
+    MLIRContext &ctx = getContext();
+
+    ConversionTarget target(ctx);
+    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+                           shape::ShapeDialect, tensor::TensorDialect>();
+    target.addIllegalDialect<quant::QuantDialect>();
+    target.addLegalOp<quant::StorageCastOp>();
+
+    RewritePatternSet patterns(&ctx);
+    populateLowerQuantOpsPatterns(patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+// Add the 'quant.dcast' and 'quant.qcast' lowering patterns to 'patterns'.
+void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<DequantizeCastOpConversion>(context);
+  patterns.add<QuantizeCastOpConversion>(context);
+}
+
+} // namespace quant
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
new file mode 100644
index 00000000000000..8996eff61a39c0
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -0,0 +1,114 @@
+//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Strips quantized types from function headers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+// Type converter that replaces quantized types with their storage types,
+// both as bare types and as tensor element types. Values crossing between
+// converted and unconverted code are bridged with 'quant::StorageCastOp'.
+class QuantizedTypeConverter : public TypeConverter {
+
+ // Map a quantized type to its underlying storage type.
+ static Type convertQuantizedType(QuantizedType quantizedType) {
+ return quantizedType.getStorageType();
+ }
+
+ // Map a tensor with quantized elements to the same tensor with storage-type
+ // elements. Tensors with any other element type are returned unchanged.
+ static Type convertTensorType(TensorType tensorType) {
+ if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
+ return tensorType.clone(convertQuantizedType(quantizedType));
+ return tensorType;
+ }
+
+ // Materialize a conversion of a single value to 'type' by inserting a
+ // 'quant::StorageCastOp'. Exactly one input value is expected.
+ static Value materializeConversion(OpBuilder &builder, Type type,
+ ValueRange inputs, Location loc) {
+ assert(inputs.size() == 1);
+ return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
+ }
+
+public:
+
+ explicit QuantizedTypeConverter() {
+ // NOTE(review): per the TypeConverter documentation, the most recently
+ // added conversion is tried first. The identity fallback is therefore
+ // registered first, so the quantized/tensor conversions take precedence.
+ // Keep this order when adding conversions.
+ addConversion([](Type type) { return type; });
+ addConversion(convertQuantizedType);
+ addConversion(convertTensorType);
+
+ // The same storage cast serves argument, source, and target
+ // materializations.
+ addArgumentMaterialization(materializeConversion);
+ addSourceMaterialization(materializeConversion);
+ addTargetMaterialization(materializeConversion);
+ }
+};
+
+// Conversion pass that replaces quantized types with their storage types in
+// the signatures of 'func.func' ops and in the operands/results of
+// 'func.return' and 'func.call' ops. Casts between the two representations
+// are inserted by the materializations of 'QuantizedTypeConverter'.
+class StripFuncQuantTypes
+    : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
+public:
+  void runOnOperation() override {
+    auto moduleOp = cast<ModuleOp>(getOperation());
+    auto *context = &getContext();
+
+    QuantizedTypeConverter typeConverter;
+    ConversionTarget target(*context);
+    RewritePatternSet patterns(context);
+
+    // Mark func.func, func.return, and func.call illegal if they contain any
+    // quantized types. For 'func.func', both the function type and the body
+    // region are checked with the type converter.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>(
+        [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
+    target.addDynamicallyLegalOp<func::CallOp>(
+        [&](func::CallOp op) { return typeConverter.isLegal(op); });
+
+    // Register conversion patterns
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+
+    // Apply conversion
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+} // namespace quant
+} // namespace mlir
+
diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
index 8c69729824691c..fb27640bfd2784 100644
--- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/FakeQuantSupport.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
using namespace mlir;
using namespace mlir::quant;
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index 408701f80444a1..62c7a7128d63a6 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/UniformSupport.h"
+#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
#include "mlir/IR/BuiltinTypes.h"
#include <numeric>
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 03876a7c64d07c..c62942e1be78e2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -11,7 +11,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6dce3d03066c9a..7f740be4efb4ff 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -14,7 +14,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
diff --git a/mlir/test/Dialect/Quant/canonicalize.mlir b/mlir/test/Dialect/Quant/canonicalize.mlir
index 36c3eaf5e10d20..73c57e2a48212a 100644
--- a/mlir/test/Dialect/Quant/canonicalize.mlir
+++ b/mlir/test/Dialect/Quant/canonicalize.mlir
@@ -1,24 +1,124 @@
// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
+// CHECK-LABEL: @dcast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @dcast_fold(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias>
+ %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32>
+ return %1 : tensor<4xf32>
+}
+
// -----
-// CHECK-LABEL: redundant_scast
-func.func @redundant_scast() -> tensor<4xi8> {
- // CHECK-NEXT: arith.constant dense<10> : tensor<4xi8>
- // CHECK-NEXT: return
- %cst = arith.constant dense<5> : tensor<4xi8>
- %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
- %2 = "quant.scast"(%1) : (tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<4xi8>
- %3 = arith.addi %2, %2 : tensor<4xi8>
- return %3 : tensor<4xi8>
+
+// CHECK-LABEL: @dcast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.dcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @dcast_no_fold_source(%arg0: tensor<4xi8>) -> tensor<4xf32> {
+ %0 = quant.scast %arg0 : tensor<4xi8> to tensor<4x!qalias>
+ %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32>
+ return %1 : tensor<4xf32>
}
// -----
-// CHECK-LABEL: non_redundant_scast
-func.func @non_redundant_scast() -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>> {
- // CHECK-NEXT: arith.constant dense<5> : tensor<4xi8>
- // CHECK-NEXT: scast
- // CHECK-NEXT: return
- %cst = arith.constant dense<5> : tensor<4xi8>
- %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
- return %1 : tensor<4x!quant.uniform<u8:f32, 7.812500e-03:128>>
+
+// CHECK-LABEL: @qcast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+// qcast(dcast(x)) back to the same quantized type is expected to fold to x.
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @qcast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> {
+ %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32>
+ %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias>
+ return %1 : tensor<4x!qalias>
 }
+
+// -----
+
+// CHECK-LABEL: @qcast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = arith.negf %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+// No folding expected: the qcast operand comes from arith.negf, not quant.dcast.
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @qcast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4x!qalias> {
+ %0 = arith.negf %arg0 : tensor<4xf32>
+ %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias>
+ return %1 : tensor<4x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_no_fold_type
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.dcast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+// No folding expected: the qcast target type (scale 3.0) differs from the dcast source type (scale 2.0).
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+!qalias1 = !quant.uniform<u8:f32, 3.0:128>
+func.func @qcast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> {
+ %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32>
+ %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias1>
+ return %1 : tensor<4x!qalias1>
+}
+
+// -----
+
+// CHECK-LABEL: @scast_fold
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: return %[[ARG_0]]
+// scast(scast(x)) back to the same quantized type is expected to fold to x.
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @scast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> {
+ %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8>
+ %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias>
+ return %1 : tensor<4x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @scast_no_fold_source
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[QCAST:.*]] = quant.qcast %[[ARG_0]]
+// CHECK: %[[SCAST:.*]] = quant.scast %[[QCAST]]
+// CHECK: return %[[SCAST]]
+// No folding expected: the scast operand comes from quant.qcast, not another quant.scast.
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+func.func @scast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4xi8> {
+ %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias>
+ %1 = quant.scast %0 : tensor<4x!qalias> to tensor<4xi8>
+ return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @scast_no_fold_type
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]]
+// CHECK: %[[VAL_1:.*]] = quant.scast %[[VAL_0]]
+// CHECK: return %[[VAL_1]]
+// No folding expected: the round trip ends at a different quantized type (scale 3.0 vs 2.0).
+!qalias = !quant.uniform<u8:f32, 2.0:128>
+!qalias1 = !quant.uniform<u8:f32, 3.0:128>
+func.func @scast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> {
+ %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8>
+ %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias1>
+ return %1 : tensor<4x!qalias1>
+}
+
diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
new file mode 100644
index 00000000000000..ba3a8e312d96e9
--- /dev/null
+++ b/mlir/test/Dialect/Quant/invalid.mlir
@@ -0,0 +1,258 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+// dcast: the operand must be quantized; a plain f32 operand is rejected.
+func.func @dcast_invalid_input(%arg0: f32) {
+ // expected-error@+1 {{operand #0 must be scalar or tensor of quantized type}}
+ %0 = quant.dcast %arg0 : f32 to f32
+ return
+}
+
+// -----
+// dcast: the result must be floating-point; a quantized result is rejected.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_invalid_result(%arg0: !qalias) {
+ // expected-error@+1 {{result #0 must be scalar or tensor of floating-point}}
+ %0 = quant.dcast %arg0 : !qalias to !qalias
+ return
+}
+
+// -----
+// dcast: a scalar operand cannot produce a tensor result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_scalar_tensor(%arg0: !qalias) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.dcast %arg0 : !qalias to tensor<f32>
+ return
+}
+
+// -----
+// dcast: a ranked operand cannot produce an unranked result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_ranked_unranked_tensor(%arg0: tensor<!qalias>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<*xf32>
+ return
+}
+
+// -----
+// dcast: a static dimension cannot become dynamic in the result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3x!qalias>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<?x3xf32>
+ return
+}
+
+// -----
+// dcast: the result float type must equal the expressed type (f32, not f64).
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_float_type_mismatch(%arg0: !qalias) {
+ // expected-error@+1 {{expressed type in quantized type expected to match float type}}
+ %0 = quant.dcast %arg0 : !qalias to f64
+ return
+}
+
+// -----
+// dcast: per-axis quantization is not allowed on scalars.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @dcast_per_axis_scalar(%arg0: !qalias) {
+ // expected-error@+1 {{scalar types may not use per-axis quantization}}
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return
+}
+
+// -----
+// dcast: the quantized dimension (2) must be below the tensor rank (2).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x!qalias>) {
+ // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+ %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<2x3xf32>
+ return
+}
+
+// -----
+// dcast: the quantized dimension size (4) must equal the number of scales (3).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_invalid_size(%arg0: tensor<2x3x4x!qalias>) {
+ // expected-error@+1 {{quantized dimension size does not match number of scales}}
+ %0 = quant.dcast %arg0 : tensor<2x3x4x!qalias> to tensor<2x3x4xf32>
+ return
+}
+
+// -----
+// qcast: the result must be quantized; a plain f32 result is rejected.
+func.func @qcast_invalid_result(%arg0: f32) {
+ // expected-error@+1 {{result #0 must be scalar or tensor of quantized type}}
+ %0 = quant.qcast %arg0 : f32 to f32
+ return
+}
+
+// -----
+// qcast: the operand must be floating-point; a quantized operand is rejected.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_invalid_input(%arg0: !qalias) {
+ // expected-error@+1 {{operand #0 must be scalar or tensor of floating-point}}
+ %0 = quant.qcast %arg0 : !qalias to !qalias
+ return
+}
+
+// -----
+// qcast: a tensor operand cannot produce a scalar result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_scalar_tensor(%arg0: tensor<f32>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.qcast %arg0 : tensor<f32> to !qalias
+ return
+}
+
+// -----
+// qcast: a ranked operand cannot produce an unranked result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_ranked_unranked_tensor(%arg0: tensor<f32>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.qcast %arg0 : tensor<f32> to tensor<*x!qalias>
+ return
+}
+
+// -----
+// qcast: a static dimension cannot become dynamic in the result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xf32>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<?x3x!qalias>
+ return
+}
+
+// -----
+// qcast: the operand float type must equal the expressed type (f32, not f64).
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_float_type_mismatch(%arg0: f64) {
+ // expected-error@+1 {{expressed type in quantized type expected to match float type}}
+ %0 = quant.qcast %arg0 : f64 to !qalias
+ return
+}
+
+// -----
+// qcast: per-axis quantization is not allowed on scalars.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @qcast_per_axis_scalar(%arg0: f32) {
+ // expected-error@+1 {{scalar types may not use per-axis quantization}}
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return
+}
+
+// -----
+// qcast: the quantized dimension (2) must be below the tensor rank (2).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3xf32>) {
+ // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+ %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<2x3x!qalias>
+ return
+}
+
+// -----
+// qcast: the quantized dimension size (4) must equal the number of scales (3).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_invalid_size(%arg0: tensor<2x3x4xf32>) {
+ // expected-error@+1 {{quantized dimension size does not match number of scales}}
+ %0 = quant.qcast %arg0 : tensor<2x3x4xf32> to tensor<2x3x4x!qalias>
+ return
+}
+
+// -----
+// scast: a signed (si32) operand is neither signless integer nor quantized.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_invalid_input(%arg0: si32) {
+ // expected-error@+1 {{operand #0 must be scalar or tensor of signless integer or quantized type}}
+ %0 = quant.scast %arg0 : si32 to !qalias
+ return
+}
+
+// -----
+// scast: a signed (si32) result is neither signless integer nor quantized.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_invalid_result(%arg0: !qalias) {
+ // expected-error@+1 {{result #0 must be scalar or tensor of signless integer or quantized type}}
+ %0 = quant.scast %arg0 : !qalias to si32
+ return
+}
+
+// -----
+// scast: integer-to-integer casts are rejected.
+func.func @scast_both_integers(%arg0: i8) {
+ // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}}
+ %0 = quant.scast %arg0 : i8 to i8
+ return
+}
+
+// -----
+// scast: quantized-to-quantized casts are rejected.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_both_quantized(%arg0: !qalias) {
+ // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}}
+ %0 = quant.scast %arg0 : !qalias to !qalias
+ return
+}
+
+// -----
+// scast: a tensor operand cannot produce a scalar result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_scalar_tensor(%arg0: tensor<i8>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.scast %arg0 : tensor<i8> to !qalias
+ return
+}
+
+// -----
+// scast: a ranked operand cannot produce an unranked result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_ranked_unranked_tensor(%arg0: tensor<i8>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.scast %arg0 : tensor<i8> to tensor<*x!qalias>
+ return
+}
+
+// -----
+// scast: a static dimension cannot become dynamic in the result.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xi8>) {
+ // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}}
+ %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<?x3x!qalias>
+ return
+}
+
+// -----
+// scast: the integer type must equal the storage type (i8, not i32).
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_integer_type_mismatch(%arg0: i32) {
+ // expected-error@+1 {{storage type in quantized type expected to match integer type}}
+ %0 = quant.scast %arg0 : i32 to !qalias
+ return
+}
+
+// -----
+// scast: per-axis quantization is not allowed on scalars.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @scast_per_axis_scalar(%arg0: i8) {
+ // expected-error@+1 {{scalar types may not use per-axis quantization}}
+ %0 = quant.scast %arg0 : i8 to !qalias
+ return
+}
+
+// -----
+// scast: the quantized dimension (2) must be below the tensor rank (2).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0}>
+func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3xi8>) {
+ // expected-error@+1 {{quantized dimension must be less than tensor rank}}
+ %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<2x3x!qalias>
+ return
+}
+
+// -----
+// scast: the quantized dimension size (4) must equal the number of scales (3).
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_invalid_size(%arg0: tensor<2x3x4xi8>) {
+ // expected-error@+1 {{quantized dimension size does not match number of scales}}
+ %0 = quant.scast %arg0 : tensor<2x3x4xi8> to tensor<2x3x4x!qalias>
+ return
+}
+
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
new file mode 100644
index 00000000000000..6bba9f5c037727
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -0,0 +1,511 @@
+// RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @dcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<i8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+// Per-layer dcast of a signed scalar lowers to (sitofp(stored) - sitofp(zero_point)) * scale.
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 {
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_scalar_unsigned
+// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform<u8:f32, 2.000000e+00:10> to i8
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: return %[[EXPRESSED]] : f32
+// Same as the signed scalar case, but the u8 storage type lowers to uitofp instead of sitofp.
+!qalias = !quant.uniform<u8:f32, 2.0:10>
+func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 {
+ %0 = quant.dcast %arg0 : !qalias to f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<i8>
+
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<i8> to tensor<f32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<f32>
+// CHECK: return %[[EXPRESSED]] : tensor<f32>
+// 0-D tensor: scale and zero point are splatted to 0-D tensors before the elementwise math.
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_0d(%arg0: tensor<!qalias>) -> tensor<f32> {
+ %0 = quant.dcast %arg0 : tensor<!qalias> to tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<3x?x5xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_INT]], %[[C_1]] : tensor<3x?x5xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+// CHECK: return %[[EXPRESSED]] : tensor<3x?x5xf32>
+// Ranked tensor with a dynamic dim: the splats are sized with tensor.dim of dimension 1.
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_ranked(%arg0: tensor<3x?x5x!qalias>) -> tensor<3x?x5xf32> {
+ %0 = quant.dcast %arg0 : tensor<3x?x5x!qalias> to tensor<3x?x5xf32>
+ return %0 : tensor<3x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>> to tensor<*xi8>
+// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[STORED_INT]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_INT]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<1xindex>) -> tensor<?xi8>
+// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?xi8>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_COLLAPSED]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<?xi8> to tensor<?xf32>
+
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[EXPRESSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32>
+// Unranked tensor: flattened to 1-D with tensor.reshape, dequantized, then reshaped back.
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @dcast_per_layer_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> {
+ %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+
+// CHECK-LABEL: @dcast_per_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>> to tensor<4x?x?x5xi8>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8>
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_1]] : tensor<4x?x?x5xi8>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_2]] : tensor<4x?x?x5xi8>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[STORED_TENSOR]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xi8>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xf32>) {
+// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32):
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: linalg.yield %[[EXPRESSED]] : f32
+// CHECK: } -> tensor<4x?x?x5xf32>
+// CHECK: return %[[GENERIC]] : tensor<4x?x?x5xf32>
+// Per-channel (axis 1), ranked: a linalg.generic reads scales/zero points through the (d1) map.
+!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+func.func @dcast_per_channel_ranked(%arg0: tensor<4x?x?x5x!qalias>) -> tensor<4x?x?x5xf32> {
+ %0 = quant.dcast %arg0 : tensor<4x?x?x5x!qalias> to tensor<4x?x?x5xf32>
+ return %0 : tensor<4x?x?x5xf32>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: @dcast_per_channel_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor
+
+// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>> to tensor<*xi8>
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[STORED_TENSOR]] : tensor<*xi8> -> tensor<?xindex>
+// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index
+// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index
+// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor<?xindex> -> index
+// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor<?xindex> -> index
+
+// CHECK: %[[NUM_CHANNELS:.*]] = arith.constant 3 : index
+// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[NUM_CHANNELS]], %[[SIZE_RIGHT]] : tensor<3xindex>
+// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_TENSOR]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<3xindex>) -> tensor<?x3x?xi8>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor<?x3x?xi8>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_2]] : tensor<?x3x?xi8>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor<?x3x?xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[STORED_COLLAPSED]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<?x3x?xi8>, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor<?x3x?xf32>) {
+// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32):
+// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32
+// CHECK: linalg.yield %[[EXPRESSED]] : f32
+// CHECK: } -> tensor<?x3x?xf32>
+
+// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor<?x3x?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32>
+// Per-channel (axis 2), unranked: collapsed to left x channels x right, dequantized by linalg.generic, reshaped back.
+!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
+func.func @dcast_per_channel_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> {
+ %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_scalar
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : f32 to i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : i8 to !quant.uniform<i8:f32, 2.000000e+00:10>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<i8:f32, 2.000000e+00:10>
+// Per-layer qcast of a scalar lowers to fptosi(input / scale + sitofp(zero_point)).
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_scalar(%arg0: f32) -> !qalias {
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_scalar_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[SCALED]] : f32 to i8
+
+// CHECK-DAG: %[[C_NEG_5:.*]] = arith.constant -5 : i8
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_5]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform<i8<-5:10>:f32, 2.000000e+00>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<i8<-5:10>:f32, 2.000000e+00>
+// Zero point 0 with bounds [-5, 10]: no zero-point add is expected; the result is clamped via maxsi/minsi.
+!qalias = !quant.uniform<i8<-5:10>:f32, 2.0>
+func.func @qcast_per_layer_scalar_bounds(%arg0: f32) -> !qalias {
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_scalar_unsigned_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: f32
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8
+
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptoui %[[SCALED]] : f32 to i8
+
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i8
+// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxui %[[STORED_INT]], %[[C_2]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minui %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform<u8<2:10>:f32, 2.000000e+00>
+// CHECK: return %[[STORED_QUANT]] : !quant.uniform<u8<2:10>:f32, 2.000000e+00>
+// Unsigned u8<2:10> with zero point 0: fptoui, then clamped to [2, 10] via maxui/minui.
+!qalias = !quant.uniform<u8<2:10>:f32, 2.0>
+func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias {
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return %0 : !qalias
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_0d
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor<f32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<f32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<i8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<i8> to tensor<f32>
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<f32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<f32> to tensor<i8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<i8> to tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<!quant.uniform<i8:f32, 2.000000e+00:10>>
+// 0-D tensor: scale and zero point are splatted to 0-D tensors before divf/addf/fptosi.
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_0d(%arg0: tensor<f32>) -> tensor<!qalias> {
+ %0 = quant.qcast %arg0 : tensor<f32> to tensor<!qalias>
+ return %0 : tensor<!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<3x?x5xf32>
+// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32>
+
+// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8>
+// CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32>
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// Ranked tensor with a dynamic dim: the splats are sized with tensor.dim of dimension 1.
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<3x?x5xf32> to tensor<3x?x5x!qalias>
+ return %0 : tensor<3x?x5x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+
+// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]] : tensor<3x5xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_SPLAT]] : tensor<3x5xf32>
+
+// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<3x5xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<3x5xi8> to tensor<3x5xf32>
+
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<3x5xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<3x5xf32> to tensor<3x5xi8>
+
+// CHECK-DAG: %[[C_NEG_8:.*]] = arith.constant -8 : i8
+// CHECK-DAG: %[[C_7:.*]] = arith.constant 7 : i8
+// CHECK-DAG: %[[SPLAT_NEG_8:.*]] = tensor.splat %[[C_NEG_8]] : tensor<3x5xi8>
+// CHECK-DAG: %[[SPLAT_7:.*]] = tensor.splat %[[C_7]] : tensor<3x5xi8>
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[SPLAT_NEG_8]] : tensor<3x5xi8>
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[SPLAT_7]] : tensor<3x5xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : tensor<3x5xi8> to tensor<3x5x!quant.uniform<i8<-8:7>:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<3x5x!quant.uniform<i8<-8:7>:f32, 2.000000e+00:10>>
+// Static shape with bounds [-8, 7]: the clamp bounds are splatted to tensor<3x5xi8> for maxsi/minsi.
+!qalias = !quant.uniform<i8<-8:7>:f32, 2.0:10>
+func.func @qcast_per_layer_ranked_bounds(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias>
+ return %0 : tensor<3x5x!qalias>
+}
+
+// -----
+
+// CHECK-LABEL: @qcast_per_layer_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>
+
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[SIZE:.*]] = shape.num_elements %[[SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[SIZE]] : tensor<1xindex>
+// CHECK: %[[RANKED_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+
+// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[RANKED_INPUT]], %[[C_0]] : tensor<?xf32>
+// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor<?xf32>
+// CHECK: %[[SCALED:.*]] = arith.divf %[[RANKED_INPUT]], %[[SCALE_SPLAT]] : tensor<?xf32>
+
+// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor<?xi8>
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<?xi8> to tensor<?xf32>
+// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<?xf32>
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<?xf32> to tensor<?xi8>
+
+// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[STORED_INT]](%[[SHAPE]]) : (tensor<?xi8>, tensor<?xindex>) -> tensor<*xi8>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// Unranked tensor: flattened to 1-D, quantized, reshaped back, then scast to the quantized type.
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+ return %0 : tensor<*x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// ATTR_0 is the identity map; ATTR_1 projects every point onto the channel axis d1.
+// CHECK-LABEL: @qcast_per_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x?x?x5xf32>
+// One scale and one zero point per channel, materialized as 1-D constants.
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8>
+// The i8 output buffer takes its dynamic sizes from the input.
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<4x?x?x5xf32>
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<4x?x?x5xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xi8>
+// Element-wise quantization; scale and zero point are read at the channel index.
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK: linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<4x?x?x5xi8>
+// The i8 result is reinterpreted as the quantized type.
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x?x?x5xi8> to tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<4x?x?x5x!quant.uniform<i8:f32:1, {2.000000e+00:10,3.000000e+00:20}>>
+// Input: channel axis 1 with two channels (scales 2.0/3.0, zero points 10/20).
+!qalias = !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
+func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<4x?x?x5xf32> to tensor<4x?x?x5x!qalias>
+ return %0 : tensor<4x?x?x5x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// ATTR_0 is the identity map; ATTR_1 projects every point onto the channel axis d1.
+// CHECK-LABEL: @qcast_per_channel_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x2x5xf32>
+// No zero points are given in the type, so the zero-point constant is all zeros.
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<0> : tensor<2xi8>
+// Static output shape; the body clamps the result to the narrowed [-8, 7] range.
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<4x2x5xi8>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x2x5xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK: %[[C_NEG_8:.*]] = arith.constant -8 : i8
+// CHECK: %[[C_7:.*]] = arith.constant 7 : i8
+// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_8]] : i8
+// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_7]] : i8
+// CHECK: linalg.yield %[[STORED_CLAMPED]] : i8
+// CHECK: } -> tensor<4x2x5xi8>
+// The clamped i8 result is reinterpreted as the quantized type.
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x2x5xi8> to tensor<4x2x5x!quant.uniform<i8<-8:7>:f32:1, {2.000000e+00,3.000000e+00}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<4x2x5x!quant.uniform<i8<-8:7>:f32:1, {2.000000e+00,3.000000e+00}>>
+// Input: i8<-8:7> storage, channel axis 1, scales only.
+!qalias = !quant.uniform<i8<-8:7>:f32:1, {2.0, 3.0}>
+func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<4x2x5xf32> to tensor<4x2x5x!qalias>
+ return %0 : tensor<4x2x5x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// The input is collapsed to 3-D (before axis, channel, after axis); ATTR_1 selects d1.
+// CHECK-LABEL: @qcast_per_channel_unranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>
+// Split the runtime shape around channel axis 2 and count the elements on each side.
+// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index
+// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index
+// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor<?xindex> -> index
+// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor<?xindex> -> index
+// Build the collapsed shape, keeping the static channel size 3 in the middle.
+// CHECK: %[[CHANNEL_AXIS_SIZE:.*]] = arith.constant 3 : index
+// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[CHANNEL_AXIS_SIZE]], %[[SIZE_RIGHT]] : tensor<3xindex>
+// CHECK: %[[FLAT_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x3x?xf32>
+// One scale and one zero point per channel, materialized as 1-D constants.
+// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8>
+// Output buffer sized from the dynamic outer dimensions of the collapsed input.
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_0]] : tensor<?x3x?xf32>
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_2]] : tensor<?x3x?xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor<?x3x?xi8>
+// Element-wise quantization; scale and zero point are read at the channel index.
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[FLAT_INPUT]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<?x3x?xf32>, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor<?x3x?xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK: linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<?x3x?xi8>
+// Reshape back to the original runtime shape and reinterpret as the quantized type.
+// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor<?x3x?xi8>, tensor<?xindex>) -> tensor<*xi8>
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>>
+// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform<i8:f32:2, {2.000000e+00:10,3.000000e+00:20,4.000000e+00:30}>>
+// Input: unranked tensor, channel axis 2 with three channels.
+!qalias = !quant.uniform<i8:f32:2, {2.0:10, 3.0:20, 4.0:30}>
+func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+ return %0 : tensor<*x!qalias>
+}
+
diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir
new file mode 100644
index 00000000000000..4abc5830d081e1
--- /dev/null
+++ b/mlir/test/Dialect/Quant/ops.mlir
@@ -0,0 +1,151 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+// quant.dcast of a scalar; every case in this file must verify without diagnostics.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_scalar(%quantized: !qalias) {
+ %dequantized = quant.dcast %quantized : !qalias to f32
+ return
+}
+
+// -----
+// quant.dcast of a ranked tensor with one dynamic dimension.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_ranked(%quantized: tensor<2x?x4x!qalias>) {
+ %dequantized = quant.dcast %quantized : tensor<2x?x4x!qalias> to tensor<2x?x4xf32>
+ return
+}
+
+// -----
+// quant.dcast of an unranked tensor.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @dcast_unranked(%quantized: tensor<*x!qalias>) {
+ %dequantized = quant.dcast %quantized : tensor<*x!qalias> to tensor<*xf32>
+ return
+}
+
+// -----
+// Per-axis quant.dcast (axis 2, three scales) of a statically shaped tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_static(%quantized: tensor<1x2x3x!qalias>) {
+ %dequantized = quant.dcast %quantized : tensor<1x2x3x!qalias> to tensor<1x2x3xf32>
+ return
+}
+
+// -----
+// Per-axis quant.dcast of a tensor whose dimensions are all dynamic.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_dynamic(%quantized: tensor<?x?x?x!qalias>) {
+ %dequantized = quant.dcast %quantized : tensor<?x?x?x!qalias> to tensor<?x?x?xf32>
+ return
+}
+
+// -----
+// Per-axis quant.dcast of an unranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @dcast_per_axis_unranked(%quantized: tensor<*x!qalias>) {
+ %dequantized = quant.dcast %quantized : tensor<*x!qalias> to tensor<*xf32>
+ return
+}
+
+// -----
+// quant.qcast of a scalar.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_scalar(%expressed: f32) {
+ %quantized = quant.qcast %expressed : f32 to !qalias
+ return
+}
+
+// -----
+// quant.qcast of a ranked tensor with one dynamic dimension.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_ranked(%expressed: tensor<2x?x4xf32>) {
+ %quantized = quant.qcast %expressed : tensor<2x?x4xf32> to tensor<2x?x4x!qalias>
+ return
+}
+
+// -----
+// quant.qcast of an unranked tensor.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @qcast_unranked(%expressed: tensor<*xf32>) {
+ %quantized = quant.qcast %expressed : tensor<*xf32> to tensor<*x!qalias>
+ return
+}
+
+// -----
+// Per-axis quant.qcast (axis 2, three scales) of a statically shaped tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_static(%expressed: tensor<1x2x3xf32>) {
+ %quantized = quant.qcast %expressed : tensor<1x2x3xf32> to tensor<1x2x3x!qalias>
+ return
+}
+
+// -----
+// Per-axis quant.qcast of a tensor whose dimensions are all dynamic.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_dynamic(%expressed: tensor<?x?x?xf32>) {
+ %quantized = quant.qcast %expressed : tensor<?x?x?xf32> to tensor<?x?x?x!qalias>
+ return
+}
+
+// -----
+// Per-axis quant.qcast of an unranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @qcast_per_axis_unranked(%expressed: tensor<*xf32>) {
+ %quantized = quant.qcast %expressed : tensor<*xf32> to tensor<*x!qalias>
+ return
+}
+
+// -----
+// quant.scast round trip on a scalar: i8 storage to the quantized type and back.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_scalar(%storage: i8) {
+ %quantized = quant.scast %storage : i8 to !qalias
+ %restored = quant.scast %quantized : !qalias to i8
+ return
+}
+
+// -----
+// quant.scast round trip on a ranked tensor with one dynamic dimension.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_ranked(%storage: tensor<2x?x4xi8>) {
+ %quantized = quant.scast %storage : tensor<2x?x4xi8> to tensor<2x?x4x!qalias>
+ %restored = quant.scast %quantized : tensor<2x?x4x!qalias> to tensor<2x?x4xi8>
+ return
+}
+
+// -----
+// quant.scast round trip on an unranked tensor.
+!qalias = !quant.uniform<i8:f32, 1.0>
+func.func @scast_unranked(%storage: tensor<*xi8>) {
+ %quantized = quant.scast %storage : tensor<*xi8> to tensor<*x!qalias>
+ %restored = quant.scast %quantized : tensor<*x!qalias> to tensor<*xi8>
+ return
+}
+
+// -----
+// Per-axis quant.scast round trip on a statically shaped tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_static(%storage: tensor<1x2x3xi8>) {
+ %quantized = quant.scast %storage : tensor<1x2x3xi8> to tensor<1x2x3x!qalias>
+ %restored = quant.scast %quantized : tensor<1x2x3x!qalias> to tensor<1x2x3xi8>
+ return
+}
+
+// -----
+// Per-axis quant.scast round trip on a tensor whose dimensions are all dynamic.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_dynamic(%storage: tensor<?x?x?xi8>) {
+ %quantized = quant.scast %storage : tensor<?x?x?xi8> to tensor<?x?x?x!qalias>
+ %restored = quant.scast %quantized : tensor<?x?x?x!qalias> to tensor<?x?x?xi8>
+ return
+}
+
+// -----
+// Per-axis quant.scast round trip on an unranked tensor.
+!qalias = !quant.uniform<i8:f32:2, {1.0, 2.0, 3.0}>
+func.func @scast_per_axis_unranked(%storage: tensor<*xi8>) {
+ %quantized = quant.scast %storage : tensor<*xi8> to tensor<*x!qalias>
+ %restored = quant.scast %quantized : tensor<*x!qalias> to tensor<*xi8>
+ return
+}
+
+
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index a82e8efdb1a3c3..7613a344cf2b8f 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -120,3 +120,28 @@
// provided.
// expected-error at +1 {{expected floating point literal}}
!qalias = !quant.uniform<i8<-4:3>:f32, {2.000000e+02,-19.987200e-01:1}>
+
+// -----
+// Per-axis quantization must reject a negative quantized dimension.
+// expected-error@+1 {{illegal quantized dimension: -1}}
+!qalias = !quant.uniform<i8:f32:-1, {2.0,3.0:1}>
+
+// -----
+// Per-layer scale smaller than any positive f16 value (underflow).
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16, 5.8e-8>
+
+// -----
+// Per-layer scale larger than the largest finite f16 value, 65504 (overflow).
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16, 6.6e4>
+
+// -----
+// Per-axis scale smaller than any positive f16 value (underflow).
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16:1, {2.0,5.8e-8}>
+
+// -----
+// Per-axis scale larger than the largest finite f16 value, 65504 (overflow).
+// expected-error@+1 {{scale out of expressed type range}}
+!qalias = !quant.uniform<i8:f16:1, {2.0,6.6e4}>
diff --git a/mlir/test/Dialect/Quant/strip-func-quant-types.mlir b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir
new file mode 100644
index 00000000000000..e5f0d4921bed3e
--- /dev/null
+++ b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s --strip-func-quant-types --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @strip_operands
+// CHECK-SAME: %[[ARG_0:.*]]: i8
+// CHECK-SAME: %[[ARG_1:.*]]: i16
+// CHECK-SAME: %[[ARG_2:.*]]: f32
+// Quantized arguments become storage types and are cast back with quant.scast.
+// CHECK: %[[ARG_1_CAST:.*]] = quant.scast %[[ARG_1]] : i16 to !quant.uniform<{{.*}}>
+// CHECK: %[[ARG_0_CAST:.*]] = quant.scast %[[ARG_0]] : i8 to !quant.uniform<{{.*}}>
+// Each use reads the cast of its own argument; the f32 argument is used as-is.
+// CHECK: "test.custom_op"(%[[ARG_0_CAST]])
+// CHECK: "test.custom_op"(%[[ARG_1_CAST]])
+// CHECK: "test.custom_op"(%[[ARG_2]])
+// Input: two quantized arguments with different storage types plus a plain f32.
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func @strip_operands(%arg0: !qalias, %arg1: !qalias1, %arg2: f32) {
+ "test.custom_op"(%arg0) : (!qalias) -> tensor<4x!qalias>
+ "test.custom_op"(%arg1) : (!qalias1) -> tensor<?x!qalias1>
+ "test.custom_op"(%arg2) : (f32) -> tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @strip_results
+// CHECK-SAME: tensor<4xi8>, tensor<?xi16>, tensor<*xi8>, tensor<4xf32>
+// Quantized results are cast to their storage types right after they are produced.
+// CHECK: %[[RESULT_0:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_0:.*]] = quant.scast %[[RESULT_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>
+
+// CHECK: %[[RESULT_1:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_1:.*]] = quant.scast %[[RESULT_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>
+
+// CHECK: %[[RESULT_2:.*]] = "test.custom_op"()
+// CHECK: %[[RESULT_CAST_2:.*]] = quant.scast %[[RESULT_2]] : tensor<*x!quant.uniform<{{.*}}>> to tensor<*xi8>
+
+// CHECK: %[[RESULT_3:.*]] = "test.custom_op"()
+// The f32 result is returned without a cast.
+// CHECK: return %[[RESULT_CAST_0]], %[[RESULT_CAST_1]], %[[RESULT_CAST_2]], %[[RESULT_3]]
+// Input: quantized ranked and unranked results of two storage types, plus an f32 result.
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func @strip_results() -> (tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>) {
+ %0 = "test.custom_op"() : () -> tensor<4x!qalias>
+ %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
+ %2 = "test.custom_op"() : () -> tensor<*x!qalias>
+ %3 = "test.custom_op"() : () -> tensor<4xf32>
+ return %0, %1, %2, %3 : tensor<4x!qalias>, tensor<?x!qalias1>, tensor<*x!qalias>, tensor<4xf32>
+}
+
+// -----
+
+// The body-less private callee also gets a storage-type signature.
+// CHECK-LABEL: @callee
+// CHECK-SAME: (tensor<4xi8>, tensor<?xi16>) -> (tensor<*xi8>, tensor<4xf32>)
+
+// CHECK-LABEL: @strip_call
+// Call operands are cast to storage types before the call.
+// CHECK: %[[OPERAND_0:.*]] = "test.custom_op"()
+// CHECK: %[[OPERAND_0_CAST:.*]] = quant.scast %[[OPERAND_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8>
+
+// CHECK: %[[OPERAND_1:.*]] = "test.custom_op"()
+// CHECK: %[[OPERAND_1_CAST:.*]] = quant.scast %[[OPERAND_1]] : tensor<?x!quant.uniform<{{.*}}>> to tensor<?xi16>
+
+// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[OPERAND_0_CAST]], %[[OPERAND_1_CAST]])
+// The quantized call result is cast back; the f32 result is used directly.
+// CHECK: %[[RESULT_0_CAST:.*]] = quant.scast %[[RESULTS]]#0 : tensor<*xi8> to tensor<*x!quant.uniform<{{.*}}>>
+// CHECK: "test.custom_op"(%[[RESULT_0_CAST]])
+
+// CHECK: "test.custom_op"(%[[RESULTS]]#1)
+
+// CHECK: return
+// Input: a call to a body-less private function with quantized operands and results.
+!qalias = !quant.uniform<i8:f32, 2.0:128>
+!qalias1 = !quant.uniform<i16:f32, 3.0:128>
+
+func.func private @callee(tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
+
+func.func @strip_call() {
+ %0 = "test.custom_op"() : () -> tensor<4x!qalias>
+ %1 = "test.custom_op"() : () -> tensor<?x!qalias1>
+ %2:2 = func.call @callee(%0, %1) : (tensor<4x!qalias>, tensor<?x!qalias1>) -> (tensor<*x!qalias>, tensor<4xf32>)
+ "test.custom_op"(%2#0) : (tensor<*x!qalias>) -> ()
+ "test.custom_op"(%2#1) : (tensor<4xf32>) -> ()
+ return
+}
More information about the Mlir-commits
mailing list