[Mlir-commits] [mlir] 94e4ec6 - Add CalibratedQuantizedType to quant dialect

Feng Liu llvmlistbot at llvm.org
Tue Nov 17 22:17:34 PST 2020


Author: Tei Jeong
Date: 2020-11-17T22:14:54-08:00
New Revision: 94e4ec6499a237aeec4f1fe8f2cc1e9bcb33f971

URL: https://github.com/llvm/llvm-project/commit/94e4ec6499a237aeec4f1fe8f2cc1e9bcb33f971
DIFF: https://github.com/llvm/llvm-project/commit/94e4ec6499a237aeec4f1fe8f2cc1e9bcb33f971.diff

LOG: Add CalibratedQuantizedType to quant dialect

This type represents a calibrated quantized type whose min and max values are provided explicitly.

This will be used to import calibration values for intermediate tensors (e.g. in LSTMs) that cannot be imported with the QuantStats op.

This type was initially suggested in the following RFC: https://llvm.discourse.group/t/rfc-a-proposal-for-implementing-quantization-transformations-in-mlir/655

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D91584

Added: 
    mlir/test/Dialect/Quant/parse-calibrated-invalid.mlir
    mlir/test/Dialect/Quant/parse-calibrated.mlir

Modified: 
    mlir/include/mlir/Dialect/Quant/QuantTypes.h
    mlir/lib/Dialect/Quant/IR/QuantOps.cpp
    mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
    mlir/lib/Dialect/Quant/IR/TypeDetail.h
    mlir/lib/Dialect/Quant/IR/TypeParser.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 567b63936dd3..7add13636695 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -28,6 +28,7 @@ struct QuantizedTypeStorage;
 struct AnyQuantizedTypeStorage;
 struct UniformQuantizedTypeStorage;
 struct UniformQuantizedPerAxisTypeStorage;
+struct CalibratedQuantizedTypeStorage;
 
 } // namespace detail
 
@@ -371,6 +372,37 @@ class UniformQuantizedPerAxisType
   }
 };
 
+/// A quantized type that infers its range from given min/max values.
+///
+/// Typical syntax:
+///   quant.calibrated<f32<-0.922:0.981>>
+class CalibratedQuantizedType
+    : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
+                            detail::CalibratedQuantizedTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static CalibratedQuantizedType get(Type expressedType, double min,
+                                     double max);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static CalibratedQuantizedType getChecked(Type expressedType, double min,
+                                            double max, Location location);
+
+  /// Verifies that expressedType is a FloatType and that min < max, emitting
+  /// an error at `loc` otherwise.
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    Type expressedType,
+                                                    double min, double max);
+  /// Returns the calibrated minimum value.
+  double getMin() const;
+  /// Returns the calibrated maximum value.
+  double getMax() const;
+};
+
 } // namespace quant
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index e7df59abc945..62a527b13243 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -24,7 +24,7 @@ using namespace mlir::quant;
 using namespace mlir::quant::detail;
 
 void QuantizationDialect::initialize() {
-  addTypes<AnyQuantizedType, UniformQuantizedType,
+  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType>();
   addOperations<
 #define GET_OP_LIST

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 41e64d1540f3..66d804c8763a 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -354,3 +354,34 @@ ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
   return getImpl()->quantizedDimension;
 }
+
+CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
+                                                     double min, double max) {
+  return Base::get(expressedType.getContext(), expressedType, min, max);
+}
+
+CalibratedQuantizedType CalibratedQuantizedType::getChecked(Type expressedType,
+                                                            double min,
+                                                            double max,
+                                                            Location location) {
+  return Base::getChecked(location, expressedType, min, max);
+}
+
+LogicalResult CalibratedQuantizedType::verifyConstructionInvariants(
+    Location loc, Type expressedType, double min, double max) {
+  // Verify that the expressed type is floating point.
+  // If this restriction is ever eliminated, the parser/printer must be
+  // extended.
+  if (!expressedType.isa<FloatType>())
+    return emitError(loc, "expressed type must be floating point");
+  // Require a non-empty range. Written as !(min < max) rather than
+  // max <= min so that NaN bounds (which compare false) are rejected too.
+  if (!(min < max))
+    return emitError(loc, "illegal min and max: (") << min << ":" << max << ")";
+
+  return success();
+}
+
+double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
+
+double CalibratedQuantizedType::getMax() const { return getImpl()->max; }

diff  --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h
index 33f537b031d7..82b28e8e06f6 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h
@@ -253,6 +253,58 @@ struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
   int32_t quantizedDimension;
 };
 
+struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage {
+  struct KeyTy {
+    KeyTy(Type expressedType, double min, double max)
+        : expressedType(expressedType), min(min), max(max) {}
+    // Floating point type that the quantized type approximates.
+    Type expressedType;
+
+    double min;
+    double max;
+
+    // Check for equality of two structures that share KeyTy data members
+    // (by name).
+    template <typename T, typename U>
+    static bool genericIsEqual(const T &lhs, const U &rhs) {
+      return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min &&
+             lhs.max == rhs.max;
+    }
+
+    bool operator==(const KeyTy &other) const {
+      return genericIsEqual(*this, other);
+    }
+
+    unsigned getHashValue() const {
+      // Hash the raw bit patterns of min/max. Converting the doubles to
+      // int64_t by value would collide on truncation and is UB out of range.
+      int64_t minBits = llvm::bit_cast<int64_t>(min);
+      int64_t maxBits = llvm::bit_cast<int64_t>(max);
+      return llvm::hash_combine(expressedType, minBits, maxBits);
+    }
+  };
+
+  CalibratedQuantizedTypeStorage(const KeyTy &key)
+      : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0),
+        min(key.min), max(key.max) {}
+
+  bool operator==(const KeyTy &key) const {
+    return KeyTy::genericIsEqual(*this, key);
+  }
+
+  /// Construction.
+  static CalibratedQuantizedTypeStorage *
+  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<CalibratedQuantizedTypeStorage>())
+        CalibratedQuantizedTypeStorage(key);
+  }
+
+  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+  double min;
+  double max;
+};
+
 } // namespace detail
 } // namespace quant
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index c3fc3e5775c6..6894463427fd 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -91,9 +91,28 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
   return success();
 }
 
-/// Parses a UniformQuantizedType.
+static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
+                                            double &min, double &max) {
+  auto typeLoc = parser.getCurrentLocation();
+  FloatType type;
+  // The expressed type must parse as a FloatType.
+  if (failed(parser.parseType(type))) {
+    parser.emitError(typeLoc, "expecting float expressed type");
+    return nullptr;
+  }
+
+  // Calibrated range: `<` min `:` max `>`, both float literals.
+  if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
+      parser.parseFloat(max) || parser.parseGreater()) {
+    parser.emitError(typeLoc, "calibrated values must be present");
+    return nullptr;
+  }
+  return type;
+}
+
+/// Parses an AnyQuantizedType.
 ///
-///   uniform_per_layer ::= `any<` storage-spec (expressed-type-spec)?`>`
+///   any ::= `any<` storage-spec (expressed-type-spec)?`>`
 ///   storage-spec ::= storage-type (`<` storage-range `>`)?
 ///   storage-range ::= integer-literal `:` integer-literal
 ///   storage-type ::= (`i` | `u`) integer-literal
@@ -269,6 +288,34 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
                                           storageTypeMin, storageTypeMax, loc);
 }
 
+/// Parses a CalibratedQuantizedType.
+///
+///   calibrated ::= `calibrated<` expressed-spec `>`
+///   expressed-spec ::= expressed-type `<` calibrated-range `>`
+///   expressed-type ::= float-type (any type that parses as a FloatType)
+///   calibrated-range ::= float-literal `:` float-literal
+static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
+  FloatType expressedType;
+  double min;
+  double max;
+
+  // Opening `<` of the type specification.
+  if (parser.parseLess())
+    return nullptr;
+
+  // Expressed type and its calibrated range.
+  expressedType = parseExpressedTypeAndRange(parser, min, max);
+  if (!expressedType) {
+    return nullptr;
+  }
+
+  if (parser.parseGreater()) {
+    return nullptr;
+  }
+
+  return CalibratedQuantizedType::getChecked(expressedType, min, max, loc);
+}
+
 /// Parse a type registered to this dialect.
 Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
@@ -282,6 +329,8 @@ Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
     return parseUniformType(parser, loc);
   if (typeNameSpelling == "any")
     return parseAnyType(parser, loc);
+  if (typeNameSpelling == "calibrated")
+    return parseCalibratedType(parser, loc);
 
   parser.emitError(parser.getNameLoc(),
                    "unknown quantized type " + typeNameSpelling);
@@ -318,7 +367,7 @@ static void printQuantParams(double scale, int64_t zeroPoint,
   }
 }
 
-/// Helper that prints a UniformQuantizedType.
+/// Helper that prints an AnyQuantizedType.
 static void printAnyQuantizedType(AnyQuantizedType type,
                                   DialectAsmPrinter &out) {
   out << "any<";
@@ -363,6 +412,15 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
   out << "}>";
 }
 
+/// Helper that prints a CalibratedQuantizedType as `calibrated<T<min:max>>`.
+static void printCalibratedQuantizedType(CalibratedQuantizedType type,
+                                         DialectAsmPrinter &out) {
+  out << "calibrated<" << type.getExpressedType();
+  // Use ':' between min and max so the output can be parsed back.
+  out << "<" << type.getMin() << ":" << type.getMax() << ">";
+  out << ">";
+}
+
 /// Print a type registered to this dialect.
 void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
   if (auto anyType = type.dyn_cast<AnyQuantizedType>())
@@ -371,6 +429,8 @@ void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
     printUniformQuantizedType(uniformType, os);
   else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
     printUniformQuantizedPerAxisType(perAxisType, os);
+  else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
+    printCalibratedQuantizedType(calibratedType, os);
   else
     llvm_unreachable("Unhandled quantized type");
 }

diff  --git a/mlir/test/Dialect/Quant/parse-calibrated-invalid.mlir b/mlir/test/Dialect/Quant/parse-calibrated-invalid.mlir
new file mode 100644
index 000000000000..eefc0dfd3382
--- /dev/null
+++ b/mlir/test/Dialect/Quant/parse-calibrated-invalid.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+// Unrecognized token: missing calibrated type maximum
+// expected-error@+2 {{calibrated values must be present}}
+// expected-error@+1 {{expected ':'}}
+!qalias = type !quant.calibrated<f32<-0.998>>
+
+// -----
+// Unrecognized token: missing closing angle bracket
+// expected-error@+1 {{expected '>'}}
+!qalias = type !quant<"calibrated<f32<-0.998:1.232>">
+
+// -----
+// Unrecognized expressed type: integer type
+// expected-error@+2 {{invalid kind of type specified}}
+// expected-error@+1 {{expecting float expressed type}}
+!qalias = type !quant.calibrated<i8<-4:3>>
+
+// -----
+// Illegal calibrated min/max: max - min < 0
+// expected-error@+1 {{illegal min and max: (1.000000e+00:-1.000000e+00)}}
+!qalias = type !quant.calibrated<f32<1.0:-1.0>>
+
+// -----
+// Illegal calibrated min/max: max - min == 0
+// expected-error@+1 {{illegal min and max: (1.000000e+00:1.000000e+00)}}
+!qalias = type !quant.calibrated<f32<1.0:1.0>>

diff  --git a/mlir/test/Dialect/Quant/parse-calibrated.mlir b/mlir/test/Dialect/Quant/parse-calibrated.mlir
new file mode 100644
index 000000000000..648715fd4c49
--- /dev/null
+++ b/mlir/test/Dialect/Quant/parse-calibrated.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file | FileCheck %s
+
+// -----
+// CHECK-LABEL: parseCalibrated
+// CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
+!qalias = type !quant.calibrated<f32<-0.998:1.2321>>
+func @parseCalibrated() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}


        


More information about the Mlir-commits mailing list