[Mlir-commits] [mlir] [mlir] Improvements to the 'quant' dialect (PR #100667)

Rafael Ubal llvmlistbot at llvm.org
Thu Sep 26 09:49:47 PDT 2024


================
@@ -6,44 +6,215 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Quant/QuantOps.h"
 #include "QuantDialectBytecode.h"
 #include "TypeDetail.h"
 
-#include "mlir/Dialect/Quant/QuantTypes.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/MathExtras.h"
 #include <numeric>
 
-using namespace mlir;
-using namespace mlir::quant;
-using namespace mlir::quant::detail;
+#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
 
-#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
 
-void QuantizationDialect::initialize() {
+namespace mlir {
+namespace quant {
+
+namespace {
+
+// Verify the integrity of per-axis quantization information, if present.
+//
+// - quantizedType
+//   Any quantized type. Any quantized type with no per-axis quantization is
+//   ignored.
+//
+// - containerType
+//   Original input or result type of the operation using the provided quantized
+//   type. Used to ensure that the quantized type appears within a tensor and
+//   that the tensor is compatible with per-axis quantization information.
+//
+LogicalResult verifyPerAxisQuantization(Operation *op,
+                                        QuantizedType quantizedType,
+                                        Type containerType) {
+  auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
+  if (!quantizedPerAxisType)
+    return success();
+
+  auto tensorType = dyn_cast<TensorType>(containerType);
+  if (!tensorType)
+    return op->emitError("scalar types may not use per-axis quantization");
+
+  if (!tensorType.hasRank())
+    return success();
+
+  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
+  if (quantizedDimension >= tensorType.getRank())
----------------
rafaelubalmw wrote:

Since this is a static check, I've added it to `UniformQuantizedPerAxisType::verifyInvariants()` in `QuantTypes.cpp`.

https://github.com/llvm/llvm-project/pull/100667


More information about the Mlir-commits mailing list