[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)

Rafael Ubal llvmlistbot at llvm.org
Tue Aug 27 11:46:40 PDT 2024


================
@@ -0,0 +1,229 @@
+//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the operation definition file for Quantization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef QUANT_OPS
+#define QUANT_OPS
+
+include "mlir/Dialect/Quant/IR/QuantBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Base classes
+//===----------------------------------------------------------------------===//
+
+class quant_Op<string mnemonic, list<Trait> traits> :
+    Op<Quant_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Quantization casts
+//===----------------------------------------------------------------------===//
+
+def quant_DequantizeCastOp : quant_Op<"dcast", [
+    Pure,
+    quant_SameScalarOrTensorShape]> {
+  let summary = "Dequantize cast operation";
+  let description = [{
+    Convert an input quantized value into its expressed floating-point value.
+    The dequantization process consists of the following steps:
+
+    ```
+    def dequantize(quantizedValue: quantizedType) -> expressedType:
+        storedValue = reinterpretCast(quantizedValue, storageType)
+        storedValueFloat = convertIntToFloat(storedValue, expressedType)
+        zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+        expressedValue = (storedValueFloat - zeroPointFloat) * scale
+        return expressedValue
+    ```
+
+    Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained
+    from the corresponding parameters encoded in `quantizedType`. For
+    per-channel quantization, the appropriate `scale` and `zeroPoint` values
+    are used for each tensor element computation according to the channel the
+    element belongs to.
+
+    The operation must satisfy the following syntactic constraints:
+
+    - Operand `input` must be a scalar or tensor of type `!quant.uniform`.
+
+    - The result type must be a floating-point scalar or tensor.
+
+    - The `expressedType` parameter of the `!quant.uniform` type of the input
+      must match the floating-point type of the result.
+
+    - The operand and result types must be both scalars or both tensors. If
+      tensors, they must be both ranked or both unranked. If ranked, both must
+      have the same shape, including matching static and dynamic dimensions.
+
+    - If the operand uses per-channel quantization, its `!quant.uniform` type
+      must adhere to the [Per-axis quantization
+      integrity](#per-axis-quantization-integrity) guidelines.
+
+    Examples:
+
+    ```
+    // Dequantize a scalar quantized value
+    %result = quant.dcast %input : !quant.uniform<i8:f32, 2.0> to f32
+
+    // Dequantize a dynamically shaped tensor of quantized values
+    %result = quant.dcast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xf32>
+
+    // Dequantize an unranked tensor using per-axis quantization information
+    %result = quant.dcast %input : tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>> to tensor<*xf32>
+    ```
+  }];
+  let arguments = (ins quant_QuantizedScalarOrTensor:$input);
+  let results = (outs quant_FloatScalarOrTensor:$result);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+  let hasVerifier = 1;
+  let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// Return the float type of the scalar or tensor result.
+    FloatType getFloatType();
+    
+    /// Return the quantized type of the scalar or tensor input.
+    quant::QuantizedType getQuantizedType();
+  }];
+}
+
+def quant_QuantizeCastOp : quant_Op<"qcast", [
+    Pure,
+    quant_SameScalarOrTensorShape]> {
+  let summary = "Quantize cast operation";
+  let description = [{
+    Convert a floating-point value to a quantized type. The quantization
+    process consists of the following steps:
+
+    ```
+    def quantize(expressedValue: expressedType) -> quantizedType:
+        zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+        scaledValue = expressedValue / scale
----------------
rafaelubalmw wrote:

This may be any `FloatType` variant, as the `quant.uniform` type already allowed before this change. The lowerings produce valid ops for any expressed type. The type-specific work happens in the generated `arith.sitofp` and `arith.fptosi` ops. A back-end that wants to support expressed types less common than `f32` or `f64` would need to lower those ops appropriately. I added some examples in `QuantBase.td`:

```
- expressedType: Floating-point type of the value expressed by this quantized type (e.g., f32, f80, bf16, or tf32).
```
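For reference, here is a minimal Python sketch of the dequantization arithmetic from the `quant.dcast` pseudocode in the diff. It covers both the per-tensor case and the per-axis case. It rests on a few assumptions. Python's binary64 `float` stands in for the expressed type. `float()` stands in for `convertIntToFloat`, the step that becomes `arith.sitofp` in the lowering for a signed storage type. An omitted zero point, as in `!quant.uniform<i8:f32, 2.0>`, is taken to be 0. The function names are hypothetical, and this is not the MLIR lowering itself.

```python
from typing import List, Sequence


def dequantize(stored: int, scale: float, zero_point: int) -> float:
    """Per-tensor dequantization, mirroring the dcast pseudocode.

    float(...) plays the role of convertIntToFloat (arith.sitofp for a
    signed storage type); Python float (binary64) is the expressed type.
    """
    stored_value_float = float(stored)
    zero_point_float = float(zero_point)
    return (stored_value_float - zero_point_float) * scale


def dequantize_per_axis(
    stored: Sequence[Sequence[int]],
    scales: Sequence[float],
    zero_points: Sequence[int],
) -> List[List[float]]:
    """Per-axis dequantization of a 2-D tensor quantized along dimension 1.

    Element [i][j] uses scales[j] and zero_points[j], i.e. the parameters
    of the channel (column) j that the element belongs to.
    """
    if len(scales) != len(zero_points):
        raise ValueError("scales and zero_points must have the same length")
    result: List[List[float]] = []
    for row in stored:
        if len(row) != len(scales):
            raise ValueError("channel dimension size must match the number of scales")
        result.append(
            [dequantize(q, s, z) for q, s, z in zip(row, scales, zero_points)]
        )
    return result
```

For example, the per-axis type `!quant.uniform<i8:f32:1, {2.0, 3.0}>` from the diff corresponds to `dequantize_per_axis(values, [2.0, 3.0], [0, 0])` under the zero-point assumption above. Column 0 is scaled by 2.0 and column 1 by 3.0.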

https://github.com/llvm/llvm-project/pull/100667

