[Mlir-commits] [mlir] Sub-channel quantized type implementation (PR #120172)
Kevin Gleason
llvmlistbot at llvm.org
Mon Mar 17 07:51:54 PDT 2025
================
@@ -0,0 +1,179 @@
+//===- NormalizeQuantTypes.cpp - Normalize quantized types ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Normalize generic quantized types to specific quantized types
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace quant {
+
+#define GEN_PASS_DEF_NORMALIZEQUANTTYPES
+#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
+
+namespace {
+
+/// Reports whether the sub-channel quantized element type of `tensorType`
+/// collapses to a per-tensor quantized type, i.e. its scales tensor holds
+/// exactly one element (a single scale / zero-point pair).
+///
+/// Precondition: the element type of `tensorType` is a
+/// `quant::UniformQuantizedSubChannelType`.
+static bool isConvertibleToPerTensor(TensorType tensorType) {
+  auto subChannelType =
+      cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
+  return subChannelType.getScales().getType().getNumElements() == 1;
+}
+
+/// Reports whether the sub-channel quantized element type of `tensorType`
+/// collapses to a per-axis quantized type, i.e. exactly one dimension of the
+/// scales tensor has extent greater than one.
+///
+/// Precondition: the element type of `tensorType` is a
+/// `quant::UniformQuantizedSubChannelType`.
+static bool isConvertibleToPerAxis(TensorType tensorType) {
+  auto subChannelType =
+      cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
+  auto scalesShape = subChannelType.getScales().getType().getShape();
+  int64_t numNonUnitDims =
+      llvm::count_if(scalesShape, [](int64_t dim) { return dim != 1; });
+  return numNonUnitDims == 1;
+}
+
+/// This class defines a type converter that converts sub-channel quantized
+/// types to per-tensor or per-axis quantized types whenever possible.
+class NormalizedQuantTypesConverter : public TypeConverter {
+
+  /// Maps a tensor of sub-channel quantized elements to the simplest
+  /// equivalent quantized tensor type. Any type that is not such a tensor is
+  /// returned unchanged.
+  static Type convertType(Type type) {
+    // Only tensor types can carry a sub-channel quantized element type.
+    auto tensorType = dyn_cast<TensorType>(type);
+    if (!tensorType) {
+      return type;
+    }
+
+    auto subChannelType =
+        dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
+    if (!subChannelType) {
+      return type;
+    }
+
+    // A single scale/zero-point pair: fold into a per-tensor quantized type.
+    if (isConvertibleToPerTensor(tensorType)) {
+      // Index 0 is the only element of the scales/zero-points tensors here.
+      double scale =
+          subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
+      int64_t zeroPoint =
+          subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
+      auto perTensorType = UniformQuantizedType::get(
+          subChannelType.getFlags(), subChannelType.getStorageType(),
+          subChannelType.getExpressedType(), scale, zeroPoint,
+          subChannelType.getStorageTypeMin(),
+          subChannelType.getStorageTypeMax());
+      // Keep the tensor shape; only the element type changes.
+      return tensorType.clone(perTensorType);
+    }
+
+    // Exactly one non-unit dimension in the scales tensor: fold into a
+    // per-axis quantized type along that dimension.
+    if (isConvertibleToPerAxis(tensorType)) {
+      auto shape = subChannelType.getScales().getType().getShape();
+      // Locate the single non-unit dimension; its index becomes the
+      // quantized axis below.
+      auto quantizedDimItr =
+          llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
+      // Flatten the scales/zero-points tensors into the flat vectors the
+      // per-axis type constructor expects.
+      auto scales = llvm::to_vector(llvm::map_range(
+          subChannelType.getScales().getValues<APFloat>(),
+          [](APFloat scale) { return scale.convertToDouble(); }));
+      auto zeroPoints = llvm::to_vector(llvm::map_range(
+          subChannelType.getZeroPoints().getValues<APInt>(),
+          [](APInt zeroPoint) { return zeroPoint.getSExtValue(); }));
+      auto perAxisType = UniformQuantizedPerAxisType::get(
+          subChannelType.getFlags(), subChannelType.getStorageType(),
+          subChannelType.getExpressedType(), scales, zeroPoints,
+          quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
+          subChannelType.getStorageTypeMax());
+      return tensorType.clone(perAxisType);
+    }
+    // Genuinely sub-channel quantized: leave the type as-is.
+    return type;
+  }
+
+public:
+  explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
+};
+
+/// This class implements a conversion pattern that converts any generic
+/// operation with sub-channel quantized types to an equivalent operation with
+/// per-tensor or per-axis quantized types.
+class ConvertGenericOpwithSubChannelType : public ConversionPattern {
+public:
+  ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter,
+                                     MLIRContext *context)
+      // MatchAnyOpTypeTag: this pattern applies to every operation; the
+      // driver invokes it only for ops the type converter marks illegal.
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+  /// Recreates `op` with converted result types and remapped operands,
+  /// moving its regions over and converting their block argument types.
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Convert all result types up front; bail out if any fails.
+    SmallVector<Type> resultTypes;
+    if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
+      return failure();
+
+    // Build a fresh, detached op that mirrors `op` except for the converted
+    // result types; `operands` are the already-remapped values.
+    auto *newOp = Operation::create(
+        op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
+        op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+    // Move each region into the new op, then convert its block signatures
+    // (e.g. block arguments with sub-channel quantized tensor types).
+    for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
+      Region &before = std::get<0>(regions);
+      Region &parent = std::get<1>(regions);
+      rewriter.inlineRegionBefore(before, parent, parent.end());
+      if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
+        return failure();
+    }
+    // Insert the detached op at the rewriter's current position, then
+    // forward all uses of the old op's results to the new op.
+    rewriter.insert(newOp);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+// Conversion pass
+class NormalizeQuantTypes
+ : public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
+public:
+ void runOnOperation() override {
+
+ auto moduleOp = cast<ModuleOp>(getOperation());
----------------
GleasonK wrote:
Looks like this is done? Feel free to resolve.
https://github.com/llvm/llvm-project/pull/120172
More information about the Mlir-commits
mailing list