[Mlir-commits] [mlir] Sub-channel quantized type implementation (PR #120172)
Kevin Gleason
llvmlistbot at llvm.org
Tue Feb 11 14:55:22 PST 2025
================
@@ -382,6 +383,136 @@ class UniformQuantizedPerAxisType
}
};
+/// Represents sub-channel quantization (also known as blockwise quantization).
+///
+/// Syntax synopsis:
+/// UniformQuantizedSubChannelType ::= '!quant.uniform' '<'
+/// storageType ('<' storageMin ':' storageMax '>')? ':'
+/// expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>'
+/// BlockSizeInfo ::= '{' '}' | '{' AxisBlock (',' AxisBlock)* '}'
+/// AxisBlock ::= AxisSpec ':' BlockSizeSpec
+/// ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList
+/// ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}'
+/// ScaleZeroList ::= ScaleZero (',' ScaleZero)*
+/// ScaleZero ::= Scale (':' ZeroPoint)?
+///
+/// StorageType: 'i'|'u' NumBits
+/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+/// AxisSpec: An integer value
+/// BlockSizeSpec: An integer value
+/// Scale: An attribute (usually a floating-point value)
+/// ZeroPoint: An attribute (usually an integer value)
+class UniformQuantizedSubChannelType
+ : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,
+ detail::UniformQuantizedSubChannelTypeStorage> {
+public:
+ using Base::Base;
+ using Base::getChecked;
+
+ static constexpr StringLiteral name = "quant.uniform_sub_channel";
+
+ /// Gets an instance of the type with all parameters specified but not
+ /// checked.
+ static UniformQuantizedSubChannelType
+ get(unsigned flags, Type storageType, Type expressedType,
+ DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+ int64_t storageTypeMin, int64_t storageTypeMax);
+
+ /// Gets an instance of the type with all specified parameters checked.
+ /// Returns a nullptr convertible type on failure.
+ static UniformQuantizedSubChannelType
+ getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, DenseElementsAttr scales,
+ DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions,
+ ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax);
+
+ /// Verifies construction invariants and issues errors/warnings.
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType,
+ DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+ ArrayRef<int32_t> quantizedDimensions,
+ ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+ int64_t storageTypeMax);
+
+ /// Gets the quantization scales. The scales are organized in a
+ /// multi-dimensional tensor. The size of each dimension in the scales tensor
+ /// is determined by the number of blocks along the corresponding dimension in
+ /// the quantized data tensor.
+ ///
+ /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+ /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will
+ /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+ ///
+ /// The scale value for a specific element in the quantized data tensor at
+ /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding
+ /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1].
+ DenseElementsAttr getScales() const;
+
+ /// Gets the quantization zero-points. The zero-points are organized in a
+ /// multi-dimensional tensor. The size of each dimension in the zero-point
+ /// tensor is determined by the number of blocks along the corresponding
+ /// dimension in the quantized data tensor.
+ ///
+ /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+ /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor
+ /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+ ///
+ /// The zero-point value for a specific element in the quantized data tensor
+ /// at position [i0, i1, ..., iR-1] is determined by accessing the
+ /// corresponding element in the zero-point tensor at position [i0/B0, i1/B1,
+ /// ..., iR-1/BR-1].
+ DenseElementsAttr getZeroPoints() const;
+
+ /// Gets the quantized dimensions. Each element in the returned list
+ /// represents an axis of the quantized data tensor that has a specified block
+ /// size. The order of elements corresponds to the order of block sizes
+ /// returned by `getBlockSizes()`.
+ ///
+ /// That is, the data tensor is quantized along the axis given by the `i`-th
+ /// element of the returned list, using the `i`-th block size from
+ /// `getBlockSizes()`.
+ ///
+ /// Note that the type expression does not have to specify the block size for
+ /// all axes in the data tensor. Any unspecified block size for an axis `i`
+ /// defaults to the tensor dimension size of that axis.
+ ///
+ /// For example, for a quantized type:
+ /// `tensor<8x4x2x!quant.uniform<i8:f32:{1:2, 0:8}, {{{1.0}, {2.0}}}>>`
+ ///
+ /// `getQuantizedDimensions()` returns [1, 0].
+ /// `getBlockSizes()` returns [2, 8].
+ ///
+ /// This indicates that:
+ /// * Axis 1 (second dimension) is quantized with a block size of 2.
+ /// * Axis 0 (first dimension) is quantized with a block size of 8.
+ /// Since axis 2 is not specified, it implicitly has a block size equal to
+ /// the size of the third dimension (which is 2 in this case).
+ ArrayRef<int32_t> getQuantizedDimensions() const;
+
+ /// Gets the block sizes for the quantized dimensions. The `i`-th element in
+ /// the returned list corresponds to the block size for the `i`-th dimension
+ /// in the list returned by `getQuantizedDimensions()`.
+ ///
+ /// See `getQuantizedDimensions()` for more details and examples.
+ ArrayRef<int64_t> getBlockSizes() const;
+
+ /// Gets the block size information. This returns a list of pairs, where each
+ /// pair represents a quantized dimension and its corresponding block size.
+ ///
+ /// For example, for the type:
+ /// `tensor<8x4x!quant.uniform<i8:f32:{1:2, 0:8}, {{2.0, 3.0}}>>`
+ ///
+ /// This method returns:
+ /// `[(1, 2), (0, 8)]`
+ ///
+ /// This list indicates that axis 1 has a block size of 2, and axis 0 has a
----------------
GleasonK wrote:
Edit: Looks like no. In that case, returning a list of pairs is an OK API. We could consider a lightweight struct to give the elements better names (`element.getQuantizationDimension()` / `element.getBlockSize()`), but a pair is probably fine.
https://github.com/llvm/llvm-project/pull/120172
More information about the Mlir-commits
mailing list