[Mlir-commits] [mlir] Sub-channel quantized type implementation (PR #120172)
Sandeep Dasgupta
llvmlistbot at llvm.org
Tue Dec 17 10:18:37 PST 2024
https://github.com/sdasgup3 updated https://github.com/llvm/llvm-project/pull/120172
>From 79c2cc7ac4483fb9e4c7586492cd69f5eb103110 Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Sun, 15 Dec 2024 12:13:43 +0000
Subject: [PATCH 1/4] Add implementation for sub-channel type
(design/assembly/verification/bytecode)
---
mlir/include/mlir-c/Dialect/Quant.h | 41 +++
.../mlir/Dialect/Quant/IR/QuantBase.td | 192 ++++++++++-
.../Dialect/Quant/IR/QuantDialectBytecode.td | 30 +-
.../mlir/Dialect/Quant/IR/QuantTypes.h | 131 ++++++++
mlir/lib/Bindings/Python/DialectQuant.cpp | 73 ++++
mlir/lib/CAPI/Dialect/Quant.cpp | 55 +++
.../Dialect/Quant/IR/QuantDialectBytecode.cpp | 1 +
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 147 ++++++--
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 121 ++++++-
mlir/lib/Dialect/Quant/IR/TypeDetail.h | 122 +++++++
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 318 +++++++++++++++---
.../Quant/Transforms/LowerQuantOps.cpp | 67 ++++
mlir/test/CAPI/quant.c | 124 +++++++
mlir/test/Dialect/Quant/Bytecode/types.mlir | 9 +
mlir/test/Dialect/Quant/invalid.mlir | 68 ++++
mlir/test/Dialect/Quant/ops.mlir | 19 ++
.../Dialect/Quant/parse-uniform-invalid.mlir | 100 +++++-
mlir/test/Dialect/Quant/parse-uniform.mlir | 18 +
18 files changed, 1547 insertions(+), 89 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h
index a7d98dc3c1a775..dc0989e53344ea 100644
--- a/mlir/include/mlir-c/Dialect/Quant.h
+++ b/mlir/include/mlir-c/Dialect/Quant.h
@@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type);
MLIR_CAPI_EXPORTED bool
mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
+//===---------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a UniformQuantizedSubChannel.
+MLIR_CAPI_EXPORTED bool
+mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
+
+/// Creates a UniformQuantizedSubChannelType with the given parameters.
+///
+/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
+/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes`
+/// point to `blockSizeInfoLength` number of elements, describing respectively
+/// the quantization axis and corresponding block size.
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet(
+ unsigned flags, MlirType storageType, MlirType expressedType,
+ MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr,
+ intptr_t blockSizeInfoLength, int32_t *quantizedDimensions,
+ int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax);
+
+/// Returns the number of block sizes provided in type.
+MLIR_CAPI_EXPORTED intptr_t
+mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type);
+
+/// Returns the quantized dimension at the given position.
+MLIR_CAPI_EXPORTED int32_t
+mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
+ intptr_t pos);
+
+/// Returns the block size at the given position.
+MLIR_CAPI_EXPORTED int64_t
+mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos);
+
+/// Returns the scales of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetScales(MlirType type);
+
+/// Returns the zero-points of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
+
//===---------------------------------------------------------------------===//
// CalibratedQuantizedType
//===---------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index 791cb9de48d058..0d97889960019c 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -40,13 +40,17 @@ def Quant_Dialect : Dialect {
encodes the necessary information for (lossy) round-trip conversion between
an expressed and a stored value.
- The `quant.uniform` type has two variants: per-layer quantization and
- per-channel (or per-axis) quantization. In per-layer quantization, the
- quantization information affects an entire tensor uniformly. Conversely, in
- per-channel quantization, the data type encodes the specific tensor axis
- that serves as the channel and includes quantization information for each
- individual channel within the tensor. Below are the specific syntactic and
- semantic considerations for each modality.
+ The `quant.uniform` type has three variants: per-layer quantization,
+ per-channel (or per-axis) quantization, and sub-channel (or blockwise)
+ quantization. In per-layer quantization, the quantization information
+ affects an entire tensor uniformly. Conversely, in per-channel
+ quantization, the data type encodes the specific tensor axis that serves
+ as the channel and includes quantization information for each individual
+ channel within the tensor. Sub-channel quantization is a generalization
+ of per-tensor and per-channel quantization, where the quantization
+ parameters are defined for blocks of elements along one or more
+ dimensions of the tensor. Below are the specific syntactic and semantic
+ considerations for each modality.
### Per-layer quantization
@@ -145,7 +149,7 @@ def Quant_Dialect : Dialect {
```
// A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
// floats. Dimension 1 of the tensor acts as the channel dimension. Its
- // size 3 matches the number of provided scale values. Tensor elemenets at
+ // size 3 matches the number of provided scale values. Tensor elements at
// positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
// 5.0, respectively.
tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
@@ -159,6 +163,72 @@ def Quant_Dialect : Dialect {
tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
```
+ ### Sub-channel quantization
+
+ Sub-channel quantization, also known as blockwise quantization, provides
+ finer-grained control than per-tensor or per-channel quantization. It
+ divides a tensor into blocks of elements, each with its own quantization
+ parameters (scale and zero point). This is particularly useful when
+ different regions of a tensor exhibit distinct value ranges.
+
+ The `!quant.uniform` type represents sub-channel quantization with the
+ following syntax:
+
+ ```
+ `!quant.uniform` `<`
+ storedType (`<` storageMin `:` storageMax `>`)? `:`
+ expressedType `:` blockSizeInfo `,`
+ scaleZeroTensor `>`
+
+ blockSizeInfo ::= `{` `}` | `{` axisBlock (`,` axisBlock)* `}`
+ axisBlock ::= axis `:` blockSize
+ scaleZeroTensor ::= scaleZeroDenseExp | scaleZeroList
+ scaleZeroDenseExp ::= `{` scaleZeroTensor (`,` scaleZeroTensor)* `}`
+ scaleZeroList ::= scaleZero (`,` scaleZero)*
+ scaleZero ::= scale (`:` zeroPoint)?
+
+ scaleZeroTensor ::= scale-zero-dense-exp | scale-zero-list
+ scale-zero-dense-exp ::= `{` scale-zero-tensor (`,` scale-zero-tensor)* `}`
+ scale-zero-list ::= scale (`:` zeroPoint)? (`,` scale (`:` zeroPoint)?)*
+ ```
+
+ The `blockSize` field specifies the size of the blocks along dimension
+ `axis` of the tensor. The `scale` and `zeroPoint` fields specify the
+ quantization parameters for a particular block. Specifically, the tensor
+ element at position [i0...iN] uses
+ `scaleZeroTensor[i0/blockSize0...iN/blockSizeN].scale` and
+ `scaleZeroTensor[i0/blockSize0...iN/blockSizeN].zeroPoint` as scale
+ and zeroPoint respectively.
+
+ Here are some examples:
+
+ ```
+ // A 3x4 tensor of i8 values representing f32 values, quantized
+ // along axis-0 and axis-1 with block sizes 1 and 2,
+ // respectively. As a result, the shape of the scales (or zero-points) will
+ // be `[3,4]/[1,2] = [3,2]`, which essentially represents the number of
+ // blocks along each axis. Tensor elements at positions
+ // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+ // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+ // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+ // [1][2] and [1][3] use scale `s11` and zero point `z11`,
+ // [2][0] and [2][1] use scale `s20` and zero point `z20`,
+ // [2][2] and [2][3] use scale `s21` and zero point `z21`,
+ tensor<3x4x!quant.uniform<i8:f32:{0:1, 1:2},
+ {{s00:z00, s01:z01}, {s10:z10,s11:z11}, {s20:z20,s21:z21}}>>
+
+ // A 2D dynamically sized tensor contains u16 values
+ // representing f32 values. Since the shape of the quantization
+ // parameters (i.e. scales and zero-points) is given as [2,2] and
+ // the block sizes are given as [1,2], the shape of the tensor is expected
+ // to be [2,4] (= [2,2] * [1,2]) at runtime. Tensor elements at positions
+ // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+ // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+ // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+ // [1][2] and [1][3] use scale `s11` and zero point `z11`,
+ tensor<?x?x!quant.uniform<u16:f32:{0:1, 1:2},
+ {{s00:z00, s01:z01}, {s10:z10,s11:z11}}>>
+ ```
## Per-axis quantization integrity
@@ -170,7 +240,7 @@ def Quant_Dialect : Dialect {
respected in any context in which the `!quant.uniform` data type is used,
such as the header of a `func.func` op, or the input of an arithmetic
operation.
-
+
- A quantized type with per-channel quantization information must be the
element type of a tensor container type, and may not occur directly as
the data type of a scalar value.
@@ -209,6 +279,110 @@ def Quant_Dialect : Dialect {
// Correct. The quantized type now includes 3 scale values, matching the
// size of dimension 1 of the result tensor.
%result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+
+ ## Sub-channel quantization integrity
+
+ When type `!quant.uniform` contains sub-channel quantization information,
+ the following rules are enforced. For efficiency, these rules are actively
+ enforced by the verifiers of `quant` dialect ops, but they must be
+ respected in any context in which the `!quant.uniform` data type is used,
+ such as the header of a `func.func` op, or the input of an arithmetic
+ operation.
+
+ - A quantized type with sub-channel quantization information must be the
+ element type of a tensor container type, and may not occur directly as
+ the data type of a scalar value.
+
+ ```
+ // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+ // scalar type.
+ %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>
+
+ // Correct. Type `!quant.uniform` with sub-channel quantization is wrapped
+ // in a `tensor` type.
+ %result = quant.qcast %input : tensor<2x2xf32> to
+ tensor<2x2x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+ ```
+
+ - The tensor containing the sub-channel quantized type must be ranked.
+
+ ```
+ // Incorrect. Type !quant.uniform specifies sub-channel quantization for an
+ // unranked tensor type.
+ %result = quant.qcast %input : tensor<*xf32> to
+ tensor<*x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+ ```
+
+ - The axis for which a block size is specified should be valid for a tensor
+ of a given rank. Block sizes can be specified for a subset of axes.
+ Any unspecified block size for an axis i defaults to the tensor dimension
+ size of that axis (shape(tensor)[i]).
+
+ ```
+ // Incorrect. The block size is specified for axis 2, which is out of range
+ // for a rank-2 tensor (valid axes are 0 and 1).
+ %result = quant.qcast %input : tensor<2x2xf32> to
+ tensor<2x2x!quant.uniform<i8:f32:{2:1, 1:2}, {{1.0}, {2.0}}>>
+
+ // Incorrect. The block-size is specified for a negative axis.
+ %result = quant.qcast %input : tensor<2x2xf32> to
+ tensor<2x2x!quant.uniform<i8:f32:{-1:1, 1:2}, {{1.0}, {2.0}}>>
+
+ // Correct. The block size for axis 1 is skipped which should be assumed as
+ // 2, the dim-size of tensor at axis 1.
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {3.0}}>>
+
+ // Correct. The block sizes for all the axes are skipped, making the
+ // sub-channel type essentially a per-tensor type.
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{}, {{1.0}}>>
+ ```
+
+ - Block size for a particular axis should be a positive integer and should
+ not exceed the dimension size of the tensor along that axis.
+
+ ```
+ // Incorrect. The block size for axis 0 is -1.
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{0:-1}, {{1.0, 2.0}}>>
+
+ // Incorrect. The block size for axis 0 is 8 which is greater than the
+ // dimension size of tensor at axis 0 (which is 6).
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{0:8}, {{1.0, 2.0}}>>
+
+ // Correct. The block size for axis 0 is now 3.
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+ ```
+
+ - shape(tensor) % blockSizes = 0 where blockSizes = [block sizes for
+ axis i in [0, 1, ..., rank(tensor)-1]].
+
+ ```
+ // Incorrect. The block size for axis 0 is 4 and the corresponding
+ // dimension size is 6 and 6 % 4 != 0.
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{0:4}, {{1.0, 2.0}}>>
+
+ // Correct. The block size for axis 0 is now 3 making 6 % 3 = 0.
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+ ```
+
+ - shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes.
+
+ ```
+ // Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but
+ // shape(scales) is [1,2] which is not equal to [6,2]/[3,2].
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0, 2.0}}>>
+
+ // Correct. shape(tensor) = [6,2], blockSizes = [3,2], and
+ // shape(scales) equals [6,2]/[3,2].
+ %result = quant.qcast %input : tensor<6x2xf32> to
+ tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
```
}];
let cppNamespace = "::mlir::quant";
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
index bd9cdf82382275..8c74dbef5d94a3 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
@@ -13,6 +13,7 @@
#ifndef QUANT_BYTECODE
#define QUANT_BYTECODE
+include "mlir/IR/BuiltinDialectBytecode.td"
include "mlir/IR/BytecodeBase.td"
def DoubleAPFloat:
@@ -81,20 +82,31 @@ def UniformQuantizedPerAxisType: DialectType<(type
}];
}
+def UniformQuantizedSubChannelType
+ : DialectType<(type VarInt:$flags, Type:$storageType, Type:$expressedType,
+ SignedVarInt:$storageTypeMin, SignedVarInt:$storageTypeMax,
+ Array<SignedVarIntList>:$quantizedDimensions,
+ Array<SignedVarIntList>:$blockSizes, DenseElementsAttr:$scales,
+ DenseElementsAttr:$zeroPoints)> {
+ // Note: builder order differs from bytecode.
+ let cBuilder = [{
+ get<$_resultType>(context, flags, storageType, expressedType, scales,
+ zeroPoints, llvm::to_vector(llvm::map_range(quantizedDimensions,
+ [](int64_t dim) { return static_cast<int32_t>(dim);})), blockSizes,
+ storageTypeMin, storageTypeMax)
+ }];
+}
+
/// This enum contains marker codes used to indicate which attribute is
/// currently being decoded, and how it should be decoded. The order of these
/// codes should generally be unchanged, as any changes will inevitably break
/// compatibility with older bytecode.
def QuantDialectTypes : DialectTypes<"Quant"> {
- let elems = [
- ReservedOrDead,
- AnyQuantizedType,
- AnyQuantizedTypeWithExpressedType,
- CalibratedQuantizedType,
- UniformQuantizedType,
- UniformQuantizedPerAxisType
- ];
+ let elems = [ReservedOrDead, AnyQuantizedType,
+ AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType,
+ UniformQuantizedType, UniformQuantizedPerAxisType,
+ UniformQuantizedSubChannelType];
}
-#endif // QUANT_BYTECODE
\ No newline at end of file
+#endif // QUANT_BYTECODE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 43440ba623b9c1..44062fe376ec0d 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -23,6 +23,7 @@ namespace detail {
struct QuantizedTypeStorage;
struct AnyQuantizedTypeStorage;
+struct UniformQuantizedSubChannelTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
struct CalibratedQuantizedTypeStorage;
@@ -382,6 +383,136 @@ class UniformQuantizedPerAxisType
}
};
+/// Represents sub-channel (also known as blockwise) quantization.
+///
+/// Syntax synopsis:
+/// UniformQuantizedSubChannelType ::= '!quant.uniform' '<'
+/// storageType ('<' storageMin ':' storageMax '>')? ':'
+/// expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>'
+/// BlockSizeInfo ::= '{' '}' | '{' AxisBlock (',' AxisBlock)* '}'
+/// AxisBlock ::= AxisSpec ':' BlockSizeSpec
+/// ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList
+/// ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}'
+/// ScaleZeroList ::= ScaleZero (',' ScaleZero)*
+/// ScaleZero ::= Scale (':' ZeroPoint)?
+///
+/// StorageType: 'i'|'u' NumBits
+/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+/// AxisSpec: An integer value
+/// BlockSizeSpec: An integer value
+/// Scale: An attribute (usually floating-point value)
+/// ZeroPoint: An attribute (usually integer value)
+class UniformQuantizedSubChannelType
+ : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,
+ detail::UniformQuantizedSubChannelTypeStorage> {
+public:
+ using Base::Base;
+ using Base::getChecked;
+
+ static constexpr StringLiteral name = "quant.uniform_sub_channel";
+
+ /// Gets an instance of the type with all parameters specified but not
+ /// checked.
+ static UniformQuantizedSubChannelType
+ get(unsigned flags, Type storageType, Type expressedType,
+ DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+ int64_t storageTypeMin, int64_t storageTypeMax);
+
+ /// Gets an instance of the type with all specified parameters checked.
+ /// Returns a nullptr convertible type on failure.
+ static UniformQuantizedSubChannelType
+ getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, DenseElementsAttr scales,
+ DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions,
+ ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax);
+
+ /// Verifies construction invariants and issues errors/warnings.
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType,
+ DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions,
+ ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax);
+
+ /// Gets the quantization scales. The scales are organized in a
+ /// multi-dimensional tensor. The size of each dimension in the scales tensor
+ /// is determined by the number of blocks along the corresponding dimension in
+ /// the quantized data tensor.
+ ///
+ /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+ /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will
+ /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+ ///
+ /// The scale value for a specific element in the quantized data tensor at
+ /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding
+ /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1].
+ DenseElementsAttr getScales() const;
+
+ /// Gets the quantization zero-points. The zero-points are organized in a
+ /// multi-dimensional tensor. The size of each dimension in the zero-point
+ /// tensor is determined by the number of blocks along the corresponding
+ /// dimension in the quantized data tensor.
+ ///
+ /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+ /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor
+ /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+ ///
+ /// The zero-point value for a specific element in the quantized data tensor
+ /// at position [i0, i1, ..., iR-1] is determined by accessing the
+ /// corresponding element in the zero-point tensor at position [i0/B0, i1/B1,
+ /// ..., iR-1/BR-1].
+ DenseElementsAttr getZeroPoints() const;
+
+ /// Gets the quantized dimensions. Each element in the returned list
+ /// represents an axis of the quantized data tensor that has a specified block
+ /// size. The order of elements corresponds to the order of block sizes
+ /// returned by `getBlockSizes()`.
+ ///
+ /// It means that the data tensor is quantized along the `i`-th dimension in
+ /// the returned list using the `i`-th block size from `getBlockSizes()`.
+ ///
+ /// Note that the type expression does not have to specify the block size for
+ /// all axes in the data tensor. Any unspecified block size for an axis `i`
+ /// defaults to the tensor dimension size of that axis.
+ ///
+ /// For example, for a quantized type:
+ /// `tensor<8x4x2x!quant.uniform<i8:f32:{1:2, 0:8}, {{{1.0}, {2.0}}}>>`
+ ///
+ /// `getQuantizedDimensions()` returns [1, 0].
+ /// `getBlockSizes()` returns [2, 8].
+ ///
+ /// This indicates that:
+ /// * Axis 1 (second dimension) is quantized with a block size of 2.
+ /// * Axis 0 (first dimension) is quantized with a block size of 8.
+ /// Since axis 2 is not specified, it implicitly has a block size equal to
+ /// the size of the third dimension (which is 2 in this case).
+ ArrayRef<int32_t> getQuantizedDimensions() const;
+
+ /// Gets the block sizes for the quantized dimensions. The `i`-th element in
+ /// the returned list corresponds to the block size for the `i`-th dimension
+ /// in the list returned by `getQuantizedDimensions()`.
+ ///
+ /// See `getQuantizedDimensions()` for more details and examples.
+ ArrayRef<int64_t> getBlockSizes() const;
+
+ /// Gets the block size information. This returns a list of pairs, where each
+ /// pair represents a quantized dimension and its corresponding block size.
+ ///
+ /// For example, for the type:
+ /// `tensor<8x4x!quant.uniform<i8:f32:{1:2, 0:8}, {{2.0, 3.0}}>>`
+ ///
+ /// This method returns:
+ /// `[(1, 2), (0, 8)]`
+ ///
+ /// This list indicates that axis 1 has a block size of 2, and axis 0 has a
+ /// block size of 8.
+ const SmallVector<std::pair<int32_t, int64_t>> getBlockSizeInfo() const;
+};
+
/// A quantized type that infers its range from given min/max values.
///
/// Typical syntax:
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 9a871f2c122d12..44a596caa24a65 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -285,6 +285,79 @@ static void populateDialectQuantSubmodule(const py::module &m) {
},
"Fixed point values are real numbers divided by a scale.");
+ //===-------------------------------------------------------------------===//
+ // UniformQuantizedSubChannelType
+ //===-------------------------------------------------------------------===//
+ auto uniformQuantizedSubChannelType = mlir_type_subclass(
+ m, "UniformQuantizedSubChannelType",
+ mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
+ uniformQuantizedSubChannelType.def_classmethod(
+ "get",
+ [](py::object cls, unsigned flags, MlirType storageType,
+ MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
+ std::vector<int32_t> quantizedDimensions,
+ std::vector<int64_t> blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
+ return cls(mlirUniformQuantizedSubChannelTypeGet(
+ flags, storageType, expressedType, scales, zeroPoints,
+ static_cast<intptr_t>(blockSizes.size()),
+ quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
+ storageTypeMax));
+ },
+ "Gets an instance of UniformQuantizedSubChannel in the same context as "
+ "the provided storage type.",
+ py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
+ py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
+ py::arg("quantized_dimensions"), py::arg("block_sizes"),
+ py::arg("storage_type_min"), py::arg("storage_type_max"));
+ uniformQuantizedSubChannelType.def_property_readonly(
+ "quantized_dimensions",
+ [](MlirType type) {
+ intptr_t nDim =
+ mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+ std::vector<int32_t> quantizedDimensions;
+ quantizedDimensions.reserve(nDim);
+ for (intptr_t i = 0; i < nDim; ++i) {
+ quantizedDimensions.push_back(
+ mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
+ }
+ return quantizedDimensions;
+ },
+ "Gets the quantized dimensions. Each element in the returned list "
+ "represents an axis of the quantized data tensor that has a specified "
+ "block size. The order of elements corresponds to the order of block "
+ "sizes returned by 'block_sizes' method. It means that the data tensor "
+ "is quantized along the i-th dimension in the returned list using the "
+ "i-th block size from block_sizes method.");
+ uniformQuantizedSubChannelType.def_property_readonly(
+ "block_sizes",
+ [](MlirType type) {
+ intptr_t nDim =
+ mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+ std::vector<int64_t> blockSizes;
+ blockSizes.reserve(nDim);
+ for (intptr_t i = 0; i < nDim; ++i) {
+ blockSizes.push_back(
+ mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
+ }
+ return blockSizes;
+ },
+ "Gets the block sizes for the quantized dimensions. The i-th element in "
+ "the returned list corresponds to the block size for the i-th dimension "
+ "in the list returned by quantized_dimensions method.");
+ uniformQuantizedSubChannelType.def_property_readonly(
+ "scales",
+ [](MlirType type) -> MlirAttribute {
+ return mlirUniformQuantizedSubChannelTypeGetScales(type);
+ },
+ "The scales of the quantized type.");
+ uniformQuantizedSubChannelType.def_property_readonly(
+ "zero_points",
+ [](MlirType type) -> MlirAttribute {
+ return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
+ },
+ "The zero points of the quantized type.");
+
//===-------------------------------------------------------------------===//
// CalibratedQuantizedType
//===-------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index c94dbb5692fdb0..88648497895ab7 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -194,6 +194,61 @@ bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint();
}
+//===---------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
+ return isa<quant::UniformQuantizedSubChannelType>(unwrap(type));
+}
+
+MlirType mlirUniformQuantizedSubChannelTypeGet(
+ unsigned flags, MlirType storageType, MlirType expressedType,
+ MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
+ int32_t *quantizedDimensions, int64_t *blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
+ auto scales = dyn_cast<mlir::DenseElementsAttr>(unwrap(scalesAttr));
+ auto zeroPoints = dyn_cast<mlir::DenseElementsAttr>(unwrap(zeroPointsAttr));
+
+ if (!scales || !zeroPoints) {
+ return {};
+ }
+
+ return wrap(quant::UniformQuantizedSubChannelType::get(
+ flags, unwrap(storageType), unwrap(expressedType), scales, zeroPoints,
+ llvm::ArrayRef<int32_t>(quantizedDimensions, nDims),
+ llvm::ArrayRef<int64_t>(blockSizes, nDims), storageTypeMin,
+ storageTypeMax));
+}
+
+intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) {
+ return cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
+ .getBlockSizes()
+ .size();
+}
+
+int32_t mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
+ intptr_t pos) {
+ return cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
+ .getQuantizedDimensions()[pos];
+}
+
+int64_t mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type,
+ intptr_t pos) {
+ return cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
+ .getBlockSizes()[pos];
+}
+
+MlirAttribute mlirUniformQuantizedSubChannelTypeGetScales(MlirType type) {
+ return wrap(
+ cast<quant::UniformQuantizedSubChannelType>(unwrap(type)).getScales());
+}
+
+MlirAttribute mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type) {
+ return wrap(cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
+ .getZeroPoints());
+}
+
//===---------------------------------------------------------------------===//
// CalibratedQuantizedType
//===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
index 6a4ac310eb0524..44ec0c517d5611 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index c584903f3a15de..683aa26a2d0621 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
-
namespace mlir {
namespace quant {
@@ -25,22 +24,17 @@ namespace {
// Verify the integrity of per-axis quantization information, if present.
//
-// - quantizedType
-// Any quantized type. Any quantized type with no per-axis quantization is
-// ignored.
+// - uniformQuantizedPerAxisType
+// A quantized type with per-axis quantization.
//
// - containerType
// Original input or result type of the operation using the provided quantized
// type. Used to ensure that the quantized type appears within a tensor and
// that the tensor is compatible with per-axis quantization information.
//
-LogicalResult verifyPerAxisQuantization(Operation *op,
- QuantizedType quantizedType,
- Type containerType) {
- auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
- if (!quantizedPerAxisType)
- return success();
-
+LogicalResult verifyPerAxisQuantization(
+ Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
+ Type containerType) {
auto tensorType = dyn_cast<TensorType>(containerType);
if (!tensorType)
return op->emitError("scalar types may not use per-axis quantization");
@@ -48,19 +42,112 @@ LogicalResult verifyPerAxisQuantization(Operation *op,
if (!tensorType.hasRank())
return success();
- int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
- if (quantizedDimension >= tensorType.getRank())
+ int32_t quantizedDimension =
+ uniformQuantizedPerAxisType.getQuantizedDimension();
+ if ((int64_t)quantizedDimension >= tensorType.getRank())
return op->emitError("quantized dimension must be less than tensor rank");
int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
if (quantizedDimensionSize != ShapedType::kDynamic &&
- quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+ quantizedDimensionSize !=
+ (int64_t)uniformQuantizedPerAxisType.getScales().size())
return op->emitError(
"quantized dimension size does not match number of scales");
return success();
}
+// Verifies that the sub-channel quantization parameters are consistent with
+// the given container type. The function checks the following:
+//
+// - The container type must be a ranked tensor type.
+// - Each quantized dimension must be less than the rank of the tensor.
+// - The size of each static quantized dimension must be divisible by the
+//   corresponding block size, and no tensor dimension may have size zero.
+// - The scales tensor must have the tensor's rank, and each static axis size
+//   must equal (tensor dimension / block size), or 1 on unquantized axes.
+//
+// The `uniformQuantizedSubChannelType` argument provides the sub-channel
+// quantization parameters, and the `containerType` argument specifies the
+// type of the container holding the quantized data.
+//
+LogicalResult verifySubChannelQuantization(
+    Operation *op,
+    UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
+    Type containerType) {
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (!tensorType)
+    return op->emitError("scalar types may not use sub-channel quantization");
+
+  if (!tensorType.hasRank())
+    return op->emitError(
+        "tensor containing the sub-channel quantized type must be ranked");
+
+  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
+      uniformQuantizedSubChannelType.getBlockSizeInfo();
+  auto shape = tensorType.getShape();
+
+  // The dimension size of scale for an axis which is not specified as quantized
+  // dimension should be 1.
+  SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
+  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
+    if (quantizedDimension >= tensorType.getRank())
+      return op->emitError()
+             << "quantized dimension " << quantizedDimension
+             << " must be less than tensor rank " << tensorType.getRank();
+    if (!tensorType.isDynamicDim(quantizedDimension) &&
+        tensorType.getDimSize(quantizedDimension) % blockSize != 0)
+      return op->emitError()
+             << "tensor dimension size "
+             << tensorType.getDimSize(quantizedDimension) << " at axis "
+             << quantizedDimension
+             << " must be divisible by the corresponding block size "
+             << blockSize;
+    if (tensorType.isDynamicDim(quantizedDimension))
+      expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
+    else
+      expectedScaleShape[quantizedDimension] =
+          tensorType.getDimSize(quantizedDimension) / blockSize;
+  }
+
+  // Block sizes must be greater than 0 and divide the corresponding dimension
+  // size. While a block size b must be less than or equal to the corresponding
+  // dimension size d, this constraint is implicitly enforced by requiring that
+  // d % b == 0 when d != 0.
+  //
+  // However, a problem arises when d = 0. The divisibility constraint allows b
+  // to be any value, potentially violating the requirement that b <= d.
+  // Furthermore, if b is unspecified (implicitly equal to d), it violates the
+  // constraint that b > 0.
+  //
+  // Therefore, we explicitly disallow the case where d = 0 to maintain
+  // consistency and avoid these issues.
+  if (llvm::find(tensorType.getShape(), 0) != tensorType.getShape().end()) {
+    return op->emitError() << "tensor dimension size of zero is not allowed "
+                              "with sub-channel quantization";
+  }
+
+  auto scaleShape =
+      uniformQuantizedSubChannelType.getScales().getType().getShape();
+  if (scaleShape.size() != shape.size()) {
+    return op->emitError() << "Rank of scales " << scaleShape.size()
+                           << " must match "
+                           << "the rank of the tensor " << shape.size();
+  }
+
+  for (auto [index, expectedDim] : llvm::enumerate(expectedScaleShape)) {
+    if (expectedDim != ShapedType::kDynamic &&
+        expectedDim != scaleShape[index])
+      return op->emitError() << "dimension size " << scaleShape[index]
+                             << " of scales tensor at axis " << index
+                             << " should match (tensor dimension at axis / "
+                                "block sizes at axis) = "
+                             << expectedDim;
+  }
+
+  return success();
+}
+
// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
//
// - quantizedType
@@ -81,11 +168,19 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
"expressed type in quantized type expected to match float type");
// Veriy integrity of per-axis quantization information, if present.
- return verifyPerAxisQuantization(op, quantizedType, containerType);
-}
+ if (auto quantizedPerAxisType =
+ dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
+ return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
+ } else if (auto quantizedSubChannelType =
+ dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
+ return verifySubChannelQuantization(op, quantizedSubChannelType,
+ containerType);
+ }
-} // namespace
+ return success();
+}
+} // namespace
//===----------------------------------------------------------------------===//
// Dialect
@@ -93,7 +188,7 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
void QuantDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
- UniformQuantizedPerAxisType>();
+ UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
@@ -101,7 +196,6 @@ void QuantDialect::initialize() {
detail::addBytecodeInterface(this);
}
-
//===----------------------------------------------------------------------===//
// DequantizeCastOp
//===----------------------------------------------------------------------===//
@@ -130,7 +224,6 @@ QuantizedType DequantizeCastOp::getQuantizedType() {
return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
}
-
//===----------------------------------------------------------------------===//
// QuantizeCastOp
//===----------------------------------------------------------------------===//
@@ -160,7 +253,6 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
}
-
//===----------------------------------------------------------------------===//
// StorageCastOp
//===----------------------------------------------------------------------===//
@@ -175,7 +267,16 @@ LogicalResult StorageCastOp::verify() {
// Verify integrity of per-axis quantization information, if available. While
// the quantization type may appear in the input or the result, their tensor
// shapes are guaranteed to be identical at this point.
- return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
+ if (auto quantizedPerAxisType =
+ dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
+ return verifyPerAxisQuantization(*this, quantizedPerAxisType,
+ getInput().getType());
+ else if (auto quantizedSunChannelType =
+ dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
+ return verifySubChannelQuantization(*this, quantizedSunChannelType,
+ getInput().getType());
+
+ return success();
}
OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
@@ -205,10 +306,8 @@ QuantizedType StorageCastOp::getQuantizedType() {
return cast<QuantizedType>(resultScalarType);
}
-
} // namespace quant
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
-
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 7c0d3696486515..9b8eec609b0396 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -34,7 +34,7 @@ double getMaxScale(Type expressedType) {
return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
}
-} // namespace
+} // namespace
unsigned QuantizedType::getFlags() const {
return static_cast<ImplType *>(impl)->flags;
@@ -410,6 +410,123 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
return getImpl()->quantizedDimension;
}
+UniformQuantizedSubChannelType UniformQuantizedSubChannelType::get(
+ unsigned flags, Type storageType, Type expressedType,
+ DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+ int64_t storageTypeMin, int64_t storageTypeMax) {
+ return Base::get(storageType.getContext(), flags, storageType, expressedType,
+ scales, zeroPoints, quantizedDimensions, blockSizes,
+ storageTypeMin, storageTypeMax);
+}
+
+UniformQuantizedSubChannelType UniformQuantizedSubChannelType::getChecked(
+ function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, DenseElementsAttr scales,
+ DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
+ ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
+ return Base::getChecked(emitError, storageType.getContext(), flags,
+ storageType, expressedType, scales, zeroPoints,
+ quantizedDimensions, blockSizes, storageTypeMin,
+ storageTypeMax);
+}
+
+/// Verifies the invariants of a sub-channel quantized type: the base
+/// QuantizedType invariants, a floating-point expressed type, scale/zero-point
+/// element types and shapes, and the quantized-dimension/block-size lists.
+LogicalResult UniformQuantizedSubChannelType::verifyInvariants(
+    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+    Type storageType, Type expressedType, DenseElementsAttr scales,
+    DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
+    ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax)))
+    return failure();
+
+  // Uniform quantization requires fully expressed parameters, including
+  // expressed type.
+  if (!expressedType)
+    return emitError() << "uniform quantization requires expressed type";
+
+  // Verify that the expressed type is floating point.
+  // If this restriction is ever eliminated, the parser/printer must be
+  // extended.
+  if (!llvm::isa<FloatType>(expressedType))
+    return emitError() << "expressed type must be floating point";
+
+  // Verify scale type to match expressedType.
+  if (scales.getType().getElementType() != expressedType)
+    return emitError() << "type of scale values "
+                       << scales.getType().getElementType()
+                       << " must match the expressed type " << expressedType;
+
+  // Verify zero-point type to match storageType.
+  if (zeroPoints.getType().getElementType() != storageType)
+    return emitError() << "type of zero point values "
+                       << zeroPoints.getType().getElementType()
+                       << " must match the storage type " << storageType;
+
+  // Ensure that the shape of scales and zeroPoints match.
+  if (scales.getType().getShape() != zeroPoints.getType().getShape())
+    return emitError() << "shape of scales and zeroPoints ("
+                       << scales.getType().getShape() << " vs "
+                       << zeroPoints.getType().getShape() << ") does not match";
+
+  // Ensure that the number of quantized-dimensions and block-sizes match.
+  if (quantizedDimensions.size() != blockSizes.size())
+    return emitError() << "number of quantized dimensions and block sizes ("
+                       << quantizedDimensions.size() << " vs "
+                       << blockSizes.size() << ") does not match";
+
+  // Verify quantized dimension.
+  for (auto quantizedDimension : quantizedDimensions) {
+    if (quantizedDimension < 0)
+      return emitError() << "illegal quantized dimension: "
+                         << quantizedDimension;
+  }
+
+  // Verify block sizes.
+  for (auto blockSize : blockSizes) {
+    if (blockSize <= 0)
+      return emitError() << "illegal block size: " << blockSize;
+  }
+
+  return success();
+}
+
+DenseElementsAttr UniformQuantizedSubChannelType::getScales() const {
+ return getImpl()->getScales();
+}
+
+DenseElementsAttr UniformQuantizedSubChannelType::getZeroPoints() const {
+ return getImpl()->getZeroPoints();
+}
+
+ArrayRef<int32_t>
+UniformQuantizedSubChannelType::getQuantizedDimensions() const {
+ return getImpl()->getQuantizedDimensions();
+}
+
+ArrayRef<int64_t> UniformQuantizedSubChannelType::getBlockSizes() const {
+ return getImpl()->getBlockSizes();
+}
+
+const SmallVector<std::pair<int32_t, int64_t>>
+UniformQuantizedSubChannelType::getBlockSizeInfo() const {
+ SmallVector<std::pair<int32_t, int64_t>> result;
+ result.reserve(getQuantizedDimensions().size());
+
+ for (auto [dim, size] :
+ llvm::zip(getQuantizedDimensions(), getBlockSizes())) {
+ result.push_back({dim, size});
+ }
+
+ return result;
+}
+
CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
double min, double max) {
return Base::get(expressedType.getContext(), expressedType, min, max);
diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h
index ef098811927cda..bb38b1a2a91e28 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h
@@ -9,6 +9,7 @@
#ifndef TYPE_DETAIL_H_
#define TYPE_DETAIL_H_
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
@@ -253,6 +254,127 @@ struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
int32_t quantizedDimension;
};
+struct UniformQuantizedSubChannelTypeStorage : public QuantizedTypeStorage {
+ struct KeyTy {
+ KeyTy(unsigned flags, Type storageType, Type expressedType,
+ DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+ int64_t storageTypeMin, int64_t storageTypeMax)
+ : flags(flags), storageType(storageType), expressedType(expressedType),
+ scales(scales), zeroPoints(zeroPoints),
+ quantizedDimensions(quantizedDimensions), blockSizes(blockSizes),
+ storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+ /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+ unsigned flags;
+
+ // Integral type for the storage point representation.
+ Type storageType;
+
+ // Floating point type that the quantized type approximates.
+ Type expressedType;
+
+ DenseElementsAttr scales;
+ DenseElementsAttr zeroPoints;
+ ArrayRef<int32_t> quantizedDimensions;
+ ArrayRef<int64_t> blockSizes;
+ int64_t storageTypeMin;
+ int64_t storageTypeMax;
+
+ DenseElementsAttr getScales() const { return scales; }
+
+ DenseElementsAttr getZeroPoints() const { return zeroPoints; }
+
+ // Check for equality of two structures that share KeyTy data members
+ // (by name).
+ template <typename T, typename U>
+ static bool genericIsEqual(const T &lhs, const U &rhs) {
+ return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+ lhs.expressedType == rhs.expressedType &&
+ lhs.scales == rhs.scales && lhs.zeroPoints == rhs.zeroPoints &&
+ lhs.quantizedDimensions == rhs.quantizedDimensions &&
+ lhs.blockSizes == rhs.blockSizes &&
+ lhs.storageTypeMin == rhs.storageTypeMin &&
+ lhs.storageTypeMax == rhs.storageTypeMax;
+ }
+
+ bool operator==(const KeyTy &other) const {
+ return genericIsEqual(*this, other);
+ }
+
+ unsigned getHashValue() const {
+ // Hash the scalar attributes.
+ unsigned hash = llvm::hash_combine(flags, storageType, expressedType,
+ storageTypeMin, storageTypeMax);
+
+ // Hash the scales.
+ for (auto scaleAttr : scales.getValues<APFloat>()) {
+ hash = llvm::hash_combine(
+ hash, llvm::bit_cast<int64_t>(scaleAttr.convertToDouble()));
+ }
+
+ // Hash the zero points. (Assumed to be integers, adjust if needed).
+ for (auto zeroPointAttr : zeroPoints.getValues<APInt>()) {
+ hash = llvm::hash_combine(hash, zeroPointAttr.getSExtValue());
+ }
+
+ // Hash the quantized dimensions and block sizes.
+ hash = llvm::hash_combine(
+ hash,
+ llvm::hash_combine_range(quantizedDimensions.begin(),
+ quantizedDimensions.end()),
+ llvm::hash_combine_range(blockSizes.begin(), blockSizes.end()));
+
+ return hash;
+ }
+ };
+
+  // Parameters are passed in directly rather than read from KeyTy because
+  // `construct` below copies the two arrays into allocator-owned storage.
+ UniformQuantizedSubChannelTypeStorage(const KeyTy &key,
+ DenseElementsAttr scales,
+ DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions,
+ ArrayRef<int64_t> blockSizes)
+ : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+ key.storageTypeMin, key.storageTypeMax),
+ scales(scales), zeroPoints(zeroPoints),
+ quantizedDimensions(quantizedDimensions), blockSizes(blockSizes) {}
+
+ bool operator==(const KeyTy &key) const {
+ return KeyTy::genericIsEqual(*this, key);
+ }
+
+ /// Construction.
+ static UniformQuantizedSubChannelTypeStorage *
+ construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+ DenseElementsAttr scales = key.scales;
+ DenseElementsAttr zeroPoints = key.zeroPoints;
+ ArrayRef<int32_t> quantizedDimensions =
+ allocator.copyInto(key.quantizedDimensions);
+ ArrayRef<int64_t> blockSizes = allocator.copyInto(key.blockSizes);
+ return new (allocator.allocate<UniformQuantizedSubChannelTypeStorage>())
+ UniformQuantizedSubChannelTypeStorage(key, scales, zeroPoints,
+ quantizedDimensions, blockSizes);
+ }
+
+ static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+ DenseElementsAttr getScales() const { return scales; }
+
+ DenseElementsAttr getZeroPoints() const { return zeroPoints; }
+
+ ArrayRef<int32_t> getQuantizedDimensions() const {
+ return quantizedDimensions;
+ }
+
+ ArrayRef<int64_t> getBlockSizes() const { return blockSizes; }
+
+ DenseElementsAttr scales;
+ DenseElementsAttr zeroPoints;
+ ArrayRef<int32_t> quantizedDimensions;
+ ArrayRef<int64_t> blockSizes;
+};
+
struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage {
struct KeyTy {
KeyTy(Type expressedType, double min, double max)
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 851763d8942e83..fb6f5d950acc9e 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -159,38 +159,173 @@ static Type parseAnyType(DialectAsmParser &parser) {
typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
}
-static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
+/// Checks if the given scale value is within the valid range of the expressed
+/// type. The `expressedType` argument is the floating-point type used for
+/// expressing the quantized values, and `scale` is the double value to check.
+LogicalResult
+isScaleInExpressedTypeRange(function_ref<InFlightDiagnostic()> emitError,
+ Type expressedType, double scale) {
+ auto floatType = cast<FloatType>(expressedType);
+ double minScale =
+ APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
+ double maxScale =
+ APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
+ if (scale < minScale || scale > maxScale)
+ return emitError() << "scale " << scale << " out of expressed type range ["
+ << minScale << ", " << maxScale << "]";
+ return success();
+}
+
+/// Parses a quantization parameter, which is either a scale value (float) or a
+/// scale-zero point pair (float:integer). `expressedType`, expressing the type
+/// of scale values, is used to validate the scale. The parsed scale and zero
+/// point (if any) are stored in `scale` and `zeroPoint`.
+static ParseResult parseQuantParams(DialectAsmParser &parser,
+ Type expressedType, double &scale,
int64_t &zeroPoint) {
- // scale[:zeroPoint]?
- // scale.
- if (parser.parseFloat(scale))
+
+ if (parser.parseFloat(scale)) {
return failure();
+ }
+
+ if (failed(isScaleInExpressedTypeRange(
+ [&]() { return parser.emitError(parser.getCurrentLocation()); },
+ expressedType, scale))) {
+ return failure();
+ }
- // zero point.
zeroPoint = 0;
if (failed(parser.parseOptionalColon())) {
- // Default zero point.
return success();
}
return parser.parseInteger(zeroPoint);
}
+/// Parses block size information for sub-channel quantization, assuming the
+/// leading '{' has already been parsed. The block size information is provided
+/// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
+///
+/// The parsed axis indices are stored in `quantizedDimensions`, and the
+/// corresponding block sizes are stored in `blockSizes`.
+static ParseResult
+parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser,
+ SmallVectorImpl<int32_t> &quantizedDimensions,
+ SmallVectorImpl<int64_t> &blockSizes) {
+ // Empty block-sizes info.
+ if (succeeded(parser.parseOptionalRBrace())) {
+ return success();
+ }
+
+ auto parseBlockSizeElements = [&]() -> ParseResult {
+ quantizedDimensions.resize(quantizedDimensions.size() + 1);
+ blockSizes.resize(blockSizes.size() + 1);
+ if (parser.parseInteger(quantizedDimensions.back()) ||
+ parser.parseColon() || parser.parseInteger(blockSizes.back()))
+ return failure();
+ return success();
+ };
+
+ if (parser.parseCommaSeparatedList(parseBlockSizeElements) ||
+ parser.parseRBrace()) {
+ return failure();
+ }
+
+ return success();
+}
+
+/// Parses a bracketed list of quantization parameters, returning the dimensions
+/// of the parsed sub-tensors in `dims`. The dimension of the list is prepended
+/// to the dimensions of the sub-tensors. This function assumes that the initial
+/// left brace has already been parsed. For example:
+///
+/// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
+/// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
+///
+/// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
+/// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
+/// 9]
+///
+/// This function expects all sub-tensors to have the same rank.
+static ParseResult
+parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
+ SmallVectorImpl<double> &scales,
+ SmallVectorImpl<int64_t> &zeroPoints,
+ SmallVectorImpl<int64_t> &dims) {
+ auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
+ const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
+ if (prevDims == newDims)
+ return success();
+ return parser.emitError(parser.getCurrentLocation())
+ << "tensor literal is invalid; ranks are not consistent "
+ "between elements";
+ };
+
+ bool first = true;
+ SmallVector<int64_t, 4> newDims;
+ unsigned size = 0;
+
+ auto parseOneElement = [&]() -> ParseResult {
+ SmallVector<int64_t, 4> thisDims;
+ if (succeeded(parser.parseOptionalLBrace())) {
+ if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
+ zeroPoints, thisDims))
+ return failure();
+ } else {
+ zeroPoints.resize(zeroPoints.size() + 1);
+ scales.resize(scales.size() + 1);
+ if (parseQuantParams(parser, expressedType, scales.back(),
+ zeroPoints.back())) {
+ return failure();
+ }
+ }
+ ++size;
+ if (!first)
+ return checkDims(newDims, thisDims);
+ newDims = thisDims;
+ first = false;
+ return success();
+ };
+
+ if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) {
+ return failure();
+ }
+
+ // Return the sublists' dimensions with 'size' prepended.
+ dims.clear();
+ dims.push_back(size);
+ dims.append(newDims.begin(), newDims.end());
+
+ return success();
+}
+
/// Parses a UniformQuantizedType.
///
/// uniform_type ::= uniform_per_layer
/// | uniform_per_axis
+/// | uniform_sub_channel
/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
/// `,` scale-zero `>`
/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
-/// axis-spec `,` scale-zero-list `>`
+/// axis-spec `,` `{` scale-zero-list `}` `>`
+/// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec
+/// block-size-info `,` scale-zero-tensor `>`
/// storage-spec ::= storage-type (`<` storage-range `>`)?
/// storage-range ::= integer-literal `:` integer-literal
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
/// axis-spec ::= `:` integer-literal
-/// scale-zero ::= float-literal `:` integer-literal
-/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
+/// scale-zero ::= scale (`:` zero-point)?
+/// scale ::= float-literal
+/// zero-point ::= integer-literal
+/// scale-zero-list ::= scale-zero (`,` scale-zero)*
+/// block-size-info ::= `:` `{` (axis-block (`,` axis-block)*)? `}`
+/// axis-block ::= integer-literal `:` block-size-spec
+/// block-size-spec ::= integer-literal
+/// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
+/// scale-zero-dense-exp ::= `{`
+/// scale-zero-tensor (`,` scale-zero-tensor)*
+/// `}`
static Type parseUniformType(DialectAsmParser &parser) {
IntegerType storageType;
FloatType expressedType;
@@ -198,7 +333,9 @@ static Type parseUniformType(DialectAsmParser &parser) {
int64_t storageTypeMin;
int64_t storageTypeMax;
bool isPerAxis = false;
- int32_t quantizedDimension;
+ bool isSubChannel = false;
+ SmallVector<int32_t, 1> quantizedDimensions;
+ SmallVector<int64_t, 1> blockSizes;
SmallVector<double, 1> scales;
SmallVector<int64_t, 1> zeroPoints;
@@ -228,11 +365,22 @@ static Type parseUniformType(DialectAsmParser &parser) {
return nullptr;
}
- // Optionally parse quantized dimension for per-axis quantization.
+ // Optionally parse quantized dimension for per-axis or sub-channel
+ // quantization.
if (succeeded(parser.parseOptionalColon())) {
- if (parser.parseInteger(quantizedDimension))
- return nullptr;
- isPerAxis = true;
+ if (succeeded(parser.parseOptionalLBrace())) {
+ isSubChannel = true;
+ if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
+ blockSizes)) {
+ return nullptr;
+ }
+ } else {
+ isPerAxis = true;
+ quantizedDimensions.resize(1);
+ if (parser.parseInteger(quantizedDimensions.back())) {
+ return nullptr;
+ }
+ }
}
// Comma leading into range_spec.
@@ -240,26 +388,21 @@ static Type parseUniformType(DialectAsmParser &parser) {
return nullptr;
}
- // Parameter specification.
- // For per-axis, ranges are in a {} delimitted list.
- if (isPerAxis) {
- if (parser.parseLBrace()) {
- return nullptr;
- }
- }
-
- // Parse scales/zeroPoints.
- SMLoc scaleZPLoc = parser.getCurrentLocation();
- do {
- scales.resize(scales.size() + 1);
+ // Quantization parameter (scales/zeroPoints) specification.
+ bool isPerTensor = !isPerAxis && !isSubChannel;
+ SmallVector<int64_t> dims;
+ if (isPerTensor) {
zeroPoints.resize(zeroPoints.size() + 1);
- if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
+ scales.resize(scales.size() + 1);
+ if (parseQuantParams(parser, expressedType, scales.back(),
+ zeroPoints.back())) {
return nullptr;
}
- } while (isPerAxis && succeeded(parser.parseOptionalComma()));
- if (isPerAxis) {
- if (parser.parseRBrace()) {
+ } else {
+ if (parser.parseLBrace() ||
+ parseQuantParamListUntilRBrace(parser, expressedType, scales,
+ zeroPoints, dims)) {
return nullptr;
}
}
@@ -268,19 +411,30 @@ static Type parseUniformType(DialectAsmParser &parser) {
return nullptr;
}
- if (!isPerAxis && scales.size() > 1) {
- return (parser.emitError(scaleZPLoc,
- "multiple scales/zeroPoints provided, but "
- "quantizedDimension wasn't specified"),
- nullptr);
- }
-
if (isPerAxis) {
- ArrayRef<double> scalesRef(scales.begin(), scales.end());
- ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
return parser.getChecked<UniformQuantizedPerAxisType>(
+ typeFlags, storageType, expressedType, scales, zeroPoints,
+ quantizedDimensions[0], storageTypeMin, storageTypeMax);
+ } else if (isSubChannel) {
+ SmallVector<APFloat> apFloatScales =
+ llvm::to_vector(llvm::map_range(scales, [&](double scale) -> APFloat {
+ APFloat apFloatScale(scale);
+ bool unused;
+ apFloatScale.convert(expressedType.getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &unused);
+ return apFloatScale;
+ }));
+ SmallVector<APInt> apIntZeroPoints = llvm::to_vector(
+ llvm::map_range(zeroPoints, [&](int64_t zeroPoint) -> APInt {
+ return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
+ }));
+ auto scalesRef = mlir::DenseElementsAttr::get(
+ RankedTensorType::get(dims, expressedType), apFloatScales);
+ auto zeroPointsRef = mlir::DenseElementsAttr::get(
+ RankedTensorType::get(dims, storageType), apIntZeroPoints);
+ return parser.getChecked<UniformQuantizedSubChannelType>(
typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
- quantizedDimension, storageTypeMin, storageTypeMax);
+ quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
}
return parser.getChecked<UniformQuantizedType>(
@@ -360,6 +514,19 @@ static void printQuantParams(double scale, int64_t zeroPoint,
}
}
+static void
+printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
+ DialectAsmPrinter &out) {
+ out << "{";
+ llvm::interleave(
+ llvm::seq<size_t>(0, blockSizeInfo.size()), out,
+ [&](size_t index) {
+ out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
+ },
+ ",");
+ out << "}";
+}
+
/// Helper that prints a AnyQuantizedType.
static void printAnyQuantizedType(AnyQuantizedType type,
DialectAsmPrinter &out) {
@@ -405,6 +572,74 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
out << "}>";
}
+/// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
+/// elements. The nesting corresponds to the `shape` dimensions.
+///
+/// Elements are delimited by commas, and the inner dimensions are enclosed in
+/// braces. `zero_point` is only printed if it is non-zero. For example:
+///
+/// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
+/// zeroPoints=[0, 0, 1, 9],
+/// shape=[2, 2])
+///
+/// would print:
+///
+/// {{1.0, 2.0}, {3.0:1, 4.0:9}}
+void printDenseQuantizationParameters(ArrayRef<APFloat> scales,
+ ArrayRef<APInt> zeroPoints,
+ ArrayRef<int64_t> shape,
+ DialectAsmPrinter &out) {
+ int64_t rank = shape.size();
+ SmallVector<unsigned, 4> counter(rank, 0);
+ unsigned openBrackets = 0;
+
+ auto bumpCounter = [&]() {
+ ++counter[rank - 1];
+ for (unsigned i = rank - 1; i > 0; --i) {
+ if (counter[i] >= shape[i]) {
+ counter[i] = 0;
+ ++counter[i - 1];
+ --openBrackets;
+ out << '}';
+ }
+ }
+ };
+
+ for (unsigned idx = 0, e = scales.size(); idx != e; ++idx) {
+ if (idx != 0)
+ out << ", ";
+ while (openBrackets++ < rank)
+ out << '{';
+ openBrackets = rank;
+ out << scales[idx];
+ if (zeroPoints[idx] != 0) {
+ out << ":" << zeroPoints[idx];
+ }
+ bumpCounter();
+ }
+ while (openBrackets-- > 0)
+ out << '}';
+}
+
+/// Helper that prints a UniformQuantizedSubChannelType.
+static void
+printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type,
+ DialectAsmPrinter &out) {
+ out << "uniform<";
+ printStorageType(type, out);
+ out << ":" << type.getExpressedType() << ":";
+ printBlockSizeInfo(type.getBlockSizeInfo(), out);
+ out << ", ";
+
+ auto scalesItr = type.getScales().getValues<APFloat>();
+ auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
+ SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
+ SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
+ printDenseQuantizationParameters(scales, zeroPoints,
+ type.getScales().getType().getShape(), out);
+ out << ">";
+}
+
/// Helper that prints a CalibratedQuantizedType.
static void printCalibratedQuantizedType(CalibratedQuantizedType type,
DialectAsmPrinter &out) {
@@ -421,6 +656,9 @@ void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
printUniformQuantizedType(uniformType, os);
else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
printUniformQuantizedPerAxisType(perAxisType, os);
+  else if (auto subChannelType =
+               llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
+    printUniformQuantizedSubChannelType(subChannelType, os);
else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
printCalibratedQuantizedType(calibratedType, os);
else
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 4adeb9218ff8ec..45d0a0c3e697a2 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -570,6 +570,73 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
return result;
}
+// Convert an operation using sub-channel quantization. The conversion is
+// emitted as a single 'linalg.generic' op with all-parallel iterators whose
+// body applies 'convertRanked()' to one element at a time.
+//
+// - op
+// 'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+// Ranked tensor. The input type is cast to 'RankedTensorType'
+// unconditionally, so scalars and unranked tensors are not accepted here.
+//
+// - quantizedType
+// Sub-channel quantized type.
+//
+// Returns the single result of the generated 'linalg.generic' op.
+//
+Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
+ Value input,
+ UniformQuantizedSubChannelType quantizedType) {
+ auto *context = builder.getContext();
+
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto inputRank = inputType.getRank();
+
+ // Scales and zero points become tensor constants that are fed to the
+ // generic op as additional inputs.
+ auto scales = materializeSubChannelScales(builder, loc, quantizedType);
+ auto zeroPoints =
+ materializeSubChannelZeroPoints(builder, loc, quantizedType);
+
+ // A float input selects the storage type as the result element type; any
+ // other input selects the expressed type.
+ auto elementType = isa<FloatType>(inputType.getElementType())
+ ? quantizedType.getStorageType()
+ : quantizedType.getExpressedType();
+ // The output tensor is created with the same sizes as the input.
+ auto initShape = tensor::getMixedSizes(builder, loc, input);
+ Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
+
+ SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+ utils::IteratorType::parallel);
+ const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
+ quantizedType.getBlockSizeInfo();
+ // Indexing map shared by scales and zero points: every dimension starts out
+ // as the constant 0, and each dimension listed in the block-size info is
+ // replaced by 'd floordiv blockSize'.
+ SmallVector<AffineExpr> affineExprs(inputRank,
+ builder.getAffineConstantExpr(0));
+ for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
+ affineExprs[quantizedDimension] =
+ builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
+ }
+ auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
+ // Maps are listed in operand order: input, scales, zero points, output.
+ SmallVector<AffineMap> indexingMaps{
+ builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
+ builder.getMultiDimIdentityMap(inputRank)};
+ auto result = builder
+ .create<linalg::GenericOp>(
+ loc,
+ init.getType(), // resultType
+ ValueRange{input, scales, zeroPoints}, // inputs
+ ValueRange{init}, // outputs
+ indexingMaps, iteratorTypes,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ // Three inputs plus one output operand; the last block
+ // argument (args[3]) is not read.
+ assert(args.size() == 4);
+ auto input = args[0];
+ auto scale = args[1];
+ auto zeroPoint = args[2];
+
+ // An empty shape is passed because the body operates on
+ // individual elements.
+ auto result =
+ convertRanked(builder, loc, op, input, {}, scale,
+ zeroPoint, quantizedType);
+
+ builder.create<linalg::YieldOp>(loc, result);
+ })
+ .getResult(0);
+
+ return result;
+}
+
// Convert a quantization operation.
//
// - op
diff --git a/mlir/test/CAPI/quant.c b/mlir/test/CAPI/quant.c
index 0a09e084119f71..bc7cd1436efe18 100644
--- a/mlir/test/CAPI/quant.c
+++ b/mlir/test/CAPI/quant.c
@@ -203,6 +203,130 @@ void testUniformPerAxisType(MlirContext ctx) {
fprintf(stderr, "\n\n");
}
+// CHECK-LABEL: testUniformSubChannelType
+// Builds a sub-channel quantized type through the C API, reads its
+// components back, and compares it with the same type obtained by parsing
+// its textual form. All output goes to stderr.
+void testUniformSubChannelType(MlirContext ctx) {
+ fprintf(stderr, "testUniformSubChannelType\n");
+
+ // Reference type obtained from the parser; it is compared against the
+ // programmatically built type near the end of this function.
+ MlirType subChannelParsed =
+ mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
+ "!quant.uniform<i8:f32:{0:1,1:2}, "
+ "{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"));
+
+ MlirType i8 = mlirIntegerTypeGet(ctx, 8);
+ MlirType f32 = mlirF32TypeGet(ctx);
+
+ // block-size information
+ int32_t quantizedDimensions[] = {0, 1};
+ int64_t blockSizes[] = {1, 2};
+ int64_t numBlockSizes = 2;
+
+ // quantization parameters
+ int64_t quantParamShape[] = {2, 2};
+ int64_t quantParamRank = 2;
+ int64_t numQuantizationParams = 4;
+ MlirAttribute scales[] = {mlirFloatAttrDoubleGet(ctx, f32, 2.0),
+ mlirFloatAttrDoubleGet(ctx, f32, 3.0),
+ mlirFloatAttrDoubleGet(ctx, f32, 4.0),
+ mlirFloatAttrDoubleGet(ctx, f32, 5.0)};
+ MlirAttribute zeroPoints[] = {
+ mlirIntegerAttrGet(i8, 10), mlirIntegerAttrGet(i8, 20),
+ mlirIntegerAttrGet(i8, 30), mlirIntegerAttrGet(i8, 40)};
+
+ // Scales and zero points are wrapped in 2x2 dense elements attributes.
+ MlirType scalesType =
+ mlirRankedTensorTypeGet(quantParamRank, quantParamShape, f32,
+ /*encoding=*/mlirAttributeGetNull());
+ MlirType zeroPointsType = mlirRankedTensorTypeGet(
+ quantParamRank, quantParamShape, i8, /*encoding=*/mlirAttributeGetNull());
+ MlirAttribute denseScalesAttr =
+ mlirDenseElementsAttrGet(scalesType, numQuantizationParams, scales);
+ MlirAttribute denseZeroPointsAttr = mlirDenseElementsAttrGet(
+ zeroPointsType, numQuantizationParams, zeroPoints);
+
+ MlirType subChannel = mlirUniformQuantizedSubChannelTypeGet(
+ mlirQuantizedTypeGetSignedFlag(), i8, f32, denseScalesAttr,
+ denseZeroPointsAttr, numBlockSizes, quantizedDimensions, blockSizes,
+ mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
+ /*integralWidth=*/8),
+ mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
+ /*integralWidth=*/8));
+
+ // Same construction, but with array attributes instead of dense elements
+ // attributes for the scales and zero points; the expectation below is that
+ // this yields a null type.
+ MlirAttribute arrayScalesAttr =
+ mlirArrayAttrGet(ctx, numQuantizationParams, scales);
+ MlirAttribute arrayZeroPointsAttr =
+ mlirArrayAttrGet(ctx, numQuantizationParams, zeroPoints);
+ MlirType illegalSubChannel = mlirUniformQuantizedSubChannelTypeGet(
+ mlirQuantizedTypeGetSignedFlag(), i8, f32, arrayScalesAttr,
+ arrayZeroPointsAttr, numBlockSizes, quantizedDimensions, blockSizes,
+ mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
+ /*integralWidth=*/8),
+ mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
+ /*integralWidth=*/8));
+
+ // CHECK: is null sub-channel type: 1
+ fprintf(stderr, "is null sub-channel type: %d\n",
+ mlirTypeIsNull(illegalSubChannel));
+
+ // CHECK: num dims: 2
+ fprintf(stderr, "num dims: %" PRId64 "\n",
+ mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(subChannel));
+
+ // CHECK: axis-block-size-pair[0]: 0:1
+ fprintf(
+ stderr, "axis-block-size-pair[0]: %" PRId32 ":%" PRId64 "\n",
+ mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(subChannel, 0),
+ mlirUniformQuantizedSubChannelTypeGetBlockSize(subChannel, 0));
+
+ // CHECK: axis-block-size-pair[1]: 1:2
+ fprintf(
+ stderr, "axis-block-size-pair[1]: %" PRId32 ":%" PRId64 "\n",
+ mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(subChannel, 1),
+ mlirUniformQuantizedSubChannelTypeGetBlockSize(subChannel, 1));
+
+ // From here on the local attribute and type variables are overwritten with
+ // the values read back from the constructed type.
+ denseScalesAttr = mlirUniformQuantizedSubChannelTypeGetScales(subChannel);
+ denseZeroPointsAttr =
+ mlirUniformQuantizedSubChannelTypeGetZeroPoints(subChannel);
+ scalesType = mlirAttributeGetType(denseScalesAttr);
+ zeroPointsType = mlirAttributeGetType(denseZeroPointsAttr);
+
+ // CHECK: tensor<2x2xf32>
+ mlirTypeDump(scalesType);
+ // CHECK: tensor<2x2xi8>
+ mlirTypeDump(zeroPointsType);
+
+ // CHECK: number of quantization parameters: 4
+ fprintf(stderr, "number of quantization parameters: %" PRId64 "\n",
+ mlirElementsAttrGetNumElements(denseScalesAttr));
+
+ // CHECK: quantization-parameter[0]: 2.000000:10
+ fprintf(stderr, "quantization-parameter[0]: %lf:%" PRId8 "\n",
+ mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 0),
+ mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 0));
+
+ // CHECK: quantization-parameter[1]: 3.000000:20
+ fprintf(stderr, "quantization-parameter[1]: %lf:%" PRId8 "\n",
+ mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 1),
+ mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 1));
+
+ // CHECK: quantization-parameter[2]: 4.000000:30
+ fprintf(stderr, "quantization-parameter[2]: %lf:%" PRId8 "\n",
+ mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 2),
+ mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 2));
+
+ // CHECK: quantization-parameter[3]: 5.000000:40
+ fprintf(stderr, "quantization-parameter[3]: %lf:%" PRId8 "\n",
+ mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 3),
+ mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 3));
+
+ // CHECK: equal: 1
+ fprintf(stderr, "equal: %d\n", mlirTypeEqual(subChannel, subChannelParsed));
+
+ // CHECK: !quant.uniform<i8:f32:{0:1,1:2},
+ // {{.*}}2.000000e+00:10, 3.000000e+00:20},
+ // {4.000000e+00:30, 5.000000e+00:40{{.*}}}}>
+ mlirTypeDump(subChannel);
+ fprintf(stderr, "\n\n");
+}
+
// CHECK-LABEL: testCalibratedType
void testCalibratedType(MlirContext ctx) {
fprintf(stderr, "testCalibratedType\n");
diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir
index 359a58557087e1..29815ec0de41ee 100644
--- a/mlir/test/Dialect/Quant/Bytecode/types.mlir
+++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir
@@ -64,3 +64,12 @@ module @parseUniformPerAxisMixed attributes {
bytecode.test = !quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>
} {}
+//===----------------------------------------------------------------------===//
+// UniformQuantizedSubChannel
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: parseUniformSubChannel
+module @parseUniformSubChannel attributes {
+ // CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
+ bytecode.test = !quant.uniform<i8:f32:{0:1, 1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
+} {}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
index ba3a8e312d96e9..7bb50f352f9389 100644
--- a/mlir/test/Dialect/Quant/invalid.mlir
+++ b/mlir/test/Dialect/Quant/invalid.mlir
@@ -256,3 +256,71 @@ func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3x4xi8>) {
return
}
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_scalar(%arg0: f32) {
+ // expected-error@+1 {{scalar types may not use sub-channel quantization}}
+ %0 = quant.qcast %arg0 : f32 to !qalias
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_unranked(%arg0: tensor<*xf32>) {
+ // expected-error@+1 {{tensor containing the sub-channel quantized type must be ranked}}
+ %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,3:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_invalid_quantized_dimension(%arg0: tensor<2x4xf32>) {
+ // expected-error@+1 {{quantized dimension 3 must be less than tensor rank 2}}
+ %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:3},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_invalid_tensor_dim_size(%arg0: tensor<2x4xf32>) {
+ // expected-error@+1 {{tensor dimension size 4 at axis 1 must be divisible by the corresponding block size 3}}
+ %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_invalid_zero_tensor_dim_size(%arg0: tensor<0x4xf32>) {
+ // expected-error@+1 {{tensor dimension size of zero is not allowed with sub-channel quantization}}
+ %0 = quant.qcast %arg0 : tensor<0x4xf32> to tensor<0x4x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {{2.000000e+02:120}, {2.000000e+02}}>
+func.func @qcast_sub_channel_invalid_scale_dim_size(%arg0: tensor<2x4xf32>) {
+ // expected-error@+1 {{dimension size 2 of scales tensor at axis 1 should match (tensor dimension at axis / block sizes at axis) = 2}}
+ %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias>
+ return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{},{{{2.000000e+02:120}}}>
+func.func @qcast_sub_channel_invalid_scale_dim_size(%arg0: tensor<?x?xf32>) {
+ // expected-error@+1 {{Rank of scales 3 must match the rank of the tensor 2}}
+ %0 = quant.qcast %arg0 : tensor<?x?xf32> to tensor<?x?x!qalias>
+ return
+}
diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir
index 4abc5830d081e1..33ff93ecbc1d7b 100644
--- a/mlir/test/Dialect/Quant/ops.mlir
+++ b/mlir/test/Dialect/Quant/ops.mlir
@@ -148,4 +148,23 @@ func.func @scast_per_axis_unranked(%arg0: tensor<*xi8>) {
return
}
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @sub_channel_quantization(%arg0: tensor<2x4xi8>) -> tensor<2x4xi8> {
+ %0 = quant.scast %arg0 : tensor<2x4xi8> to tensor<2x4x!qalias>
+ %1 = quant.dcast %0 : tensor<2x4x!qalias> to tensor<2x4xf32>
+ %2 = quant.qcast %1 : tensor<2x4xf32> to tensor<2x4x!qalias>
+ %3 = quant.scast %2 : tensor<2x4x!qalias> to tensor<2x4xi8>
+ return %3 : tensor<2x4xi8>
+}
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @sub_channel_quantization_with_unknown_dims(%arg0: tensor<2x?xf32>) {
+ %0 = quant.qcast %arg0 : tensor<2x?xf32> to tensor<2x?x!qalias>
+ return
+}
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index 4528d2826a850c..3b358443e43f2d 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -107,7 +107,7 @@
// -----
// Illegal scale: negative
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale -1.000000e+00 out of expressed type range}}
!qalias = !quant.uniform<i8<-4:3>:f32, -1.0:127>
// -----
@@ -128,20 +128,110 @@
// -----
// Scale f16 underflow
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale 5.800000e-08 out of expressed type range}}
!qalias = !quant.uniform<i8:f16, 5.8e-8>
// -----
// Scale f16 overflow
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale 6.600000e+04 out of expressed type range}}
!qalias = !quant.uniform<i8:f16, 6.6e4>
// -----
// Scale f16 underflow in per-axis quantization
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale 5.800000e-08 out of expressed type range}}
!qalias = !quant.uniform<i8:f16:1, {2.0,5.8e-8}>
// -----
// Scale f16 overflow in per-axis quantization
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale 6.600000e+04 out of expressed type range}}
!qalias = !quant.uniform<i8:f16:1, {2.0,6.6e4}>
+
+// -----
+// Illegal negative axis in sub-channel quantization
+// expected-error@+1 {{illegal quantized dimension: -1}}
+!qalias = !quant.uniform<u8:f32:{0:1,-1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Illegal zero block-size in sub-channel quantization
+// expected-error@+1 {{illegal block size: 0}}
+!qalias = !quant.uniform<u8:f32:{0:0,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Illegal negative block-size in sub-channel quantization
+// expected-error@+1 {{illegal block size: -1}}
+!qalias = !quant.uniform<u8:f32:{0:-1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing block size in sub-channel quantization
+// expected-error@+1 {{expected ':'}}
+!qalias = !quant.uniform<u8:f32:{0,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing quantization dimension in sub-channel quantization
+// expected-error@+1 {{expected integer value}}
+!qalias = !quant.uniform<u8:f32:{:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Invalid tensor literal structure in sub-channel quantization
+// expected-error@+2 {{expected '>'}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}>
+
+// -----
+// Ragged tensor literal in sub-channel quantization
+// expected-error@+2 {{ranks are not consistent between elements}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02}}>
+
+// -----
+// Missing braces around block-size information in sub-channel quantization
+// expected-error@+1 {{expected ','}}
+!qalias = !quant.uniform<u8:f32:0:1,1:2,
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing right-brace around block-size information in sub-channel quantization
+// expected-error@+1 {{unbalanced '{' character}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2,
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing left-brace around block-size information in sub-channel quantization
+// expected-error@+1 {{unbalanced '<' character}}
+!qalias = !quant.uniform<u8:f32:0:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing Axis:BlockSize pair
+// expected-error@+1 {{expected integer value}}
+!qalias = !quant.uniform<u8:f32:{0:1,},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing Scale:ZeroPoint pair
+// expected-error@+2 {{expected floating point literal}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,}}>
+
+// -----
+// Missing ZeroPoint in Scale:ZeroPoint pair
+// expected-error@+2 {{expected integer value}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+ {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01:}}>
+
+// -----
+// Empty quantization parameters in sub-channel quantization
+// expected-error@+1 {{expected floating point literal}}
+!qalias = !quant.uniform<u8:f32:{0:1, 1:2}, {}>
+
+// -----
+// Scale out of expressed type range in sub-channel quantization
+// expected-error@+2 {{scale 6.600000e+04 out of expressed type range}}
+!qalias = !quant.uniform<i8:f16:{0:1,1:2},
+ {{6.6e4:120,9.987200e-01:127}, {2.000000e+02:256,9.987200e-01}}>
+
diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index 4fbe86d935ea39..35530833f2d3a7 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -154,3 +154,21 @@ func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
+
+// -----
+// Sub-channel scales and zero points (mixed affine and fixedpoint)
+// CHECK: !quant.uniform<u8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:120, 3.000000e+00:127}, {4.000000e+00, 5.000000e+00}}>
+!qalias = !quant.uniform<u8:f32:{0:1,1:2}, {{2.0:120,3.0:127}, {4.0,5.0}}>
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Empty block-size information in sub-channel quantization
+// CHECK: !quant.uniform<u8:f32:{}, {{\{}}{2.000000e+00:120, 3.000000e+00:127}, {4.000000e+00, 5.000000e+00}}>
+!qalias = !quant.uniform<u8:f32:{}, {{2.0:120,3.0:127}, {4.0,5.0}}>
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
>From a19f70c16bdcfb33f58242c03827191376fa670a Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Sun, 15 Dec 2024 10:37:17 +0000
Subject: [PATCH 2/4] Lowering to linalg ops
---
.../Quant/Transforms/LowerQuantOps.cpp | 214 +++++++++++-------
mlir/test/Dialect/Quant/lower-quant-ops.mlir | 64 ++++++
2 files changed, 194 insertions(+), 84 deletions(-)
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 45d0a0c3e697a2..c2dbcde1aeba6b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -38,11 +38,11 @@ Type getScalarType(Type inputType) {
return inputType;
}
-// Return the shape of an input value as a list of attributes (static dimensions)
-// and values (dynamic dimensions). If 'input' is a scalar, an empty list is
-// returned. If 'input' is a tensor, its shape is returned.
-SmallVector<OpFoldResult>
-getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
+// Return the shape of an input value as a list of attributes (static
+// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
+// list is returned. If 'input' is a tensor, its shape is returned.
+SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
+ Location loc, Value input) {
if (isa<TensorType>(input.getType()))
return tensor::getMixedSizes(builder, loc, input);
return {};
@@ -100,16 +100,16 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
// Turn input size into 1D tensor
auto flatShapeType = shape::getExtentTensorType(context, 1);
- auto flatInputShape = builder.create<tensor::FromElementsOp>(
- loc, flatShapeType, inputSize);
+ auto flatInputShape =
+ builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);
// Reshape input tensor into 1D
auto inputType = cast<UnrankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto flatInputType =
RankedTensorType::get({ShapedType::kDynamic}, elementType);
- auto flatInput = builder.create<tensor::ReshapeOp>(
- loc, flatInputType, input, flatInputShape);
+ auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+ flatInputShape);
return std::make_pair(flatInput, inputShape);
}
@@ -135,11 +135,9 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
// - inputShape
// 1D extent tensor containing the shape of the original unranked input.
//
-std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
- Location loc,
- Value input,
- int64_t axis,
- int64_t axisSize) {
+std::pair<Value, Value>
+flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
+ int64_t axis, int64_t axisSize) {
// Get full tensor shape
auto *context = builder.getContext();
auto indexType = builder.getIndexType();
@@ -149,16 +147,20 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
// Get shape and sizes on left and right of axis
auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
- auto shapeLeft = builder.create<shape::SplitAtOp>(
- loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
- .getResult(0);
- auto sizeLeft = builder.create<shape::NumElementsOp>(
- loc, indexType, shapeLeft);
- auto shapeRight = builder.create<shape::SplitAtOp>(
- loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
- .getResult(1);
- auto sizeRight = builder.create<shape::NumElementsOp>(
- loc, indexType, shapeRight);
+ auto shapeLeft =
+ builder
+ .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+ inputShape, axisValue)
+ .getResult(0);
+ auto sizeLeft =
+ builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
+ auto shapeRight =
+ builder
+ .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+ inputShape, axisNextValue)
+ .getResult(1);
+ auto sizeRight =
+ builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
// Compute flat input shape as a 3-element 1D tensor
auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
@@ -171,8 +173,8 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
auto elementType = inputType.getElementType();
auto flatInputType = RankedTensorType::get(
{ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
- auto flatInput = builder.create<tensor::ReshapeOp>(
- loc, flatInputType, input, flatInputShape);
+ auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+ flatInputShape);
return std::make_pair(flatInput, inputShape);
}
@@ -190,7 +192,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
auto inputType = cast<RankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto unrankedType = UnrankedTensorType::get(elementType);
- return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
+ return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
+ inputShape);
}
// Create a tensor constant containing all scales in a per-channel quantized
@@ -209,7 +212,8 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
return builder.getFloatAttr(expressedType, scale);
});
- auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType);
+ auto tensorType =
+ RankedTensorType::get({(int64_t)scales.size()}, expressedType);
auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}
@@ -228,9 +232,8 @@ Value materializePerChannelZeroPoints(
UniformQuantizedPerAxisType quantizedType) {
auto zeroPoints = quantizedType.getZeroPoints();
auto storageType = quantizedType.getStorageType();
- auto zeroPointAttrs = llvm::map_to_vector(
- zeroPoints,
- [&](int64_t zeroPoint) -> Attribute {
+ auto zeroPointAttrs =
+ llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
return builder.getIntegerAttr(storageType, zeroPoint);
});
auto tensorType =
@@ -239,6 +242,54 @@ Value materializePerChannelZeroPoints(
return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
}
+// Materialize the scales of a sub-channel quantized type as a tensor
+// constant. The constant has the shape of the scales attribute and the
+// expressed type as its element type. Example:
+//
+//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
+//
+Value materializeSubChannelScales(
+    OpBuilder &builder, Location loc,
+    UniformQuantizedSubChannelType quantizedType) {
+  auto expressedType = quantizedType.getExpressedType();
+  auto scales = quantizedType.getScales();
+
+  // Re-create every scale as a float attribute of the expressed type.
+  SmallVector<Attribute> scaleAttrs;
+  for (APFloat scale : scales.getValues<APFloat>())
+    scaleAttrs.push_back(builder.getFloatAttr(expressedType, scale));
+
+  auto tensorType =
+      RankedTensorType::get(scales.getType().getShape(), expressedType);
+  return builder.create<arith::ConstantOp>(
+      loc, tensorType, DenseElementsAttr::get(tensorType, scaleAttrs));
+}
+
+// Materialize the zero points of a sub-channel quantized type as a tensor
+// constant. The constant has the shape of the zero-points attribute and the
+// storage type as its element type. Example:
+//
+//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
+//
+Value materializeSubChannelZeroPoints(
+    OpBuilder &builder, Location loc,
+    UniformQuantizedSubChannelType quantizedType) {
+  auto storageType = quantizedType.getStorageType();
+  auto zeroPoints = quantizedType.getZeroPoints();
+
+  // Re-create every zero point as an integer attribute of the storage type.
+  SmallVector<Attribute> zeroPointAttrs;
+  for (APInt zeroPoint : zeroPoints.getValues<APInt>())
+    zeroPointAttrs.push_back(builder.getIntegerAttr(storageType, zeroPoint));
+
+  auto tensorType =
+      RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
+  return builder.create<arith::ConstantOp>(
+      loc, tensorType, DenseElementsAttr::get(tensorType, zeroPointAttrs));
+}
+
// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
@@ -299,7 +350,7 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
return builder.create<arith::UIToFPOp>(loc, resultType, input);
}
-// Quantize a scalar or ranked tensor value. The stored value is clamped using
+// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
@@ -308,8 +359,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
Value zeroPoint, QuantizedType quantizedType) {
// Convert scale to tensor if necessary
auto inputType = input.getType();
- scale = getScalarOrTensorConstant(
- builder, loc, scale, inputType, inputShape);
+ scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Scale input
auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
@@ -322,8 +372,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
inputShape);
// Convert zero point from storage to expressed type
- zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
- scale.getType(),
+ zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
quantizedType.isSigned());
// Add zero point to stored value
@@ -334,9 +383,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
// Convert stored value to storage type
auto storageScalarOrTensorType =
getScalarOrTensorType(quantizedType.getStorageType(), inputType);
- auto storedValueInt = convertFloatToInteger(
- builder, loc, storedValueFloat, storageScalarOrTensorType,
- quantizedType.isSigned());
+ auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
+ storageScalarOrTensorType,
+ quantizedType.isSigned());
// Clamp stored value it if the storage type is bound
auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
@@ -352,12 +401,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
Value zeroPoint, QuantizedType quantizedType) {
// Convert scale to tensor if necessary
auto inputType = input.getType();
- scale = getScalarOrTensorConstant(
- builder, loc, scale, inputType, inputShape);
+ scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Convert stored value to float
- auto result = convertIntegerToFloat(
- builder, loc, input, scale.getType(), quantizedType.isSigned());
+ auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
+ quantizedType.isSigned());
// Skip unnecessary computations if no zero point is given
if (!matchPattern(zeroPoint, m_Zero())) {
@@ -366,8 +414,7 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
inputShape);
// Convert zero point from storage to expressed type
- zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
- scale.getType(),
+ zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
quantizedType.isSigned());
// Subtract zero point to stored value
@@ -501,35 +548,33 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
auto initShape = tensor::getMixedSizes(builder, loc, input);
Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
- SmallVector<utils::IteratorType> iteratorTypes(
- inputRank, utils::IteratorType::parallel);
+ SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+ utils::IteratorType::parallel);
auto channelAxisAffineMap = AffineMap::get(
inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
SmallVector<AffineMap> indexingMaps{
- builder.getMultiDimIdentityMap(inputRank),
- channelAxisAffineMap,
- channelAxisAffineMap,
- builder.getMultiDimIdentityMap(inputRank)
- };
- auto result = builder.create<linalg::GenericOp>(
- loc,
- init.getType(), // resultType
- ValueRange{input, scales, zeroPoints}, // inputs
- ValueRange{init}, // outputs
- indexingMaps,
- iteratorTypes,
- [&](OpBuilder& builder, Location loc, ValueRange args) {
- assert(args.size() == 4);
- auto input = args[0];
- auto scale = args[1];
- auto zeroPoint = args[2];
-
- auto result = convertRanked(builder, loc, op, input, {}, scale,
- zeroPoint, quantizedType);
-
- builder.create<linalg::YieldOp>(loc, result);
- })
- .getResult(0);
+ builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
+ channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
+ auto result = builder
+ .create<linalg::GenericOp>(
+ loc,
+ init.getType(), // resultType
+ ValueRange{input, scales, zeroPoints}, // inputs
+ ValueRange{init}, // outputs
+ indexingMaps, iteratorTypes,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ assert(args.size() == 4);
+ auto input = args[0];
+ auto scale = args[1];
+ auto zeroPoint = args[2];
+
+ auto result =
+ convertRanked(builder, loc, op, input, {}, scale,
+ zeroPoint, quantizedType);
+
+ builder.create<linalg::YieldOp>(loc, result);
+ })
+ .getResult(0);
return result;
}
@@ -551,7 +596,7 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
// Flatten unranked tensor into a 3D ranked tensor if necessary
bool isUnranked = isa<UnrankedTensorType>(input.getType());
int64_t channelAxis = quantizedType.getQuantizedDimension();
- int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
+ int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
Value inputShape;
if (isUnranked) {
std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
@@ -660,11 +705,17 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
return convertPerChannel(builder, loc, op, input,
uniformQuantizedPerAxisType);
+ if (auto uniformQuantizedSubChannelType =
+ dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
+ return convertSubChannel(builder, loc, op, input,
+ uniformQuantizedSubChannelType);
+
llvm_unreachable("unexpected quantized type");
}
// Lowering pattern for 'quant.dcast'
-struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
+struct DequantizeCastOpConversion
+ : public OpConversionPattern<quant::DequantizeCastOp> {
using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
LogicalResult
@@ -689,7 +740,8 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
};
// Lowering pattern for 'quant.qcast'
-struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
+struct QuantizeCastOpConversion
+ : public OpConversionPattern<quant::QuantizeCastOp> {
using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
LogicalResult
@@ -717,12 +769,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
ConversionTarget target(getContext());
target.addLegalOp<quant::StorageCastOp>();
target.addIllegalDialect<quant::QuantDialect>();
- target.addLegalDialect<
- arith::ArithDialect,
- linalg::LinalgDialect,
- shape::ShapeDialect,
- tensor::TensorDialect
- >();
+ target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+ shape::ShapeDialect, tensor::TensorDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -733,10 +781,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
} // namespace
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
- patterns.add<
- DequantizeCastOpConversion,
- QuantizeCastOpConversion
- >(patterns.getContext());
+ patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
+ patterns.getContext());
}
} // namespace quant
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index 6bba9f5c037727..2d22b1af89f2b4 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -509,3 +509,67 @@ func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias>
return %0 : tensor<*x!qalias>
}
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3 floordiv 2)>
+
+// CHECK-LABEL: @qcast_sub_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<2x?x?x4xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<{{.*}}2.000000e+00, 3.000000e+00{{.*}}, {{.*}}4.000000e+00, 5.000000e+00{{.*}}> : tensor<2x1x1x2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<{{.*}}10, 20{{.*}}, {{.*}}30, 40{{.*}}> : tensor<2x1x1x2xi8>
+
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<2x?x?x4xf32>
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<2x?x?x4xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<2x?x?x4xi8>
+
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<2x?x?x4xf32>, tensor<2x1x1x2xf32>, tensor<2x1x1x2xi8>) outs(%[[INIT]] : tensor<2x?x?x4xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK: linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<2x?x?x4xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x?x?x4xi8> to tensor<2x?x?x4x!quant.uniform<i8:f32:{0:1,3:2}, {{.*}}2.000000e+00:10, 3.000000e+00:20{{.*}}, {{.*}}4.000000e+00:30, 5.000000e+00:40{{.*}}>>
+// CHECK: return %[[STORED_QUANT]]
+
+!qalias = !quant.uniform<i8:f32:{0:1, 3:2}, {{{{2.0:10, 3.0:20}}}, {{{4.0:30, 5.0:40}}}}>
+func.func @qcast_sub_channel_ranked(%arg0: tensor<2x?x?x4xf32>) -> tensor<2x?x?x4x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<2x?x?x4xf32> to tensor<2x?x?x4x!qalias>
+ return %0 : tensor<2x?x?x4x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3 floordiv 2)>
+
+// CHECK-LABEL: @qcast_sub_channel_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<2x3x5x4xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<{{.*}}2.000000e+00, 3.000000e+00{{.*}}, {{.*}}4.000000e+00, 5.000000e+00{{.*}}> : tensor<2x1x1x2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<{{.*}}10, 20{{.*}}, {{.*}}30, 40{{.*}}> : tensor<2x1x1x2xi8>
+
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x3x5x4xi8>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<2x3x5x4xf32>, tensor<2x1x1x2xf32>, tensor<2x1x1x2xi8>) outs(%[[INIT]] : tensor<2x3x5x4xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK: linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<2x3x5x4xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x3x5x4xi8> to tensor<2x3x5x4x!quant.uniform<i8:f32:{0:1,3:2}, {{.*}}2.000000e+00:10, 3.000000e+00:20{{.*}}, {{.*}}4.000000e+00:30, 5.000000e+00:40{{.*}}>>
+// CHECK: return %[[STORED_QUANT]]
+
+!qalias = !quant.uniform<i8:f32:{0:1, 3:2}, {{{{2.0:10, 3.0:20}}}, {{{4.0:30, 5.0:40}}}}>
+func.func @qcast_sub_channel_ranked_bounds(%arg0: tensor<2x3x5x4xf32>) -> tensor<2x3x5x4x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<2x3x5x4xf32> to tensor<2x3x5x4x!qalias>
+ return %0 : tensor<2x3x5x4x!qalias>
+}
\ No newline at end of file
>From 4727eba354ed48b0da6fec83644d7f6faa1aeeb7 Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Sun, 15 Dec 2024 12:13:59 +0000
Subject: [PATCH 3/4] Addings c-api and py-apis
---
mlir/lib/Bindings/Python/DialectQuant.cpp | 1 +
mlir/lib/CAPI/Dialect/Quant.cpp | 1 +
.../mlir/_mlir_libs/_mlir/dialects/quant.pyi | 22 ++++++++-
mlir/test/CAPI/quant.c | 2 +
mlir/test/python/dialects/quant.py | 49 +++++++++++++++++++
5 files changed, 74 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 44a596caa24a65..8f7da9ef3c1bca 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index 88648497895ab7..01a6a948f1dc07 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
index 47168d49c5568b..3f5304584edeff 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from mlir.ir import Type
+from mlir.ir import DenseElementsAttr, Type
__all__ = [
"QuantizedType",
@@ -109,6 +109,26 @@ class UniformQuantizedPerAxisType(QuantizedType):
@property
def is_fixed_point(self) -> bool: ...
+class UniformQuantizedSubChannelType(QuantizedType):
+
+ @classmethod
+ def get(cls, flags: int, storage_type: Type, expressed_type: Type,
+ scales: DenseElementsAttr, zero_points: DenseElementsAttr,
+ quantized_dimensions: list[int], block_sizes: list[int],
+ storage_type_min: int, storage_type_max: int):
+ ...
+
+ @property
+ def quantized_dimensions(self) -> list[int]: ...
+
+ @property
+ def block_sizes(self) -> list[int]: ...
+
+ @property
+ def scales(self) -> DenseElementsAttr: ...
+
+ @property
+ def zero_points(self) -> DenseElementsAttr: ...
def CalibratedQuantizedType(QuantizedType):
diff --git a/mlir/test/CAPI/quant.c b/mlir/test/CAPI/quant.c
index bc7cd1436efe18..a5be7550b7d63f 100644
--- a/mlir/test/CAPI/quant.c
+++ b/mlir/test/CAPI/quant.c
@@ -10,6 +10,7 @@
// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
#include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"
@@ -357,6 +358,7 @@ int main(void) {
testAnyQuantizedType(ctx);
testUniformType(ctx);
testUniformPerAxisType(ctx);
+ testUniformSubChannelType(ctx);
testCalibratedType(ctx);
mlirContextDestroy(ctx);
return EXIT_SUCCESS;
diff --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py
index b1d6e85f519b5d..b31bd2088610a3 100644
--- a/mlir/test/python/dialects/quant.py
+++ b/mlir/test/python/dialects/quant.py
@@ -1,5 +1,6 @@
# RUN: %PYTHON %s | FileCheck %s
+import numpy as np
from mlir.ir import *
from mlir.dialects import quant
@@ -18,21 +19,28 @@ def test_type_hierarchy():
any = Type.parse("!quant.any<i8<-8:7>:f32>")
uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+ sub_channel = Type.parse(
+ "!quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"
+ )
calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
assert not quant.QuantizedType.isinstance(i8)
assert quant.QuantizedType.isinstance(any)
assert quant.QuantizedType.isinstance(uniform)
assert quant.QuantizedType.isinstance(per_axis)
+ assert quant.QuantizedType.isinstance(sub_channel)
assert quant.QuantizedType.isinstance(calibrated)
assert quant.AnyQuantizedType.isinstance(any)
assert quant.UniformQuantizedType.isinstance(uniform)
assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
+ assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel)
assert quant.CalibratedQuantizedType.isinstance(calibrated)
assert not quant.AnyQuantizedType.isinstance(uniform)
assert not quant.UniformQuantizedType.isinstance(per_axis)
+ assert not quant.UniformQuantizedType.isinstance(sub_channel)
+ assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel)
# CHECK-LABEL: TEST: test_any_quantized_type
@@ -121,6 +129,47 @@ def test_uniform_per_axis_type():
assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+# CHECK-LABEL: TEST: test_uniform_sub_channel_type
+@run
+def test_uniform_sub_channel_type():
+ with Context():
+ i8 = IntegerType.get_signless(8)
+ f32 = F32Type.get()
+ sub_channel = quant.UniformQuantizedSubChannelType.get(
+ quant.QuantizedType.FLAG_SIGNED,
+ i8,
+ f32,
+ DenseElementsAttr.get(
+ np.asarray([2.0, 3.0, 4.0, 5.0], np.float32).reshape(2, 2)
+ ),
+ DenseElementsAttr.get(np.asarray([10, 20, 30, 40], np.int8).reshape(2, 2)),
+ [0, 1],
+ [1, 2],
+ storage_type_min=quant.QuantizedType.default_minimum_for_integer(
+ is_signed=True, integral_width=8
+ ),
+ storage_type_max=quant.QuantizedType.default_maximum_for_integer(
+ is_signed=True, integral_width=8
+ ),
+ )
+
+ # CHECK: quantized dimensions: [0, 1]
+ print(f"quantized dimensions: {sub_channel.quantized_dimensions}")
+ # CHECK: block sizes: [1, 2]
+ print(f"block sizes: {sub_channel.block_sizes}")
+ # CHECK: scales: {{\[}}[2. 3.]
+ # CHECK: [4. 5.]]
+ print(f"scales: {np.asarray(sub_channel.scales)}")
+ # CHECK: zero-points: {{\[}}[10 20]
+ # CHECK: [30 40]]
+ print(f"zero-points: {np.asarray(sub_channel.zero_points)}")
+ # CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
+ print(sub_channel)
+ assert sub_channel == Type.parse(
+ "!quant.uniform<i8:f32:{0:1,1:2},{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"
+ )
+
+
# CHECK-LABEL: TEST: test_calibrated_type
@run
def test_calibrated_type():
>From 0f4147e224f733df95d88367e6eb4a6a06c2714b Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Thu, 12 Dec 2024 20:27:04 +0000
Subject: [PATCH 4/4] Add pass to normalize generic quantized types to specific
quantized types
---
.../mlir/Dialect/Quant/Transforms/Passes.td | 33 ++++
.../Dialect/Quant/Transforms/CMakeLists.txt | 1 +
.../Quant/Transforms/NormalizeQuantTypes.cpp | 179 ++++++++++++++++++
.../Dialect/Quant/normalize-quant-types.mlir | 51 +++++
4 files changed, 264 insertions(+)
create mode 100644 mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
create mode 100644 mlir/test/Dialect/Quant/normalize-quant-types.mlir
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index b25296d4db5a99..4b438706c9af69 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -31,6 +31,39 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
];
}
+def NormalizeQuantTypes : Pass<"normalize-quant-types"> {
+ let summary = "Normalize generic quantized types to specific quantized types";
+ let description = [{
+ This pass converts generic quantized types in the `quant` dialect to more
+ specific types when possible.
+
+ The following conversions are performed:
+
+ 1. Sub-channel to per-axis: If the shape of the scales tensor of a
+ sub-channel quantized type has exactly one dimension whose size is not
+ one, it is converted to a per-axis quantized type.
+
+ For example:
+
+ * `!quant.uniform<i8:f32:{0:1}, {{2.0}, {3.0}}>`
+ -> `!quant.uniform<i8:f32:0, {2.0, 3.0}>`
+ * `tensor<?x?x!quant.uniform<i8:f32:{0:1,1:4}, {{2.0}, {3.0}}>>`
+ -> `tensor<?x?x!quant.uniform<i8:f32:0, {2.0, 3.0}>>`
+
+ 2. Sub-channel to per-tensor: If a sub-channel quantized type has only
+ one scale and one zero-point, it is converted to a per-tensor
+ quantized type.
+
+ For example:
+
+ * `!quant.uniform<i8:f32:{}, {{2.0}}>`
+ -> `!quant.uniform<i8:f32, 2.0>`
+ * `tensor<?x?x!quant.uniform<i8:f32:{0:1, 1:4}, {{2.0}}>>`
+ -> `tensor<?x?x!quant.uniform<i8:f32, 2.0>>`
+ }];
+ let dependentDialects = ["func::FuncDialect", "quant::QuantDialect"];
+}
+
def StripFuncQuantTypes : Pass<"strip-func-quant-types"> {
let summary = "Strip quantized types from function headers";
let description = [{
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
index 2fd4a41999d456..825d11992d309c 100644
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRQuantTransforms
LowerQuantOps.cpp
+ NormalizeQuantTypes.cpp
StripFuncQuantTypes.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
new file mode 100644
index 00000000000000..ecd6679651f462
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
@@ -0,0 +1,179 @@
+//===- NormalizeQuantTypes.cpp - Normalize quantized types
+//----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Normalize generic quantized types to specific quantized types
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_NORMALIZEQUANTTYPES
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+/// Returns true if the given sub-channel quantized type is convertible to a
+/// per-tensor quantized type. This is true if the sub-channel type has only
+/// one scale and one zero point.
+///
+/// Assumes that `tensorType` is a tensor with element type
+/// `quant::UniformQuantizedSubChannelType`.
+static bool isConvertibleToPerTensor(TensorType tensorType) {
+ return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
+ .getScales()
+ .getType()
+ .getNumElements() == 1;
+}
+
+/// Returns true if the given sub-channel quantized type is convertible to a
+/// per-axis quantized type. This is true if the shape of the scales tensor has
+/// exactly one dimension whose size is not one.
+///
+/// Assumes that `tensorType` is a tensor with element type
+/// `quant::UniformQuantizedSubChannelType`.
+static bool isConvertibleToPerAxis(TensorType tensorType) {
+ auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
+ .getScales()
+ .getType()
+ .getShape();
+ return llvm::count_if(shape, [](int64_t dim) { return dim != 1; }) == 1;
+}
+
+/// This class defines a type converter that converts sub-channel quantized
+/// types to per-tensor or per-axis quantized types whenever possible.
+class NormalizedQuantTypesConverter : public TypeConverter {
+
+ static Type convertType(Type type) {
+ auto tensorType = dyn_cast<TensorType>(type);
+ if (!tensorType) {
+ return type;
+ }
+
+ auto subChannelType =
+ dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
+ if (!subChannelType) {
+ return type;
+ }
+
+ if (isConvertibleToPerTensor(tensorType)) {
+ double scale =
+ subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
+ int64_t zeroPoint =
+ subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
+ auto perTensorType = UniformQuantizedType::get(
+ subChannelType.getFlags(), subChannelType.getStorageType(),
+ subChannelType.getExpressedType(), scale, zeroPoint,
+ subChannelType.getStorageTypeMin(),
+ subChannelType.getStorageTypeMax());
+ return tensorType.clone(perTensorType);
+ }
+
+ if (isConvertibleToPerAxis(tensorType)) {
+ auto shape = subChannelType.getScales().getType().getShape();
+ auto quantizedDimItr =
+ llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
+ auto scales = llvm::to_vector(llvm::map_range(
+ subChannelType.getScales().getValues<APFloat>(),
+ [](APFloat scale) { return scale.convertToDouble(); }));
+ auto zeroPoints = llvm::to_vector(llvm::map_range(
+ subChannelType.getZeroPoints().getValues<APInt>(),
+ [](APInt zeroPoint) { return zeroPoint.getSExtValue(); }));
+ auto perAxisType = UniformQuantizedPerAxisType::get(
+ subChannelType.getFlags(), subChannelType.getStorageType(),
+ subChannelType.getExpressedType(), scales, zeroPoints,
+ quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
+ subChannelType.getStorageTypeMax());
+ return tensorType.clone(perAxisType);
+ }
+ return type;
+ }
+
+public:
+ explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
+};
+
+/// This class implements a conversion pattern that converts any generic
+/// operation with sub-channel quantized types to an equivalent operation with
+/// per-tensor or per-axis quantized types.
+class ConvertGenericOpwithSubChannelType : public ConversionPattern {
+public:
+ ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter,
+ MLIRContext *context)
+ : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ SmallVector<Type> resultTypes;
+ if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
+ return failure();
+
+ auto *newOp = Operation::create(
+ op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
+ op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+ for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
+ Region &before = std::get<0>(regions);
+ Region &parent = std::get<1>(regions);
+ rewriter.inlineRegionBefore(before, parent, parent.end());
+ if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
+ return failure();
+ }
+ rewriter.insert(newOp);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+};
+
+// Conversion pass
+class NormalizeQuantTypes
+ : public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
+public:
+ void runOnOperation() override {
+
+ auto moduleOp = cast<ModuleOp>(getOperation());
+ auto *context = &getContext();
+
+ NormalizedQuantTypesConverter typeConverter;
+ ConversionTarget target(*context);
+
+ // Determine legal operations.
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+ return typeConverter.isLegal(op->getOperandTypes()) &&
+ typeConverter.isLegal(op->getResultTypes());
+ });
+
+ // Register conversion patterns
+ RewritePatternSet patterns(context);
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
+
+ // Apply conversion
+ if (failed(applyFullConversion(moduleOp, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+} // namespace quant
+} // namespace mlir
diff --git a/mlir/test/Dialect/Quant/normalize-quant-types.mlir b/mlir/test/Dialect/Quant/normalize-quant-types.mlir
new file mode 100644
index 00000000000000..573781c9ecc04a
--- /dev/null
+++ b/mlir/test/Dialect/Quant/normalize-quant-types.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s --normalize-quant-types --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @callee(
+// CHECK-SAME: [[PER_TENSOR:tensor<\?x\?x!quant.uniform<i8:f32, 2.000000e\+00:127>>]],
+// CHECK-SAME: [[PER_TENSOR]]
+// CHECK-SAME: ([[PER_TENSOR]], [[PER_TENSOR]])
+// CHECK-LABEL: @normalize_quant_types_to_per_tensor
+// CHECK-SAME: %[[ARG_0:.*]]: [[PER_TENSOR:tensor<\?x\?x!quant.uniform<i8:f32, 2.000000e\+00:127>>]],
+// CHECK-SAME: %[[ARG_1:.*]]: [[PER_TENSOR]]
+// CHECK-SAME: ([[PER_TENSOR]], [[PER_TENSOR]])
+// CHECK: %[[TEMP_0:.*]] = "test.custom_op"(%[[ARG_0]]) : ([[PER_TENSOR]]) -> [[PER_TENSOR]]
+// CHECK: %[[TEMP_1:.*]] = "test.custom_op"(%[[ARG_1]]) : ([[PER_TENSOR]]) -> [[PER_TENSOR]]
+// CHECK: %[[TEMP_3:.*]]:2 = call @callee(%[[TEMP_0]], %[[TEMP_1]])
+// CHECK: return %[[TEMP_3]]#0, %[[TEMP_3]]#1 : [[PER_TENSOR]], [[PER_TENSOR]]
+
+!qalias1 = !quant.uniform<i8:f32:{}, {{2.0:127}}>
+!qalias2 = !quant.uniform<i8:f32:{0:1,1:4}, {{2.0:127}}>
+
+func.func private @callee(tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
+
+func.func @normalize_quant_types_to_per_tensor(%arg0: tensor<?x?x!qalias1>,
+ %arg1: tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) {
+ %0 = "test.custom_op"(%arg0) : (tensor<?x?x!qalias1>) -> tensor<?x?x!qalias1>
+ %1 = "test.custom_op"(%arg1) : (tensor<?x?x!qalias2>) -> tensor<?x?x!qalias2>
+ %3:2 = func.call @callee(%0, %1) : (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
+ return %3#0, %3#1 : tensor<?x?x!qalias1>, tensor<?x?x!qalias2>
+}
+
+// -----
+
+// CHECK-LABEL: @normalize_quant_types_to_per_axis
+// CHECK-SAME: %[[ARG_0:.*]]: [[PER_AXIS:tensor<\?x\?x!quant.uniform<i8:f32:0, \{2.000000e\+00:127,3.000000e\+00:127\}>>]],
+// CHECK-SAME: %[[ARG_1:.*]]: [[PER_AXIS]]
+// CHECK-SAME: ([[PER_AXIS]], [[PER_AXIS]])
+// CHECK: %[[TEMP_0:.*]] = "test.custom_op"(%[[ARG_0]]) : ([[PER_AXIS]]) -> [[PER_AXIS]]
+// CHECK: %[[TEMP_1:.*]] = "test.custom_op"(%[[ARG_1]]) : ([[PER_AXIS]]) -> [[PER_AXIS]]
+// CHECK: %[[TEMP_3:.*]]:2 = call @callee(%[[TEMP_0]], %[[TEMP_1]])
+// CHECK: return %[[TEMP_3]]#0, %[[TEMP_3]]#1 : [[PER_AXIS]], [[PER_AXIS]]
+
+!qalias1 = !quant.uniform<i8:f32:{0:1}, {{2.0:127}, {3.0:127}}>
+!qalias2 = !quant.uniform<i8:f32:{0:1,1:4}, {{2.0:127}, {3.0:127}}>
+
+func.func private @callee(tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
+
+func.func @normalize_quant_types_to_per_axis(%arg0: tensor<?x?x!qalias1>,
+ %arg1: tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) {
+ %0 = "test.custom_op"(%arg0) : (tensor<?x?x!qalias1>) -> tensor<?x?x!qalias1>
+ %1 = "test.custom_op"(%arg1) : (tensor<?x?x!qalias2>) -> tensor<?x?x!qalias2>
+ %3:2 = func.call @callee(%0, %1) : (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
+ return %3#0, %3#1 : tensor<?x?x!qalias1>, tensor<?x?x!qalias2>
+}
More information about the Mlir-commits
mailing list