[Mlir-commits] [mlir] Sub-channel quantized type implementation (PR #120172)

Sandeep Dasgupta llvmlistbot at llvm.org
Wed Mar 12 18:15:34 PDT 2025


================
@@ -17,50 +17,137 @@
 
 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
-
 namespace mlir {
 namespace quant {
 
 namespace {
 
 // Verify the integrity of per-axis quantization information, if present.
 //
-// - quantizedType
-//   Any quantized type. Any quantized type with no per-axis quantization is
-//   ignored.
+// - uniformQuantizedPerAxisType
+//   A quantized type with per-axis quantization.
 //
 // - containerType
 //   Original input or result type of the operation using the provided quantized
 //   type. Used to ensure that the quantized type appears within a tensor and
 //   that the tensor is compatible with per-axis quantization information.
 //
-LogicalResult verifyPerAxisQuantization(Operation *op,
-                                        QuantizedType quantizedType,
-                                        Type containerType) {
-  auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
-  if (!quantizedPerAxisType)
-    return success();
-
+LogicalResult verifyPerAxisQuantization(
+    Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
+    Type containerType) {
   auto tensorType = dyn_cast<TensorType>(containerType);
   if (!tensorType)
     return op->emitError("scalar types may not use per-axis quantization");
 
   if (!tensorType.hasRank())
     return success();
 
-  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
-  if (quantizedDimension >= tensorType.getRank())
+  int32_t quantizedDimension =
+      uniformQuantizedPerAxisType.getQuantizedDimension();
+  if ((int64_t)quantizedDimension >= tensorType.getRank())
     return op->emitError("quantized dimension must be less than tensor rank");
 
   int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
   if (quantizedDimensionSize != ShapedType::kDynamic &&
-      quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+      quantizedDimensionSize !=
+          (int64_t)uniformQuantizedPerAxisType.getScales().size())
     return op->emitError(
         "quantized dimension size does not match number of scales");
 
   return success();
 }
 
+// Verifies that the sub-channel quantization parameters are consistent with
+// the given container type. The function checks the following:
+//
+// - The container type must be a ranked tensor type.
+// - Each quantized dimension must be less than the rank of the tensor.
+// - The tensor dimension size along each quantized dimension must be
+//   divisible by the corresponding block size.
+// - The scale dimension size at each axis index must equal the tensor
+//   dimension size at that index divided by the corresponding block size.
+//
+// The `uniformQuantizedSubChannelType` argument provides the sub-channel
+// quantization parameters, and the `containerType` argument specifies the
+// type of the container holding the quantized data.
----------------
sdasgup3 wrote:

>> Do we have any constraints on the datatype of the scales / zps in relation to the storageType?

Yes, there are strict checks for that [ref](https://github.com/llvm/llvm-project/pull/120172#discussion_r1992493402), performed as part of the quantized-type verification in `UniformQuantizedSubChannelType::verifyInvariants`.

IMO, let's defer loosening these constraints until we start on the super-block quantization implementation [ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694#p-332808-looking-ahead-optimizing-quantized-model-memory-footprint-14).


>> Is there any reason that verification is spread out over several places?

For the reasoning, see [this discussion](https://github.com/llvm/llvm-project/pull/120172#discussion_r1992493402).

>> Can we consolidate?


The checks in `UniformQuantizedSubChannelType::verifyInvariants` are also used independently by dialects that share only the quantized types (e.g., the stablehlo dialect). Those dialects make up for the absence of quant-dialect op-level verification with their own op-level checks [eg](https://github.com/openxla/stablehlo/blob/ec66eef1ef763f3d67f2b6dfe7c08056140fe195/stablehlo/dialect/Base.cpp#L765). By contrast, `verifySubChannelQuantization` performs op-level checks that are relevant only in the context of quant dialect ops.
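To make the op-level checks listed in the `verifySubChannelQuantization` comment concrete, here is a minimal, MLIR-free C++ sketch. It is not the PR's implementation: `SubChannelParams`, `verifySubChannelSketch`, and the error strings are hypothetical; tensor shapes are plain `std::vector<int64_t>`, with `std::nullopt` standing in for an unranked container; dynamic dimensions are not modeled; and the positive-block-size guard is an added assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the parameters of a sub-channel quantized type:
// (quantized dimension, block size) pairs plus the shape of the scales tensor.
// These names are illustrative only; they are not the MLIR API.
struct SubChannelParams {
  std::vector<std::pair<int32_t, int64_t>> blockSizeInfo;
  std::vector<int64_t> scaleShape;
};

// Sketch of the op-level sub-channel checks. `tensorShape` is std::nullopt for
// an unranked container. Returns an error message on the first failed check,
// or std::nullopt if every check passes. Assumes static dimension sizes.
std::optional<std::string>
verifySubChannelSketch(const SubChannelParams &params,
                       const std::optional<std::vector<int64_t>> &tensorShape) {
  // The container type must be a ranked tensor.
  if (!tensorShape)
    return "container must be a ranked tensor";
  const std::vector<int64_t> &shape = *tensorShape;
  const int64_t rank = static_cast<int64_t>(shape.size());

  for (const auto &[axis, blockSize] : params.blockSizeInfo) {
    // Each quantized dimension must be less than the tensor rank.
    if (axis < 0 || static_cast<int64_t>(axis) >= rank)
      return "quantized dimension must be less than tensor rank";
    // Assumption: block sizes are positive (guards the modulo/division below).
    if (blockSize <= 0)
      return "block size must be positive";
    // The tensor dimension size must be divisible by the block size.
    const int64_t dimSize = shape[axis];
    if (dimSize % blockSize != 0)
      return "tensor dimension size must be divisible by block size";
    // The scale dimension at this axis must equal dimSize / blockSize.
    if (static_cast<int64_t>(axis) >=
            static_cast<int64_t>(params.scaleShape.size()) ||
        params.scaleShape[axis] != dimSize / blockSize)
      return "scale dimension size must equal tensor dimension size / block size";
  }
  return std::nullopt;
}
```

For example, a `4x6` tensor quantized with block size 2 on axis 0 and block size 3 on axis 1 needs a `2x2` scales tensor. Changing the axis-1 block size to 4 fails the divisibility check because 6 is not a multiple of 4. None of these checks depend on how the type is declared in isolation; they compare the type's parameters against the shape of the tensor that carries it. That is why, in the structure described above, they live at the op level rather than in `verifyInvariants`.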




https://github.com/llvm/llvm-project/pull/120172

