[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)
Rafael Ubal
llvmlistbot at llvm.org
Tue Aug 27 11:46:40 PDT 2024
================
@@ -0,0 +1,229 @@
+//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the operation definition file for Quantization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef QUANT_OPS
+#define QUANT_OPS
+
+include "mlir/Dialect/Quant/IR/QuantBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Base classes
+//===----------------------------------------------------------------------===//
+
+class quant_Op<string mnemonic, list<Trait> traits> :
+ Op<Quant_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Quantization casts
+//===----------------------------------------------------------------------===//
+
+def quant_DequantizeCastOp : quant_Op<"dcast", [
+ Pure,
+ quant_SameScalarOrTensorShape]> {
+ let summary = "Dequantize cast operation";
+ let description = [{
+ Convert an input quantized value into its expressed floating-point value.
+ The dequantization process consists of the following steps:
+
+ ```
+ def dequantize(quantizedValue: quantizedType) -> expressedType:
+ storedValue = reinterpretCast(quantizedValue, storageType)
+ storedValueFloat = convertIntToFloat(storedValue, expressedType)
+ zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+ expressedValue = (storedValueFloat - zeroPointFloat) * scale
+ return expressedValue
+ ```
+
+ Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained
+  from the corresponding parameters encoded in `quantizedType`. For
+  per-channel quantization, each tensor element is computed using the
+  `scale` and `zeroPoint` values of the channel that the element belongs
+  to.
+
+ The operation must satisfy the following syntactic constraints:
+
+ - Operand `input` must be a scalar or tensor of type `!quant.uniform`.
+
+ - The result type must be a floating-point scalar or tensor.
+
+ - The `expressedType` parameter of the `!quant.uniform` type of the input
+ must match the floating-point type of the result.
+
+  - The operand and result types must both be scalars or both be tensors.
+    If tensors, they must both be ranked or both be unranked. If ranked,
+    both must have the same shape, including matching static and dynamic
+    dimensions.
+
+ - If the operand uses per-channel quantization, its `!quant.uniform` type
+ must adhere to the [Per-axis quantization
+ integrity](#per-axis-quantization-integrity) guidelines.
+
+ Examples:
+
+ ```
+ // Dequantize a scalar quantized value
+ %result = quant.dcast %input : !quant.uniform<i8:f32, 2.0> to f32
+
+ // Dequantize a dynamically shaped tensor of quantized values
+ %result = quant.dcast %input : tensor<?x!quant.uniform<i8:f32, 2.0>> to tensor<?xf32>
+
+ // Dequantize an unranked tensor using per-axis quantization information
+ %result = quant.dcast %input : tensor<*x!quant.uniform<i8:f32:1, {2.0, 3.0}>> to tensor<*xf32>
+ ```
+ }];
+ let arguments = (ins quant_QuantizedScalarOrTensor:$input);
+ let results = (outs quant_FloatScalarOrTensor:$result);
+ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
+ let hasVerifier = 1;
+ let hasFolder = 1;
+ let extraClassDeclaration = [{
+ /// Return the float type of the scalar or tensor result.
+ FloatType getFloatType();
+
+ /// Return the quantized type of the scalar or tensor input.
+ quant::QuantizedType getQuantizedType();
+ }];
+}
+
+def quant_QuantizeCastOp : quant_Op<"qcast", [
+ Pure,
+ quant_SameScalarOrTensorShape]> {
+ let summary = "Quantize cast operation";
+ let description = [{
+ Convert a floating-point value to a quantized type. The quantization
+ process consists of the following steps:
+
+ ```
+ def quantize(expressedValue: expressedType) -> quantizedType:
+ zeroPointFloat = convertIntToFloat(zeroPoint, expressedType)
+ scaledValue = expressedValue / scale
----------------
rafaelubalmw wrote:
This may be any variant of `FloatType`, as already allowed by the `quant.uniform` data type before this change. The lowerings produce valid ops for any expressed type. The magic lies in the generated `arith.sitofp` or `arith.fptosi` ops, which would need to be appropriately lowered by a back-end interested in supporting expressed types less common than `f32` or `f64`. I added some examples in `QuantBase.td`:
```
- expressedType: Floating-point type of the value expressed by this quantized type (e.g., f32, f80, bf16, or tf32).
```
https://github.com/llvm/llvm-project/pull/100667
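For readers following along, the dequantization steps quoted in the patch can be sketched as plain Python. This is only an illustration of the quoted pseudocode, not the MLIR lowering: the function names and the nested-list tensor representation are my own, and the `reinterpretCast` to the storage type is elided since Python ints already model the stored values.

```python
def dequantize(stored_value: int, scale: float, zero_point: int) -> float:
    """Per-tensor dequantization, mirroring the pseudocode in the patch:
    expressedValue = (storedValueFloat - zeroPointFloat) * scale."""
    return (float(stored_value) - float(zero_point)) * scale

def dequantize_per_channel(stored, scales, zero_points, axis=1):
    """Per-channel dequantization of a rank-2 nested-list 'tensor'.

    Each element uses the scale/zeroPoint of the channel it belongs to,
    here assumed to run along axis 1 (the inner dimension)."""
    assert axis == 1, "sketch only handles quantization along axis 1"
    return [
        [dequantize(v, scales[j], zero_points[j]) for j, v in enumerate(row)]
        for row in stored
    ]

# A per-axis example analogous to !quant.uniform<i8:f32:1, {2.0, 3.0}>:
# channel 0 uses scale 2.0, channel 1 uses scale 3.0.
values = dequantize_per_channel([[1, 2], [3, 4]], [2.0, 3.0], [0, 0])
```

The quantize direction is the inverse (divide by `scale`, add `zeroPoint`, then round and clamp to the storage range), but since the quoted hunk cuts off mid-definition, that part is left out here.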