[Mlir-commits] [mlir] Extending UniformQuantizedType with interface-based support for new storage types in Quant dialect (PR #152966)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 11 00:38:15 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (Roman-Pevnyi)

<details>
<summary>Changes</summary>

Currently, UniformQuantizedType only supports built-in MLIR storage types such as Integer. LLM quantization research has introduced the use of NF4 as a low-precision datatype (see https://arxiv.org/pdf/2305.14314). There is a growing need to make the system extensible and maintainable as more types are added. Ensuring that MLIR can natively support NF4 through a clean, extensible interface is essential for both current and future quantization workflows.

**Current Approach and Its Limitations:**

- The present implementation relies on dynamic checks (e.g., type switches or if-else chains) to determine the storage type and retrieve type-specific information for legality checks.

- This approach works for a small, fixed set of types, but as the number of supported types grows, the code becomes harder to read, maintain, and extend.

**Proposed Interface-Based Approach:**

- Define a StorageTypeInterface (named `QuantizationInterface` in this patch) that specifies the required methods any storage type must implement to be used in UniformQuantizedType.
- Each storage type (Integer, Float8E5M2, Float8E4M3FN, and new types like NF4) would implement this interface, encapsulating its type-specific logic.
- When UniformQuantizedType needs to check legality or retrieve information, it can use MLIR’s dyn_cast mechanism to check if the type implements the interface and then call the required methods.
- This design decouples UniformQuantizedType from the specifics of each storage type, making it easy to add new types (such as NF4) without modifying the core logic or introducing more type checks.

**Benefits:**

- Extensibility: New storage types can be added by simply implementing the interface, without touching the core UniformQuantizedType logic.
- Readability: The code is cleaner, as it avoids large switch statements or if-else chains.
- Maintainability: Type-specific logic is encapsulated within each type, reducing the risk of errors and making the codebase easier to understand and update.

---
Full diff: https://github.com/llvm/llvm-project/pull/152966.diff


10 Files Affected:

- (modified) mlir/cmake/modules/AddMLIR.cmake (+8) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+2) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+28-1) 
- (modified) mlir/include/mlir/IR/CMakeLists.txt (+2) 
- (added) mlir/include/mlir/IR/QuantizationInterface.h (+22) 
- (added) mlir/include/mlir/IR/QuantizationInterface.td (+44) 
- (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+35-30) 
- (modified) mlir/lib/Dialect/Quant/IR/TypeParser.cpp (+40-34) 
- (modified) mlir/lib/IR/CMakeLists.txt (+3-1) 
- (added) mlir/lib/IR/QuantizationInterface.cpp (+23) 


``````````diff
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index ff4269ed7acd2..c35308d57eadd 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -203,6 +203,14 @@ function(add_mlir_interface interface)
   add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
 endfunction()
 
+# Declare a type interface in the include directory
+function(add_mlir_type_interface interface)
+  set(LLVM_TARGET_DEFINITIONS ${interface}.td)
+  mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
+  mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
+  add_public_tablegen_target(MLIR${interface}IncGen)
+  add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
+endfunction()
 
 # Generate Documentation
 function(add_mlir_doc doc_filename output_file output_directory command)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 86ec5c43970b1..204da9553e915 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
 // Tablegen Type Declarations
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/QuantizationInterface.h"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.h.inc"
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..762f9262adbf2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/QuantizationInterface.td"
 include "mlir/IR/CommonTypeConstraints.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -497,7 +498,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
 //===----------------------------------------------------------------------===//
 
 def Builtin_Integer : Builtin_Type<"Integer", "integer",
-    [VectorElementTypeInterface]> {
+    [VectorElementTypeInterface, QuantizationInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -554,6 +555,32 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
+
+    /// QuantizationInterface method implementations
+    /// Return true if this is a signed integer type.
+    bool isStorageSigned() const { return !isUnsigned(); }
+    /// Get the bit width of this integer type.
+    unsigned getStorageWidth() const { return getWidth(); }
+    
+    /// Get default minimum value for this integer type.
+    int64_t getDefaultMinimum() const {
+      if (isStorageSigned()) {
+        return llvm::minIntN(getStorageWidth());
+      }
+      return 0;
+    }
+    /// Get default maximum value for this integer type.
+    int64_t getDefaultMaximum() const {
+      if (isStorageSigned()) {
+        return llvm::maxIntN(getStorageWidth());
+      }
+      return llvm::maxUIntN(getStorageWidth());
+    }
+    
+    /// Get the storage type as a string.
+    std::string getStorageType() const {
+      return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 846547ff131e3..153502c6e981b 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,6 +1,8 @@
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+add_mlir_type_interface(QuantizationInterface)
+
 set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
 mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
 mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantizationInterface.h
new file mode 100644
index 0000000000000..0d6709ff52065
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.h
@@ -0,0 +1,22 @@
+//===- QuantizationInterface.h - Quantization Interfaces -------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_QuantizationInterface_H
+#define MLIR_IR_QuantizationInterface_H
+
+#include "mlir/IR/Types.h"
+
+// Forward declarations for the types we need in the implementation
+namespace mlir {
+class IntegerType;
+} // namespace mlir
+
+#include "mlir/IR/QuantizationInterface.h.inc"
+
+#endif // MLIR_IR_QuantizationInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantizationInterface.td
new file mode 100644
index 0000000000000..1008ac8e1dcf1
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.td
@@ -0,0 +1,44 @@
+#ifndef MLIR_IR_QUANTIZATIONINTERFACE
+#define MLIR_IR_QUANTIZATIONINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+  let description = [{
+    Interface for types that can be used as storage types in Quant dialect.
+    This interface provides methods to determine storage characteristics for quantization purposes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Check if the storage type is signed.
+      Returns true if the type represents signed values, false for unsigned.
+    }],
+    "bool", "isStorageSigned", (ins)>,
+
+    InterfaceMethod<[{
+      Get the bit width of this integer type.
+      Returns the number of bits used to store values of this type.
+    }],
+    "unsigned", "getStorageWidth", (ins)>,
+
+    InterfaceMethod<[{
+      Get default minimum value for this integer type.
+    }],
+    "int64_t", "getDefaultMinimum", (ins)>,
+
+    InterfaceMethod<[{
+      Get default maximum value for this integer type.
+    }],
+    "int64_t", "getDefaultMaximum", (ins)>,
+
+    InterfaceMethod<[{
+      Get the storage type as a string.
+    }],
+    "std::string", "getStorageType", (ins)>
+  ];
+
+}
+
+#endif // MLIR_IR_QUANTIZATIONINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index b2227792f32ca..e7f9b1dc8a7e1 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "TypeDetail.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/IR/QuantizationInterface.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
   if (!intStorageType)
     return emitError() << "storage type must be integral";
-  unsigned integralWidth = intStorageType.getWidth();
-
-  // Verify storage width.
-  if (integralWidth == 0 || integralWidth > MaxStorageBits)
-    return emitError() << "illegal storage type size: " << integralWidth;
-
-  // Verify storageTypeMin and storageTypeMax.
-  bool isSigned =
-      (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
-  int64_t defaultIntegerMin =
-      getDefaultMinimumForInteger(isSigned, integralWidth);
-  int64_t defaultIntegerMax =
-      getDefaultMaximumForInteger(isSigned, integralWidth);
-  if (storageTypeMax - storageTypeMin <= 0 ||
-      storageTypeMin < defaultIntegerMin ||
-      storageTypeMax > defaultIntegerMax) {
-    return emitError() << "illegal storage min and storage max: ("
-                       << storageTypeMin << ":" << storageTypeMax << ")";
+
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    unsigned integralWidth = quantizationInterface.getStorageWidth();
+
+    // Verify storage width.
+    if (integralWidth == 0 || integralWidth > MaxStorageBits)
+      return emitError() << "illegal storage type size: " << integralWidth;
+
+    int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+    int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+    if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
+        storageTypeMax > defaultMax) {
+      return emitError() << "illegal storage min and storage max: ("
+                         << storageTypeMin << ":" << storageTypeMax << ")";
+    }
+
+    return success();
   }
-  return success();
+
+  return emitError() << "storage type must implement QuantizationInterface";
 }
 
 Type QuantizedType::getStorageType() const {
@@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
 }
 
 bool QuantizedType::hasStorageTypeBounds() const {
-  unsigned int integralWidth = getStorageTypeIntegralWidth();
-  bool isSignedInteger = isSigned();
-  int64_t defaultIntegerMin =
-      getDefaultMinimumForInteger(isSignedInteger, integralWidth);
-  int64_t defaultIntegerMax =
-      getDefaultMaximumForInteger(isSignedInteger, integralWidth);
-  return defaultIntegerMin != getStorageTypeMin() ||
-         defaultIntegerMax != getStorageTypeMax();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+  auto quantizationInterface =
+      llvm::dyn_cast<QuantizationInterface>(storageType);
+
+  int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+  int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+  return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
 }
 
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
-  // NOTE: If ever supporting non-integral storage types, some other scheme
-  // for determining the width will be needed.
-  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+  auto quantizationInterface =
+      llvm::dyn_cast<QuantizationInterface>(storageType);
+
+  return quantizationInterface.getStorageWidth();
 }
 
 Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 9a18cff24e62a..758399a2af5e8 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,15 +10,16 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/QuantizationInterface.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
 
 using namespace mlir;
 using namespace quant;
 
-static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
+static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   auto typeLoc = parser.getCurrentLocation();
-  IntegerType type;
+  Type type;
 
   // Parse storage type (alpha_ident, integer_literal).
   StringRef identifier;
@@ -27,20 +28,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   if (result.has_value()) {
     if (!succeeded(*result))
       return nullptr;
-    isSigned = !type.isUnsigned();
-    storageTypeWidth = type.getWidth();
-  } else if (succeeded(parser.parseKeyword(&identifier))) {
-    // Otherwise, this must be an unsigned integer (`u` integer-literal).
-    if (!identifier.consume_front("u")) {
-      parser.emitError(typeLoc, "illegal storage type prefix");
+
+    if (auto quantizationInterface =
+            llvm::dyn_cast<QuantizationInterface>(type)) {
+      isSigned = quantizationInterface.isStorageSigned();
+      storageTypeWidth = quantizationInterface.getStorageWidth();
+    } else {
+      parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
     }
-    if (identifier.getAsInteger(10, storageTypeWidth)) {
-      parser.emitError(typeLoc, "expected storage type width");
+  } else if (succeeded(parser.parseKeyword(&identifier))) {
+    // Otherwise, this must be an unsigned integer (`u` integer-literal)
+    if (identifier.consume_front("u")) {
+      if (identifier.getAsInteger(10, storageTypeWidth)) {
+        parser.emitError(typeLoc, "expected storage type width");
+        return nullptr;
+      }
+      isSigned = false;
+      type = parser.getBuilder().getIntegerType(storageTypeWidth);
+    } else {
+      parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
     }
-    isSigned = false;
-    type = parser.getBuilder().getIntegerType(storageTypeWidth);
   } else {
     return nullptr;
   }
@@ -55,17 +64,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   return type;
 }
 
-static ParseResult parseStorageRange(DialectAsmParser &parser,
-                                     IntegerType storageType, bool isSigned,
+static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
                                      int64_t &storageTypeMin,
                                      int64_t &storageTypeMax) {
-  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
-      isSigned, storageType.getWidth());
-  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
-      isSigned, storageType.getWidth());
+  int64_t defaultMin, defaultMax;
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    defaultMin = quantizationInterface.getDefaultMinimum();
+    defaultMax = quantizationInterface.getDefaultMaximum();
+  }
+
   if (failed(parser.parseOptionalLess())) {
-    storageTypeMin = defaultIntegerMin;
-    storageTypeMax = defaultIntegerMax;
+    storageTypeMin = defaultMin;
+    storageTypeMax = defaultMax;
     return success();
   }
 
@@ -75,11 +86,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
       parser.getCurrentLocation(&maxLoc) ||
       parser.parseInteger(storageTypeMax) || parser.parseGreater())
     return failure();
-  if (storageTypeMin < defaultIntegerMin) {
+  if (storageTypeMin < defaultMin) {
     return parser.emitError(minLoc, "illegal storage type minimum: ")
            << storageTypeMin;
   }
-  if (storageTypeMax > defaultIntegerMax) {
+  if (storageTypeMax > defaultMax) {
     return parser.emitError(maxLoc, "illegal storage type maximum: ")
            << storageTypeMax;
   }
@@ -113,7 +124,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
 ///   storage-type ::= (`i` | `u`) integer-literal
 ///   expressed-type-spec ::= `:` `f` integer-literal
 static Type parseAnyType(DialectAsmParser &parser) {
-  IntegerType storageType;
+  Type storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
   int64_t storageTypeMin;
@@ -134,8 +145,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
   }
 
   // Storage type range.
-  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
-                        storageTypeMax)) {
+  if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
     return nullptr;
   }
 
@@ -322,7 +332,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
 ///     scale-zero-tensor (`,` scale-zero-tensor)*
 ///   `}`
 static Type parseUniformType(DialectAsmParser &parser) {
-  IntegerType storageType;
+  Type storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
   int64_t storageTypeMin;
@@ -350,8 +360,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
   }
 
   // Storage type range.
-  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
-                        storageTypeMax)) {
+  if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
     return nullptr;
   }
 
@@ -486,12 +495,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
 
 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
   // storage type
-  unsigned storageWidth = type.getStorageTypeIntegralWidth();
-  bool isSigned = type.isSigned();
-  if (isSigned) {
-    out << "i" << storageWidth;
-  } else {
-    out << "u" << storageWidth;
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
+    out << quantizationInterface.getStorageType();
   }
 
   // storageTypeMin and storageTypeMax if not default.
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3ef69cea18f0a..f539aca7fff48 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_library(MLIRIR
   OperationSupport.cpp
   PatternLoggingListener.cpp
   PatternMatch.cpp
+  QuantizationInterface.cpp  
   Region.cpp
   RegionKindInterface.cpp
   SymbolTable.cpp
@@ -66,7 +67,8 @@ add_mlir_library(MLIRIR
   MLIRSideEffectInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
-
+  MLIRQuantizationInterfaceIncGen
+  
   LINK_LIBS PUBLIC
   MLIRSupport
   )
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantizationInterface.cpp
new file mode 100644
index 0000000000000..a93333278610e
--- /dev/null
+++ b/mlir/lib/IR/QuantizationInterface.cpp
@@ -0,0 +1,23 @@
+//===- QuantizationInterface.cpp
+//------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.cpp.inc"

``````````

</details>


https://github.com/llvm/llvm-project/pull/152966


More information about the Mlir-commits mailing list