[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)
Rafael Ubal
llvmlistbot at llvm.org
Tue Jul 30 08:47:25 PDT 2024
================
@@ -6,44 +6,215 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantOps.h"
#include "QuantDialectBytecode.h"
#include "TypeDetail.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
-using namespace mlir;
-using namespace mlir::quant;
-using namespace mlir::quant::detail;
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
-#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
-void QuantizationDialect::initialize() {
+namespace mlir {
+namespace quant {
+
+namespace {
+
+// Verify the integrity of per-axis quantization information, if present.
+//
+// - quantizedType
+// Any quantized type. Any quantized type with no per-axis quantization is
+// ignored.
+//
+// - containerType
+// Original input or result type of the operation using the provided quantized
+// type. Used to ensure that the quantized type appears within a tensor and
+// that the tensor is compatible with per-axis quantization information.
+//
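+// As a purely illustrative example (the type below is hypothetical, not
+// taken from this patch), a 2x3 tensor quantized along dimension 1 must
+// carry exactly 3 scales:
+//
+//   tensor<2x3x!quant.uniform<i8:f32:1, {2.0e+0, 3.0e+0, 4.0e+0}>>
+//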
+LogicalResult verifyPerAxisQuantization(Operation *op,
+ QuantizedType quantizedType,
+ Type containerType) {
+  auto quantizedPerAxisType =
+      dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+ if (!quantizedPerAxisType)
+ return success();
+
+ auto tensorType = dyn_cast<TensorType>(containerType);
+ if (!tensorType)
+ return op->emitError("scalar types may not use per-axis quantization");
+
+ if (!tensorType.hasRank())
+ return success();
+
+ int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
+ if (quantizedDimension >= tensorType.getRank())
+ return op->emitError("quantized dimension must be less than tensor rank");
+
+ int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
+ if (quantizedDimensionSize != ShapedType::kDynamic &&
+ quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
+ return op->emitError(
+ "quantized dimension size does not match number of scales");
+
+ return success();
+}
+
+// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
+//
+// - quantizedType
+// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
+// whether as a primitive type or in a tensor.
+//
+// - floatType
+// Float type used in the input ('quant.qcast') or result ('quant.dcast'),
+// whether as a primitive type or in a tensor.
+//
+// - containerType
+// Type of original input or result.
+//
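+// As a purely illustrative example (the op and types are hypothetical),
+// given
+//
+//   quant.dcast %arg0 : !quant.uniform<i8:f32, 1.0> to f32
+//
+// quantizedType is !quant.uniform<i8:f32, 1.0>, floatType is f32, and
+// verification succeeds because the expressed type of the quantized type
+// (f32) matches the float type.
+//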
+LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
+ FloatType floatType, Type containerType) {
+ if (quantizedType.getExpressedType() != floatType)
+ return op->emitError(
+ "expressed type in quantized type expected to match float type");
+
+  // Verify integrity of per-axis quantization information, if present.
+ return verifyPerAxisQuantization(op, quantizedType, containerType);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Dialect
+//===----------------------------------------------------------------------===//
+
+void QuantDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
addOperations<
#define GET_OP_LIST
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
>();
- addBytecodeInterface(this);
+ detail::addBytecodeInterface(this);
+}
+
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DequantizeCastOp::verify() {
+ return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+ getInput().getType());
+}
+
+OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
+ // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
+ // with the value of x. Values x and y are guaranteed to be of the same type
+ // in this pattern.
+ auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
+ if (!srcQcastOp)
+ return {};
+ assert(srcQcastOp.getInput().getType() == getType());
----------------
rafaelubalmw wrote:
Yes. In a pattern `floatType1 -> [qcast] -> quantizedType -> [dcast] -> floatType2`, it is guaranteed that `floatType1 = floatType2`, both in the form of a scalar or a tensor. This is enforced by the verifiers of `qcast` and `dcast`, which check for matching float and expressed types, as well as for identical tensor shapes. I wouldn't want to use conditional execution and pattern match failure since that should never occur for well-formed IR.
https://github.com/llvm/llvm-project/pull/100667