[Mlir-commits] [mlir] Sub-channel quantized type implementation (PR #120172)

Kevin Gleason llvmlistbot at llvm.org
Tue Feb 11 14:55:22 PST 2025


================
@@ -382,6 +383,136 @@ class UniformQuantizedPerAxisType
   }
 };
 
+/// Represents sub-channel quantization (also known as blockwise quantization).
+///
+/// Syntax synopsis:
+///   UniformQuantizedSubChannelType ::= '!quant.uniform' '<'
+///       storageType ('<' storageMin ':' storageMax '>')? ':'
+///       expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>'
+///   BlockSizeInfo ::= '{' '}' | '{' AxisBlock (',' AxisBlock)* '}'
+///   AxisBlock ::= AxisSpec ':' BlockSizeSpec
+///   ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList
+///   ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}'
+///   ScaleZeroList  ::= ScaleZero (',' ScaleZero)*
+///   ScaleZero ::= Scale (':' ZeroPoint)?
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   AxisSpec: An integer value
+///   BlockSizeSpec: An integer value
+///   Scale: An attribute (usually floating-point value)
+///   ZeroPoint: An attribute (usually integer value)
+class UniformQuantizedSubChannelType
+    : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,
+                            detail::UniformQuantizedSubChannelTypeStorage> {
+public:
+  using Base::Base;
+  using Base::getChecked;
+
+  static constexpr StringLiteral name = "quant.uniform_sub_channel";
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedSubChannelType
+  get(unsigned flags, Type storageType, Type expressedType,
+      DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+      ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static UniformQuantizedSubChannelType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+             Type storageType, Type expressedType, DenseElementsAttr scales,
+             DenseElementsAttr zeroPoints,
+             ArrayRef<int32_t> quantizedDimensions,
+             ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+             int64_t storageTypeMax);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+                   ArrayRef<int32_t> quantizedDimensions,
+                   ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
+
+  /// Gets the quantization scales. The scales are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the scales tensor
+  /// is determined by the number of blocks along the corresponding dimension in
+  /// the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will
+  /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The scale value for a specific element in the quantized data tensor at
+  /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding
+  /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1].
+  DenseElementsAttr getScales() const;
+
+  /// Gets the quantization zero-points. The zero-points are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the zero-point
+  /// tensor is determined by the number of blocks along the corresponding
+  /// dimension in the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor
+  /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The zero-point value for a specific element in the quantized data tensor
+  /// at position [i0, i1, ..., iR-1] is determined by accessing the
+  /// corresponding element in the zero-point tensor at position [i0/B0, i1/B1,
+  /// ..., iR-1/BR-1].
+  DenseElementsAttr getZeroPoints() const;
+
+  /// Gets the quantized dimensions. Each element in the returned list
+  /// represents an axis of the quantized data tensor that has a specified block
+  /// size. The order of elements corresponds to the order of block sizes
+  /// returned by `getBlockSizes()`.
+  ///
+  /// That is, the data tensor is quantized along the axis given by the `i`-th
+  /// element of the returned list, using the `i`-th block size from
+  /// `getBlockSizes()`.
+  ///
+  /// Note that the type expression does not have to specify the block size for
+  /// all axes in the data tensor. Any unspecified block size for an axis `i`
+  /// defaults to the tensor dimension size of that axis.
+  ///
+  /// For example, for a quantized type:
+  /// `tensor<8x4x2x!quant.uniform<i8:f32:{1:2, 0:8}, {{{1.0}, {2.0}}}>>`
+  ///
+  /// `getQuantizedDimensions()` returns [1, 0].
+  /// `getBlockSizes()` returns [2, 8].
+  ///
+  /// This indicates that:
+  ///  * Axis 1 (second dimension) is quantized with a block size of 2.
+  ///  * Axis 0 (first dimension) is quantized with a block size of 8.
+  ///  Since axis 2 is not specified, it implicitly has a block size equal to
+  ///  the size of the third dimension (which is 2 in this case).
+  ArrayRef<int32_t> getQuantizedDimensions() const;
+
+  /// Gets the block sizes for the quantized dimensions. The `i`-th element in
+  /// the returned list corresponds to the block size for the `i`-th dimension
+  /// in the list returned by `getQuantizedDimensions()`.
+  ///
+  /// See `getQuantizedDimensions()` for more details and examples.
+  ArrayRef<int64_t> getBlockSizes() const;
+
+  /// Gets the block size information. This returns a list of pairs, where each
+  /// pair represents a quantized dimension and its corresponding block size.
+  ///
+  /// For example, for the type:
+  ///  `tensor<8x4x!quant.uniform<i8:f32:{1:2, 0:8}, {{2.0, 3.0}}>>`
+  ///
+  /// This method returns:
+  ///  `[(1, 2), (0, 8)]`
+  ///
+  /// This list indicates that axis 1 has a block size of 2, and axis 0 has a
----------------
GleasonK wrote:

Do we enforce that quantization dims must be increasing, or that a block size is specified for every quantization dimension? Should we?
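A minimal sketch of what such a check could look like, written against plain `std::vector` rather than MLIR's `ArrayRef`/`emitError` machinery. `checkQuantizedDims` and its `requireAllAxes` flag are hypothetical names for illustration only, not part of the PR:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper (not part of the PR). Returns an empty string when
// `quantizedDims` is strictly increasing and every entry is a valid axis of a
// tensor of rank `rank`; otherwise returns a description of the violation.
// With `requireAllAxes`, it also requires exactly one entry per axis.
std::string checkQuantizedDims(const std::vector<int32_t> &quantizedDims,
                               int64_t rank, bool requireAllAxes) {
  assert(rank >= 0 && "rank must be non-negative");
  for (size_t i = 0; i < quantizedDims.size(); ++i) {
    int32_t dim = quantizedDims[i];
    if (dim < 0 || dim >= rank)
      return "quantized dimension " + std::to_string(dim) + " is out of range";
    // Strictly increasing also rules out duplicate axes.
    if (i > 0 && dim <= quantizedDims[i - 1])
      return "quantized dimensions must be strictly increasing";
  }
  if (requireAllAxes && static_cast<int64_t>(quantizedDims.size()) != rank)
    return "expected a block size for every axis";
  return "";
}
```

Under the strictly-increasing rule, the `{1:2, 0:8}` spelling used in the doc comments would be rejected and would have to be written as `{0:8, 1:2}`.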

This API could instead return a list of ints giving the block size for each quantization dimension, where `getBlockSizeInfo()[0]` is the block size for quantization dim 0.
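A minimal sketch of that alternative, again using plain `std::vector` in place of `ArrayRef`. The free functions `denseBlockSizes`, `scaleTensorShape`, and `scaleIndexFor` are hypothetical and not part of the PR. The sketch fills unspecified axes with their full dimension size (as the `getQuantizedDimensions()` comment describes), which turns the scale-shape and scale-lookup rules from the `getScales()` comment into a plain per-axis division:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: expands sparse (quantizedDims, blockSizes) pairs into
// one block size per axis. Axes without an entry default to their full
// dimension size. Assumes the dims were already validated as in-range.
std::vector<int64_t> denseBlockSizes(const std::vector<int64_t> &shape,
                                     const std::vector<int32_t> &quantizedDims,
                                     const std::vector<int64_t> &blockSizes) {
  assert(quantizedDims.size() == blockSizes.size());
  std::vector<int64_t> result(shape); // default: one block spans the axis
  for (size_t i = 0; i < quantizedDims.size(); ++i)
    result[static_cast<size_t>(quantizedDims[i])] = blockSizes[i];
  return result;
}

// Shape of the scales/zero-points tensor: [X0/B0, X1/B1, ..., XR-1/BR-1].
// Assumption of this sketch: every dimension divides evenly by its block size.
std::vector<int64_t> scaleTensorShape(const std::vector<int64_t> &shape,
                                      const std::vector<int64_t> &dense) {
  assert(shape.size() == dense.size());
  std::vector<int64_t> result(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    assert(shape[i] % dense[i] == 0);
    result[i] = shape[i] / dense[i];
  }
  return result;
}

// Position in the scales tensor for data element [i0, ..., iR-1]:
// [i0/B0, ..., iR-1/BR-1].
std::vector<int64_t> scaleIndexFor(const std::vector<int64_t> &index,
                                   const std::vector<int64_t> &dense) {
  assert(index.size() == dense.size());
  std::vector<int64_t> result(index.size());
  for (size_t i = 0; i < index.size(); ++i)
    result[i] = index[i] / dense[i];
  return result;
}
```

For the `tensor<8x4x2x...{1:2, 0:8}...>` example, this gives block sizes `[8, 2, 2]` and a scales shape of `[1, 2, 1]`.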

https://github.com/llvm/llvm-project/pull/120172

