[Mlir-commits] [mlir] Subchannel quant impl (PR #120172)

Sandeep Dasgupta llvmlistbot at llvm.org
Mon Dec 16 18:14:40 PST 2024


https://github.com/sdasgup3 created https://github.com/llvm/llvm-project/pull/120172

This is an implementation for [RFC: Supporting Sub-Channel Quantization in MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694).

To make the review process easier, the PR has been divided into the following labeled commits:

1. **Add implementation for sub-channel type:** Includes the class design for `UniformQuantizedSubChannelType`, printer/parser and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered.
2. **Add lowering support for sub-channel type:** Lowers the `quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python APIs:** We first define the C APIs and build the Python APIs on top of them.
4. **Add pass to normalize generic ....:** This pass normalizes sub-channel quantized types to per-tensor or per-axis types, where possible.


A design note:
 - **Explicitly storing the `quantized_dimensions`, even when they could be derived for ranked tensors.**
   While it's possible to infer the quantized dimensions from the static shape of the scales (or zero-points)
   tensor for ranked data tensors ([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3) for background),
   there are cases where this can lead to ambiguity and issues with round-tripping.

Consider the example:

```
tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>
```

The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. While this inference is technically correct, since the block size for axis 0 is a degenerate case (equal to the dimension size), it can cause problems with round-tripping. Therefore, even for ranked tensors, we explicitly store the quantized dimensions. Suggestions welcome!
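To make the ambiguity concrete, here is a small self-contained C++ sketch; `scalesShape` is a hypothetical helper (not part of this PR) that applies the defaulting rule described above:

```
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Derives the scales shape from the data shape and an (axis, blockSize) list,
// treating every unspecified axis as a single block spanning the whole axis.
std::vector<int64_t>
scalesShape(const std::vector<int64_t> &dataShape,
            const std::vector<std::pair<int32_t, int64_t>> &blockSizeInfo) {
  std::vector<int64_t> result(dataShape.size(), 1);
  for (auto [axis, blockSize] : blockSizeInfo)
    result[axis] = dataShape[axis] / blockSize;
  return result;
}

int main() {
  std::vector<int64_t> dataShape = {2, 4};
  // Explicit degenerate block on axis 0 (block size == dimension size) ...
  auto a = scalesShape(dataShape, {{0, 2}, {1, 2}});
  // ... versus leaving axis 0 unspecified entirely.
  auto b = scalesShape(dataShape, {{1, 2}});
  // Both print "1 2": the scales shape alone cannot distinguish the two
  // spellings, hence the round-tripping problem.
  for (int64_t d : a) std::cout << d << ' ';
  std::cout << '\n';
  for (int64_t d : b) std::cout << d << ' ';
  std::cout << '\n';
}
```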


PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush.

From 4a160d34a551692eec03146c9697fe8e2d412bd8 Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Sun, 15 Dec 2024 12:13:43 +0000
Subject: [PATCH 1/4] Add implementation for sub-channel type
 (design/assembly/verification/bytecode)

---
 mlir/include/mlir-c/Dialect/Quant.h           |  41 +++
 .../mlir/Dialect/Quant/IR/QuantBase.td        | 192 ++++++++++-
 .../Dialect/Quant/IR/QuantDialectBytecode.td  |  30 +-
 .../mlir/Dialect/Quant/IR/QuantTypes.h        | 131 ++++++++
 mlir/lib/Bindings/Python/DialectQuant.cpp     |  73 ++++
 mlir/lib/CAPI/Dialect/Quant.cpp               |  55 +++
 .../Dialect/Quant/IR/QuantDialectBytecode.cpp |   1 +
 mlir/lib/Dialect/Quant/IR/QuantOps.cpp        | 147 ++++++--
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      | 121 ++++++-
 mlir/lib/Dialect/Quant/IR/TypeDetail.h        | 122 +++++++
 mlir/lib/Dialect/Quant/IR/TypeParser.cpp      | 318 +++++++++++++++---
 .../Quant/Transforms/LowerQuantOps.cpp        |  67 ++++
 mlir/test/CAPI/quant.c                        | 124 +++++++
 mlir/test/Dialect/Quant/Bytecode/types.mlir   |   9 +
 mlir/test/Dialect/Quant/invalid.mlir          |  68 ++++
 mlir/test/Dialect/Quant/ops.mlir              |  19 ++
 .../Dialect/Quant/parse-uniform-invalid.mlir  | 100 +++++-
 mlir/test/Dialect/Quant/parse-uniform.mlir    |  18 +
 18 files changed, 1547 insertions(+), 89 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h
index a7d98dc3c1a775..dc0989e53344ea 100644
--- a/mlir/include/mlir-c/Dialect/Quant.h
+++ b/mlir/include/mlir-c/Dialect/Quant.h
@@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type);
 MLIR_CAPI_EXPORTED bool
 mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a UniformQuantizedSubChannel.
+MLIR_CAPI_EXPORTED bool
+mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
+
+/// Creates a UniformQuantizedSubChannelType with the given parameters.
+///
+/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
+/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes` each point to
+/// `blockSizeInfoLength` elements, describing the quantization axes and the
+/// corresponding block sizes, respectively.
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet(
+    unsigned flags, MlirType storageType, MlirType expressedType,
+    MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr,
+    intptr_t blockSizeInfoLength, int32_t *quantizedDimensions,
+    int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax);
+
+/// Returns the number of block sizes provided in the type.
+MLIR_CAPI_EXPORTED intptr_t
+mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type);
+
+/// Returns the quantized dimension at the given position.
+MLIR_CAPI_EXPORTED int32_t
+mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
+                                                        intptr_t pos);
+
+/// Returns the block size at the given position.
+MLIR_CAPI_EXPORTED int64_t
+mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos);
+
+/// Returns the scales of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetScales(MlirType type);
+
+/// Returns the zero-points of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
+
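For reference, here is a minimal usage sketch of the C API declared above (my own illustration, not code from this PR; error handling elided):

```
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/IR.h"

int main(void) {
  MlirContext ctx = mlirContextCreate();
  mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);
  mlirDialectHandleLoadDialect(mlirGetDialectHandle__quant__(), ctx);

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);

  // Parameters for a 6x2 data tensor with block sizes {0:3, 1:2}: one block
  // per 3x2 tile, so the scales/zero-points tensors have shape [2, 1].
  int64_t paramShape[] = {2, 1};
  float scaleValues[] = {2.0f, 3.0f};
  int8_t zeroPointValues[] = {10, 20};
  MlirAttribute scales = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(2, paramShape, f32, mlirAttributeGetNull()), 2,
      scaleValues);
  MlirAttribute zeroPoints = mlirDenseElementsAttrInt8Get(
      mlirRankedTensorTypeGet(2, paramShape, i8, mlirAttributeGetNull()), 2,
      zeroPointValues);

  int32_t quantizedDimensions[] = {0, 1};
  int64_t blockSizes[] = {3, 2};
  MlirType subChannel = mlirUniformQuantizedSubChannelTypeGet(
      mlirQuantizedTypeGetSignedFlag(), i8, f32, scales, zeroPoints,
      /*blockSizeInfoLength=*/2, quantizedDimensions, blockSizes,
      /*storageTypeMin=*/-128, /*storageTypeMax=*/127);

  // Accessors mirror the construction arguments.
  (void)mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(subChannel); // 2
  (void)mlirUniformQuantizedSubChannelTypeGetBlockSize(subChannel, 0);  // 3

  mlirContextDestroy(ctx);
  return 0;
}
```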
 //===---------------------------------------------------------------------===//
 // CalibratedQuantizedType
 //===---------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index 791cb9de48d058..0d97889960019c 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -40,13 +40,17 @@ def Quant_Dialect : Dialect {
     encodes the necessary information for (lossy) round-trip conversion between
     an expressed and a stored value.
 
-    The `quant.uniform` type has two variants: per-layer quantization and
-    per-channel (or per-axis) quantization. In per-layer quantization, the
-    quantization information affects an entire tensor uniformly. Conversely, in
-    per-channel quantization, the data type encodes the specific tensor axis
-    that serves as the channel and includes quantization information for each
-    individual channel within the tensor. Below are the specific syntactic and
-    semantic considerations for each modality.
+    The `quant.uniform` type has three variants: per-layer quantization,
+    per-channel (or per-axis) quantization, and sub-channel (or blockwise)
+    quantization.  In per-layer quantization, the quantization information
+    affects an entire tensor uniformly. Conversely, in per-channel
+    quantization, the data type encodes the specific tensor axis that serves
+    as the channel and includes quantization information for each individual
+    channel within the tensor. Sub-channel quantization is a generalization
+    of per-tensor and per-channel quantization, where the quantization
+    parameters are defined for blocks of elements along one or more
+    dimensions of the tensor. Below are the specific syntactic and semantic
+    considerations for each modality.
 
 
     ### Per-layer quantization
@@ -145,7 +149,7 @@ def Quant_Dialect : Dialect {
     ```
     // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
     // floats. Dimension 1 of the tensor acts as the channel dimension. Its
-    // size 3 matches the number of provided scale values. Tensor elemenets at
+    // size 3 matches the number of provided scale values. Tensor elements at
     // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
     // 5.0, respectively.
     tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
@@ -159,6 +163,72 @@ def Quant_Dialect : Dialect {
     tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
     ```
 
+    ### Sub-channel quantization
+
+    Sub-channel quantization, also known as blockwise quantization, provides
+    finer-grained control than per-tensor or per-channel quantization. It
+    divides a tensor into blocks of elements, each with its own quantization
+    parameters (scale and zero point). This is particularly useful when
+    different regions of a tensor exhibit distinct value ranges.
+
+    The `!quant.uniform` type represents sub-channel quantization with the
+    following syntax:
+
+    ```
+    `!quant.uniform` `<`
+      storedType (`<` storageMin `:` storageMax `>`)? `:`
+      expressedType `:` blockSizeInfo
+      scaleZeroTensor `>`
+
+    blockSizeInfo ::= `{` `}` | `{` axisBlock (`,` axisBlock)* `}`
+    axisBlock ::= axis `:` blockSize
+    scaleZeroTensor ::= scaleZeroDenseExp | scaleZeroList
+    scaleZeroDenseExp ::= `{` scaleZeroTensor (`,` scaleZeroTensor)* `}`
+    scaleZeroList  ::= scaleZero (`,` scaleZero)*
+    scaleZero ::= scale (`:` zeroPoint)?
+    ```
+
+    The `blockSize` field specifies the size of the blocks along dimension
+    `axis` of the tensor. The `scale` and `zeroPoint` fields specify the
+    quantization parameters for a particular block. Specifically, the tensor
+    element at position [i0...iN] uses
+    `scaleZeroTensor[i0/blockSize0...iN/blockSizeN].scale` and
+    `scaleZeroTensor[i0/blockSize0...iN/blockSizeN].zeroPoint` as its scale
+    and zero point, respectively.
+
+    Here are some examples:
+
+    ```
+    // A 3x4 tensor of i8 values representing f32 values, quantized 
+    // along axis-0 and axis-1 with block sizes 1 and 2,
+    // respectively. As a result, the shape of the scales (or zero-points) will
+    // be `[3,4]/[1,2] = [3,2]`, which essentially represents the number of
+    // blocks along each axis. Tensor elements at positions 
+    // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+    // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+    // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+    // [1][2] and [1][3] use scale `s11` and zero point `z11`,
+    // [2][0] and [2][1] use scale `s20` and zero point `z20`,
+    // [2][2] and [2][3] use scale `s21` and zero point `z21`.
+    tensor<3x4x!quant.uniform<i8:f32:{0:1, 1:2},
+      {{s00:z00, s01:z01}, {s10:z10,s11:z11}, {s20:z20,s21:z21}}>>
+
+    // A 2D dynamically sized tensor contains u16 values
+    // representing f32 values. Since the shape of the quantization
+    // parameters (i.e. scales and zero-points) is given as [2,2] and
+    // the blocks-sizes are given as [1,2], the shape of the tensor is expected
+    // to be [2,4] (= [2,2] * [1,2]) at runtime. Tensor elements at positions
+    // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+    // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+    // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+    // [1][2] and [1][3] use scale `s11` and zero point `z11`.
+    tensor<?x?x!quant.uniform<u16:f32:{0:1, 1:2},
+      {{s00:z00, s01:z01}, {s10:z10,s11:z11}}>>
+    ```
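As a sanity check on the examples above, the element-to-parameter mapping can be written out directly (a C++ sketch of the semantics, not code from this patch):

```
#include <cstdint>
#include <vector>

// Maps an element index in the data tensor to the index of its quantization
// parameters: result[i] = elementIndex[i] / blockSizes[i], where `blockSizes`
// holds the effective per-axis block size after defaulting unspecified axes
// to the full dimension size.
std::vector<int64_t> paramIndex(const std::vector<int64_t> &elementIndex,
                                const std::vector<int64_t> &blockSizes) {
  std::vector<int64_t> result(elementIndex.size());
  for (size_t i = 0; i < elementIndex.size(); ++i)
    result[i] = elementIndex[i] / blockSizes[i];
  return result;
}

// For the 3x4 example above with effective block sizes [1, 2]:
//   paramIndex({0, 3}, {1, 2}) == {0, 1}  // element [0][3] uses s01:z01
//   paramIndex({2, 1}, {1, 2}) == {2, 0}  // element [2][1] uses s20:z20
```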
 
     ## Per-axis quantization integrity
 
@@ -170,7 +240,7 @@ def Quant_Dialect : Dialect {
     respected in any context in which the `!quant.uniform` data type is used,
     such as the header of a `func.func` op, or the input of an arithmetic
     operation.
- 
+
     - A quantized type with per-channel quantization information must be the
       element type of a tensor container type, and may not occur directly as
       the data type of a scalar value.
@@ -209,6 +279,110 @@ def Quant_Dialect : Dialect {
     // Correct. The quantized type now includes 3 scale values, matching the
     // size of dimension 1 of the result tensor.
     %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+
+    ## Sub-channel quantization integrity
+
+    When type `!quant.uniform` contains sub-channel quantization information,
+    the following rules apply. For efficiency, these rules are actively
+    enforced by the verifiers of `quant` dialect ops, but they must be
+    respected in any context in which the `!quant.uniform` data type is used,
+    such as the header of a `func.func` op, or the input of an arithmetic
+    operation.
+
+    - A quantized type with sub-channel quantization information must be the
+      element type of a tensor container type, and may not occur directly as
+      the data type of a scalar value.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+    // scalar type.
+    %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>
+
+    // Correct. Type `!quant.uniform` with sub-channel quantization is wrapped
+    // in a `tensor` type.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+    ```
+
+    - The tensor containing the sub-channel quantized type must be ranked.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+    // unranked tensor type.
+    %result = quant.qcast %input : tensor<*xf32> to
+                tensor<*x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+    ```
+
+    - The axis for which a block size is specified must be valid for a tensor
+      of the given rank. Block sizes may be specified for only a subset of
+      axes; any unspecified block size for an axis i defaults to the dimension
+      size of that axis (shape(tensor)[i]).
+
+    ```
+    // Incorrect. The block-size is specified for axis 2 which is greater than
+    // the rank of the tensor.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{2:1, 1:2}, {{1.0}, {2.0}}>>
+
+    // Incorrect. The block-size is specified for a negative axis.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{-1:1, 1:2}, {{1.0}, {2.0}}>>
+
+    // Correct. The block size for axis 1 is omitted, so it defaults to 2, the
+    // dimension size of the tensor at axis 1.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {3.0}}>>
+
+    // Correct. The block sizes for all axes are omitted, making the
+    // sub-channel type essentially a per-tensor type.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{}, {{1.0}}>>
+    ```
+
+    - The block size for a particular axis must be a positive integer and no
+      greater than the dimension size of the tensor along that axis.
+
+    ```
+    // Incorrect. The block size for axis 0 is -1.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:-1}, {{1.0, 2.0}}>>
+
+    // Incorrect. The block size for axis 0 is 8 which is greater than the
+    // dimension size of tensor at axis 0 (which is 6).
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:8}, {{1.0, 2.0}}>>
+
+    // Correct. The block size for axis 0 is now 3.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+    ```
+
+    - shape(tensor) % blockSizes = 0 (elementwise), where blockSizes[i] is the
+      effective block size for axis i, for i in [0, 1, ..., rank(tensor)-1].
+
+    ```
+    // Incorrect. The block size for axis 0 is 4 and the corresponding
+    // dimension size is 6 and 6 % 4 != 0.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:4}, {{1.0, 2.0}}>>
+
+    // Correct. The block size for axis 0 is now 3 making 6 % 3 = 0.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+    ```
+
+    - shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes.
+
+    ```
+    // Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but
+    // shape(scales) is [1,2] which is not equal to [6,2]/[3,2].
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0, 2.0}}>>
+
+    // Correct. shape(tensor) = [6,2], blockSizes = [3,2], and
+    // shape(scales) equals [6,2]/[3,2].
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
     ```
   }];
   let cppNamespace = "::mlir::quant";
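Taken together, the static checks above amount to simple shape arithmetic. The following stand-alone C++ sketch (an illustration, not the verifier added by this patch) mirrors them for fully static shapes:

```
#include <cstdint>
#include <vector>

// Checks, for fully static shapes, that every block size is positive and
// divides its dimension, and that shape(scales) == shape(tensor) / blockSizes.
// `blockSizes` holds the effective per-axis block size (defaulted axes
// included); zero-sized dimensions are rejected outright.
bool shapesAreConsistent(const std::vector<int64_t> &dataShape,
                         const std::vector<int64_t> &blockSizes,
                         const std::vector<int64_t> &scalesShape) {
  if (dataShape.size() != blockSizes.size() ||
      dataShape.size() != scalesShape.size())
    return false;
  for (size_t i = 0; i < dataShape.size(); ++i) {
    if (dataShape[i] == 0 || blockSizes[i] <= 0 ||
        dataShape[i] % blockSizes[i] != 0 ||
        scalesShape[i] != dataShape[i] / blockSizes[i])
      return false;
  }
  return true;
}

// Mirrors the examples above:
//   shapesAreConsistent({6, 2}, {3, 2}, {2, 1}) -> true
//   shapesAreConsistent({6, 2}, {4, 2}, {2, 1}) -> false  (6 % 4 != 0)
```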
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
index bd9cdf82382275..8c74dbef5d94a3 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
@@ -13,6 +13,7 @@
 #ifndef QUANT_BYTECODE
 #define QUANT_BYTECODE
 
+include "mlir/IR/BuiltinDialectBytecode.td"
 include "mlir/IR/BytecodeBase.td"
 
 def DoubleAPFloat:
@@ -81,20 +82,31 @@ def UniformQuantizedPerAxisType: DialectType<(type
   }];
 }
 
+def UniformQuantizedSubChannelType
+    : DialectType<(type VarInt:$flags, Type:$storageType, Type:$expressedType,
+          SignedVarInt:$storageTypeMin, SignedVarInt:$storageTypeMax,
+          Array<SignedVarIntList>:$quantizedDimensions,
+          Array<SignedVarIntList>:$blockSizes, DenseElementsAttr:$scales,
+          DenseElementsAttr:$zeroPoints)> {
+  // Note: builder order differs from bytecode.
+  let cBuilder = [{
+      get<$_resultType>(context, flags, storageType, expressedType, scales,
+        zeroPoints, llvm::to_vector(llvm::map_range(quantizedDimensions,
+        [](int64_t dim) { return static_cast<int32_t>(dim);})), blockSizes,
+        storageTypeMin, storageTypeMax)
+  }];
+}
+
 /// This enum contains marker codes used to indicate which attribute is
 /// currently being decoded, and how it should be decoded. The order of these
 /// codes should generally be unchanged, as any changes will inevitably break
 /// compatibility with older bytecode.
 
 def QuantDialectTypes : DialectTypes<"Quant"> {
-  let elems = [
-    ReservedOrDead,
-    AnyQuantizedType,
-    AnyQuantizedTypeWithExpressedType,
-    CalibratedQuantizedType,
-    UniformQuantizedType,
-    UniformQuantizedPerAxisType
-  ];
+  let elems = [ReservedOrDead, AnyQuantizedType,
+               AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType,
+               UniformQuantizedType, UniformQuantizedPerAxisType,
+               UniformQuantizedSubChannelType];
 }
 
-#endif // QUANT_BYTECODE
\ No newline at end of file
+#endif // QUANT_BYTECODE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 43440ba623b9c1..44062fe376ec0d 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -23,6 +23,7 @@ namespace detail {
 
 struct QuantizedTypeStorage;
 struct AnyQuantizedTypeStorage;
+struct UniformQuantizedSubChannelTypeStorage;
 struct UniformQuantizedTypeStorage;
 struct UniformQuantizedPerAxisTypeStorage;
 struct CalibratedQuantizedTypeStorage;
@@ -382,6 +383,136 @@ class UniformQuantizedPerAxisType
   }
 };
 
+/// Represents sub-channel (also known as blockwise) quantization.
+///
+/// Syntax synopsis:
+///   UniformQuantizedSubChannelType ::= '!quant.uniform' '<'
+///       storageType ('<' storageMin ':' storageMax '>')? ':'
+///       expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>'
+///   BlockSizeInfo ::= '{' '}' | '{' AxisBlock (',' AxisBlock)* '}'
+///   AxisBlock ::= AxisSpec ':' BlockSizeSpec
+///   ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList
+///   ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}'
+///   ScaleZeroList  ::= ScaleZero (',' ScaleZero)*
+///   ScaleZero ::= Scale (':' ZeroPoint)?
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   AxisSpec: An integer value
+///   BlockSizeSpec: An integer value
+///   Scale: An attribute (usually floating-point value)
+///   ZeroPoint: An attribute (usually integer value)
+class UniformQuantizedSubChannelType
+    : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,
+                            detail::UniformQuantizedSubChannelTypeStorage> {
+public:
+  using Base::Base;
+  using Base::getChecked;
+
+  static constexpr StringLiteral name = "quant.uniform_sub_channel";
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedSubChannelType
+  get(unsigned flags, Type storageType, Type expressedType,
+      DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+      ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static UniformQuantizedSubChannelType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+             Type storageType, Type expressedType, DenseElementsAttr scales,
+             DenseElementsAttr zeroPoints,
+             ArrayRef<int32_t> quantizedDimensions,
+             ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+             int64_t storageTypeMax);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+                   ArrayRef<int32_t> quantizedDimensions,
+                   ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
+
+  /// Gets the quantization scales. The scales are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the scales tensor
+  /// is determined by the number of blocks along the corresponding dimension in
+  /// the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will
+  /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The scale value for a specific element in the quantized data tensor at
+  /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding
+  /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1].
+  DenseElementsAttr getScales() const;
+
+  /// Gets the quantization zero-points. The zero-points are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the zero-point
+  /// tensor is determined by the number of blocks along the corresponding
+  /// dimension in the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor
+  /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The zero-point value for a specific element in the quantized data tensor
+  /// at position [i0, i1, ..., iR-1] is determined by accessing the
+  /// corresponding element in the zero-point tensor at position [i0/B0, i1/B1,
+  /// ..., iR-1/BR-1].
+  DenseElementsAttr getZeroPoints() const;
+
+  /// Gets the quantized dimensions. Each element in the returned list
+  /// represents an axis of the quantized data tensor that has a specified block
+  /// size. The order of elements corresponds to the order of block sizes
+  /// returned by `getBlockSizes()`.
+  ///
+  /// That is, the data tensor is quantized along the axis given by the `i`-th
+  /// element of the returned list, using the `i`-th block size from
+  /// `getBlockSizes()`.
+  ///
+  /// Note that the type expression does not have to specify the block size for
+  /// all axes in the data tensor. Any unspecified block size for an axis `i`
+  /// defaults to the tensor dimension size of that axis.
+  ///
+  /// For example, for a quantized type:
+  /// `tensor<8x4x2x!quant.uniform<i8:f32:{1:2, 0:8}, {{1.0, 2.0}, {3.0, 4.0}}>`
+  ///
+  /// `getQuantizedDimensions()` returns [1, 0].
+  /// `getBlockSizes()` returns [2, 8].
+  ///
+  /// This indicates that:
+  ///  * Axis 1 (second dimension) is quantized with a block size of 2.
+  ///  * Axis 0 (first dimension) is quantized with a block size of 8.
+  ///  Since axis 2 is not specified, it implicitly has a block size equal to
+  ///  the size of the third dimension (which is 2 in this case).
+  ArrayRef<int32_t> getQuantizedDimensions() const;
+
+  /// Gets the block sizes for the quantized dimensions. The `i`-th element in
+  /// the returned list corresponds to the block size for the `i`-th dimension
+  /// in the list returned by `getQuantizedDimensions()`.
+  ///
+  /// See `getQuantizedDimensions()` for more details and examples.
+  ArrayRef<int64_t> getBlockSizes() const;
+
+  /// Gets the block size information. This returns a list of pairs, where each
+  /// pair represents a quantized dimension and its corresponding block size.
+  ///
+  /// For example, for the type:
+  ///  `tensor<8x4x!quant.uniform<i8:f32:{1:2, 0:8}, {{2.0, 3.0}}>`
+  ///
+  /// This method returns:
+  ///  `[(1, 2), (0, 8)]`
+  ///
+  /// This list indicates that axis 1 has a block size of 2, and axis 0 has a
+  /// block size of 8.
+  const SmallVector<std::pair<int32_t, int64_t>> getBlockSizeInfo() const;
+};
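A hypothetical construction using the API above could look as follows (my own sketch, not part of the patch):

```
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Builds a type equivalent to
//   !quant.uniform<i8:f32:{0:3, 1:2}, {{2.0:10}, {3.0:20}}>
// intended for a 6x2 data tensor, so scales/zero-points have shape [2, 1].
static quant::UniformQuantizedSubChannelType buildExample(MLIRContext &ctx) {
  ctx.getOrLoadDialect<quant::QuantDialect>();
  Builder b(&ctx);
  auto scalesType = RankedTensorType::get({2, 1}, b.getF32Type());
  auto zeroPointsType = RankedTensorType::get({2, 1}, b.getI8Type());
  SmallVector<float> scaleVals{2.0f, 3.0f};
  SmallVector<int8_t> zeroPointVals{10, 20};
  auto scales =
      DenseElementsAttr::get(scalesType, llvm::ArrayRef<float>(scaleVals));
  auto zeroPoints = DenseElementsAttr::get(
      zeroPointsType, llvm::ArrayRef<int8_t>(zeroPointVals));
  return quant::UniformQuantizedSubChannelType::get(
      quant::QuantizationFlags::Signed, b.getI8Type(), b.getF32Type(), scales,
      zeroPoints, /*quantizedDimensions=*/{0, 1}, /*blockSizes=*/{3, 2},
      /*storageTypeMin=*/-128, /*storageTypeMax=*/127);
}
// On the result, getBlockSizeInfo() yields [(0, 3), (1, 2)].
```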
+
 /// A quantized type that infers its range from given min/max values.
 ///
 /// Typical syntax:
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 9a871f2c122d12..44a596caa24a65 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -285,6 +285,79 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       },
       "Fixed point values are real numbers divided by a scale.");
 
+  //===-------------------------------------------------------------------===//
+  // UniformQuantizedSubChannelType
+  //===-------------------------------------------------------------------===//
+  auto uniformQuantizedSubChannelType = mlir_type_subclass(
+      m, "UniformQuantizedSubChannelType",
+      mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
+  uniformQuantizedSubChannelType.def_classmethod(
+      "get",
+      [](py::object cls, unsigned flags, MlirType storageType,
+         MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
+         std::vector<int32_t> quantizedDimensions,
+         std::vector<int64_t> blockSizes, int64_t storageTypeMin,
+         int64_t storageTypeMax) {
+        return cls(mlirUniformQuantizedSubChannelTypeGet(
+            flags, storageType, expressedType, scales, zeroPoints,
+            static_cast<intptr_t>(blockSizes.size()),
+            quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
+            storageTypeMax));
+      },
+      "Gets an instance of UniformQuantizedSubChannel in the same context as "
+      "the provided storage type.",
+      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
+      py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
+      py::arg("quantized_dimensions"), py::arg("block_sizes"),
+      py::arg("storage_type_min"), py::arg("storage_type_max"));
+  uniformQuantizedSubChannelType.def_property_readonly(
+      "quantized_dimensions",
+      [](MlirType type) {
+        intptr_t nDim =
+            mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+        std::vector<int32_t> quantizedDimensions;
+        quantizedDimensions.reserve(nDim);
+        for (intptr_t i = 0; i < nDim; ++i) {
+          quantizedDimensions.push_back(
+              mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
+        }
+        return quantizedDimensions;
+      },
+      "Gets the quantized dimensions. Each element in the returned list "
+      "represents an axis of the quantized data tensor that has a specified "
+      "block size. The order of elements corresponds to the order of block "
+      "sizes returned by 'block_sizes' method. It means that the data tensor "
+      "is quantized along the i-th dimension in the returned list using the "
+      "i-th block size from block_sizes method.");
+  uniformQuantizedSubChannelType.def_property_readonly(
+      "block_sizes",
+      [](MlirType type) {
+        intptr_t nDim =
+            mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+        std::vector<int64_t> blockSizes;
+        blockSizes.reserve(nDim);
+        for (intptr_t i = 0; i < nDim; ++i) {
+          blockSizes.push_back(
+              mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
+        }
+        return blockSizes;
+      },
+      "Gets the block sizes for the quantized dimensions. The i-th element in "
+      "the returned list corresponds to the block size for the i-th dimension "
+      "in the list returned by quantized_dimensions method.");
+  uniformQuantizedSubChannelType.def_property_readonly(
+      "scales",
+      [](MlirType type) -> MlirAttribute {
+        return mlirUniformQuantizedSubChannelTypeGetScales(type);
+      },
+      "The scales of the quantized type.");
+  uniformQuantizedSubChannelType.def_property_readonly(
+      "zero_points",
+      [](MlirType type) -> MlirAttribute {
+        return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
+      },
+      "The zero points of the quantized type.");
+
   //===-------------------------------------------------------------------===//
   // CalibratedQuantizedType
   //===-------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index c94dbb5692fdb0..88648497895ab7 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -194,6 +194,61 @@ bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
   return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint();
 }
 
+//===---------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
+  return isa<quant::UniformQuantizedSubChannelType>(unwrap(type));
+}
+
+MlirType mlirUniformQuantizedSubChannelTypeGet(
+    unsigned flags, MlirType storageType, MlirType expressedType,
+    MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
+    int32_t *quantizedDimensions, int64_t *blockSizes, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  auto scales = dyn_cast<mlir::DenseElementsAttr>(unwrap(scalesAttr));
+  auto zeroPoints = dyn_cast<mlir::DenseElementsAttr>(unwrap(zeroPointsAttr));
+
+  if (!scales || !zeroPoints) {
+    return {};
+  }
+
+  return wrap(quant::UniformQuantizedSubChannelType::get(
+      flags, unwrap(storageType), unwrap(expressedType), scales, zeroPoints,
+      llvm::ArrayRef<int32_t>(quantizedDimensions, nDims),
+      llvm::ArrayRef<int64_t>(blockSizes, nDims), storageTypeMin,
+      storageTypeMax));
+}
+
+intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) {
+  return cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
+      .getBlockSizes()
+      .size();
+}
+
+int32_t mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
+                                                                intptr_t pos) {
+  return cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
+      .getQuantizedDimensions()[pos];
+}
+
+int64_t mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type,
+                                                       intptr_t pos) {
+  return cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
+      .getBlockSizes()[pos];
+}
+
+MlirAttribute mlirUniformQuantizedSubChannelTypeGetScales(MlirType type) {
+  return wrap(
+      cast<quant::UniformQuantizedSubChannelType>(unwrap(type)).getScales());
+}
+
+MlirAttribute mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type) {
+  return wrap(cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
+                  .getZeroPoints());
+}
+
 //===---------------------------------------------------------------------===//
 // CalibratedQuantizedType
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
index 6a4ac310eb0524..44ec0c517d5611 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index c584903f3a15de..683aa26a2d0621 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -17,7 +17,6 @@
 
 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
-
 namespace mlir {
 namespace quant {
 
@@ -25,22 +24,17 @@ namespace {
 
 // Verify the integrity of per-axis quantization information, if present.
 //
-// - quantizedType
-//   Any quantized type. Any quantized type with no per-axis quantization is
-//   ignored.
+// - uniformQuantizedPerAxisType
+//   A quantized type with per-axis quantization.
 //
 // - containerType
 //   Original input or result type of the operation using the provided quantized
 //   type. Used to ensure that the quantized type appears within a tensor and
 //   that the tensor is compatible with per-axis quantization information.
 //
-LogicalResult verifyPerAxisQuantization(Operation *op,
-                                        QuantizedType quantizedType,
-                                        Type containerType) {
-  auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
-  if (!quantizedPerAxisType)
-    return success();
-
+LogicalResult verifyPerAxisQuantization(
+    Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
+    Type containerType) {
   auto tensorType = dyn_cast<TensorType>(containerType);
   if (!tensorType)
     return op->emitError("scalar types may not use per-axis quantization");
@@ -48,19 +42,112 @@ LogicalResult verifyPerAxisQuantization(Operation *op,
   if (!tensorType.hasRank())
     return success();
 
-  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
-  if (quantizedDimension >= tensorType.getRank())
+  int32_t quantizedDimension =
+      uniformQuantizedPerAxisType.getQuantizedDimension();
+  if ((int64_t)quantizedDimension >= tensorType.getRank())
     return op->emitError("quantized dimension must be less than tensor rank");
 
   int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
   if (quantizedDimensionSize != ShapedType::kDynamic &&
-      quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+      quantizedDimensionSize !=
+          (int64_t)uniformQuantizedPerAxisType.getScales().size())
     return op->emitError(
         "quantized dimension size does not match number of scales");
 
   return success();
 }
 
+// Verifies that the sub-channel quantization parameters are consistent with
+// the given container type. The function checks the following:
+//
+// - The container type must be a ranked tensor type.
+// - Each quantized dimension must be less than the rank of the tensor.
+// - The size of each dimension at the quantized dimension must be divisible
+//    by the corresponding block size.
+// - The scale dimension size at each axis index should match the tensor
+//    dimension at the index divided by the corresponding block size.
+//
+// The `uniformQuantizedSubChannelType` argument provides the sub-channel
+// quantization parameters, and the `containerType` argument specifies the
+// type of the container holding the quantized data.
+//
+LogicalResult verifySubChannelQuantization(
+    Operation *op,
+    UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
+    Type containerType) {
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (!tensorType)
+    return op->emitError("scalar types may not use sub-channel quantization");
+
+  if (!tensorType.hasRank())
+    return op->emitError(
+        "tensor containing the sub-channel quantized type must be ranked");
+
+  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
+      uniformQuantizedSubChannelType.getBlockSizeInfo();
+  auto shape = tensorType.getShape();
+
+  // The scales tensor has dimension size 1 along any axis that is not
+  // specified as a quantized dimension.
+  SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
+  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
+    if (quantizedDimension >= tensorType.getRank())
+      return op->emitError()
+             << "quantized dimension " << quantizedDimension
+             << " must be less than tensor rank " << tensorType.getRank();
+    if (!tensorType.isDynamicDim(quantizedDimension) &&
+        tensorType.getDimSize(quantizedDimension) % blockSize != 0)
+      return op->emitError()
+             << "tensor dimension size "
+             << tensorType.getDimSize(quantizedDimension) << " at axis "
+             << quantizedDimension
+             << " must be divisible by the corresponding block size "
+             << blockSize;
+    if (tensorType.isDynamicDim(quantizedDimension))
+      expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
+    else
+      expectedScaleShape[quantizedDimension] =
+          tensorType.getDimSize(quantizedDimension) / blockSize;
+  }
+
+  // Block sizes must be greater than 0 and divide the corresponding dimension
+  // size. While a block size b must be less than or equal to the corresponding
+  // dimension size d, this constraint is implicitly enforced by requiring that
+  // d % b == 0 when d != 0.
+  //
+  // However, a problem arises when d = 0.  The divisibility constraint allows b
+  // to be any value, potentially violating the requirement that b <= d.
+  // Furthermore, if b is unspecified (implicitly equal to d), it violates the
+  // constraint that b > 0.
+  //
+  // Therefore, we explicitly disallow the case where d = 0 to maintain
+  // consistency and avoid these issues.
+  if (llvm::find(tensorType.getShape(), 0) != tensorType.getShape().end()) {
+    return op->emitError() << "tensor dimension size of zero is not allowed "
+                              "with sub-channel quantization";
+  }
+
+  auto scaleShape =
+      uniformQuantizedSubChannelType.getScales().getType().getShape();
+  if (scaleShape.size() != shape.size()) {
+    return op->emitError() << "Rank of scales " << scaleShape.size()
+                           << " must match "
+                           << "the rank of the tensor " << shape.size();
+  }
+
+  for (auto [index, expectedDim] : llvm::enumerate(expectedScaleShape)) {
+    if (expectedDim != ShapedType::kDynamic &&
+        expectedDim != scaleShape[index])
+      return op->emitError() << "dimension size " << scaleShape[index]
+                             << " of scales tensor at axis " << index
+                             << " should match (tensor dimension at axis / "
+                                "block sizes at axis) = "
+                             << expectedDim;
+  }
+
+  return success();
+}
+
 // Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
 //
 // - quantizedType
@@ -81,11 +168,19 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
         "expressed type in quantized type expected to match float type");
 
   // Verify integrity of per-axis quantization information, if present.
-  return verifyPerAxisQuantization(op, quantizedType, containerType);
-}
+  if (auto quantizedPerAxisType =
+          dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
+    return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
+  } else if (auto quantizedSubChannelType =
+                 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
+    return verifySubChannelQuantization(op, quantizedSubChannelType,
+                                        containerType);
+  }
 
-}  // namespace
+  return success();
+}
 
+} // namespace
 
 //===----------------------------------------------------------------------===//
 // Dialect
@@ -93,7 +188,7 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
 
 void QuantDialect::initialize() {
   addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
-           UniformQuantizedPerAxisType>();
+           UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
@@ -101,7 +196,6 @@ void QuantDialect::initialize() {
   detail::addBytecodeInterface(this);
 }
 
-
 //===----------------------------------------------------------------------===//
 // DequantizeCastOp
 //===----------------------------------------------------------------------===//
@@ -130,7 +224,6 @@ QuantizedType DequantizeCastOp::getQuantizedType() {
   return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
 }
 
-
 //===----------------------------------------------------------------------===//
 // QuantizeCastOp
 //===----------------------------------------------------------------------===//
@@ -160,7 +253,6 @@ QuantizedType QuantizeCastOp::getQuantizedType() {
   return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
 }
 
-
 //===----------------------------------------------------------------------===//
 // StorageCastOp
 //===----------------------------------------------------------------------===//
@@ -175,7 +267,16 @@ LogicalResult StorageCastOp::verify() {
   // Verify integrity of per-axis quantization information, if available. While
   // the quantization type may appear in the input or the result, their tensor
   // shapes are guaranteed to be identical at this point.
-  return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
+  if (auto quantizedPerAxisType =
+          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
+    return verifyPerAxisQuantization(*this, quantizedPerAxisType,
+                                     getInput().getType());
+  else if (auto quantizedSubChannelType =
+               dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
+    return verifySubChannelQuantization(*this, quantizedSubChannelType,
+                                        getInput().getType());
+
+  return success();
 }
 
 OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
@@ -205,10 +306,8 @@ QuantizedType StorageCastOp::getQuantizedType() {
   return cast<QuantizedType>(resultScalarType);
 }
 
-
 } // namespace quant
 } // namespace mlir
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
-
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 7c0d3696486515..9b8eec609b0396 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "TypeDetail.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -34,7 +34,7 @@ double getMaxScale(Type expressedType) {
   return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
 }
 
-}  // namespace
+} // namespace
 
 unsigned QuantizedType::getFlags() const {
   return static_cast<ImplType *>(impl)->flags;
@@ -410,6 +410,123 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
   return getImpl()->quantizedDimension;
 }
 
+UniformQuantizedSubChannelType UniformQuantizedSubChannelType::get(
+    unsigned flags, Type storageType, Type expressedType,
+    DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+    ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  return Base::get(storageType.getContext(), flags, storageType, expressedType,
+                   scales, zeroPoints, quantizedDimensions, blockSizes,
+                   storageTypeMin, storageTypeMax);
+}
+
+UniformQuantizedSubChannelType UniformQuantizedSubChannelType::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+    Type storageType, Type expressedType, DenseElementsAttr scales,
+    DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
+    ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  return Base::getChecked(emitError, storageType.getContext(), flags,
+                          storageType, expressedType, scales, zeroPoints,
+                          quantizedDimensions, blockSizes, storageTypeMin,
+                          storageTypeMax);
+}
+
+LogicalResult UniformQuantizedSubChannelType::verifyInvariants(
+    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+    Type storageType, Type expressedType, DenseElementsAttr scales,
+    DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
+    ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
+    return failure();
+  }
+
+  // Uniform quantization requires fully expressed parameters, including
+  // expressed type.
+  if (!expressedType)
+    return emitError() << "uniform quantization requires expressed type";
+
+  // Verify that the expressed type is floating point.
+  // If this restriction is ever eliminated, the parser/printer must be
+  // extended.
+  if (!llvm::isa<FloatType>(expressedType))
+    return emitError() << "expressed type must be floating point";
+
+  // Verify scale type to match expressedType.
+  if (scales.getType().getElementType() != expressedType) {
+    return emitError() << "type of scale values "
+                       << scales.getType().getElementType()
+                       << " must match the expressed type " << expressedType;
+  }
+
+  // Verify zero-point type to match storageType.
+  if (zeroPoints.getType().getElementType() != storageType) {
+    return emitError() << "type of zero point values "
+                       << zeroPoints.getType().getElementType()
+                       << " must match the storage type " << storageType;
+  }
+
+  // Ensure that the shape of scales and zeroPoints match.
+  if (scales.getType().getShape() != zeroPoints.getType().getShape())
+    return emitError() << "shape of scales and zeroPoints ("
+                       << scales.getType().getShape() << " vs "
+                       << zeroPoints.getType().getShape() << ") does not match";
+
+  // Ensure that the number of quantized-dimensions and block-sizes match.
+  if (quantizedDimensions.size() != blockSizes.size())
+    return emitError() << "number of quantized dimensions and block sizes ("
+                       << quantizedDimensions.size() << " vs "
+                       << blockSizes.size() << ") does not match";
+
+  // Verify quantized dimension.
+  for (auto quantizedDimension : quantizedDimensions) {
+    if (quantizedDimension < 0)
+      return emitError() << "illegal quantized dimension: "
+                         << quantizedDimension;
+  }
+
+  // Verify block sizes.
+  for (auto blockSize : blockSizes) {
+    if (blockSize <= 0)
+      return emitError() << "illegal block size: " << blockSize;
+  }
+
+  return success();
+}
+
+DenseElementsAttr UniformQuantizedSubChannelType::getScales() const {
+  return getImpl()->getScales();
+}
+
+DenseElementsAttr UniformQuantizedSubChannelType::getZeroPoints() const {
+  return getImpl()->getZeroPoints();
+}
+
+ArrayRef<int32_t>
+UniformQuantizedSubChannelType::getQuantizedDimensions() const {
+  return getImpl()->getQuantizedDimensions();
+}
+
+ArrayRef<int64_t> UniformQuantizedSubChannelType::getBlockSizes() const {
+  return getImpl()->getBlockSizes();
+}
+
+const SmallVector<std::pair<int32_t, int64_t>>
+UniformQuantizedSubChannelType::getBlockSizeInfo() const {
+  SmallVector<std::pair<int32_t, int64_t>> result;
+  result.reserve(getQuantizedDimensions().size());
+
+  for (auto [dim, size] :
+       llvm::zip(getQuantizedDimensions(), getBlockSizes())) {
+    result.push_back({dim, size});
+  }
+
+  return result;
+}
+
 CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
                                                      double min, double max) {
   return Base::get(expressedType.getContext(), expressedType, min, max);
diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h
index ef098811927cda..bb38b1a2a91e28 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h
@@ -9,6 +9,7 @@
 #ifndef TYPE_DETAIL_H_
 #define TYPE_DETAIL_H_
 
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"
@@ -253,6 +254,127 @@ struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
   int32_t quantizedDimension;
 };
 
+struct UniformQuantizedSubChannelTypeStorage : public QuantizedTypeStorage {
+  struct KeyTy {
+    KeyTy(unsigned flags, Type storageType, Type expressedType,
+          DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+          ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+          int64_t storageTypeMin, int64_t storageTypeMax)
+        : flags(flags), storageType(storageType), expressedType(expressedType),
+          scales(scales), zeroPoints(zeroPoints),
+          quantizedDimensions(quantizedDimensions), blockSizes(blockSizes),
+          storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+    /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+    unsigned flags;
+
+    // Integral type for the storage point representation.
+    Type storageType;
+
+    // Floating point type that the quantized type approximates.
+    Type expressedType;
+
+    DenseElementsAttr scales;
+    DenseElementsAttr zeroPoints;
+    ArrayRef<int32_t> quantizedDimensions;
+    ArrayRef<int64_t> blockSizes;
+    int64_t storageTypeMin;
+    int64_t storageTypeMax;
+
+    DenseElementsAttr getScales() const { return scales; }
+
+    DenseElementsAttr getZeroPoints() const { return zeroPoints; }
+
+    // Check for equality of two structures that share KeyTy data members
+    // (by name).
+    template <typename T, typename U>
+    static bool genericIsEqual(const T &lhs, const U &rhs) {
+      return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+             lhs.expressedType == rhs.expressedType &&
+             lhs.scales == rhs.scales && lhs.zeroPoints == rhs.zeroPoints &&
+             lhs.quantizedDimensions == rhs.quantizedDimensions &&
+             lhs.blockSizes == rhs.blockSizes &&
+             lhs.storageTypeMin == rhs.storageTypeMin &&
+             lhs.storageTypeMax == rhs.storageTypeMax;
+    }
+
+    bool operator==(const KeyTy &other) const {
+      return genericIsEqual(*this, other);
+    }
+
+    unsigned getHashValue() const {
+      // Hash the scalar attributes.
+      unsigned hash = llvm::hash_combine(flags, storageType, expressedType,
+                                         storageTypeMin, storageTypeMax);
+
+      // Hash the scales.
+      for (auto scaleAttr : scales.getValues<APFloat>()) {
+        hash = llvm::hash_combine(
+            hash, llvm::bit_cast<int64_t>(scaleAttr.convertToDouble()));
+      }
+
+      // Hash the zero points.  (Assumed to be integers, adjust if needed).
+      for (auto zeroPointAttr : zeroPoints.getValues<APInt>()) {
+        hash = llvm::hash_combine(hash, zeroPointAttr.getSExtValue());
+      }
+
+      // Hash the quantized dimensions and block sizes.
+      hash = llvm::hash_combine(
+          hash,
+          llvm::hash_combine_range(quantizedDimensions.begin(),
+                                   quantizedDimensions.end()),
+          llvm::hash_combine_range(blockSizes.begin(), blockSizes.end()));
+
+      return hash;
+    }
+  };
+
+  // We pass scales and zeroPoints in directly rather than relying on KeyTy
+  // because we have to create new reallocated versions in `construct` below.
+  UniformQuantizedSubChannelTypeStorage(const KeyTy &key,
+                                        DenseElementsAttr scales,
+                                        DenseElementsAttr zeroPoints,
+                                        ArrayRef<int32_t> quantizedDimensions,
+                                        ArrayRef<int64_t> blockSizes)
+      : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+                             key.storageTypeMin, key.storageTypeMax),
+        scales(scales), zeroPoints(zeroPoints),
+        quantizedDimensions(quantizedDimensions), blockSizes(blockSizes) {}
+
+  bool operator==(const KeyTy &key) const {
+    return KeyTy::genericIsEqual(*this, key);
+  }
+
+  /// Construction.
+  static UniformQuantizedSubChannelTypeStorage *
+  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+    DenseElementsAttr scales = key.scales;
+    DenseElementsAttr zeroPoints = key.zeroPoints;
+    ArrayRef<int32_t> quantizedDimensions =
+        allocator.copyInto(key.quantizedDimensions);
+    ArrayRef<int64_t> blockSizes = allocator.copyInto(key.blockSizes);
+    return new (allocator.allocate<UniformQuantizedSubChannelTypeStorage>())
+        UniformQuantizedSubChannelTypeStorage(key, scales, zeroPoints,
+                                              quantizedDimensions, blockSizes);
+  }
+
+  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+  DenseElementsAttr getScales() const { return scales; }
+
+  DenseElementsAttr getZeroPoints() const { return zeroPoints; }
+
+  ArrayRef<int32_t> getQuantizedDimensions() const {
+    return quantizedDimensions;
+  }
+
+  ArrayRef<int64_t> getBlockSizes() const { return blockSizes; }
+
+  DenseElementsAttr scales;
+  DenseElementsAttr zeroPoints;
+  ArrayRef<int32_t> quantizedDimensions;
+  ArrayRef<int64_t> blockSizes;
+};
+
 struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage {
   struct KeyTy {
     KeyTy(Type expressedType, double min, double max)
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 851763d8942e83..fb6f5d950acc9e 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -159,38 +159,173 @@ static Type parseAnyType(DialectAsmParser &parser) {
       typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
 }
 
-static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
+/// Checks if the given scale value is within the valid range of the expressed
+/// type. The `expressedType` argument is the floating-point type used for
+/// expressing the quantized values, and `scale` is the double value to check.
+LogicalResult
+isScaleInExpressedTypeRange(function_ref<InFlightDiagnostic()> emitError,
+                            Type expressedType, double scale) {
+  auto floatType = cast<FloatType>(expressedType);
+  double minScale =
+      APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
+  double maxScale =
+      APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
+  if (scale < minScale || scale > maxScale)
+    return emitError() << "scale " << scale << " out of expressed type range ["
+                       << minScale << ", " << maxScale << "]";
+  return success();
+}
+
+/// Parses a quantization parameter, which is either a scale value (float) or a
+/// scale:zero-point pair (float:integer). `expressedType`, the floating-point
+/// type of the scale values, is used to validate the scale. The parsed scale
+/// and zero point (if any) are stored in `scale` and `zeroPoint`.
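+///
+/// For example, "2.0:10" yields scale 2.0 with zero point 10, while "2.0"
+/// alone yields scale 2.0 with the default zero point 0.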
+static ParseResult parseQuantParams(DialectAsmParser &parser,
+                                    Type expressedType, double &scale,
                                     int64_t &zeroPoint) {
-  // scale[:zeroPoint]?
-  // scale.
-  if (parser.parseFloat(scale))
+
+  if (parser.parseFloat(scale)) {
     return failure();
+  }
+
+  if (failed(isScaleInExpressedTypeRange(
+          [&]() { return parser.emitError(parser.getCurrentLocation()); },
+          expressedType, scale))) {
+    return failure();
+  }
 
-  // zero point.
   zeroPoint = 0;
   if (failed(parser.parseOptionalColon())) {
-    // Default zero point.
     return success();
   }
 
   return parser.parseInteger(zeroPoint);
 }
 
+/// Parses block size information for sub-channel quantization, assuming the
+/// leading '{' has already been parsed. The block size information is provided
+/// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
+///
+/// The parsed axis indices are stored in `quantizedDimensions`, and the
+/// corresponding block sizes are stored in `blockSizes`.
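+///
+/// For example, parsing "0:1, 1:2}" yields quantizedDimensions = [0, 1] and
+/// blockSizes = [1, 2].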
+static ParseResult
+parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser,
+                              SmallVectorImpl<int32_t> &quantizedDimensions,
+                              SmallVectorImpl<int64_t> &blockSizes) {
+  // Empty block-sizes info.
+  if (succeeded(parser.parseOptionalRBrace())) {
+    return success();
+  }
+
+  auto parseBlockSizeElements = [&]() -> ParseResult {
+    quantizedDimensions.resize(quantizedDimensions.size() + 1);
+    blockSizes.resize(blockSizes.size() + 1);
+    if (parser.parseInteger(quantizedDimensions.back()) ||
+        parser.parseColon() || parser.parseInteger(blockSizes.back()))
+      return failure();
+    return success();
+  };
+
+  if (parser.parseCommaSeparatedList(parseBlockSizeElements) ||
+      parser.parseRBrace()) {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Parses a bracketed list of quantization parameters, returning the
+/// dimensions of the parsed sub-tensors in `dims`. The size of the list is
+/// prepended to the dimensions of the sub-tensors. This function assumes that
+/// the initial left brace has already been parsed. For example:
+///
+///   parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
+///       dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
+///
+///   parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
+///       dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
+///       9]
+///
+/// This function expects all sub-tensors to have the same rank.
+static ParseResult
+parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
+                               SmallVectorImpl<double> &scales,
+                               SmallVectorImpl<int64_t> &zeroPoints,
+                               SmallVectorImpl<int64_t> &dims) {
+  auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
+                       const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
+    if (prevDims == newDims)
+      return success();
+    return parser.emitError(parser.getCurrentLocation())
+           << "tensor literal is invalid; ranks are not consistent "
+              "between elements";
+  };
+
+  bool first = true;
+  SmallVector<int64_t, 4> newDims;
+  unsigned size = 0;
+
+  auto parseOneElement = [&]() -> ParseResult {
+    SmallVector<int64_t, 4> thisDims;
+    if (succeeded(parser.parseOptionalLBrace())) {
+      if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
+                                         zeroPoints, thisDims))
+        return failure();
+    } else {
+      zeroPoints.resize(zeroPoints.size() + 1);
+      scales.resize(scales.size() + 1);
+      if (parseQuantParams(parser, expressedType, scales.back(),
+                           zeroPoints.back())) {
+        return failure();
+      }
+    }
+    ++size;
+    if (!first)
+      return checkDims(newDims, thisDims);
+    newDims = thisDims;
+    first = false;
+    return success();
+  };
+
+  if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) {
+    return failure();
+  }
+
+  // Return the sublists' dimensions with 'size' prepended.
+  dims.clear();
+  dims.push_back(size);
+  dims.append(newDims.begin(), newDims.end());
+
+  return success();
+}
+
 /// Parses a UniformQuantizedType.
 ///
 ///   uniform_type ::= uniform_per_layer
 ///                  | uniform_per_axis
+///                  | uniform_sub_channel
 ///   uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
 ///                          `,` scale-zero `>`
 ///   uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
-///                        axis-spec `,` scale-zero-list `>`
+///                        axis-spec `,` `{` scale-zero-list `}` `>`
+///   uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec `:`
+///                        block-size-info `,` scale-zero-tensor `>`
 ///   storage-spec ::= storage-type (`<` storage-range `>`)?
 ///   storage-range ::= integer-literal `:` integer-literal
 ///   storage-type ::= (`i` | `u`) integer-literal
 ///   expressed-type-spec ::= `:` `f` integer-literal
 ///   axis-spec ::= `:` integer-literal
-///   scale-zero ::= float-literal `:` integer-literal
-///   scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
+///   scale-zero ::= scale (`:` zero-point)?
+///   scale ::= float-literal
+///   zero-point ::= integer-literal
+///   scale-zero-list ::= scale-zero (`,` scale-zero)*
+///   block-size-info ::= `{` `}` | `{` axis-block (`,` axis-block)* `}`
+///   axis-block ::= integer-literal `:` block-size-spec
+///   block-size-spec ::= integer-literal
+///   scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
+///   scale-zero-dense-exp ::= `{`
+///     scale-zero-tensor (`,` scale-zero-tensor)*
+///   `}`
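+///
+/// For example, the following sub-channel type uses block size 1 along axis 0
+/// and block size 2 along axis 1, with a 2x2 tensor of scale:zero-point pairs:
+///
+///   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>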
 static Type parseUniformType(DialectAsmParser &parser) {
   IntegerType storageType;
   FloatType expressedType;
@@ -198,7 +333,9 @@ static Type parseUniformType(DialectAsmParser &parser) {
   int64_t storageTypeMin;
   int64_t storageTypeMax;
   bool isPerAxis = false;
-  int32_t quantizedDimension;
+  bool isSubChannel = false;
+  SmallVector<int32_t, 1> quantizedDimensions;
+  SmallVector<int64_t, 1> blockSizes;
   SmallVector<double, 1> scales;
   SmallVector<int64_t, 1> zeroPoints;
 
@@ -228,11 +365,22 @@ static Type parseUniformType(DialectAsmParser &parser) {
     return nullptr;
   }
 
-  // Optionally parse quantized dimension for per-axis quantization.
+  // Optionally parse quantized dimension for per-axis or sub-channel
+  // quantization.
   if (succeeded(parser.parseOptionalColon())) {
-    if (parser.parseInteger(quantizedDimension))
-      return nullptr;
-    isPerAxis = true;
+    if (succeeded(parser.parseOptionalLBrace())) {
+      isSubChannel = true;
+      if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
+                                        blockSizes)) {
+        return nullptr;
+      }
+    } else {
+      isPerAxis = true;
+      quantizedDimensions.resize(1);
+      if (parser.parseInteger(quantizedDimensions.back())) {
+        return nullptr;
+      }
+    }
   }
 
   // Comma leading into range_spec.
@@ -240,26 +388,21 @@ static Type parseUniformType(DialectAsmParser &parser) {
     return nullptr;
   }
 
-  // Parameter specification.
-  // For per-axis, ranges are in a {} delimitted list.
-  if (isPerAxis) {
-    if (parser.parseLBrace()) {
-      return nullptr;
-    }
-  }
-
-  // Parse scales/zeroPoints.
-  SMLoc scaleZPLoc = parser.getCurrentLocation();
-  do {
-    scales.resize(scales.size() + 1);
+  // Quantization parameter (scales/zeroPoints) specification.
+  bool isPerTensor = !isPerAxis && !isSubChannel;
+  SmallVector<int64_t> dims;
+  if (isPerTensor) {
     zeroPoints.resize(zeroPoints.size() + 1);
-    if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
+    scales.resize(scales.size() + 1);
+    if (parseQuantParams(parser, expressedType, scales.back(),
+                         zeroPoints.back())) {
       return nullptr;
     }
-  } while (isPerAxis && succeeded(parser.parseOptionalComma()));
 
-  if (isPerAxis) {
-    if (parser.parseRBrace()) {
+  } else {
+    if (parser.parseLBrace() ||
+        parseQuantParamListUntilRBrace(parser, expressedType, scales,
+                                       zeroPoints, dims)) {
       return nullptr;
     }
   }
@@ -268,19 +411,30 @@ static Type parseUniformType(DialectAsmParser &parser) {
     return nullptr;
   }
 
-  if (!isPerAxis && scales.size() > 1) {
-    return (parser.emitError(scaleZPLoc,
-                             "multiple scales/zeroPoints provided, but "
-                             "quantizedDimension wasn't specified"),
-            nullptr);
-  }
-
   if (isPerAxis) {
-    ArrayRef<double> scalesRef(scales.begin(), scales.end());
-    ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
     return parser.getChecked<UniformQuantizedPerAxisType>(
+        typeFlags, storageType, expressedType, scales, zeroPoints,
+        quantizedDimensions[0], storageTypeMin, storageTypeMax);
+  } else if (isSubChannel) {
+    SmallVector<APFloat> apFloatScales =
+        llvm::to_vector(llvm::map_range(scales, [&](double scale) -> APFloat {
+          APFloat apFloatScale(scale);
+          bool unused;
+          apFloatScale.convert(expressedType.getFloatSemantics(),
+                               APFloat::rmNearestTiesToEven, &unused);
+          return apFloatScale;
+        }));
+    SmallVector<APInt> apIntZeroPoints = llvm::to_vector(
+        llvm::map_range(zeroPoints, [&](int64_t zeroPoint) -> APInt {
+          return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
+        }));
+    auto scalesRef = mlir::DenseElementsAttr::get(
+        RankedTensorType::get(dims, expressedType), apFloatScales);
+    auto zeroPointsRef = mlir::DenseElementsAttr::get(
+        RankedTensorType::get(dims, storageType), apIntZeroPoints);
+    return parser.getChecked<UniformQuantizedSubChannelType>(
         typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
-        quantizedDimension, storageTypeMin, storageTypeMax);
+        quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
   }
 
   return parser.getChecked<UniformQuantizedType>(
@@ -360,6 +514,19 @@ static void printQuantParams(double scale, int64_t zeroPoint,
   }
 }
 
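+/// Prints the block size information as a comma-separated list of
+/// "Axis:BlockSize" pairs enclosed in braces, e.g. {0:1,1:2}.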
+static void
+printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
+                   DialectAsmPrinter &out) {
+  out << "{";
+  llvm::interleave(
+      blockSizeInfo, out,
+      [&](const std::pair<int32_t, int64_t> &blockSize) {
+        out << blockSize.first << ":" << blockSize.second;
+      },
+      ",");
+  out << "}";
+}
+
 /// Helper that prints a AnyQuantizedType.
 static void printAnyQuantizedType(AnyQuantizedType type,
                                   DialectAsmPrinter &out) {
@@ -405,6 +572,74 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
   out << "}>";
 }
 
+/// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
+/// elements.  The nesting corresponds to the `shape` dimensions.
+///
+/// Elements are delimited by commas, and the inner dimensions are enclosed in
+/// braces.  `zero_point` is only printed if it is non-zero.  For example:
+///
+///   printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
+///                                   zeroPoints=[0, 0, 1, 9],
+///                                   shape=[2, 2])
+///
+///   would print:
+///
+///     {{1.0, 2.0}, {3.0:1, 4.0:9}}
+static void printDenseQuantizationParameters(ArrayRef<APFloat> scales,
+                                             ArrayRef<APInt> zeroPoints,
+                                             ArrayRef<int64_t> shape,
+                                             DialectAsmPrinter &out) {
+  int64_t rank = shape.size();
+  SmallVector<unsigned, 4> counter(rank, 0);
+  unsigned openBrackets = 0;
+
+  auto bumpCounter = [&]() {
+    ++counter[rank - 1];
+    for (unsigned i = rank - 1; i > 0; --i) {
+      if (counter[i] >= shape[i]) {
+        counter[i] = 0;
+        ++counter[i - 1];
+        --openBrackets;
+        out << '}';
+      }
+    }
+  };
+
+  for (unsigned idx = 0, e = scales.size(); idx != e; ++idx) {
+    if (idx != 0)
+      out << ", ";
+    while (openBrackets++ < rank)
+      out << '{';
+    openBrackets = rank;
+    out << scales[idx];
+    if (zeroPoints[idx] != 0) {
+      out << ":" << zeroPoints[idx];
+    }
+    bumpCounter();
+  }
+  while (openBrackets-- > 0)
+    out << '}';
+}
+
+/// Helper that prints a UniformQuantizedSubChannelType.
+static void
+printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type,
+                                    DialectAsmPrinter &out) {
+  out << "uniform<";
+  printStorageType(type, out);
+  out << ":" << type.getExpressedType() << ":";
+  printBlockSizeInfo(type.getBlockSizeInfo(), out);
+  out << ", ";
+
+  auto scalesItr = type.getScales().getValues<APFloat>();
+  auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
+  SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
+  SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
+  printDenseQuantizationParameters(scales, zeroPoints,
+                                   type.getScales().getType().getShape(), out);
+  out << ">";
+}
+
 /// Helper that prints a CalibratedQuantizedType.
 static void printCalibratedQuantizedType(CalibratedQuantizedType type,
                                          DialectAsmPrinter &out) {
@@ -421,6 +656,9 @@ void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
     printUniformQuantizedType(uniformType, os);
   else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
     printUniformQuantizedPerAxisType(perAxisType, os);
+  else if (auto subChannelType =
+               llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
+    printUniformQuantizedSubChannelType(subChannelType, os);
   else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
     printCalibratedQuantizedType(calibratedType, os);
   else
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 4adeb9218ff8ec..45d0a0c3e697a2 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -570,6 +570,73 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
   return result;
 }
 
+// Convert an operation using sub-channel quantization.
+//
+// - op
+//   'quant.dcast' or 'quant.qcast' op.
+//
+// - input
+//   Ranked tensor.
+//
+// - quantizedType
+//   Sub-channel quantized type.
+//
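+// For instance, with block sizes {0:1, 3:2} on a rank-4 input, the scales and
+// zero points are indexed through the affine map
+//   (d0, d1, d2, d3) -> (d0, 0, 0, d3 floordiv 2),
+// so every block of two elements along axis 3 shares one scale/zero-point
+// pair.
+//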
+Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
+                        Value input,
+                        UniformQuantizedSubChannelType quantizedType) {
+  auto *context = builder.getContext();
+
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto inputRank = inputType.getRank();
+
+  auto scales = materializeSubChannelScales(builder, loc, quantizedType);
+  auto zeroPoints =
+      materializeSubChannelZeroPoints(builder, loc, quantizedType);
+
+  auto elementType = isa<FloatType>(inputType.getElementType())
+                         ? quantizedType.getStorageType()
+                         : quantizedType.getExpressedType();
+  auto initShape = tensor::getMixedSizes(builder, loc, input);
+  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
+
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
+      quantizedType.getBlockSizeInfo();
+  SmallVector<AffineExpr> affineExprs(inputRank,
+                                      builder.getAffineConstantExpr(0));
+  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
+    affineExprs[quantizedDimension] =
+        builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
+  }
+  auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
+  SmallVector<AffineMap> indexingMaps{
+      builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
+      builder.getMultiDimIdentityMap(inputRank)};
+  auto result = builder
+                    .create<linalg::GenericOp>(
+                        loc,
+                        init.getType(),                        // resultType
+                        ValueRange{input, scales, zeroPoints}, // inputs
+                        ValueRange{init},                      // outputs
+                        indexingMaps, iteratorTypes,
+                        [&](OpBuilder &builder, Location loc, ValueRange args) {
+                          assert(args.size() == 4);
+                          auto input = args[0];
+                          auto scale = args[1];
+                          auto zeroPoint = args[2];
+
+                          auto result =
+                              convertRanked(builder, loc, op, input, {}, scale,
+                                            zeroPoint, quantizedType);
+
+                          builder.create<linalg::YieldOp>(loc, result);
+                        })
+                    .getResult(0);
+
+  return result;
+}
+
 // Convert a quantization operation.
 //
 // - op
diff --git a/mlir/test/CAPI/quant.c b/mlir/test/CAPI/quant.c
index 0a09e084119f71..bc7cd1436efe18 100644
--- a/mlir/test/CAPI/quant.c
+++ b/mlir/test/CAPI/quant.c
@@ -203,6 +203,130 @@ void testUniformPerAxisType(MlirContext ctx) {
   fprintf(stderr, "\n\n");
 }
 
+// CHECK-LABEL: testUniformSubChannelType
+void testUniformSubChannelType(MlirContext ctx) {
+  fprintf(stderr, "testUniformSubChannelType\n");
+
+  MlirType subChannelParsed =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
+                                "!quant.uniform<i8:f32:{0:1,1:2}, "
+                                "{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"));
+
+  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
+  MlirType f32 = mlirF32TypeGet(ctx);
+
+  // block-size information
+  int32_t quantizedDimensions[] = {0, 1};
+  int64_t blockSizes[] = {1, 2};
+  int64_t numBlockSizes = 2;
+
+  // quantization parameters
+  int64_t quantParamShape[] = {2, 2};
+  int64_t quantParamRank = 2;
+  int64_t numQuantizationParams = 4;
+  MlirAttribute scales[] = {mlirFloatAttrDoubleGet(ctx, f32, 2.0),
+                            mlirFloatAttrDoubleGet(ctx, f32, 3.0),
+                            mlirFloatAttrDoubleGet(ctx, f32, 4.0),
+                            mlirFloatAttrDoubleGet(ctx, f32, 5.0)};
+  MlirAttribute zeroPoints[] = {
+      mlirIntegerAttrGet(i8, 10), mlirIntegerAttrGet(i8, 20),
+      mlirIntegerAttrGet(i8, 30), mlirIntegerAttrGet(i8, 40)};
+
+  MlirType scalesType =
+      mlirRankedTensorTypeGet(quantParamRank, quantParamShape, f32,
+                              /*encoding=*/mlirAttributeGetNull());
+  MlirType zeroPointsType = mlirRankedTensorTypeGet(
+      quantParamRank, quantParamShape, i8, /*encoding=*/mlirAttributeGetNull());
+  MlirAttribute denseScalesAttr =
+      mlirDenseElementsAttrGet(scalesType, numQuantizationParams, scales);
+  MlirAttribute denseZeroPointsAttr = mlirDenseElementsAttrGet(
+      zeroPointsType, numQuantizationParams, zeroPoints);
+
+  MlirType subChannel = mlirUniformQuantizedSubChannelTypeGet(
+      mlirQuantizedTypeGetSignedFlag(), i8, f32, denseScalesAttr,
+      denseZeroPointsAttr, numBlockSizes, quantizedDimensions, blockSizes,
+      mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
+                                                   /*integralWidth=*/8),
+      mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
+                                                   /*integralWidth=*/8));
+
+  MlirAttribute arrayScalesAttr =
+      mlirArrayAttrGet(ctx, numQuantizationParams, scales);
+  MlirAttribute arrayZeroPointsAttr =
+      mlirArrayAttrGet(ctx, numQuantizationParams, zeroPoints);
+  MlirType illegalSubChannel = mlirUniformQuantizedSubChannelTypeGet(
+      mlirQuantizedTypeGetSignedFlag(), i8, f32, arrayScalesAttr,
+      arrayZeroPointsAttr, numBlockSizes, quantizedDimensions, blockSizes,
+      mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
+                                                   /*integralWidth=*/8),
+      mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
+                                                   /*integralWidth=*/8));
+
+  // CHECK: is null sub-channel type: 1
+  fprintf(stderr, "is null sub-channel type: %d\n",
+          mlirTypeIsNull(illegalSubChannel));
+
+  // CHECK: num dims: 2
+  fprintf(stderr, "num dims: %" PRId64 "\n",
+          mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(subChannel));
+
+  // CHECK: axis-block-size-pair[0]: 0:1
+  fprintf(
+      stderr, "axis-block-size-pair[0]: %" PRId32 ":%" PRId64 "\n",
+      mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(subChannel, 0),
+      mlirUniformQuantizedSubChannelTypeGetBlockSize(subChannel, 0));
+
+  // CHECK: axis-block-size-pair[1]: 1:2
+  fprintf(
+      stderr, "axis-block-size-pair[1]: %" PRId32 ":%" PRId64 "\n",
+      mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(subChannel, 1),
+      mlirUniformQuantizedSubChannelTypeGetBlockSize(subChannel, 1));
+
+  denseScalesAttr = mlirUniformQuantizedSubChannelTypeGetScales(subChannel);
+  denseZeroPointsAttr =
+      mlirUniformQuantizedSubChannelTypeGetZeroPoints(subChannel);
+  scalesType = mlirAttributeGetType(denseScalesAttr);
+  zeroPointsType = mlirAttributeGetType(denseZeroPointsAttr);
+
+  // CHECK: tensor<2x2xf32>
+  mlirTypeDump(scalesType);
+  // CHECK: tensor<2x2xi8>
+  mlirTypeDump(zeroPointsType);
+
+  // CHECK: number of quantization parameters: 4
+  fprintf(stderr, "number of quantization parameters: %" PRId64 "\n",
+          mlirElementsAttrGetNumElements(denseScalesAttr));
+
+  // CHECK: quantization-parameter[0]: 2.000000:10
+  fprintf(stderr, "quantization-parameter[0]: %lf:%" PRId8 "\n",
+          mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 0),
+          mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 0));
+
+  // CHECK: quantization-parameter[1]: 3.000000:20
+  fprintf(stderr, "quantization-parameter[1]: %lf:%" PRId8 "\n",
+          mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 1),
+          mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 1));
+
+  // CHECK: quantization-parameter[2]: 4.000000:30
+  fprintf(stderr, "quantization-parameter[2]: %lf:%" PRId8 "\n",
+          mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 2),
+          mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 2));
+
+  // CHECK: quantization-parameter[3]: 5.000000:40
+  fprintf(stderr, "quantization-parameter[3]: %lf:%" PRId8 "\n",
+          mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 3),
+          mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 3));
+
+  // CHECK: equal: 1
+  fprintf(stderr, "equal: %d\n", mlirTypeEqual(subChannel, subChannelParsed));
+
+  // CHECK: !quant.uniform<i8:f32:{0:1,1:2},
+  // {{.*}}2.000000e+00:10, 3.000000e+00:20},
+  // {4.000000e+00:30, 5.000000e+00:40{{.*}}}}>
+  mlirTypeDump(subChannel);
+  fprintf(stderr, "\n\n");
+}
+
 // CHECK-LABEL: testCalibratedType
 void testCalibratedType(MlirContext ctx) {
   fprintf(stderr, "testCalibratedType\n");
diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir
index 359a58557087e1..29815ec0de41ee 100644
--- a/mlir/test/Dialect/Quant/Bytecode/types.mlir
+++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir
@@ -64,3 +64,12 @@ module @parseUniformPerAxisMixed attributes {
   bytecode.test = !quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>
 } {}
 
+//===----------------------------------------------------------------------===//
+// UniformQuantizedSubChannel
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: parseUniformSubChannel
+module @parseUniformSubChannel attributes {
+  // CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
+  bytecode.test = !quant.uniform<i8:f32:{0:1, 1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
+} {}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir
index ba3a8e312d96e9..7bb50f352f9389 100644
--- a/mlir/test/Dialect/Quant/invalid.mlir
+++ b/mlir/test/Dialect/Quant/invalid.mlir
@@ -256,3 +256,71 @@ func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3x4xi8>) {
   return
 }
 
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_scalar(%arg0: f32) {
+  // expected-error@+1 {{scalar types may not use sub-channel quantization}}
+  %0 = quant.qcast %arg0 : f32 to !qalias
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_unranked(%arg0: tensor<*xf32>) {
+  // expected-error@+1 {{tensor containing the sub-channel quantized type must be ranked}}
+  %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,3:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_invalid_quantized_dimension(%arg0: tensor<2x4xf32>) {
+  // expected-error@+1 {{quantized dimension 3 must be less than tensor rank 2}}
+  %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:3},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_invalid_tensor_dim_size(%arg0: tensor<2x4xf32>) {
+  // expected-error@+1 {{tensor dimension size 4 at axis 1 must be divisible by the corresponding block size 3}}
+  %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @qcast_sub_channel_invalid_zero_tensor_dim_size(%arg0: tensor<0x4xf32>) {
+  // expected-error@+1 {{tensor dimension size of zero is not allowed with sub-channel quantization}}
+  %0 = quant.qcast %arg0 : tensor<0x4xf32> to tensor<0x4x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {{2.000000e+02:120}, {2.000000e+02}}>
+func.func @qcast_sub_channel_invalid_scale_dim_size(%arg0: tensor<2x4xf32>) {
+  // expected-error@+1 {{dimension size 2 of scales tensor at axis 1 should match (tensor dimension at axis / block sizes at axis) = 2}}
+  %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias>
+  return
+}
+
+// -----
+
+!qalias = !quant.uniform<u8:f32:{},{{{2.000000e+02:120}}}>
+func.func @qcast_sub_channel_invalid_scale_rank(%arg0: tensor<?x?xf32>) {
+  // expected-error@+1 {{Rank of scales 3 must match the rank of the tensor 2}}
+  %0 = quant.qcast %arg0 : tensor<?x?xf32> to tensor<?x?x!qalias>
+  return
+}
diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir
index 4abc5830d081e1..33ff93ecbc1d7b 100644
--- a/mlir/test/Dialect/Quant/ops.mlir
+++ b/mlir/test/Dialect/Quant/ops.mlir
@@ -148,4 +148,23 @@ func.func @scast_per_axis_unranked(%arg0: tensor<*xi8>) {
   return
 }
 
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @sub_channel_quantization(%arg0: tensor<2x4xi8>) -> tensor<2x4xi8> {
+  %0 = quant.scast %arg0 : tensor<2x4xi8> to tensor<2x4x!qalias>
+  %1 = quant.dcast %0 : tensor<2x4x!qalias> to tensor<2x4xf32>
+  %2 = quant.qcast %1 : tensor<2x4xf32> to tensor<2x4x!qalias>
+  %3 = quant.scast %2 : tensor<2x4x!qalias> to tensor<2x4xi8>
+  return %3 : tensor<2x4xi8>
+}
 
+// -----
+
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+func.func @sub_channel_quantization_with_unknown_dims(%arg0: tensor<2x?xf32>) {
+  %0 = quant.qcast %arg0 : tensor<2x?xf32> to tensor<2x?x!qalias>
+  return
+}
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index 4528d2826a850c..3b358443e43f2d 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -107,7 +107,7 @@
 
 // -----
 // Illegal scale: negative
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale -1.000000e+00 out of expressed type range}}
 !qalias = !quant.uniform<i8<-4:3>:f32, -1.0:127>
 
 // -----
@@ -128,20 +128,110 @@
 
 // -----
 // Scale f16 underflow
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale 5.800000e-08 out of expressed type range}}
 !qalias = !quant.uniform<i8:f16, 5.8e-8>
 
 // -----
 // Scale f16 overflow
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale 6.600000e+04 out of expressed type range}}
 !qalias = !quant.uniform<i8:f16, 6.6e4>
 
 // -----
 // Scale f16 underflow in per-axis quantization
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale 5.800000e-08 out of expressed type range}}
 !qalias = !quant.uniform<i8:f16:1, {2.0,5.8e-8}>
 
 // -----
 // Scale f16 overflow in per-axis quantization
-// expected-error@+1 {{scale out of expressed type range}}
+// expected-error@+1 {{scale 6.600000e+04 out of expressed type range}}
 !qalias = !quant.uniform<i8:f16:1, {2.0,6.6e4}>
+
+// -----
+// Illegal negative axis in sub-channel quantization
+// expected-error@+1 {{illegal quantized dimension: -1}}
+!qalias = !quant.uniform<u8:f32:{0:1,-1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Illegal zero block-size in sub-channel quantization
+// expected-error@+1 {{illegal block size: 0}}
+!qalias = !quant.uniform<u8:f32:{0:0,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Illegal negative block-size in sub-channel quantization
+// expected-error@+1 {{illegal block size: -1}}
+!qalias = !quant.uniform<u8:f32:{0:-1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing block size in sub-channel quantization
+// expected-error@+1 {{expected ':'}}
+!qalias = !quant.uniform<u8:f32:{0,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing quantization dimension in sub-channel quantization
+// expected-error@+1 {{expected integer value}}
+!qalias = !quant.uniform<u8:f32:{:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Invalid tensor literal structure in sub-channel quantization
+// expected-error@+2 {{expected '>'}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}>
+
+// -----
+// Ragged tensor literal in sub-channel quantization
+// expected-error@+2 {{ranks are not consistent between elements}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02}}>
+
+// -----
+// Missing braces around block-size information in sub-channel quantization
+// expected-error@+1 {{expected ','}}
+!qalias = !quant.uniform<u8:f32:0:1,1:2,
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing right-brace around block-size information in sub-channel quantization
+// expected-error@+1 {{unbalanced '{' character}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2,
+      {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing left-brace around block-size information in sub-channel quantization
+// expected-error@+1 {{unbalanced '<' character}}
+!qalias = !quant.uniform<u8:f32:0:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing Axis:BlockSize pair
+// expected-error@+1 {{expected integer value}}
+!qalias = !quant.uniform<u8:f32:{0:1,},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01}}>
+
+// -----
+// Missing Scale:ZeroPoint pair
+// expected-error@+2 {{expected floating point literal}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,}}>
+
+// -----
+// Missing ZeroPoint in Scale:ZeroPoint pair
+// expected-error@+2 {{expected integer value}}
+!qalias = !quant.uniform<u8:f32:{0:1,1:2},
+    {{2.000000e+02:120,9.987200e-01:127}, {2.000000e+02,9.987200e-01:}}>
+
+// -----
+// Empty quantization parameters in sub-channel quantization
+// expected-error@+1 {{expected floating point literal}}
+!qalias = !quant.uniform<u8:f32:{0:1, 1:2}, {}>
+
+// -----
+// Scale out of expressed type range in sub-channel quantization
+// expected-error@+2 {{scale 6.600000e+04 out of expressed type range}}
+!qalias = !quant.uniform<i8:f16:{0:1,1:2},
+    {{6.6e4:120,9.987200e-01:127}, {2.000000e+02:256,9.987200e-01}}>
+
diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index 4fbe86d935ea39..35530833f2d3a7 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -154,3 +154,21 @@ func.func @parse() -> !qalias {
   %0 = "foo"() : () -> !qalias
   return %0 : !qalias
 }
+
+// -----
+// Sub-channel scales and zero points (mixed affine and fixed-point)
+// CHECK: !quant.uniform<u8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:120, 3.000000e+00:127}, {4.000000e+00, 5.000000e+00}}>
+!qalias = !quant.uniform<u8:f32:{0:1,1:2}, {{2.0:120,3.0:127}, {4.0,5.0}}>
+func.func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Empty block-size information in sub-channel quantization
+// CHECK: !quant.uniform<u8:f32:{}, {{\{}}{2.000000e+00:120, 3.000000e+00:127}, {4.000000e+00, 5.000000e+00}}>
+!qalias = !quant.uniform<u8:f32:{}, {{2.0:120,3.0:127}, {4.0,5.0}}>
+func.func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}

>From 5de1147ca98ba7c226809543478775de3d3d4e98 Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Sun, 15 Dec 2024 10:37:17 +0000
Subject: [PATCH 2/4] Lowering to linalg ops

---
 .../Quant/Transforms/LowerQuantOps.cpp        | 214 +++++++++++-------
 mlir/test/Dialect/Quant/lower-quant-ops.mlir  |  64 ++++++
 2 files changed, 194 insertions(+), 84 deletions(-)

diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 45d0a0c3e697a2..c2dbcde1aeba6b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -38,11 +38,11 @@ Type getScalarType(Type inputType) {
   return inputType;
 }
 
-// Return the shape of an input value as a list of attributes (static dimensions)
-// and values (dynamic dimensions). If 'input' is a scalar, an empty list is
-// returned. If 'input' is a tensor, its shape is returned.
-SmallVector<OpFoldResult>
-getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
+// Return the shape of an input value as a list of attributes (static
+// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
+// list is returned. If 'input' is a tensor, its shape is returned.
+SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
+                                                 Location loc, Value input) {
   if (isa<TensorType>(input.getType()))
     return tensor::getMixedSizes(builder, loc, input);
   return {};
@@ -100,16 +100,16 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
 
   // Turn input size into 1D tensor
   auto flatShapeType = shape::getExtentTensorType(context, 1);
-  auto flatInputShape = builder.create<tensor::FromElementsOp>(
-      loc, flatShapeType, inputSize);
+  auto flatInputShape =
+      builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);
 
   // Reshape input tensor into 1D
   auto inputType = cast<UnrankedTensorType>(input.getType());
   auto elementType = inputType.getElementType();
   auto flatInputType =
       RankedTensorType::get({ShapedType::kDynamic}, elementType);
-  auto flatInput = builder.create<tensor::ReshapeOp>(
-      loc, flatInputType, input, flatInputShape);
+  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+                                                     flatInputShape);
   return std::make_pair(flatInput, inputShape);
 }
 
@@ -135,11 +135,9 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
 // - inputShape
 //   1D extent tensor containing the shape of the original unranked input.
 //
-std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
-                                                        Location loc,
-                                                        Value input,
-                                                        int64_t axis,
-                                                        int64_t axisSize) {
+std::pair<Value, Value>
+flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
+                                int64_t axis, int64_t axisSize) {
   // Get full tensor shape
   auto *context = builder.getContext();
   auto indexType = builder.getIndexType();
@@ -149,16 +147,20 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
   // Get shape and sizes on left and right of axis
   auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
   auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
-  auto shapeLeft = builder.create<shape::SplitAtOp>(
-      loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
-      .getResult(0);
-  auto sizeLeft = builder.create<shape::NumElementsOp>(
-      loc, indexType, shapeLeft);
-  auto shapeRight = builder.create<shape::SplitAtOp>(
-      loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
-      .getResult(1);
-  auto sizeRight = builder.create<shape::NumElementsOp>(
-      loc, indexType, shapeRight);
+  auto shapeLeft =
+      builder
+          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+                                    inputShape, axisValue)
+          .getResult(0);
+  auto sizeLeft =
+      builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
+  auto shapeRight =
+      builder
+          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
+                                    inputShape, axisNextValue)
+          .getResult(1);
+  auto sizeRight =
+      builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
 
   // Compute flat input shape as a 3-element 1D tensor
   auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
@@ -171,8 +173,8 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
   auto elementType = inputType.getElementType();
   auto flatInputType = RankedTensorType::get(
       {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
-  auto flatInput = builder.create<tensor::ReshapeOp>(
-      loc, flatInputType, input, flatInputShape);
+  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
+                                                     flatInputShape);
 
   return std::make_pair(flatInput, inputShape);
 }
@@ -190,7 +192,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
   auto inputType = cast<RankedTensorType>(input.getType());
   auto elementType = inputType.getElementType();
   auto unrankedType = UnrankedTensorType::get(elementType);
-  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
+  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
+                                           inputShape);
 }
 
 // Create a tensor constant containing all scales in a per-channel quantized
@@ -209,7 +212,8 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
   auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
     return builder.getFloatAttr(expressedType, scale);
   });
-  auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType);
+  auto tensorType =
+      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
   auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
   return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
 }
@@ -228,9 +232,8 @@ Value materializePerChannelZeroPoints(
     UniformQuantizedPerAxisType quantizedType) {
   auto zeroPoints = quantizedType.getZeroPoints();
   auto storageType = quantizedType.getStorageType();
-  auto zeroPointAttrs = llvm::map_to_vector(
-      zeroPoints,
-      [&](int64_t zeroPoint) -> Attribute {
+  auto zeroPointAttrs =
+      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
         return builder.getIntegerAttr(storageType, zeroPoint);
       });
   auto tensorType =
@@ -239,6 +242,54 @@ Value materializePerChannelZeroPoints(
   return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
 }
 
+// Create a tensor constant containing all scales in a sub-channel quantized
+// type. Example:
+//
+//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
+//
+Value materializeSubChannelScales(
+    OpBuilder &builder, Location loc,
+    UniformQuantizedSubChannelType quantizedType) {
+  auto scales = quantizedType.getScales();
+  auto expressedType = quantizedType.getExpressedType();
+  auto scaleAttrs = llvm::map_to_vector(
+      scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
+        return builder.getFloatAttr(expressedType, scale);
+      });
+  auto tensorType =
+      RankedTensorType::get(scales.getType().getShape(), expressedType);
+  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
+  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
+}
+
+// Create a tensor constant containing all zero points in a sub-channel
+// quantized type. Example:
+//
+//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
+//
+// produces
+//
+//   %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
+//
+Value materializeSubChannelZeroPoints(
+    OpBuilder &builder, Location loc,
+    UniformQuantizedSubChannelType quantizedType) {
+  auto zeroPoints = quantizedType.getZeroPoints();
+  auto storageType = quantizedType.getStorageType();
+  auto zeroPointAttrs = llvm::map_to_vector(
+      zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
+        return builder.getIntegerAttr(storageType, zeroPoint);
+      });
+  auto tensorType =
+      RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
+  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
+  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
+}
+
 // Clamp the given scalar or tensor input using the storage bounds encoded in
 // the given quantized type, if present.
 //
@@ -299,7 +350,7 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
   return builder.create<arith::UIToFPOp>(loc, resultType, input);
 }
 
-// Quantize a scalar or ranked tensor value. The stored value is clamped using 
+// Quantize a scalar or ranked tensor value. The stored value is clamped using
 // the storage bounds encoded in the given quantized type.
 //
 // See function 'convertRanked()' below for a description of the arguments.
@@ -308,8 +359,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                     Value zeroPoint, QuantizedType quantizedType) {
   // Convert scale to tensor if necessary
   auto inputType = input.getType();
-  scale = getScalarOrTensorConstant(
-      builder, loc, scale, inputType, inputShape);
+  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
 
   // Scale input
   auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
@@ -322,8 +372,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                                           inputShape);
 
     // Convert zero point from storage to expressed type
-    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
-                                      scale.getType(),
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                       quantizedType.isSigned());
 
     // Add zero point to stored value
@@ -334,9 +383,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
   // Convert stored value to storage type
   auto storageScalarOrTensorType =
       getScalarOrTensorType(quantizedType.getStorageType(), inputType);
-  auto storedValueInt = convertFloatToInteger(
-      builder, loc, storedValueFloat, storageScalarOrTensorType,
-      quantizedType.isSigned());
+  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
+                                              storageScalarOrTensorType,
+                                              quantizedType.isSigned());
 
   // Clamp the stored value if the storage type is bound
   auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
@@ -352,12 +401,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                       Value zeroPoint, QuantizedType quantizedType) {
   // Convert scale to tensor if necessary
   auto inputType = input.getType();
-  scale = getScalarOrTensorConstant(
-      builder, loc, scale, inputType, inputShape);
+  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
 
   // Convert stored value to float
-  auto result = convertIntegerToFloat(
-      builder, loc, input, scale.getType(), quantizedType.isSigned());
+  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
+                                      quantizedType.isSigned());
 
   // Skip unnecessary computations if no zero point is given
   if (!matchPattern(zeroPoint, m_Zero())) {
@@ -366,8 +414,7 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                                           inputShape);
 
     // Convert zero point from storage to expressed type
-    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
-                                      scale.getType(),
+    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                       quantizedType.isSigned());
 
     // Subtract zero point from stored value
@@ -501,35 +548,33 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
   auto initShape = tensor::getMixedSizes(builder, loc, input);
   Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
 
-  SmallVector<utils::IteratorType> iteratorTypes(
-      inputRank, utils::IteratorType::parallel);
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
   auto channelAxisAffineMap = AffineMap::get(
       inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
   SmallVector<AffineMap> indexingMaps{
-    builder.getMultiDimIdentityMap(inputRank),
-    channelAxisAffineMap,
-    channelAxisAffineMap,
-    builder.getMultiDimIdentityMap(inputRank)
-  };
-  auto result = builder.create<linalg::GenericOp>(
-      loc,
-      init.getType(),  // resultType
-      ValueRange{input, scales, zeroPoints},  // inputs
-      ValueRange{init},  // outputs
-      indexingMaps,
-      iteratorTypes,
-      [&](OpBuilder& builder, Location loc, ValueRange args) {
-        assert(args.size() == 4);
-        auto input = args[0];
-        auto scale = args[1];
-        auto zeroPoint = args[2];
-
-        auto result = convertRanked(builder, loc, op, input, {}, scale,
-                                    zeroPoint, quantizedType);
-
-        builder.create<linalg::YieldOp>(loc, result);
-      })
-      .getResult(0);
+      builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
+      channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
+  auto result = builder
+                    .create<linalg::GenericOp>(
+                        loc,
+                        init.getType(),                        // resultType
+                        ValueRange{input, scales, zeroPoints}, // inputs
+                        ValueRange{init},                      // outputs
+                        indexingMaps, iteratorTypes,
+                        [&](OpBuilder &builder, Location loc, ValueRange args) {
+                          assert(args.size() == 4);
+                          auto input = args[0];
+                          auto scale = args[1];
+                          auto zeroPoint = args[2];
+
+                          auto result =
+                              convertRanked(builder, loc, op, input, {}, scale,
+                                            zeroPoint, quantizedType);
+
+                          builder.create<linalg::YieldOp>(loc, result);
+                        })
+                    .getResult(0);
 
   return result;
 }
@@ -551,7 +596,7 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
   // Flatten unranked tensor into a 3D ranked tensor if necessary
   bool isUnranked = isa<UnrankedTensorType>(input.getType());
   int64_t channelAxis = quantizedType.getQuantizedDimension();
-  int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
+  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
   Value inputShape;
   if (isUnranked) {
     std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
@@ -660,11 +705,17 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
     return convertPerChannel(builder, loc, op, input,
                              uniformQuantizedPerAxisType);
 
+  if (auto uniformQuantizedSubChannelType =
+          dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
+    return convertSubChannel(builder, loc, op, input,
+                             uniformQuantizedSubChannelType);
+
   llvm_unreachable("unexpected quantized type");
 }
 
 // Lowering pattern for 'quant.dcast'
-struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
+struct DequantizeCastOpConversion
+    : public OpConversionPattern<quant::DequantizeCastOp> {
   using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
 
   LogicalResult
@@ -689,7 +740,8 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
 };
 
 // Lowering pattern for 'quant.qcast'
-struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
+struct QuantizeCastOpConversion
+    : public OpConversionPattern<quant::QuantizeCastOp> {
   using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
 
   LogicalResult
@@ -717,12 +769,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
     ConversionTarget target(getContext());
     target.addLegalOp<quant::StorageCastOp>();
     target.addIllegalDialect<quant::QuantDialect>();
-    target.addLegalDialect<
-      arith::ArithDialect,
-      linalg::LinalgDialect,
-      shape::ShapeDialect,
-      tensor::TensorDialect
-    >();
+    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+                           shape::ShapeDialect, tensor::TensorDialect>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -733,10 +781,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
 } // namespace
 
 void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
-  patterns.add<
-    DequantizeCastOpConversion,
-    QuantizeCastOpConversion
-  >(patterns.getContext());
+  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
+      patterns.getContext());
 }
 
 } // namespace quant
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
index 6bba9f5c037727..2d22b1af89f2b4 100644
--- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir
+++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir
@@ -509,3 +509,67 @@ func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias>
   return %0 : tensor<*x!qalias>
 }
 
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3 floordiv 2)>
+
+// CHECK-LABEL: @qcast_sub_channel_ranked
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<2x?x?x4xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<{{.*}}2.000000e+00, 3.000000e+00{{.*}}, {{.*}}4.000000e+00, 5.000000e+00{{.*}}> : tensor<2x1x1x2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<{{.*}}10, 20{{.*}}, {{.*}}30, 40{{.*}}> : tensor<2x1x1x2xi8>
+
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<2x?x?x4xf32>
+// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<2x?x?x4xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<2x?x?x4xi8>
+
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<2x?x?x4xf32>, tensor<2x1x1x2xf32>, tensor<2x1x1x2xi8>) outs(%[[INIT]] : tensor<2x?x?x4xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK:   %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK:   linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<2x?x?x4xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x?x?x4xi8> to tensor<2x?x?x4x!quant.uniform<i8:f32:{0:1,3:2}, {{.*}}2.000000e+00:10, 3.000000e+00:20{{.*}}, {{.*}}4.000000e+00:30, 5.000000e+00:40{{.*}}>>
+// CHECK: return %[[STORED_QUANT]]
+
+!qalias = !quant.uniform<i8:f32:{0:1, 3:2}, {{{{2.0:10, 3.0:20}}}, {{{4.0:30, 5.0:40}}}}>
+func.func @qcast_sub_channel_ranked(%arg0: tensor<2x?x?x4xf32>) -> tensor<2x?x?x4x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<2x?x?x4xf32> to tensor<2x?x?x4x!qalias>
+  return %0 : tensor<2x?x?x4x!qalias>
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3 floordiv 2)>
+
+// CHECK-LABEL: @qcast_sub_channel_ranked_bounds
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<2x3x5x4xf32>
+
+// CHECK: %[[SCALES:.*]] = arith.constant dense<{{.*}}2.000000e+00, 3.000000e+00{{.*}}, {{.*}}4.000000e+00, 5.000000e+00{{.*}}> : tensor<2x1x1x2xf32>
+// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<{{.*}}10, 20{{.*}}, {{.*}}30, 40{{.*}}> : tensor<2x1x1x2xi8>
+
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x3x5x4xi8>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<2x3x5x4xf32>, tensor<2x1x1x2xf32>, tensor<2x1x1x2xi8>) outs(%[[INIT]] : tensor<2x3x5x4xi8>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8):
+// CHECK:   %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32
+// CHECK:   %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32
+// CHECK:   %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32
+// CHECK:   %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8
+// CHECK:   linalg.yield %[[STORED_INT]] : i8
+// CHECK: } -> tensor<2x3x5x4xi8>
+
+// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x3x5x4xi8> to tensor<2x3x5x4x!quant.uniform<i8:f32:{0:1,3:2}, {{.*}}2.000000e+00:10, 3.000000e+00:20{{.*}}, {{.*}}4.000000e+00:30, 5.000000e+00:40{{.*}}>>
+// CHECK: return %[[STORED_QUANT]]
+
+!qalias = !quant.uniform<i8:f32:{0:1, 3:2}, {{{{2.0:10, 3.0:20}}}, {{{4.0:30, 5.0:40}}}}>
+func.func @qcast_sub_channel_ranked_bounds(%arg0: tensor<2x3x5x4xf32>) -> tensor<2x3x5x4x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<2x3x5x4xf32> to tensor<2x3x5x4x!qalias>
+  return %0 : tensor<2x3x5x4x!qalias>
+}
\ No newline at end of file
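For readers tracing the CHECK lines above: the generated `linalg.generic` body computes `fptosi(divf(in, scale) + sitofp(zero_point))`, and the `(d0, 0, 0, d3 floordiv 2)` indexing map routes every data element to the parameters of its block. A minimal NumPy sketch of the same computation (illustrative only, not part of the patch):

```python
import numpy as np

# Mirrors @qcast_sub_channel_ranked_bounds: a tensor<2x3x5x4xf32> quantized
# with !quant.uniform<i8:f32:{0:1, 3:2}, ...>, i.e. block size 1 along axis 0
# and block size 2 along axis 3.
x = np.random.rand(2, 3, 5, 4).astype(np.float32)
scales = np.array([[2.0, 3.0], [4.0, 5.0]], np.float32)  # tensor<2x1x1x2xf32>
zero_points = np.array([[10, 20], [30, 40]], np.int8)    # tensor<2x1x1x2xi8>

# Expand each block of 2 along the last axis so element (d0, d1, d2, d3) reads
# parameter (d0, 0, 0, d3 // 2), matching the #ATTR_1 affine map above.
s = np.repeat(scales, 2, axis=1).reshape(2, 1, 1, 4)
z = np.repeat(zero_points, 2, axis=1).reshape(2, 1, 1, 4).astype(np.float32)  # sitofp

# divf, addf, fptosi (float-to-int conversion truncates toward zero).
stored = (x / s + z).astype(np.int8)
```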

>From 6d33d7c5c553217965ca9c32945ea5679185b2a4 Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Sun, 15 Dec 2024 12:13:59 +0000
Subject: [PATCH 3/4] Adding C and Python APIs

---
 mlir/lib/Bindings/Python/DialectQuant.cpp     |  1 +
 mlir/lib/CAPI/Dialect/Quant.cpp               |  1 +
 .../mlir/_mlir_libs/_mlir/dialects/quant.pyi  | 22 ++++++++-
 mlir/test/CAPI/quant.c                        |  2 +
 mlir/test/python/dialects/quant.py            | 46 +++++++++++++++++++
 5 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 44a596caa24a65..8f7da9ef3c1bca 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Dialect/Quant.h"
 #include "mlir-c/IR.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index 88648497895ab7..01a6a948f1dc07 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/BuiltinAttributes.h"
 #include "mlir/CAPI/Registration.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
index 47168d49c5568b..3f5304584edeff 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 
-from mlir.ir import Type
+from mlir.ir import DenseElementsAttr, Type
 
 __all__ = [
   "QuantizedType",
@@ -109,6 +109,26 @@ class UniformQuantizedPerAxisType(QuantizedType):
   @property
   def is_fixed_point(self) -> bool: ...
 
+class UniformQuantizedSubChannelType(QuantizedType):
+
+  @classmethod
+  def get(cls, flags: int, storage_type: Type, expressed_type: Type,
+          scales: DenseElementsAttr, zero_points: DenseElementsAttr,
+          quantized_dimensions: list[int], block_sizes: list[int],
+          storage_type_min: int, storage_type_max: int):
+    ...
+
+  @property
+  def quantized_dimensions(self) -> list[int]: ...
+
+  @property
+  def block_sizes(self) -> list[int]: ...
+
+  @property
+  def scales(self) -> DenseElementsAttr: ...
+
+  @property
+  def zero_points(self) -> DenseElementsAttr: ...
 
 class CalibratedQuantizedType(QuantizedType):
 
diff --git a/mlir/test/CAPI/quant.c b/mlir/test/CAPI/quant.c
index bc7cd1436efe18..a5be7550b7d63f 100644
--- a/mlir/test/CAPI/quant.c
+++ b/mlir/test/CAPI/quant.c
@@ -10,6 +10,7 @@
 // RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
 
 #include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/IR.h"
 
@@ -357,6 +358,7 @@ int main(void) {
   testAnyQuantizedType(ctx);
   testUniformType(ctx);
   testUniformPerAxisType(ctx);
+  testUniformSubChannelType(ctx);
   testCalibratedType(ctx);
   mlirContextDestroy(ctx);
   return EXIT_SUCCESS;
diff --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py
index b1d6e85f519b5d..b30a7322dca4a1 100644
--- a/mlir/test/python/dialects/quant.py
+++ b/mlir/test/python/dialects/quant.py
@@ -1,5 +1,6 @@
 # RUN: %PYTHON %s | FileCheck %s
 
+import numpy as np
 from mlir.ir import *
 from mlir.dialects import quant
 
@@ -18,21 +19,27 @@ def test_type_hierarchy():
         any = Type.parse("!quant.any<i8<-8:7>:f32>")
         uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
         per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+        sub_channel = Type.parse(
+            "!quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>")
         calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
 
         assert not quant.QuantizedType.isinstance(i8)
         assert quant.QuantizedType.isinstance(any)
         assert quant.QuantizedType.isinstance(uniform)
         assert quant.QuantizedType.isinstance(per_axis)
+        assert quant.QuantizedType.isinstance(sub_channel)
         assert quant.QuantizedType.isinstance(calibrated)
 
         assert quant.AnyQuantizedType.isinstance(any)
         assert quant.UniformQuantizedType.isinstance(uniform)
         assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
+        assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel)
         assert quant.CalibratedQuantizedType.isinstance(calibrated)
 
         assert not quant.AnyQuantizedType.isinstance(uniform)
         assert not quant.UniformQuantizedType.isinstance(per_axis)
+        assert not quant.UniformQuantizedType.isinstance(sub_channel)
+        assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel)
 
 
 # CHECK-LABEL: TEST: test_any_quantized_type
@@ -121,6 +128,45 @@ def test_uniform_per_axis_type():
         assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
 
 
+# CHECK-LABEL: TEST: test_uniform_sub_channel_type
+@run
+def test_uniform_sub_channel_type():
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        f32 = F32Type.get()
+        sub_channel = quant.UniformQuantizedSubChannelType.get(
+            quant.QuantizedType.FLAG_SIGNED,
+            i8,
+            f32,
+            DenseElementsAttr.get(np.asarray(
+                [2.0, 3.0, 4.0, 5.0], np.float32).reshape(2, 2)),
+            DenseElementsAttr.get(np.asarray(
+                [10, 20, 30, 40], np.int8).reshape(2, 2)),
+            [0, 1], [1, 2],
+            storage_type_min=quant.QuantizedType.default_minimum_for_integer(
+                is_signed=True, integral_width=8
+            ),
+            storage_type_max=quant.QuantizedType.default_maximum_for_integer(
+                is_signed=True, integral_width=8
+            ),
+        )
+
+        # CHECK: quantized dimensions: [0, 1]
+        print(f"quantized dimensions: {sub_channel.quantized_dimensions}")
+        # CHECK: block sizes: [1, 2]
+        print(f"block sizes: {sub_channel.block_sizes}")
+        # CHECK: scales: {{\[}}[2. 3.]
+        # CHECK:               [4. 5.]]
+        print(f"scales: {np.asarray(sub_channel.scales)}")
+        # CHECK: zero-points: {{\[}}[10 20]
+        # CHECK:                    [30 40]]
+        print(f"zero-points: {np.asarray(sub_channel.zero_points)}")
+        # CHECK: !quant.uniform<i8:f32:{0:1,1:2}, {{\{}}{2.000000e+00:10, 3.000000e+00:20}, {4.000000e+00:30, 5.000000e+00:40}}>
+        print(sub_channel)
+        assert sub_channel == Type.parse(
+            "!quant.uniform<i8:f32:{0:1,1:2},{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>")
+
+
 # CHECK-LABEL: TEST: test_calibrated_type
 @run
 def test_calibrated_type():
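
A small helper (illustrative only, not part of the bindings) making explicit the relationship, visible in the tests above, between the data shape, `quantized_dimensions`/`block_sizes`, and the shape of the `scales`/`zero_points` attributes:

```python
def expected_scales_shape(data_shape, quantized_dimensions, block_sizes):
    # One scale/zero-point per block along each quantized axis; size 1 along
    # every other axis (dynamic data dims never appear in the scales shape).
    shape = [1] * len(data_shape)
    for dim, block in zip(quantized_dimensions, block_sizes):
        shape[dim] = data_shape[dim] // block
    return shape

# tensor<2x3x5x4xf32> quantized with {0:1, 3:2} carries tensor<2x1x1x2xf32>
# scales, as in the lower-quant-ops tests of the previous patch.
assert expected_scales_shape([2, 3, 5, 4], [0, 3], [1, 2]) == [2, 1, 1, 2]
```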

>From 3a6a0a71f55a92ad597cbfb4ac596e6f50ee2405 Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Thu, 12 Dec 2024 20:27:04 +0000
Subject: [PATCH 4/4] Add pass to normalize generic quantized types to specific
 quantized types

---
 .../mlir/Dialect/Quant/Transforms/Passes.td   |  33 ++++
 .../Dialect/Quant/Transforms/CMakeLists.txt   |   1 +
 .../Quant/Transforms/NormalizeQuantTypes.cpp  | 179 ++++++++++++++++++
 .../Dialect/Quant/normalize-quant-types.mlir  |  51 +++++
 4 files changed, 264 insertions(+)
 create mode 100644 mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
 create mode 100644 mlir/test/Dialect/Quant/normalize-quant-types.mlir

diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
index b25296d4db5a99..4b438706c9af69 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td
@@ -31,6 +31,39 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> {
   ];
 }
 
+def NormalizeQuantTypes : Pass<"normalize-quant-types"> {
+  let summary = "Normalize generic quantized types to specific quantized types";
+  let description = [{
+    This pass converts generic quantized types in the `quant` dialect to more
+    specific types when possible.
+
+    The following conversions are performed:
+
+    1. Sub-channel to per-axis: If exactly one dimension of the scales tensor
+       of a sub-channel quantized type has a size greater than one, the type
+       is converted to a per-axis quantized type.
+       
+       For example:
+       
+       * `!quant.uniform<i8:f32:{0:1}, {{2.0}, {3.0}}>` 
+          -> `!quant.uniform<i8:f32:0, {2.0, 3.0}>`
+       * `tensor<?x?x!quant.uniform<i8:f32:{0:1,1:4}, {{2.0}, {3.0}}>>` 
+          -> `tensor<?x?x!quant.uniform<i8:f32:0, {2.0, 3.0}>>`
+
+    2. Sub-channel to per-tensor: If a sub-channel quantized type has exactly
+       one scale and one zero point, it is converted to a per-tensor
+       quantized type.
+       
+       For example:
+       
+       * `!quant.uniform<i8:f32:{}, {{2.0}}>`
+          -> `!quant.uniform<i8:f32, 2.0>`
+       * `tensor<?x?x!quant.uniform<i8:f32:{0:1, 0:4}, {{2.0}}>>`
+          -> `tensor<?x?x!quant.uniform<i8:f32, 2.0>>`
+  }];
+  let dependentDialects = ["func::FuncDialect", "quant::QuantDialect"];
+}
+
 def StripFuncQuantTypes : Pass<"strip-func-quant-types"> {
   let summary = "Strip quantized types from function headers";
   let description = [{
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
index 2fd4a41999d456..825d11992d309c 100644
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRQuantTransforms
   LowerQuantOps.cpp
+  NormalizeQuantTypes.cpp
   StripFuncQuantTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
new file mode 100644
index 00000000000000..ecd6679651f462
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
@@ -0,0 +1,179 @@
+//===- NormalizeQuantTypes.cpp - Normalize quantized types ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Normalize generic quantized types to specific quantized types
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_NORMALIZEQUANTTYPES
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+/// Returns true if the given sub-channel quantized type is convertible to a
+/// per-tensor quantized type. This is true if the sub-channel type has only
+/// one scale and one zero point.
+///
+/// Assumes that `tensorType` is a tensor with element type
+/// `quant::UniformQuantizedSubChannelType`.
+static bool isConvertibleToPerTensor(TensorType tensorType) {
+  return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
+             .getScales()
+             .getType()
+             .getNumElements() == 1;
+}
+
+/// Returns true if the given sub-channel quantized type is convertible to a
+/// per-axis quantized type. This is true if exactly one dimension of the
+/// scales tensor's shape is greater than one.
+///
+/// Assumes that `tensorType` is a tensor with element type
+/// `quant::UniformQuantizedSubChannelType`.
+static bool isConvertibleToPerAxis(TensorType tensorType) {
+  auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
+                   .getScales()
+                   .getType()
+                   .getShape();
+  return llvm::count_if(shape, [](int64_t dim) { return dim != 1; }) == 1;
+}
+
+/// This class defines a type converter that converts sub-channel quantized
+/// types to per-tensor or per-axis quantized types whenever possible.
+class NormalizedQuantTypesConverter : public TypeConverter {
+
+  static Type convertType(Type type) {
+    auto tensorType = dyn_cast<TensorType>(type);
+    if (!tensorType) {
+      return type;
+    }
+
+    auto subChannelType =
+        dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
+    if (!subChannelType) {
+      return type;
+    }
+
+    if (isConvertibleToPerTensor(tensorType)) {
+      double scale =
+          subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
+      int64_t zeroPoint =
+          subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
+      auto perTensorType = UniformQuantizedType::get(
+          subChannelType.getFlags(), subChannelType.getStorageType(),
+          subChannelType.getExpressedType(), scale, zeroPoint,
+          subChannelType.getStorageTypeMin(),
+          subChannelType.getStorageTypeMax());
+      return tensorType.clone(perTensorType);
+    }
+
+    if (isConvertibleToPerAxis(tensorType)) {
+      auto shape = subChannelType.getScales().getType().getShape();
+      auto quantizedDimItr =
+          llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
+      auto scales = llvm::to_vector(llvm::map_range(
+          subChannelType.getScales().getValues<APFloat>(),
+          [](APFloat scale) { return scale.convertToDouble(); }));
+      auto zeroPoints = llvm::to_vector(llvm::map_range(
+          subChannelType.getZeroPoints().getValues<APInt>(),
+          [](APInt zeroPoint) { return zeroPoint.getSExtValue(); }));
+      auto perAxisType = UniformQuantizedPerAxisType::get(
+          subChannelType.getFlags(), subChannelType.getStorageType(),
+          subChannelType.getExpressedType(), scales, zeroPoints,
+          quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
+          subChannelType.getStorageTypeMax());
+      return tensorType.clone(perAxisType);
+    }
+    return type;
+  }
+
+public:
+  NormalizedQuantTypesConverter() { addConversion(convertType); }
+};
+
+/// This class implements a conversion pattern that converts any generic
+/// operation with sub-channel quantized types to an equivalent operation with
+/// per-tensor or per-axis quantized types.
+class ConvertGenericOpWithSubChannelType : public ConversionPattern {
+public:
+  ConvertGenericOpWithSubChannelType(TypeConverter &typeConverter,
+                                     MLIRContext *context)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SmallVector<Type> resultTypes;
+    if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
+      return failure();
+
+    auto *newOp = Operation::create(
+        op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
+        op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+    for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
+      Region &before = std::get<0>(regions);
+      Region &parent = std::get<1>(regions);
+      rewriter.inlineRegionBefore(before, parent, parent.end());
+      if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
+        return failure();
+    }
+    rewriter.insert(newOp);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+// Conversion pass
+class NormalizeQuantTypes
+    : public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
+public:
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+    auto *context = &getContext();
+
+    NormalizedQuantTypesConverter typeConverter;
+    ConversionTarget target(*context);
+
+    // Determine legal operations.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      return typeConverter.isLegal(op->getOperandTypes()) &&
+             typeConverter.isLegal(op->getResultTypes());
+    });
+
+    // Register conversion patterns
+    RewritePatternSet patterns(context);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    patterns.add<ConvertGenericOpWithSubChannelType>(typeConverter, context);
+
+    // Apply conversion
+    if (failed(applyFullConversion(rootOp, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+} // namespace quant
+} // namespace mlir
diff --git a/mlir/test/Dialect/Quant/normalize-quant-types.mlir b/mlir/test/Dialect/Quant/normalize-quant-types.mlir
new file mode 100644
index 00000000000000..573781c9ecc04a
--- /dev/null
+++ b/mlir/test/Dialect/Quant/normalize-quant-types.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s --normalize-quant-types --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @callee(
+// CHECK-SAME: [[PER_TENSOR:tensor<\?x\?x!quant.uniform<i8:f32, 2.000000e\+00:127>>]],
+// CHECK-SAME: [[PER_TENSOR]]
+// CHECK-SAME: ([[PER_TENSOR]], [[PER_TENSOR]])
+// CHECK-LABEL: @normalize_quant_types_to_per_tensor
+// CHECK-SAME: %[[ARG_0:.*]]: [[PER_TENSOR:tensor<\?x\?x!quant.uniform<i8:f32, 2.000000e\+00:127>>]],
+// CHECK-SAME: %[[ARG_1:.*]]: [[PER_TENSOR]]
+// CHECK-SAME: ([[PER_TENSOR]], [[PER_TENSOR]])
+// CHECK: %[[TEMP_0:.*]] = "test.custom_op"(%[[ARG_0]]) : ([[PER_TENSOR]]) -> [[PER_TENSOR]]
+// CHECK: %[[TEMP_1:.*]] = "test.custom_op"(%[[ARG_1]]) : ([[PER_TENSOR]]) -> [[PER_TENSOR]]
+// CHECK: %[[TEMP_3:.*]]:2 = call @callee(%[[TEMP_0]], %[[TEMP_1]])
+// CHECK: return %[[TEMP_3]]#0, %[[TEMP_3]]#1 : [[PER_TENSOR]], [[PER_TENSOR]]
+
+!qalias1 = !quant.uniform<i8:f32:{}, {{2.0:127}}>
+!qalias2 = !quant.uniform<i8:f32:{0:1,1:4}, {{2.0:127}}>
+
+func.func private @callee(tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
+
+func.func @normalize_quant_types_to_per_tensor(%arg0: tensor<?x?x!qalias1>,
+    %arg1: tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) {
+  %0 = "test.custom_op"(%arg0) : (tensor<?x?x!qalias1>) -> tensor<?x?x!qalias1>
+  %1 = "test.custom_op"(%arg1) : (tensor<?x?x!qalias2>) -> tensor<?x?x!qalias2>
+  %3:2 = func.call @callee(%0, %1) : (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
+  return %3#0, %3#1 : tensor<?x?x!qalias1>, tensor<?x?x!qalias2>
+}
+
+// -----
+
+// CHECK-LABEL: @normalize_quant_types_to_per_axis
+// CHECK-SAME: %[[ARG_0:.*]]: [[PER_AXIS:tensor<\?x\?x!quant.uniform<i8:f32:0, \{2.000000e\+00:127,3.000000e\+00:127\}>>]],
+// CHECK-SAME: %[[ARG_1:.*]]: [[PER_AXIS]]
+// CHECK-SAME: ([[PER_AXIS]], [[PER_AXIS]])
+// CHECK: %[[TEMP_0:.*]] = "test.custom_op"(%[[ARG_0]]) : ([[PER_AXIS]]) -> [[PER_AXIS]]
+// CHECK: %[[TEMP_1:.*]] = "test.custom_op"(%[[ARG_1]]) : ([[PER_AXIS]]) -> [[PER_AXIS]]
+// CHECK: %[[TEMP_3:.*]]:2 = call @callee(%[[TEMP_0]], %[[TEMP_1]])
+// CHECK: return %[[TEMP_3]]#0, %[[TEMP_3]]#1 : [[PER_AXIS]], [[PER_AXIS]]
+
+!qalias1 = !quant.uniform<i8:f32:{0:1}, {{2.0:127}, {3.0:127}}>
+!qalias2 = !quant.uniform<i8:f32:{0:1,1:4}, {{2.0:127}, {3.0:127}}>
+
+func.func private @callee(tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
+
+func.func @normalize_quant_types_to_per_axis(%arg0: tensor<?x?x!qalias1>,
+    %arg1: tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) {
+  %0 = "test.custom_op"(%arg0) : (tensor<?x?x!qalias1>) -> tensor<?x?x!qalias1>
+  %1 = "test.custom_op"(%arg1) : (tensor<?x?x!qalias2>) -> tensor<?x?x!qalias2>
+  %3:2 = func.call @callee(%0, %1) : (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>) -> (tensor<?x?x!qalias1>, tensor<?x?x!qalias2>)
+  return %3#0, %3#1 : tensor<?x?x!qalias1>, tensor<?x?x!qalias2>
+}


