[Mlir-commits] [mlir] Extending UniformQuantizedType with interface-based support for new storage types in Quant dialect (PR #152966)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 20 05:45:31 PDT 2025


https://github.com/Roman-Pevnyi updated https://github.com/llvm/llvm-project/pull/152966

>From 784e1d7d23b8712840f25768cee890142b94caaa Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Mon, 11 Aug 2025 09:33:14 +0200
Subject: [PATCH 1/4] Added new type interface to let UniformQuantizedType
 accept types other than built-in ones. Updated parser and printer in Quant dialect

---
 mlir/cmake/modules/AddMLIR.cmake              |  8 ++
 mlir/include/mlir/IR/BuiltinTypes.h           |  2 +
 mlir/include/mlir/IR/BuiltinTypes.td          | 29 +++++++-
 mlir/include/mlir/IR/CMakeLists.txt           |  2 +
 mlir/include/mlir/IR/QuantizationInterface.h  | 22 ++++++
 mlir/include/mlir/IR/QuantizationInterface.td | 44 +++++++++++
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      | 65 ++++++++--------
 mlir/lib/Dialect/Quant/IR/TypeParser.cpp      | 74 ++++++++++---------
 mlir/lib/IR/CMakeLists.txt                    |  4 +-
 mlir/lib/IR/QuantizationInterface.cpp         | 23 ++++++
 10 files changed, 207 insertions(+), 66 deletions(-)
 create mode 100644 mlir/include/mlir/IR/QuantizationInterface.h
 create mode 100644 mlir/include/mlir/IR/QuantizationInterface.td
 create mode 100644 mlir/lib/IR/QuantizationInterface.cpp

diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index ff4269ed7acd2..c35308d57eadd 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -203,6 +203,14 @@ function(add_mlir_interface interface)
   add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
 endfunction()
 
+# Declare a type interface in the include directory
+function(add_mlir_type_interface interface)
+  set(LLVM_TARGET_DEFINITIONS ${interface}.td)
+  mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
+  mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
+  add_public_tablegen_target(MLIR${interface}IncGen)
+  add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
+endfunction()
 
 # Generate Documentation
 function(add_mlir_doc doc_filename output_file output_directory command)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 86ec5c43970b1..204da9553e915 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
 // Tablegen Type Declarations
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/QuantizationInterface.h"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.h.inc"
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..762f9262adbf2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/QuantizationInterface.td"
 include "mlir/IR/CommonTypeConstraints.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -497,7 +498,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
 //===----------------------------------------------------------------------===//
 
 def Builtin_Integer : Builtin_Type<"Integer", "integer",
-    [VectorElementTypeInterface]> {
+    [VectorElementTypeInterface, QuantizationInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -554,6 +555,32 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
+
+    /// QuantizationInterface method implementations
+    /// Return true if this is a signed integer type.
+    bool isStorageSigned() const { return !isUnsigned(); }
+    /// Get the bit width of this integer type.
+    unsigned getStorageWidth() const { return getWidth(); }
+    
+    /// Get default minimum value for this integer type.
+    int64_t getDefaultMinimum() const {
+      if (isStorageSigned()) {
+        return llvm::minIntN(getStorageWidth());
+      }
+      return 0;
+    }
+    /// Get default maximum value for this integer type.
+    int64_t getDefaultMaximum() const {
+      if (isStorageSigned()) {
+        return llvm::maxIntN(getStorageWidth());
+      }
+      return llvm::maxUIntN(getStorageWidth());
+    }
+    
+    /// Get the storage type as a string.
+    std::string getStorageType() const {
+      return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 846547ff131e3..153502c6e981b 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,6 +1,8 @@
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
+add_mlir_type_interface(QuantizationInterface)
+
 set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
 mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
 mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantizationInterface.h
new file mode 100644
index 0000000000000..0d6709ff52065
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.h
@@ -0,0 +1,22 @@
+//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_QuantizationInterface_H
+#define MLIR_IR_QuantizationInterface_H
+
+#include "mlir/IR/Types.h"
+
+// Forward declarations for the types we need in the implementation
+namespace mlir {
+class IntegerType;
+} // namespace mlir
+
+#include "mlir/IR/QuantizationInterface.h.inc"
+
+#endif // MLIR_IR_QuantizationInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantizationInterface.td
new file mode 100644
index 0000000000000..1008ac8e1dcf1
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.td
@@ -0,0 +1,44 @@
+#ifndef MLIR_IR_QUANTIZATIONINTERFACE
+#define MLIR_IR_QUANTIZATIONINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+  let description = [{
+    Interface for types that can be used as storage types in Quant dialect.
+    This interface provides methods to determine storage characteristics for quantization purposes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Check if the storage type is signed.
+      Returns true if the type represents signed values, false for unsigned.
+    }],
+    "bool", "isStorageSigned", (ins)>,
+
+    InterfaceMethod<[{
+      Get the bit width of this storage type.
+      Returns the number of bits used to store values of this type.
+    }],
+    "unsigned", "getStorageWidth", (ins)>,
+
+    InterfaceMethod<[{
+      Get the default minimum value for this storage type.
+    }],
+    "int64_t", "getDefaultMinimum", (ins)>,
+
+    InterfaceMethod<[{
+      Get the default maximum value for this storage type.
+    }],
+    "int64_t", "getDefaultMaximum", (ins)>,
+
+    InterfaceMethod<[{
+      Get the storage type as a string.
+    }],
+    "std::string", "getStorageType", (ins)>
+  ];
+
+}
+
+#endif // MLIR_IR_QUANTIZATIONINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index b2227792f32ca..e7f9b1dc8a7e1 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "TypeDetail.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/IR/QuantizationInterface.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
   if (!intStorageType)
     return emitError() << "storage type must be integral";
-  unsigned integralWidth = intStorageType.getWidth();
-
-  // Verify storage width.
-  if (integralWidth == 0 || integralWidth > MaxStorageBits)
-    return emitError() << "illegal storage type size: " << integralWidth;
-
-  // Verify storageTypeMin and storageTypeMax.
-  bool isSigned =
-      (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
-  int64_t defaultIntegerMin =
-      getDefaultMinimumForInteger(isSigned, integralWidth);
-  int64_t defaultIntegerMax =
-      getDefaultMaximumForInteger(isSigned, integralWidth);
-  if (storageTypeMax - storageTypeMin <= 0 ||
-      storageTypeMin < defaultIntegerMin ||
-      storageTypeMax > defaultIntegerMax) {
-    return emitError() << "illegal storage min and storage max: ("
-                       << storageTypeMin << ":" << storageTypeMax << ")";
+
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    unsigned integralWidth = quantizationInterface.getStorageWidth();
+
+    // Verify storage width.
+    if (integralWidth == 0 || integralWidth > MaxStorageBits)
+      return emitError() << "illegal storage type size: " << integralWidth;
+
+    int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+    int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+    if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
+        storageTypeMax > defaultMax) {
+      return emitError() << "illegal storage min and storage max: ("
+                         << storageTypeMin << ":" << storageTypeMax << ")";
+    }
+
+    return success();
   }
-  return success();
+
+  return emitError() << "storage type must implement QuantizationInterface";
 }
 
 Type QuantizedType::getStorageType() const {
@@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
 }
 
 bool QuantizedType::hasStorageTypeBounds() const {
-  unsigned int integralWidth = getStorageTypeIntegralWidth();
-  bool isSignedInteger = isSigned();
-  int64_t defaultIntegerMin =
-      getDefaultMinimumForInteger(isSignedInteger, integralWidth);
-  int64_t defaultIntegerMax =
-      getDefaultMaximumForInteger(isSignedInteger, integralWidth);
-  return defaultIntegerMin != getStorageTypeMin() ||
-         defaultIntegerMax != getStorageTypeMax();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+  auto quantizationInterface =
+      llvm::dyn_cast<QuantizationInterface>(storageType);
+
+  int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+  int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+  return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
 }
 
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
-  // NOTE: If ever supporting non-integral storage types, some other scheme
-  // for determining the width will be needed.
-  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+  Type storageType = static_cast<ImplType *>(impl)->storageType;
+  auto quantizationInterface =
+      llvm::dyn_cast<QuantizationInterface>(storageType);
+
+  return quantizationInterface.getStorageWidth();
 }
 
 Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 9a18cff24e62a..758399a2af5e8 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,15 +10,16 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/QuantizationInterface.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
 
 using namespace mlir;
 using namespace quant;
 
-static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
+static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   auto typeLoc = parser.getCurrentLocation();
-  IntegerType type;
+  Type type;
 
   // Parse storage type (alpha_ident, integer_literal).
   StringRef identifier;
@@ -27,20 +28,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   if (result.has_value()) {
     if (!succeeded(*result))
       return nullptr;
-    isSigned = !type.isUnsigned();
-    storageTypeWidth = type.getWidth();
-  } else if (succeeded(parser.parseKeyword(&identifier))) {
-    // Otherwise, this must be an unsigned integer (`u` integer-literal).
-    if (!identifier.consume_front("u")) {
-      parser.emitError(typeLoc, "illegal storage type prefix");
+
+    if (auto quantizationInterface =
+            llvm::dyn_cast<QuantizationInterface>(type)) {
+      isSigned = quantizationInterface.isStorageSigned();
+      storageTypeWidth = quantizationInterface.getStorageWidth();
+    } else {
+      parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
     }
-    if (identifier.getAsInteger(10, storageTypeWidth)) {
-      parser.emitError(typeLoc, "expected storage type width");
+  } else if (succeeded(parser.parseKeyword(&identifier))) {
+    // Otherwise, this must be an unsigned integer (`u` integer-literal)
+    if (identifier.consume_front("u")) {
+      if (identifier.getAsInteger(10, storageTypeWidth)) {
+        parser.emitError(typeLoc, "expected storage type width");
+        return nullptr;
+      }
+      isSigned = false;
+      type = parser.getBuilder().getIntegerType(storageTypeWidth);
+    } else {
+      parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
     }
-    isSigned = false;
-    type = parser.getBuilder().getIntegerType(storageTypeWidth);
   } else {
     return nullptr;
   }
@@ -55,17 +64,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
   return type;
 }
 
-static ParseResult parseStorageRange(DialectAsmParser &parser,
-                                     IntegerType storageType, bool isSigned,
+static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
                                      int64_t &storageTypeMin,
                                      int64_t &storageTypeMax) {
-  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
-      isSigned, storageType.getWidth());
-  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
-      isSigned, storageType.getWidth());
+  int64_t defaultMin, defaultMax;
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(storageType)) {
+    defaultMin = quantizationInterface.getDefaultMinimum();
+    defaultMax = quantizationInterface.getDefaultMaximum();
+  }
+
   if (failed(parser.parseOptionalLess())) {
-    storageTypeMin = defaultIntegerMin;
-    storageTypeMax = defaultIntegerMax;
+    storageTypeMin = defaultMin;
+    storageTypeMax = defaultMax;
     return success();
   }
 
@@ -75,11 +86,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
       parser.getCurrentLocation(&maxLoc) ||
       parser.parseInteger(storageTypeMax) || parser.parseGreater())
     return failure();
-  if (storageTypeMin < defaultIntegerMin) {
+  if (storageTypeMin < defaultMin) {
     return parser.emitError(minLoc, "illegal storage type minimum: ")
            << storageTypeMin;
   }
-  if (storageTypeMax > defaultIntegerMax) {
+  if (storageTypeMax > defaultMax) {
     return parser.emitError(maxLoc, "illegal storage type maximum: ")
            << storageTypeMax;
   }
@@ -113,7 +124,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
 ///   storage-type ::= (`i` | `u`) integer-literal
 ///   expressed-type-spec ::= `:` `f` integer-literal
 static Type parseAnyType(DialectAsmParser &parser) {
-  IntegerType storageType;
+  Type storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
   int64_t storageTypeMin;
@@ -134,8 +145,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
   }
 
   // Storage type range.
-  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
-                        storageTypeMax)) {
+  if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
     return nullptr;
   }
 
@@ -322,7 +332,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
 ///     scale-zero-tensor (`,` scale-zero-tensor)*
 ///   `}`
 static Type parseUniformType(DialectAsmParser &parser) {
-  IntegerType storageType;
+  Type storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
   int64_t storageTypeMin;
@@ -350,8 +360,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
   }
 
   // Storage type range.
-  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
-                        storageTypeMax)) {
+  if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
     return nullptr;
   }
 
@@ -486,12 +495,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
 
 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
   // storage type
-  unsigned storageWidth = type.getStorageTypeIntegralWidth();
-  bool isSigned = type.isSigned();
-  if (isSigned) {
-    out << "i" << storageWidth;
-  } else {
-    out << "u" << storageWidth;
+  if (auto quantizationInterface =
+          llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
+    out << quantizationInterface.getStorageType();
   }
 
   // storageTypeMin and storageTypeMax if not default.
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3ef69cea18f0a..f539aca7fff48 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_library(MLIRIR
   OperationSupport.cpp
   PatternLoggingListener.cpp
   PatternMatch.cpp
+  QuantizationInterface.cpp  
   Region.cpp
   RegionKindInterface.cpp
   SymbolTable.cpp
@@ -66,7 +67,8 @@ add_mlir_library(MLIRIR
   MLIRSideEffectInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
-
+  MLIRQuantizationInterfaceIncGen
+  
   LINK_LIBS PUBLIC
   MLIRSupport
   )
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantizationInterface.cpp
new file mode 100644
index 0000000000000..a93333278610e
--- /dev/null
+++ b/mlir/lib/IR/QuantizationInterface.cpp
@@ -0,0 +1,23 @@
+//===- QuantizationInterface.cpp
+//------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.cpp.inc"

>From 407fae44ee95ac8f53b587e17f0d26cd8bdbaa1b Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 13 Aug 2025 13:13:42 +0200
Subject: [PATCH 2/4] Added QuantizationInterface to Float8E5M2Type and
 Float8E4M3FNType

---
 mlir/include/mlir/IR/BuiltinTypes.td | 54 ++++++++++++++++++++++------
 mlir/lib/IR/CMakeLists.txt           |  2 +-
 2 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 762f9262adbf2..12ac05b64b577 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -101,7 +101,8 @@ class Builtin_CachedFloatType<string name, string mnemonic,
 // Float8E5M2Type
 //===----------------------------------------------------------------------===//
 
-def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
+def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2", 
+                                   ["QuantizationInterface"]> {
   let summary = "8-bit floating point with 2 bit mantissa";
   let description = [{
     An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -117,6 +118,21 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
 
     Described in: https://arxiv.org/abs/2209.05433
   }];
+
+  let extraClassDeclaration = [{
+    /// QuantizationInterface method implementations
+    bool isStorageSigned() const { return true; }
+    /// Get the bit width of this 8-bit floating point type.
+    unsigned getStorageWidth() const { return 8; }
+    
+    /// Get default maximum value for this 8-bit floating point type.
+    int64_t getDefaultMaximum() const { return 57344; }
+    /// Get default minimum value for this 8-bit floating point type.
+    int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
+    
+    /// Get the storage type as a string.
+    std::string getStorageType() const { return "f8E5M2"; }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -143,7 +159,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
 // Float8E4M3FNType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
+def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN", 
+                                   ["QuantizationInterface"]> {
   let summary = "8-bit floating point with 3 bit mantissa";
   let description = [{
     An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -160,6 +177,21 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
 
     Described in: https://arxiv.org/abs/2209.05433
   }];
+
+  let extraClassDeclaration = [{
+    /// QuantizationInterface method implementations
+    bool isStorageSigned() const { return true; }
+    /// Get the bit width of this 8-bit floating point type.
+    unsigned getStorageWidth() const { return 8; }
+    
+    /// Get default maximum value for this 8-bit floating point type.
+    int64_t getDefaultMaximum() const { return 448; }
+    /// Get default minimum value for this 8-bit floating point type.
+    int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
+    
+    /// Get the storage type as a string.
+    std::string getStorageType() const { return "f8E4M3FN"; }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -557,18 +589,11 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
 
     /// QuantizationInterface method implementations
-    /// Return true if this is a signed integer type.
+    /// Return true if this is a signed or signless integer type.
     bool isStorageSigned() const { return !isUnsigned(); }
     /// Get the bit width of this integer type.
     unsigned getStorageWidth() const { return getWidth(); }
     
-    /// Get default minimum value for this integer type.
-    int64_t getDefaultMinimum() const {
-      if (isStorageSigned()) {
-        return llvm::minIntN(getStorageWidth());
-      }
-      return 0;
-    }
     /// Get default maximum value for this integer type.
     int64_t getDefaultMaximum() const {
       if (isStorageSigned()) {
@@ -576,7 +601,14 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
       }
       return llvm::maxUIntN(getStorageWidth());
     }
-    
+    /// Get default minimum value for this integer type.
+    int64_t getDefaultMinimum() const {
+      if (isStorageSigned()) {
+        return llvm::minIntN(getStorageWidth());
+      }
+      return 0;
+    }
+
     /// Get the storage type as a string.
     std::string getStorageType() const {
       return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index f539aca7fff48..8a50b077d9014 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -63,11 +63,11 @@ add_mlir_library(MLIRIR
   MLIRCastInterfacesIncGen
   MLIRDataLayoutInterfacesIncGen
   MLIROpAsmInterfaceIncGen
+  MLIRQuantizationInterfaceIncGen
   MLIRRegionKindInterfaceIncGen
   MLIRSideEffectInterfacesIncGen
   MLIRSymbolInterfacesIncGen
   MLIRTensorEncodingIncGen
-  MLIRQuantizationInterfaceIncGen
   
   LINK_LIBS PUBLIC
   MLIRSupport

>From d7447a8b168ad1111e49deaa526c01e89620cd76 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 20 Aug 2025 13:18:44 +0200
Subject: [PATCH 3/4] Renamed to QuantStorageTypeInterface and removed
 redundant dyn_cast checks

---
 mlir/include/mlir/IR/BuiltinTypes.h           |  2 +-
 mlir/include/mlir/IR/BuiltinTypes.td          | 14 ++++----
 mlir/include/mlir/IR/CMakeLists.txt           |  2 +-
 ...nterface.h => QuantStorageTypeInterface.h} | 10 +++---
 ...erface.td => QuantStorageTypeInterface.td} |  8 ++---
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      | 35 ++++++++-----------
 mlir/lib/Dialect/Quant/IR/TypeParser.cpp      | 29 ++++++++-------
 mlir/lib/IR/CMakeLists.txt                    |  5 ++-
 ...face.cpp => QuantStorageTypeInterface.cpp} |  4 +--
 9 files changed, 50 insertions(+), 59 deletions(-)
 rename mlir/include/mlir/IR/{QuantizationInterface.h => QuantStorageTypeInterface.h} (63%)
 rename mlir/include/mlir/IR/{QuantizationInterface.td => QuantStorageTypeInterface.td} (83%)
 rename mlir/lib/IR/{QuantizationInterface.cpp => QuantStorageTypeInterface.cpp} (89%)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 204da9553e915..d30cba29c9814 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,7 +167,7 @@ class BaseMemRefType : public Type,
 // Tablegen Type Declarations
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
 
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 12ac05b64b577..1a29b9549cf51 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,7 +17,7 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
-include "mlir/IR/QuantizationInterface.td"
+include "mlir/IR/QuantStorageTypeInterface.td"
 include "mlir/IR/CommonTypeConstraints.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -102,7 +102,7 @@ class Builtin_CachedFloatType<string name, string mnemonic,
 //===----------------------------------------------------------------------===//
 
 def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2", 
-                                   ["QuantizationInterface"]> {
+                                   ["QuantStorageTypeInterface"]> {
   let summary = "8-bit floating point with 2 bit mantissa";
   let description = [{
     An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -120,7 +120,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
   }];
 
   let extraClassDeclaration = [{
-    /// QuantizationInterface method implementations
+    /// QuantStorageTypeInterface method implementations
     bool isStorageSigned() const { return true; }
     /// Get the bit width of this 8-bit floating point type.
     unsigned getStorageWidth() const { return 8; }
@@ -160,7 +160,7 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN", 
-                                   ["QuantizationInterface"]> {
+                                   ["QuantStorageTypeInterface"]> {
   let summary = "8-bit floating point with 3 bit mantissa";
   let description = [{
     An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -179,7 +179,7 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
   }];
 
   let extraClassDeclaration = [{
-    /// QuantizationInterface method implementations
+    /// QuantStorageTypeInterface method implementations
     bool isStorageSigned() const { return true; }
     /// Get the bit width of this 8-bit floating point type.
     unsigned getStorageWidth() const { return 8; }
@@ -530,7 +530,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
 //===----------------------------------------------------------------------===//
 
 def Builtin_Integer : Builtin_Type<"Integer", "integer",
-    [VectorElementTypeInterface, QuantizationInterface]> {
+    [VectorElementTypeInterface, QuantStorageTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -588,7 +588,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
 
-    /// QuantizationInterface method implementations
+    /// QuantStorageTypeInterface method implementations
     /// Return true if this is a signed or signless integer type.
     bool isStorageSigned() const { return !isUnsigned(); }
     /// Get the bit width of this integer type.
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 153502c6e981b..1e4ca1b8328c6 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_interface(SymbolInterfaces)
 add_mlir_interface(RegionKindInterface)
 
-add_mlir_type_interface(QuantizationInterface)
+add_mlir_type_interface(QuantStorageTypeInterface)
 
 set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
 mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantStorageTypeInterface.h
similarity index 63%
rename from mlir/include/mlir/IR/QuantizationInterface.h
rename to mlir/include/mlir/IR/QuantStorageTypeInterface.h
index 0d6709ff52065..dce430efdbc89 100644
--- a/mlir/include/mlir/IR/QuantizationInterface.h
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.h
@@ -1,4 +1,4 @@
-//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
+//===- QuantStorageTypeInterface.h - Quantzation Interfaces --------*- C++
 //-*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_IR_QuantizationInterface_H
-#define MLIR_IR_QuantizationInterface_H
+#ifndef MLIR_IR_QuantStorageTypeInterface_H
+#define MLIR_IR_QuantStorageTypeInterface_H
 
 #include "mlir/IR/Types.h"
 
@@ -17,6 +17,6 @@ namespace mlir {
 class IntegerType;
 } // namespace mlir
 
-#include "mlir/IR/QuantizationInterface.h.inc"
+#include "mlir/IR/QuantStorageTypeInterface.h.inc"
 
-#endif // MLIR_IR_QuantizationInterface_H
+#endif // MLIR_IR_QuantStorageTypeInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
similarity index 83%
rename from mlir/include/mlir/IR/QuantizationInterface.td
rename to mlir/include/mlir/IR/QuantStorageTypeInterface.td
index 1008ac8e1dcf1..37fe4aaccf610 100644
--- a/mlir/include/mlir/IR/QuantizationInterface.td
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
@@ -1,9 +1,9 @@
-#ifndef MLIR_IR_QUANTIZATIONINTERFACE
-#define MLIR_IR_QUANTIZATIONINTERFACE
+#ifndef MLIR_IR_QUANTSTORAGETYPEINTERFACE
+#define MLIR_IR_QUANTSTORAGETYPEINTERFACE
 
 include "mlir/IR/OpBase.td"
 
-def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
   let description = [{
     Interface for types that can be used as storage types in Quant dialect.
     This interface provides methods to determine storage characteristics for quantization purposes.
@@ -41,4 +41,4 @@ def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
 
 }
 
-#endif // MLIR_IR_QUANTIZATIONINTERFACE
+#endif // MLIR_IR_QUANTSTORAGETYPEINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index e7f9b1dc8a7e1..b27cb790b9052 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,7 +9,7 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "TypeDetail.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -47,23 +47,16 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                 unsigned flags, Type storageType,
                                 Type expressedType, int64_t storageTypeMin,
                                 int64_t storageTypeMax) {
-  // Verify that the storage type is integral.
-  // This restriction may be lifted at some point in favor of using bf16
-  // or f16 as exact representations on hardware where that is advantageous.
-  auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
-  if (!intStorageType)
-    return emitError() << "storage type must be integral";
-
-  if (auto quantizationInterface =
-          llvm::dyn_cast<QuantizationInterface>(storageType)) {
-    unsigned integralWidth = quantizationInterface.getStorageWidth();
+  if (auto quantStorageTypeInterface =
+          llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
+    unsigned integralWidth = quantStorageTypeInterface.getStorageWidth();
 
     // Verify storage width.
     if (integralWidth == 0 || integralWidth > MaxStorageBits)
       return emitError() << "illegal storage type size: " << integralWidth;
 
-    int64_t defaultMin = quantizationInterface.getDefaultMinimum();
-    int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+    int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+    int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
 
     if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
         storageTypeMax > defaultMax) {
@@ -74,7 +67,7 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
     return success();
   }
 
-  return emitError() << "storage type must implement QuantizationInterface";
+  return emitError() << "storage type must implement QuantStorageTypeInterface";
 }
 
 Type QuantizedType::getStorageType() const {
@@ -91,21 +84,21 @@ int64_t QuantizedType::getStorageTypeMax() const {
 
 bool QuantizedType::hasStorageTypeBounds() const {
   Type storageType = static_cast<ImplType *>(impl)->storageType;
-  auto quantizationInterface =
-      llvm::dyn_cast<QuantizationInterface>(storageType);
+  auto quantStorageTypeInterface =
+      llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
 
-  int64_t defaultMin = quantizationInterface.getDefaultMinimum();
-  int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+  int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+  int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
 
   return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
 }
 
 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
   Type storageType = static_cast<ImplType *>(impl)->storageType;
-  auto quantizationInterface =
-      llvm::dyn_cast<QuantizationInterface>(storageType);
+  auto quantStorageTypeInterface =
+      llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
 
-  return quantizationInterface.getStorageWidth();
+  return quantStorageTypeInterface.getStorageWidth();
 }
 
 Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 758399a2af5e8..f86df4eacd21f 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,7 +10,7 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
 
@@ -29,10 +29,10 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
     if (!succeeded(*result))
       return nullptr;
 
-    if (auto quantizationInterface =
-            llvm::dyn_cast<QuantizationInterface>(type)) {
-      isSigned = quantizationInterface.isStorageSigned();
-      storageTypeWidth = quantizationInterface.getStorageWidth();
+    if (auto quantStorageTypeInterface =
+            llvm::dyn_cast<QuantStorageTypeInterface>(type)) {
+      isSigned = quantStorageTypeInterface.isStorageSigned();
+      storageTypeWidth = quantStorageTypeInterface.getStorageWidth();
     } else {
       parser.emitError(typeLoc, "illegal quantized storage type alias");
       return nullptr;
@@ -67,12 +67,11 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
 static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
                                      int64_t &storageTypeMin,
                                      int64_t &storageTypeMax) {
-  int64_t defaultMin, defaultMax;
-  if (auto quantizationInterface =
-          llvm::dyn_cast<QuantizationInterface>(storageType)) {
-    defaultMin = quantizationInterface.getDefaultMinimum();
-    defaultMax = quantizationInterface.getDefaultMaximum();
-  }
+  auto quantStorageTypeInterface =
+      llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
+
+  int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+  int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
 
   if (failed(parser.parseOptionalLess())) {
     storageTypeMin = defaultMin;
@@ -495,10 +494,10 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
 
 static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
   // storage type
-  if (auto quantizationInterface =
-          llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
-    out << quantizationInterface.getStorageType();
-  }
+  auto quantStorageTypeInterface =
+      llvm::dyn_cast<QuantStorageTypeInterface>(type.getStorageType());
+
+  out << quantStorageTypeInterface.getStorageType();
 
   // storageTypeMin and storageTypeMax if not default.
   if (type.hasStorageTypeBounds()) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 8a50b077d9014..9e0d283854b38 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,7 +31,7 @@ add_mlir_library(MLIRIR
   OperationSupport.cpp
   PatternLoggingListener.cpp
   PatternMatch.cpp
-  QuantizationInterface.cpp  
+  QuantStorageTypeInterface.cpp  
   Region.cpp
   RegionKindInterface.cpp
   SymbolTable.cpp
@@ -63,7 +63,7 @@ add_mlir_library(MLIRIR
   MLIRCastInterfacesIncGen
   MLIRDataLayoutInterfacesIncGen
   MLIROpAsmInterfaceIncGen
-  MLIRQuantizationInterfaceIncGen
+  MLIRQuantStorageTypeInterfaceIncGen
   MLIRRegionKindInterfaceIncGen
   MLIRSideEffectInterfacesIncGen
   MLIRSymbolInterfacesIncGen
@@ -72,4 +72,3 @@ add_mlir_library(MLIRIR
   LINK_LIBS PUBLIC
   MLIRSupport
   )
-
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantStorageTypeInterface.cpp
similarity index 89%
rename from mlir/lib/IR/QuantizationInterface.cpp
rename to mlir/lib/IR/QuantStorageTypeInterface.cpp
index a93333278610e..bcc2dc1f3337c 100644
--- a/mlir/lib/IR/QuantizationInterface.cpp
+++ b/mlir/lib/IR/QuantStorageTypeInterface.cpp
@@ -1,4 +1,4 @@
-//===- QuantizationInterface.cpp
+//===- QuantStorageTypeInterface.cpp
 //------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -20,4 +20,4 @@ using namespace mlir::detail;
 /// Tablegen Interface Definitions
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/QuantizationInterface.cpp.inc"
+#include "mlir/IR/QuantStorageTypeInterface.cpp.inc"

>From f8458e773b1003c669bbed7db46c95253ef8e810 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 20 Aug 2025 14:45:14 +0200
Subject: [PATCH 4/4] Added interface methods which expose a few basic packing
 and alignment facts

---
 mlir/include/mlir/IR/BuiltinTypes.td          | 36 +++++++++++++++++++
 .../mlir/IR/QuantStorageTypeInterface.td      | 35 +++++++++++++++---
 2 files changed, 66 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1a29b9549cf51..c7e9ee2a4c95f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -132,6 +132,18 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
     
     /// Get the storage type as a string.
     std::string getStorageType() const { return "f8E5M2"; }
+
+    /// Check if this 8-bit floating point type uses packed representation.
+    bool isPacked() const { return false; }
+
+    /// Get the logical bit width per value for this 8-bit floating point type.
+    unsigned getLogicalBitWidth() const { return 8; }
+
+    /// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
+    unsigned getElementsPerByte() const { return 1; }
+
+    /// Get the preferred alignment in bytes for this 8-bit floating point type.
+    std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
   }];
 }
 
@@ -191,6 +203,18 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
     
     /// Get the storage type as a string.
     std::string getStorageType() const { return "f8E4M3FN"; }
+
+    /// Check if this 8-bit floating point type uses packed representation.
+    bool isPacked() const { return false; }
+
+    /// Get the logical bit width per value for this 8-bit floating point type.
+    unsigned getLogicalBitWidth() const { return 8; }
+
+    /// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
+    unsigned getElementsPerByte() const { return 1; }
+
+    /// Get the preferred alignment in bytes for this 8-bit floating point type.
+    std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
   }];
 }
 
@@ -613,6 +637,18 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
     std::string getStorageType() const {
       return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
     }
+
+    /// Check if this integer type uses packed representation.
+    bool isPacked() const { return false; }
+
+    /// Get the logical bit width per value for this integer type.
+    unsigned getLogicalBitWidth() const { return getWidth(); }
+
+    /// Get the number of logical elements that fit in one byte for this integer type.
+    unsigned getElementsPerByte() const { return 1; }
+
+    /// Get the preferred alignment in bytes for this integer type.
+    std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
   }];
 }
 
diff --git a/mlir/include/mlir/IR/QuantStorageTypeInterface.td b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
index 37fe4aaccf610..b09c6e79df840 100644
--- a/mlir/include/mlir/IR/QuantStorageTypeInterface.td
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
@@ -6,7 +6,8 @@ include "mlir/IR/OpBase.td"
 def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
   let description = [{
     Interface for types that can be used as storage types in Quant dialect.
-    This interface provides methods to determine storage characteristics for quantization purposes.
+    This interface provides methods to determine storage characteristics for quantization purposes,
+    including packing behavior and alignment requirements.
   }];
   let cppNamespace = "::mlir";
 
@@ -18,25 +19,49 @@ def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
     "bool", "isStorageSigned", (ins)>,
 
     InterfaceMethod<[{
-      Get the bit width of this integer type.
+      Get the bit width of this type.
       Returns the number of bits used to store values of this type.
     }],
     "unsigned", "getStorageWidth", (ins)>,
 
     InterfaceMethod<[{
-      Get default minimum value for this integer type.
+      Get default minimum value for this type.
     }],
     "int64_t", "getDefaultMinimum", (ins)>,
 
     InterfaceMethod<[{
-      Get default maximum value for this integer type.
+      Get default maximum value for this type.
     }],
     "int64_t", "getDefaultMaximum", (ins)>,
 
     InterfaceMethod<[{
       Get the storage type as a string.
     }],
-    "std::string", "getStorageType", (ins)>
+    "std::string", "getStorageType", (ins)>,
+
+    InterfaceMethod<[{
+      Check if the storage type uses packed representation.
+      Returns true if multiple values are packed into one byte (e.g., sub-byte types),
+      false if each value uses a full byte.
+    }],
+    "bool", "isPacked", (ins)>,
+
+    InterfaceMethod<[{
+      Get the logical bit width per value.
+      For packed sub-byte types, this may differ from getStorageWidth().
+    }],
+    "unsigned", "getLogicalBitWidth", (ins)>,
+
+    InterfaceMethod<[{
+      Get the number of logical elements that fit in one byte.
+      For packed sub-byte types, this returns how many values can be stored per byte.
+    }],
+    "unsigned", "getElementsPerByte", (ins)>,
+
+    InterfaceMethod<[{
+      Returns the preferred alignment for this type, in bytes.
+    }],
+    "std::optional<unsigned>", "getPreferredAlignmentBytes", (ins)>
   ];
 
 }



More information about the Mlir-commits mailing list