[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)
Spenser Bauman
llvmlistbot at llvm.org
Fri Jul 26 17:54:16 PDT 2024
================
@@ -6,44 +6,215 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/QuantOps.h"
#include "QuantDialectBytecode.h"
#include "TypeDetail.h"
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
-using namespace mlir;
-using namespace mlir::quant;
-using namespace mlir::quant::detail;
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
-#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
-void QuantizationDialect::initialize() {
+namespace mlir {
+namespace quant {
+
+// Verify the integrity of per-axis quantization information, if present.
+//
+// - quantizedType
+//   Any quantized type. Quantized types that are not
+//   UniformQuantizedPerAxisType are accepted without further checks.
+//
+// - containerType
+//   Original input or result type of the operation using the provided quantized
+//   type. Used to ensure that the quantized type appears within a tensor and
+//   that the tensor is compatible with per-axis quantization information.
+//
+// Unranked tensors and a dynamically sized quantized dimension are accepted,
+// since neither can be checked against the number of scales here.
+//
+static LogicalResult verifyPerAxisQuantization(Operation *op,
+                                               QuantizedType quantizedType,
+                                               Type containerType) {
+  auto quantizedPerAxisType =
+      dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+  if (!quantizedPerAxisType)
+    return success();
+
+  // Per-axis quantization needs an axis, so the container must be a tensor.
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (!tensorType)
+    return op->emitError("scalar types may not use per-axis quantization");
+
+  if (!tensorType.hasRank())
+    return success();
+
+  // NOTE(review): only the upper bound is checked here; this assumes the
+  // UniformQuantizedPerAxisType verifier rejects negative dimensions — confirm.
+  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
+  if (quantizedDimension >= tensorType.getRank())
+    return op->emitError("quantized dimension must be less than tensor rank");
+
+  // A static quantized dimension must have exactly one scale per element.
+  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
+  int64_t numScales =
+      static_cast<int64_t>(quantizedPerAxisType.getScales().size());
+  if (quantizedDimensionSize != ShapedType::kDynamic &&
+      quantizedDimensionSize != numScales)
+    return op->emitError(
+        "quantized dimension size does not match number of scales");
+
+  return success();
+}
+
+// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
+//
+// - quantizedType
+//   Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
+//   whether as a primitive type or in a tensor.
+//
+// - floatType
+//   Float type used in the input ('quant.qcast') or result ('quant.dcast'),
+//   whether as a primitive type or in a tensor.
+//
+// - containerType
+//   Type of original input or result.
+//
+// Fails if the expressed type of 'quantizedType' differs from 'floatType', or
+// if its per-axis quantization information does not fit 'containerType'.
+//
+static LogicalResult verifyQuantizationOp(Operation *op,
+                                          QuantizedType quantizedType,
+                                          FloatType floatType,
+                                          Type containerType) {
+  if (quantizedType.getExpressedType() != floatType)
+    return op->emitError(
+        "expressed type in quantized type expected to match float type");
+
+  // Verify integrity of per-axis quantization information, if present.
+  return verifyPerAxisQuantization(op, quantizedType, containerType);
+}
+
+
+//===----------------------------------------------------------------------===//
+// Dialect
+//===----------------------------------------------------------------------===//
+
+// Registers the quant dialect's types, the operation list generated into
+// QuantOps.cpp.inc (selected with GET_OP_LIST), and its bytecode interface.
+void QuantDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
addOperations<
#define GET_OP_LIST
-#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
+#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
>();
- addBytecodeInterface(this);
+ detail::addBytecodeInterface(this);
+}
+
+
+//===----------------------------------------------------------------------===//
+// DequantizeCastOp
+//===----------------------------------------------------------------------===//
+
+// For 'quant.dcast' the quantized type belongs to the input and the float type
+// to the result; the shared checks inspect the input's (container) type.
+LogicalResult DequantizeCastOp::verify() {
+  Type inputType = getInput().getType();
+  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
+                              inputType);
+}
+
+OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
+ // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
+ // with the value of x. Values x and y are guaranteed to be of the same type
+ // in this pattern.
+ auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
+ if (!srcQcastOp)
+ return {};
+ assert(srcQcastOp.getInput().getType() == getType());
----------------
sabauma wrote:
I'm not sure the fold method is the right place to do this kind of correctness check on the IR structure. Do the verifiers enforce this invariant as well?
The other fold methods look like they just fail to fold when the source and target types do not match.
https://github.com/llvm/llvm-project/pull/100667
More information about the Mlir-commits
mailing list