[Mlir-commits] [mlir] Extending UniformQuantizedType with interface-based support for new storage types in Quant dialect (PR #152966)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 9 00:29:27 PST 2026
https://github.com/Roman-Pevnyi updated https://github.com/llvm/llvm-project/pull/152966
>From 93f80f09629b90c5a6b8270881733edf2384e433 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Mon, 11 Aug 2025 09:33:14 +0200
Subject: [PATCH 1/8] Added new type interface to let UniformQuantizedType
accept types other than built-in ones. Updated parser and printer in Quant dialect
---
mlir/cmake/modules/AddMLIR.cmake | 9 +++
mlir/include/mlir/IR/BuiltinTypes.h | 2 +
mlir/include/mlir/IR/BuiltinTypes.td | 29 +++++++-
mlir/include/mlir/IR/CMakeLists.txt | 2 +
mlir/include/mlir/IR/QuantizationInterface.h | 22 ++++++
mlir/include/mlir/IR/QuantizationInterface.td | 44 +++++++++++
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 65 ++++++++--------
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 74 ++++++++++---------
mlir/lib/IR/CMakeLists.txt | 2 +
mlir/lib/IR/QuantizationInterface.cpp | 23 ++++++
10 files changed, 207 insertions(+), 65 deletions(-)
create mode 100644 mlir/include/mlir/IR/QuantizationInterface.h
create mode 100644 mlir/include/mlir/IR/QuantizationInterface.td
create mode 100644 mlir/lib/IR/QuantizationInterface.cpp
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 6589458ab7894..be28a2ab900c9 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -216,6 +216,15 @@ macro(add_mlir_generic_tablegen_target target)
add_dependencies(mlir-generic-headers ${target})
endmacro()
+# Declare a dialect in the include directory
+function(add_mlir_type_interface interface)
+ set(LLVM_TARGET_DEFINITIONS ${interface}.td)
+ mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
+ mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
+ add_public_tablegen_target(MLIR${interface}IncGen)
+ add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
+endfunction()
+
# Generate Documentation
function(add_mlir_doc doc_filename output_file output_directory command)
set(LLVM_TARGET_DEFINITIONS ${doc_filename}.td)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 86ec5c43970b1..204da9553e915 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//
+#include "mlir/IR/QuantizationInterface.h"
+
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 08847dd11c685..650285135e5b4 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/QuantizationInterface.td"
include "mlir/IR/CommonTypeConstraints.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -501,7 +502,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//
def Builtin_Integer : Builtin_Type<"Integer", "integer",
- [VectorElementTypeInterface]> {
+ [VectorElementTypeInterface, QuantizationInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -558,6 +559,32 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
+
+ /// QuantizationInterface method implementations
+ /// Return true if this is a signed integer type.
+ bool isStorageSigned() const { return !isUnsigned(); }
+ /// Get the bit width of this integer type.
+ unsigned getStorageWidth() const { return getWidth(); }
+
+ /// Get default minimum value for this integer type.
+ int64_t getDefaultMinimum() const {
+ if (isStorageSigned()) {
+ return llvm::minIntN(getStorageWidth());
+ }
+ return 0;
+ }
+ /// Get default maximum value for this integer type.
+ int64_t getDefaultMaximum() const {
+ if (isStorageSigned()) {
+ return llvm::maxIntN(getStorageWidth());
+ }
+ return llvm::maxUIntN(getStorageWidth());
+ }
+
+ /// Get the storage type as a string.
+ std::string getStorageType() const {
+ return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
+ }
}];
}
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 3d30d92ed6ec4..35279a13cf109 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -4,6 +4,8 @@ mlir_tablegen(SymbolInterfacesAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(SymbolInterfacesAttrInterface.cpp.inc -gen-attr-interface-defs)
add_mlir_interface(RegionKindInterface)
+add_mlir_type_interface(QuantizationInterface)
+
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantizationInterface.h
new file mode 100644
index 0000000000000..0d6709ff52065
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.h
@@ -0,0 +1,22 @@
+//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_QuantizationInterface_H
+#define MLIR_IR_QuantizationInterface_H
+
+#include "mlir/IR/Types.h"
+
+// Forward declarations for the types we need in the implementation
+namespace mlir {
+class IntegerType;
+} // namespace mlir
+
+#include "mlir/IR/QuantizationInterface.h.inc"
+
+#endif // MLIR_IR_QuantizationInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantizationInterface.td
new file mode 100644
index 0000000000000..1008ac8e1dcf1
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.td
@@ -0,0 +1,44 @@
+#ifndef MLIR_IR_QUANTIZATIONINTERFACE
+#define MLIR_IR_QUANTIZATIONINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+ let description = [{
+ Interface for types that can be used as storage types in Quant dialect.
+ This interface provides methods to determine storage characteristics for quantization purposes.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Check if the storage type is signed.
+ Returns true if the type represents signed values, false for unsigned.
+ }],
+ "bool", "isStorageSigned", (ins)>,
+
+ InterfaceMethod<[{
+ Get the bit width of this integer type.
+ Returns the number of bits used to store values of this type.
+ }],
+ "unsigned", "getStorageWidth", (ins)>,
+
+ InterfaceMethod<[{
+ Get default minimum value for this integer type.
+ }],
+ "int64_t", "getDefaultMinimum", (ins)>,
+
+ InterfaceMethod<[{
+ Get default maximum value for this integer type.
+ }],
+ "int64_t", "getDefaultMaximum", (ins)>,
+
+ InterfaceMethod<[{
+ Get the storage type as a string.
+ }],
+ "std::string", "getStorageType", (ins)>
+ ];
+
+}
+
+#endif // MLIR_IR_QUANTIZATIONINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index b2227792f32ca..e7f9b1dc8a7e1 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/IR/QuantizationInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
if (!intStorageType)
return emitError() << "storage type must be integral";
- unsigned integralWidth = intStorageType.getWidth();
-
- // Verify storage width.
- if (integralWidth == 0 || integralWidth > MaxStorageBits)
- return emitError() << "illegal storage type size: " << integralWidth;
-
- // Verify storageTypeMin and storageTypeMax.
- bool isSigned =
- (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
- int64_t defaultIntegerMin =
- getDefaultMinimumForInteger(isSigned, integralWidth);
- int64_t defaultIntegerMax =
- getDefaultMaximumForInteger(isSigned, integralWidth);
- if (storageTypeMax - storageTypeMin <= 0 ||
- storageTypeMin < defaultIntegerMin ||
- storageTypeMax > defaultIntegerMax) {
- return emitError() << "illegal storage min and storage max: ("
- << storageTypeMin << ":" << storageTypeMax << ")";
+
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
+ unsigned integralWidth = quantizationInterface.getStorageWidth();
+
+ // Verify storage width.
+ if (integralWidth == 0 || integralWidth > MaxStorageBits)
+ return emitError() << "illegal storage type size: " << integralWidth;
+
+ int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+ int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+ if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
+ storageTypeMax > defaultMax) {
+ return emitError() << "illegal storage min and storage max: ("
+ << storageTypeMin << ":" << storageTypeMax << ")";
+ }
+
+ return success();
}
- return success();
+
+ return emitError() << "storage type must implement QuantizationInterface";
}
Type QuantizedType::getStorageType() const {
@@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
}
bool QuantizedType::hasStorageTypeBounds() const {
- unsigned int integralWidth = getStorageTypeIntegralWidth();
- bool isSignedInteger = isSigned();
- int64_t defaultIntegerMin =
- getDefaultMinimumForInteger(isSignedInteger, integralWidth);
- int64_t defaultIntegerMax =
- getDefaultMaximumForInteger(isSignedInteger, integralWidth);
- return defaultIntegerMin != getStorageTypeMin() ||
- defaultIntegerMax != getStorageTypeMax();
+ Type storageType = static_cast<ImplType *>(impl)->storageType;
+ auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType);
+
+ int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+ int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+ return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
}
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
- // NOTE: If ever supporting non-integral storage types, some other scheme
- // for determining the width will be needed.
- return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+ Type storageType = static_cast<ImplType *>(impl)->storageType;
+ auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType);
+
+ return quantizationInterface.getStorageWidth();
}
Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 7c3840abbf91c..bb38897ce5f4c 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/QuantizationInterface.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVectorExtras.h"
@@ -17,9 +18,9 @@
using namespace mlir;
using namespace quant;
-static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
+static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
auto typeLoc = parser.getCurrentLocation();
- IntegerType type;
+ Type type;
// Parse storage type (alpha_ident, integer_literal).
StringRef identifier;
@@ -28,20 +29,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
if (result.has_value()) {
if (!succeeded(*result))
return nullptr;
- isSigned = !type.isUnsigned();
- storageTypeWidth = type.getWidth();
- } else if (succeeded(parser.parseKeyword(&identifier))) {
- // Otherwise, this must be an unsigned integer (`u` integer-literal).
- if (!identifier.consume_front("u")) {
- parser.emitError(typeLoc, "illegal storage type prefix");
+
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(type)) {
+ isSigned = quantizationInterface.isStorageSigned();
+ storageTypeWidth = quantizationInterface.getStorageWidth();
+ } else {
+ parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
}
- if (identifier.getAsInteger(10, storageTypeWidth)) {
- parser.emitError(typeLoc, "expected storage type width");
+ } else if (succeeded(parser.parseKeyword(&identifier))) {
+ // Otherwise, this must be an unsigned integer (`u` integer-literal)
+ if (identifier.consume_front("u")) {
+ if (identifier.getAsInteger(10, storageTypeWidth)) {
+ parser.emitError(typeLoc, "expected storage type width");
+ return nullptr;
+ }
+ isSigned = false;
+ type = parser.getBuilder().getIntegerType(storageTypeWidth);
+ } else {
+ parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
}
- isSigned = false;
- type = parser.getBuilder().getIntegerType(storageTypeWidth);
} else {
return nullptr;
}
@@ -56,17 +65,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
return type;
}
-static ParseResult parseStorageRange(DialectAsmParser &parser,
- IntegerType storageType, bool isSigned,
+static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
int64_t &storageTypeMin,
int64_t &storageTypeMax) {
- int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
- isSigned, storageType.getWidth());
- int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
- isSigned, storageType.getWidth());
+ int64_t defaultMin, defaultMax;
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
+ defaultMin = quantizationInterface.getDefaultMinimum();
+ defaultMax = quantizationInterface.getDefaultMaximum();
+ }
+
if (failed(parser.parseOptionalLess())) {
- storageTypeMin = defaultIntegerMin;
- storageTypeMax = defaultIntegerMax;
+ storageTypeMin = defaultMin;
+ storageTypeMax = defaultMax;
return success();
}
@@ -76,11 +87,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
parser.getCurrentLocation(&maxLoc) ||
parser.parseInteger(storageTypeMax) || parser.parseGreater())
return failure();
- if (storageTypeMin < defaultIntegerMin) {
+ if (storageTypeMin < defaultMin) {
return parser.emitError(minLoc, "illegal storage type minimum: ")
<< storageTypeMin;
}
- if (storageTypeMax > defaultIntegerMax) {
+ if (storageTypeMax > defaultMax) {
return parser.emitError(maxLoc, "illegal storage type maximum: ")
<< storageTypeMax;
}
@@ -114,7 +125,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
static Type parseAnyType(DialectAsmParser &parser) {
- IntegerType storageType;
+ Type storageType;
FloatType expressedType;
unsigned typeFlags = 0;
int64_t storageTypeMin;
@@ -135,8 +146,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
}
// Storage type range.
- if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
- storageTypeMax)) {
+ if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
return nullptr;
}
@@ -323,7 +333,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
/// scale-zero-tensor (`,` scale-zero-tensor)*
/// `}`
static Type parseUniformType(DialectAsmParser &parser) {
- IntegerType storageType;
+ Type storageType;
FloatType expressedType;
unsigned typeFlags = 0;
int64_t storageTypeMin;
@@ -351,8 +361,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
}
// Storage type range.
- if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
- storageTypeMax)) {
+ if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
return nullptr;
}
@@ -488,12 +497,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
// storage type
- unsigned storageWidth = type.getStorageTypeIntegralWidth();
- bool isSigned = type.isSigned();
- if (isSigned) {
- out << "i" << storageWidth;
- } else {
- out << "u" << storageWidth;
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
+ out << quantizationInterface.getStorageType();
}
// storageTypeMin and storageTypeMax if not default.
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 563c8c6285ef3..342d8a233eea5 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_library(MLIRIR
OperationSupport.cpp
PatternLoggingListener.cpp
PatternMatch.cpp
+ QuantizationInterface.cpp
Region.cpp
RegionKindInterface.cpp
Remarks.cpp
@@ -68,6 +69,7 @@ add_mlir_library(MLIRIR
MLIRSymbolInterfacesIncGen
MLIRTensorEncodingIncGen
MLIROpAsmDialectInterfaceIncGen
+ MLIRQuantizationInterfaceIncGen
LINK_LIBS PUBLIC
MLIRSupport
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantizationInterface.cpp
new file mode 100644
index 0000000000000..a93333278610e
--- /dev/null
+++ b/mlir/lib/IR/QuantizationInterface.cpp
@@ -0,0 +1,23 @@
+//===- QuantizationInterface.cpp
+//------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.cpp.inc"
>From 4fcfd23928c0657b40bc81d6be5c8dd8a78a410c Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 13 Aug 2025 13:13:42 +0200
Subject: [PATCH 2/8] Added QuantizationInterface to Float8E5M2Type and
Float8E4M3FNType
---
mlir/include/mlir/IR/BuiltinTypes.td | 54 ++++++++++++++++++++++------
mlir/lib/IR/CMakeLists.txt | 1 +
2 files changed, 44 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 650285135e5b4..eef2719056d02 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -101,7 +101,8 @@ class Builtin_CachedFloatType<string name, string mnemonic,
// Float8E5M2Type
//===----------------------------------------------------------------------===//
-def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
+def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
+ ["QuantizationInterface"]> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -117,6 +118,21 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
Described in: https://arxiv.org/abs/2209.05433
}];
+
+ let extraClassDeclaration = [{
+ /// QuantizationInterface method implementations
+ bool isStorageSigned() const { return true; }
+ /// Get the bit width of this 8-bit floating point type.
+ unsigned getStorageWidth() const { return 8; }
+
+ /// Get default maximum value for this 8-bit floating point type.
+ int64_t getDefaultMaximum() const { return 57344; }
+ /// Get default minimum value for this 8-bit floating point type.
+ int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
+
+ /// Get the storage type as a string.
+ std::string getStorageType() const { return "f8E5M2"; }
+ }];
}
//===----------------------------------------------------------------------===//
@@ -143,7 +159,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
// Float8E4M3FNType
//===----------------------------------------------------------------------===//
-def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
+def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
+ ["QuantizationInterface"]> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -160,6 +177,21 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
Described in: https://arxiv.org/abs/2209.05433
}];
+
+ let extraClassDeclaration = [{
+ /// QuantizationInterface method implementations
+ bool isStorageSigned() const { return true; }
+ /// Get the bit width of this 8-bit floating point type.
+ unsigned getStorageWidth() const { return 8; }
+
+ /// Get default maximum value for this 8-bit floating point type.
+ int64_t getDefaultMaximum() const { return 448; }
+ /// Get default minimum value for this 8-bit floating point type.
+ int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
+
+ /// Get the storage type as a string.
+ std::string getStorageType() const { return "f8E4M3FN"; }
+ }];
}
//===----------------------------------------------------------------------===//
@@ -561,18 +593,11 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
/// QuantizationInterface method implementations
- /// Return true if this is a signed integer type.
+ /// Return true if this is a signed or signless integer type.
bool isStorageSigned() const { return !isUnsigned(); }
/// Get the bit width of this integer type.
unsigned getStorageWidth() const { return getWidth(); }
- /// Get default minimum value for this integer type.
- int64_t getDefaultMinimum() const {
- if (isStorageSigned()) {
- return llvm::minIntN(getStorageWidth());
- }
- return 0;
- }
/// Get default maximum value for this integer type.
int64_t getDefaultMaximum() const {
if (isStorageSigned()) {
@@ -580,7 +605,14 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
}
return llvm::maxUIntN(getStorageWidth());
}
-
+ /// Get default minimum value for this integer type.
+ int64_t getDefaultMinimum() const {
+ if (isStorageSigned()) {
+ return llvm::minIntN(getStorageWidth());
+ }
+ return 0;
+ }
+
/// Get the storage type as a string.
std::string getStorageType() const {
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 342d8a233eea5..364f3c14f081e 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -64,6 +64,7 @@ add_mlir_library(MLIRIR
MLIRCastInterfacesIncGen
MLIRDataLayoutInterfacesIncGen
MLIROpAsmInterfaceIncGen
+ MLIRQuantizationInterfaceIncGen
MLIRRegionKindInterfaceIncGen
MLIRSideEffectInterfacesIncGen
MLIRSymbolInterfacesIncGen
>From b83a7a3867b1ab642a4031957454bc79f60d9024 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 20 Aug 2025 13:18:44 +0200
Subject: [PATCH 3/8] Renamed to QuantStorageTypeInterface and removed
redundant dyn_cast checks
---
mlir/include/mlir/IR/BuiltinTypes.h | 2 +-
mlir/include/mlir/IR/BuiltinTypes.td | 14 ++++----
mlir/include/mlir/IR/CMakeLists.txt | 2 +-
...nterface.h => QuantStorageTypeInterface.h} | 10 +++---
...erface.td => QuantStorageTypeInterface.td} | 8 ++---
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 35 ++++++++-----------
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 29 ++++++++-------
mlir/lib/IR/CMakeLists.txt | 5 ++-
...face.cpp => QuantStorageTypeInterface.cpp} | 4 +--
9 files changed, 50 insertions(+), 59 deletions(-)
rename mlir/include/mlir/IR/{QuantizationInterface.h => QuantStorageTypeInterface.h} (63%)
rename mlir/include/mlir/IR/{QuantizationInterface.td => QuantStorageTypeInterface.td} (83%)
rename mlir/lib/IR/{QuantizationInterface.cpp => QuantStorageTypeInterface.cpp} (89%)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 204da9553e915..d30cba29c9814 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,7 +167,7 @@ class BaseMemRefType : public Type,
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index eef2719056d02..a44c1119afb7d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,7 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
-include "mlir/IR/QuantizationInterface.td"
+include "mlir/IR/QuantStorageTypeInterface.td"
include "mlir/IR/CommonTypeConstraints.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -102,7 +102,7 @@ class Builtin_CachedFloatType<string name, string mnemonic,
//===----------------------------------------------------------------------===//
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
- ["QuantizationInterface"]> {
+ ["QuantStorageTypeInterface"]> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -120,7 +120,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
}];
let extraClassDeclaration = [{
- /// QuantizationInterface method implementations
+ /// QuantStorageTypeInterface method implementations
bool isStorageSigned() const { return true; }
/// Get the bit width of this 8-bit floating point type.
unsigned getStorageWidth() const { return 8; }
@@ -160,7 +160,7 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
//===----------------------------------------------------------------------===//
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
- ["QuantizationInterface"]> {
+ ["QuantStorageTypeInterface"]> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -179,7 +179,7 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
}];
let extraClassDeclaration = [{
- /// QuantizationInterface method implementations
+ /// QuantStorageTypeInterface method implementations
bool isStorageSigned() const { return true; }
/// Get the bit width of this 8-bit floating point type.
unsigned getStorageWidth() const { return 8; }
@@ -534,7 +534,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//
def Builtin_Integer : Builtin_Type<"Integer", "integer",
- [VectorElementTypeInterface, QuantizationInterface]> {
+ [VectorElementTypeInterface, QuantStorageTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -592,7 +592,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
- /// QuantizationInterface method implementations
+ /// QuantStorageTypeInterface method implementations
/// Return true if this is a signed or signless integer type.
bool isStorageSigned() const { return !isUnsigned(); }
/// Get the bit width of this integer type.
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 35279a13cf109..15518901b901a 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -4,7 +4,7 @@ mlir_tablegen(SymbolInterfacesAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(SymbolInterfacesAttrInterface.cpp.inc -gen-attr-interface-defs)
add_mlir_interface(RegionKindInterface)
-add_mlir_type_interface(QuantizationInterface)
+add_mlir_type_interface(QuantStorageTypeInterface)
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantStorageTypeInterface.h
similarity index 63%
rename from mlir/include/mlir/IR/QuantizationInterface.h
rename to mlir/include/mlir/IR/QuantStorageTypeInterface.h
index 0d6709ff52065..dce430efdbc89 100644
--- a/mlir/include/mlir/IR/QuantizationInterface.h
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.h
@@ -1,4 +1,4 @@
-//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
+//===- QuantStorageTypeInterface.h - Quantzation Interfaces --------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_IR_QuantizationInterface_H
-#define MLIR_IR_QuantizationInterface_H
+#ifndef MLIR_IR_QuantStorageTypeInterface_H
+#define MLIR_IR_QuantStorageTypeInterface_H
#include "mlir/IR/Types.h"
@@ -17,6 +17,6 @@ namespace mlir {
class IntegerType;
} // namespace mlir
-#include "mlir/IR/QuantizationInterface.h.inc"
+#include "mlir/IR/QuantStorageTypeInterface.h.inc"
-#endif // MLIR_IR_QuantizationInterface_H
+#endif // MLIR_IR_QuantStorageTypeInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
similarity index 83%
rename from mlir/include/mlir/IR/QuantizationInterface.td
rename to mlir/include/mlir/IR/QuantStorageTypeInterface.td
index 1008ac8e1dcf1..37fe4aaccf610 100644
--- a/mlir/include/mlir/IR/QuantizationInterface.td
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
@@ -1,9 +1,9 @@
-#ifndef MLIR_IR_QUANTIZATIONINTERFACE
-#define MLIR_IR_QUANTIZATIONINTERFACE
+#ifndef MLIR_IR_QUANTSTORAGETYPEINTERFACE
+#define MLIR_IR_QUANTSTORAGETYPEINTERFACE
include "mlir/IR/OpBase.td"
-def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
let description = [{
Interface for types that can be used as storage types in Quant dialect.
This interface provides methods to determine storage characteristics for quantization purposes.
@@ -41,4 +41,4 @@ def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
}
-#endif // MLIR_IR_QUANTIZATIONINTERFACE
+#endif // MLIR_IR_QUANTSTORAGETYPEINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index e7f9b1dc8a7e1..b27cb790b9052 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,7 +9,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -47,23 +47,16 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
- // Verify that the storage type is integral.
- // This restriction may be lifted at some point in favor of using bf16
- // or f16 as exact representations on hardware where that is advantageous.
- auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
- if (!intStorageType)
- return emitError() << "storage type must be integral";
-
- if (auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(storageType)) {
- unsigned integralWidth = quantizationInterface.getStorageWidth();
+ if (auto quantStorageTypeInterface =
+ llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
+ unsigned integralWidth = quantStorageTypeInterface.getStorageWidth();
// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitError() << "illegal storage type size: " << integralWidth;
- int64_t defaultMin = quantizationInterface.getDefaultMinimum();
- int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
storageTypeMax > defaultMax) {
@@ -74,7 +67,7 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
- return emitError() << "storage type must implement QuantizationInterface";
+ return emitError() << "storage type must implement QuantStorageTypeInterface";
}
Type QuantizedType::getStorageType() const {
@@ -91,21 +84,21 @@ int64_t QuantizedType::getStorageTypeMax() const {
bool QuantizedType::hasStorageTypeBounds() const {
Type storageType = static_cast<ImplType *>(impl)->storageType;
- auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(storageType);
+ auto quantStorageTypeInterface =
+ llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
- int64_t defaultMin = quantizationInterface.getDefaultMinimum();
- int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
}
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
Type storageType = static_cast<ImplType *>(impl)->storageType;
- auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(storageType);
+ auto quantStorageTypeInterface =
+ llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
- return quantizationInterface.getStorageWidth();
+ return quantStorageTypeInterface.getStorageWidth();
}
Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index bb38897ce5f4c..bebc331c6af85 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,7 +10,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVectorExtras.h"
@@ -30,10 +30,10 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
if (!succeeded(*result))
return nullptr;
- if (auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(type)) {
- isSigned = quantizationInterface.isStorageSigned();
- storageTypeWidth = quantizationInterface.getStorageWidth();
+ if (auto quantStorageTypeInterface =
+ llvm::dyn_cast<QuantStorageTypeInterface>(type)) {
+ isSigned = quantStorageTypeInterface.isStorageSigned();
+ storageTypeWidth = quantStorageTypeInterface.getStorageWidth();
} else {
parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
@@ -68,12 +68,11 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
int64_t &storageTypeMin,
int64_t &storageTypeMax) {
- int64_t defaultMin, defaultMax;
- if (auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(storageType)) {
- defaultMin = quantizationInterface.getDefaultMinimum();
- defaultMax = quantizationInterface.getDefaultMaximum();
- }
+ auto quantStorageTypeInterface =
+ llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
+
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
if (failed(parser.parseOptionalLess())) {
storageTypeMin = defaultMin;
@@ -497,10 +496,10 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
// storage type
- if (auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
- out << quantizationInterface.getStorageType();
- }
+ auto quantStorageTypeInterface =
+ llvm::dyn_cast<QuantStorageTypeInterface>(type.getStorageType());
+
+ out << quantStorageTypeInterface.getStorageType();
// storageTypeMin and storageTypeMax if not default.
if (type.hasStorageTypeBounds()) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 364f3c14f081e..579d35b3f82a5 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,7 +31,7 @@ add_mlir_library(MLIRIR
OperationSupport.cpp
PatternLoggingListener.cpp
PatternMatch.cpp
- QuantizationInterface.cpp
+ QuantStorageTypeInterface.cpp
Region.cpp
RegionKindInterface.cpp
Remarks.cpp
@@ -64,7 +64,7 @@ add_mlir_library(MLIRIR
MLIRCastInterfacesIncGen
MLIRDataLayoutInterfacesIncGen
MLIROpAsmInterfaceIncGen
- MLIRQuantizationInterfaceIncGen
+ MLIRQuantStorageTypeInterfaceIncGen
MLIRRegionKindInterfaceIncGen
MLIRSideEffectInterfacesIncGen
MLIRSymbolInterfacesIncGen
@@ -75,4 +75,3 @@ add_mlir_library(MLIRIR
LINK_LIBS PUBLIC
MLIRSupport
)
-
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantStorageTypeInterface.cpp
similarity index 89%
rename from mlir/lib/IR/QuantizationInterface.cpp
rename to mlir/lib/IR/QuantStorageTypeInterface.cpp
index a93333278610e..bcc2dc1f3337c 100644
--- a/mlir/lib/IR/QuantizationInterface.cpp
+++ b/mlir/lib/IR/QuantStorageTypeInterface.cpp
@@ -1,4 +1,4 @@
-//===- QuantizationInterface.cpp
+//===- QuantStorageTypeInterface.cpp
//------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -20,4 +20,4 @@ using namespace mlir::detail;
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//
-#include "mlir/IR/QuantizationInterface.cpp.inc"
+#include "mlir/IR/QuantStorageTypeInterface.cpp.inc"
>From a6b17ecd010b7319a798ad03ff280b6620d3433a Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 20 Aug 2025 14:45:14 +0200
Subject: [PATCH 4/8] Added interface methods which expose a few basic packing
and alignment facts
---
mlir/include/mlir/IR/BuiltinTypes.td | 36 +++++++++++++++++++
.../mlir/IR/QuantStorageTypeInterface.td | 35 +++++++++++++++---
2 files changed, 66 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a44c1119afb7d..8d37ee263415d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -132,6 +132,18 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
/// Get the storage type as a string.
std::string getStorageType() const { return "f8E5M2"; }
+
+ /// Check if this 8-bit floating point type uses packed representation.
+ bool isPacked() const { return false; }
+
+ /// Get the logical bit width per value for this 8-bit floating point type.
+ unsigned getLogicalBitWidth() const { return 8; }
+
+ /// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
+ unsigned getElementsPerByte() const { return 1; }
+
+ /// Get the preferred alignment in bytes for this 8-bit floating point type.
+ std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}
@@ -191,6 +203,18 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
/// Get the storage type as a string.
std::string getStorageType() const { return "f8E4M3FN"; }
+
+ /// Check if this 8-bit floating point type uses packed representation.
+ bool isPacked() const { return false; }
+
+ /// Get the logical bit width per value for this 8-bit floating point type.
+ unsigned getLogicalBitWidth() const { return 8; }
+
+ /// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
+ unsigned getElementsPerByte() const { return 1; }
+
+ /// Get the preferred alignment in bytes for this 8-bit floating point type.
+ std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}
@@ -617,6 +641,18 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
std::string getStorageType() const {
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
}
+
+ /// Check if this integer type uses packed representation.
+ bool isPacked() const { return false; }
+
+ /// Get the logical bit width per value for this integer type.
+ unsigned getLogicalBitWidth() const { return getWidth(); }
+
+ /// Get the number of logical elements that fit in one byte for this integer type.
+ unsigned getElementsPerByte() const { return 1; }
+
+ /// Get the preferred alignment in bytes for this integer type.
+ std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}
diff --git a/mlir/include/mlir/IR/QuantStorageTypeInterface.td b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
index 37fe4aaccf610..b09c6e79df840 100644
--- a/mlir/include/mlir/IR/QuantStorageTypeInterface.td
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
@@ -6,7 +6,8 @@ include "mlir/IR/OpBase.td"
def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
let description = [{
Interface for types that can be used as storage types in Quant dialect.
- This interface provides methods to determine storage characteristics for quantization purposes.
+ This interface provides methods to determine storage characteristics for quantization purposes,
+ including packing behavior, and alignment requirements.
}];
let cppNamespace = "::mlir";
@@ -18,25 +19,49 @@ def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
"bool", "isStorageSigned", (ins)>,
InterfaceMethod<[{
- Get the bit width of this integer type.
+ Get the bit width of this type.
Returns the number of bits used to store values of this type.
}],
"unsigned", "getStorageWidth", (ins)>,
InterfaceMethod<[{
- Get default minimum value for this integer type.
+ Get default minimum value for this type.
}],
"int64_t", "getDefaultMinimum", (ins)>,
InterfaceMethod<[{
- Get default maximum value for this integer type.
+ Get default maximum value for this type.
}],
"int64_t", "getDefaultMaximum", (ins)>,
InterfaceMethod<[{
Get the storage type as a string.
}],
- "std::string", "getStorageType", (ins)>
+ "std::string", "getStorageType", (ins)>,
+
+ InterfaceMethod<[{
+ Check if the storage type uses packed representation.
+ Returns true if multiple values are packed into one byte (e.g., sub-byte types),
+ false if value uses full byte.
+ }],
+ "bool", "isPacked", (ins)>,
+
+ InterfaceMethod<[{
+ Get the logical bit width per value.
+ For packed sub-byte types, this may differ from getStorageWidth().
+ }],
+ "unsigned", "getLogicalBitWidth", (ins)>,
+
+ InterfaceMethod<[{
+ Get the number of logical elements that fit in one byte.
+ For packed sub-byte types, this returns how many values can be stored per byte.
+ }],
+ "unsigned", "getElementsPerByte", (ins)>,
+
+ InterfaceMethod<[{
+ Returns the preferred alignment for this type, in bytes.
+ }],
+ "std::optional<unsigned>", "getPreferredAlignmentBytes", (ins)>
];
}
>From 6be747c93cedb448ed24b386f6b9a04d3b49d4e8 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Thu, 8 Jan 2026 16:41:18 +0100
Subject: [PATCH 5/8] Updated interface methods to be used correctly with
Integer storage type
---
mlir/include/mlir/IR/BuiltinTypes.td | 41 ++++++++++---------
.../mlir/IR/QuantStorageTypeInterface.td | 13 +++---
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 9 ++--
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 22 ++++++----
4 files changed, 47 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 8d37ee263415d..42285ef49ea5f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -121,17 +121,18 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
let extraClassDeclaration = [{
/// QuantStorageTypeInterface method implementations
- bool isStorageSigned() const { return true; }
+ /// Whether the storage type should default to signed when used in quantization.
+ bool shouldDefaultToSigned() const { return true; }
/// Get the bit width of this 8-bit floating point type.
unsigned getStorageWidth() const { return 8; }
/// Get default maximum value for this 8-bit floating point type.
- int64_t getDefaultMaximum() const { return 57344; }
+ int64_t getDefaultMaximum(bool isSigned) const { return 57344; }
/// Get default minimum value for this 8-bit floating point type.
- int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
+ int64_t getDefaultMinimum(bool isSigned) const { return -getDefaultMaximum(isSigned); }
/// Get the storage type as a string.
- std::string getStorageType() const { return "f8E5M2"; }
+ std::string getStorageType(bool isSigned) const { return "f8E5M2"; }
/// Check if this 8-bit floating point type uses packed representation.
bool isPacked() const { return false; }
@@ -192,17 +193,18 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
let extraClassDeclaration = [{
/// QuantStorageTypeInterface method implementations
- bool isStorageSigned() const { return true; }
+ /// Whether the storage type should default to signed when used in quantization.
+ bool shouldDefaultToSigned() const { return true; }
/// Get the bit width of this 8-bit floating point type.
unsigned getStorageWidth() const { return 8; }
/// Get default maximum value for this 8-bit floating point type.
- int64_t getDefaultMaximum() const { return 448; }
+ int64_t getDefaultMaximum(bool isSigned) const { return 448; }
/// Get default minimum value for this 8-bit floating point type.
- int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
+ int64_t getDefaultMinimum(bool isSigned) const { return -getDefaultMaximum(isSigned); }
/// Get the storage type as a string.
- std::string getStorageType() const { return "f8E4M3FN"; }
+ std::string getStorageType(bool isSigned) const { return "f8E4M3FN"; }
/// Check if this 8-bit floating point type uses packed representation.
bool isPacked() const { return false; }
@@ -617,29 +619,30 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
/// QuantStorageTypeInterface method implementations
- /// Return true if this is a signed or signless integer type.
- bool isStorageSigned() const { return !isUnsigned(); }
+ /// Whether the storage type should default to signed when used in quantization.
+ /// Returns true if this is a signed or signless integer type.
+ bool shouldDefaultToSigned() const { return !isUnsigned(); }
/// Get the bit width of this integer type.
unsigned getStorageWidth() const { return getWidth(); }
/// Get default maximum value for this integer type.
- int64_t getDefaultMaximum() const {
- if (isStorageSigned()) {
- return llvm::maxIntN(getStorageWidth());
+ int64_t getDefaultMaximum(bool isSigned) const {
+ if (isSigned) {
+ return llvm::maxIntN(getWidth());
}
- return llvm::maxUIntN(getStorageWidth());
+ return llvm::maxUIntN(getWidth());
}
/// Get default minimum value for this integer type.
- int64_t getDefaultMinimum() const {
- if (isStorageSigned()) {
- return llvm::minIntN(getStorageWidth());
+ int64_t getDefaultMinimum(bool isSigned) const {
+ if (isSigned) {
+ return llvm::minIntN(getWidth());
}
return 0;
}
/// Get the storage type as a string.
- std::string getStorageType() const {
- return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
+ std::string getStorageType(bool isSigned) const {
+ return (isSigned ? "i" : "u") + std::to_string(getWidth());
}
/// Check if this integer type uses packed representation.
diff --git a/mlir/include/mlir/IR/QuantStorageTypeInterface.td b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
index b09c6e79df840..4df7b1cfdbd29 100644
--- a/mlir/include/mlir/IR/QuantStorageTypeInterface.td
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
@@ -13,10 +13,11 @@ def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
let methods = [
InterfaceMethod<[{
- Check if the storage type is signed.
- Returns true if the type represents signed values, false for unsigned.
+ Whether the storage type should default to signed when used in quantization.
+ Returns true if the type defaults to signed (e.g., si8, i8 or float types),
+ false if it defaults to unsigned.
}],
- "bool", "isStorageSigned", (ins)>,
+ "bool", "shouldDefaultToSigned", (ins)>,
InterfaceMethod<[{
Get the bit width of this type.
@@ -27,17 +28,17 @@ def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
InterfaceMethod<[{
Get default minimum value for this type.
}],
- "int64_t", "getDefaultMinimum", (ins)>,
+ "int64_t", "getDefaultMinimum", (ins "bool":$isSigned)>,
InterfaceMethod<[{
Get default maximum value for this type.
}],
- "int64_t", "getDefaultMaximum", (ins)>,
+ "int64_t", "getDefaultMaximum", (ins "bool":$isSigned)>,
InterfaceMethod<[{
Get the storage type as a string.
}],
- "std::string", "getStorageType", (ins)>,
+ "std::string", "getStorageType", (ins "bool":$isSigned)>,
InterfaceMethod<[{
Check if the storage type uses packed representation.
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index b27cb790b9052..c5a36f7106ad3 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -55,8 +55,9 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitError() << "illegal storage type size: " << integralWidth;
- int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
- int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
+ bool isSigned = flags & QuantizationFlags::Signed;
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned);
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned);
if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
storageTypeMax > defaultMax) {
@@ -87,8 +88,8 @@ bool QuantizedType::hasStorageTypeBounds() const {
auto quantStorageTypeInterface =
llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
- int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
- int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned());
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned());
return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
}
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index bebc331c6af85..c1fe1285cefa5 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -32,10 +32,12 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
if (auto quantStorageTypeInterface =
llvm::dyn_cast<QuantStorageTypeInterface>(type)) {
- isSigned = quantStorageTypeInterface.isStorageSigned();
+ // Returns true if the type defaults to signed (e.g., si8, i8 or float
+ // types), false if it defaults to unsigned.
+ isSigned = quantStorageTypeInterface.shouldDefaultToSigned();
storageTypeWidth = quantStorageTypeInterface.getStorageWidth();
} else {
- parser.emitError(typeLoc, "illegal quantized storage type alias");
+ parser.emitError(typeLoc, "illegal storage type prefix");
return nullptr;
}
} else if (succeeded(parser.parseKeyword(&identifier))) {
@@ -48,7 +50,7 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
isSigned = false;
type = parser.getBuilder().getIntegerType(storageTypeWidth);
} else {
- parser.emitError(typeLoc, "illegal quantized storage type alias");
+ parser.emitError(typeLoc, "illegal storage type prefix");
return nullptr;
}
} else {
@@ -66,13 +68,13 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
}
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
- int64_t &storageTypeMin,
+ bool isSigned, int64_t &storageTypeMin,
int64_t &storageTypeMax) {
auto quantStorageTypeInterface =
llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
- int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
- int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned);
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned);
if (failed(parser.parseOptionalLess())) {
storageTypeMin = defaultMin;
@@ -145,7 +147,8 @@ static Type parseAnyType(DialectAsmParser &parser) {
}
// Storage type range.
- if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
+ if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
+ storageTypeMax)) {
return nullptr;
}
@@ -360,7 +363,8 @@ static Type parseUniformType(DialectAsmParser &parser) {
}
// Storage type range.
- if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
+ if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
+ storageTypeMax)) {
return nullptr;
}
@@ -499,7 +503,7 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
auto quantStorageTypeInterface =
llvm::dyn_cast<QuantStorageTypeInterface>(type.getStorageType());
- out << quantStorageTypeInterface.getStorageType();
+ out << quantStorageTypeInterface.getStorageType(type.isSigned());
// storageTypeMin and storageTypeMax if not default.
if (type.hasStorageTypeBounds()) {
>From c187374553d6909500b294b42448e299c4f0e165 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Fri, 9 Jan 2026 12:36:11 +0100
Subject: [PATCH 6/8] Added LIT tests for f8E5M2, f8E4M3FN and f4E2M1FN types +
cosmetic changes
---
mlir/include/mlir/IR/BuiltinTypes.td | 49 ++++++++++++++---
.../mlir/IR/QuantStorageTypeInterface.h | 10 ++--
mlir/lib/IR/CMakeLists.txt | 1 -
mlir/lib/IR/QuantStorageTypeInterface.cpp | 3 +-
.../Dialect/Quant/parse-uniform-invalid.mlir | 29 ++++++++++
mlir/test/Dialect/Quant/parse-uniform.mlir | 55 +++++++++++++++++++
6 files changed, 129 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 42285ef49ea5f..d842789f69943 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -81,17 +81,19 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
// Base class for Builtin dialect float types.
class Builtin_FloatType<string name, string mnemonic,
+ list<Trait> traits = [],
list<string> declaredInterfaceMethods = []>
- : Builtin_Type<name, mnemonic, /*traits=*/[
- DeclareTypeInterfaceMethods<
+ : Builtin_Type<name, mnemonic, /*traits=*/
+ traits # [DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
}
// Float types that are cached in MLIRContext.
class Builtin_CachedFloatType<string name, string mnemonic,
+ list<Trait> traits = [],
list<string> declaredInterfaceMethods = []>
- : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
+ : Builtin_FloatType<name, mnemonic, traits, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
@@ -102,7 +104,7 @@ class Builtin_CachedFloatType<string name, string mnemonic,
//===----------------------------------------------------------------------===//
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
- ["QuantStorageTypeInterface"]> {
+ [QuantStorageTypeInterface]> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -173,7 +175,7 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
//===----------------------------------------------------------------------===//
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
- ["QuantStorageTypeInterface"]> {
+ [QuantStorageTypeInterface]> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -313,7 +315,8 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
// Float4E2M1FNType
//===----------------------------------------------------------------------===//
-def Builtin_Float4E2M1FN : Builtin_FloatType<"Float4E2M1FN", "f4E2M1FN"> {
+def Builtin_Float4E2M1FN : Builtin_FloatType<"Float4E2M1FN", "f4E2M1FN",
+ [QuantStorageTypeInterface]> {
let summary = "4-bit floating point with 2-bit exponent and 1-bit mantissa";
let description = [{
An 4-bit floating point type with 1 sign bit, 2 bits exponent and 1 bit
@@ -329,6 +332,34 @@ def Builtin_Float4E2M1FN : Builtin_FloatType<"Float4E2M1FN", "f4E2M1FN"> {
Open Compute Project (OCP) microscaling formats (MX) specification:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
}];
+
+ let extraClassDeclaration = [{
+ /// QuantStorageTypeInterface method implementations
+ /// Whether the storage type should default to signed when used in quantization.
+ bool shouldDefaultToSigned() const { return true; }
+ /// Get the bit width of this 4-bit floating point type.
+ unsigned getStorageWidth() const { return 4; }
+
+ /// Get default maximum value for this 4-bit floating point type.
+ int64_t getDefaultMaximum(bool isSigned) const { return 6; }
+ /// Get default minimum value for this 4-bit floating point type.
+ int64_t getDefaultMinimum(bool isSigned) const { return -getDefaultMaximum(isSigned); }
+
+ /// Get the storage type as a string.
+ std::string getStorageType(bool isSigned) const { return "f4E2M1FN"; }
+
+ /// Check if this 4-bit floating point type uses packed representation.
+ bool isPacked() const { return true; }
+
+ /// Get the logical bit width per value for this 4-bit floating point type.
+ unsigned getLogicalBitWidth() const { return 4; }
+
+ /// Get the number of logical elements that fit in one byte for this 4-bit floating point type.
+ unsigned getElementsPerByte() const { return 2; }
+
+ /// Get the preferred alignment in bytes for this 4-bit floating point type.
+ std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
+ }];
}
//===----------------------------------------------------------------------===//
@@ -404,7 +435,7 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
//===----------------------------------------------------------------------===//
def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
- /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
+ /*traits=*/[], /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "bfloat16 floating-point type";
}
@@ -413,7 +444,7 @@ def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
//===----------------------------------------------------------------------===//
def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16",
- /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
+ /*traits=*/[], /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "16-bit floating-point type";
}
@@ -430,7 +461,7 @@ def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> {
//===----------------------------------------------------------------------===//
def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32",
- /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
+ /*traits=*/[], /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "32-bit floating-point type";
}
diff --git a/mlir/include/mlir/IR/QuantStorageTypeInterface.h b/mlir/include/mlir/IR/QuantStorageTypeInterface.h
index dce430efdbc89..4f5e102feea85 100644
--- a/mlir/include/mlir/IR/QuantStorageTypeInterface.h
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.h
@@ -1,5 +1,4 @@
-//===- QuantStorageTypeInterface.h - Quantzation Interfaces --------*- C++
-//-*-===//
+//===- QuantStorageTypeInterface.h - Quantization Interfaces ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,10 +11,9 @@
#include "mlir/IR/Types.h"
-// Forward declarations for the types we need in the implementation
-namespace mlir {
-class IntegerType;
-} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Tablegen Interface Declarations
+//===----------------------------------------------------------------------===//
#include "mlir/IR/QuantStorageTypeInterface.h.inc"
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 579d35b3f82a5..95a3f554d5e78 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -70,7 +70,6 @@ add_mlir_library(MLIRIR
MLIRSymbolInterfacesIncGen
MLIRTensorEncodingIncGen
MLIROpAsmDialectInterfaceIncGen
- MLIRQuantizationInterfaceIncGen
LINK_LIBS PUBLIC
MLIRSupport
diff --git a/mlir/lib/IR/QuantStorageTypeInterface.cpp b/mlir/lib/IR/QuantStorageTypeInterface.cpp
index bcc2dc1f3337c..9287c76cc8c6a 100644
--- a/mlir/lib/IR/QuantStorageTypeInterface.cpp
+++ b/mlir/lib/IR/QuantStorageTypeInterface.cpp
@@ -1,5 +1,4 @@
-//===- QuantStorageTypeInterface.cpp
-//------------------------------------------===//
+//===- QuantStorageTypeInterface.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index 3b358443e43f2..6dbc86263bd71 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -235,3 +235,32 @@
!qalias = !quant.uniform<i8:f16:{0:1,1:2},
{{6.6e4:120,9.987200e-01:127}, {2.000000e+02:256,9.987200e-01}}>
+// -----
+// Illegal storage min/max: max > defaultMax
+// expected-error at +1 {{illegal storage type maximum: 60000}}
+!qalias = !quant.uniform<f8E5M2<-57344:60000>:f32, 0.99872:127>
+
+// -----
+// Illegal storage min/max: min < defaultMin
+// expected-error at +1 {{illegal storage type minimum: -60000}}
+!qalias = !quant.uniform<f8E5M2<-60000:57344>:f32, 0.99872:127>
+
+// -----
+// Illegal storage min/max: max > defaultMax
+// expected-error at +1 {{illegal storage type maximum: 500}}
+!qalias = !quant.uniform<f8E4M3FN<-448:500>:f32, 0.99872:127>
+
+// -----
+// Illegal storage min/max: min < defaultMin
+// expected-error at +1 {{illegal storage type minimum: -500}}
+!qalias = !quant.uniform<f8E4M3FN<-500:448>:f32, 0.99872:127>
+
+// -----
+// Illegal storage min/max: max > defaultMax
+// expected-error at +1 {{illegal storage type maximum: 10}}
+!qalias = !quant.uniform<f4E2M1FN<-6:10>:f32, 0.99872:127>
+
+// -----
+// Illegal storage min/max: min < defaultMin
+// expected-error at +1 {{illegal storage type minimum: -10}}
+!qalias = !quant.uniform<f4E2M1FN<-10:6>:f32, 0.99872:127>
diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index 80a6621ed6979..1431403aece41 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -172,3 +172,58 @@ func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
+
+// -----
+// Default min/max value optimization for f8E5M2.
+// CHECK: !quant.uniform<f8E5M2:f32, 9.987200e-01:127>
+!qalias = !quant.uniform<f8E5M2<-57344:57344>:f32, 0.99872:127 >
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: f8E5M2
+// CHECK: !quant.uniform<f8E5M2:f32, 2.000000e+02>
+!qalias = !quant.uniform<f8E5M2:f32, 2.0e+2>
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Default min/max value optimization for f8E4M3FN.
+// CHECK: !quant.uniform<f8E4M3FN:f32, 9.987200e-01:127>
+!qalias = !quant.uniform<f8E4M3FN<-448:448>:f32, 0.99872:127 >
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: f8E4M3FN
+// CHECK: !quant.uniform<f8E4M3FN:f32, 2.000000e+02>
+!qalias = !quant.uniform<f8E4M3FN:f32, 2.0e+2>
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+
+// -----
+// Default min/max value optimization for f4E2M1FN.
+// CHECK: !quant.uniform<f4E2M1FN:f32, 9.987200e-01:127>
+!qalias = !quant.uniform<f4E2M1FN<-6:6>:f32, 0.99872:127 >
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: f4E2M1FN
+// CHECK: !quant.uniform<f4E2M1FN:f32, 2.000000e+02>
+!qalias = !quant.uniform<f4E2M1FN:f32, 2.0e+2>
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
>From 6933ce02846cb832a33a381d0461e7493ebcdaf3 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Fri, 9 Jan 2026 13:06:31 +0100
Subject: [PATCH 7/8] Rename getStorageType interface method to
getStorageTypeName, since it better describes its behaviour
---
mlir/include/mlir/IR/BuiltinTypes.td | 8 ++++----
mlir/include/mlir/IR/QuantStorageTypeInterface.td | 2 +-
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 2 +-
mlir/lib/IR/CMakeLists.txt | 2 +-
4 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index d842789f69943..37904e7b3cebb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -134,7 +134,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
int64_t getDefaultMinimum(bool isSigned) const { return -getDefaultMaximum(isSigned); }
/// Get the storage type as a string.
- std::string getStorageType(bool isSigned) const { return "f8E5M2"; }
+ std::string getStorageTypeName(bool isSigned) const { return "f8E5M2"; }
/// Check if this 8-bit floating point type uses packed representation.
bool isPacked() const { return false; }
@@ -206,7 +206,7 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
int64_t getDefaultMinimum(bool isSigned) const { return -getDefaultMaximum(isSigned); }
/// Get the storage type as a string.
- std::string getStorageType(bool isSigned) const { return "f8E4M3FN"; }
+ std::string getStorageTypeName(bool isSigned) const { return "f8E4M3FN"; }
/// Check if this 8-bit floating point type uses packed representation.
bool isPacked() const { return false; }
@@ -346,7 +346,7 @@ def Builtin_Float4E2M1FN : Builtin_FloatType<"Float4E2M1FN", "f4E2M1FN",
int64_t getDefaultMinimum(bool isSigned) const { return -getDefaultMaximum(isSigned); }
/// Get the storage type as a string.
- std::string getStorageType(bool isSigned) const { return "f4E2M1FN"; }
+ std::string getStorageTypeName(bool isSigned) const { return "f4E2M1FN"; }
/// Check if this 4-bit floating point type uses packed representation.
bool isPacked() const { return true; }
@@ -672,7 +672,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
}
/// Get the storage type as a string.
- std::string getStorageType(bool isSigned) const {
+ std::string getStorageTypeName(bool isSigned) const {
return (isSigned ? "i" : "u") + std::to_string(getWidth());
}
diff --git a/mlir/include/mlir/IR/QuantStorageTypeInterface.td b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
index 4df7b1cfdbd29..22dabc58fca54 100644
--- a/mlir/include/mlir/IR/QuantStorageTypeInterface.td
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
@@ -38,7 +38,7 @@ def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
InterfaceMethod<[{
Get the storage type as a string.
}],
- "std::string", "getStorageType", (ins "bool":$isSigned)>,
+ "std::string", "getStorageTypeName", (ins "bool":$isSigned)>,
InterfaceMethod<[{
Check if the storage type uses packed representation.
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index c1fe1285cefa5..1a42b90ac31e2 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -503,7 +503,7 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
auto quantStorageTypeInterface =
llvm::dyn_cast<QuantStorageTypeInterface>(type.getStorageType());
- out << quantStorageTypeInterface.getStorageType(type.isSigned());
+ out << quantStorageTypeInterface.getStorageTypeName(type.isSigned());
// storageTypeMin and storageTypeMax if not default.
if (type.hasStorageTypeBounds()) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 95a3f554d5e78..632b06e4378f2 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,7 +31,7 @@ add_mlir_library(MLIRIR
OperationSupport.cpp
PatternLoggingListener.cpp
PatternMatch.cpp
- QuantStorageTypeInterface.cpp
+ QuantStorageTypeInterface.cpp
Region.cpp
RegionKindInterface.cpp
Remarks.cpp
>From fa37a1cbef5e07ffd376054ea66df64017b93dcc Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Fri, 9 Jan 2026 14:02:14 +0100
Subject: [PATCH 8/8] Removed unnecessary headers
---
mlir/lib/IR/QuantStorageTypeInterface.cpp | 7 +------
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/mlir/lib/IR/QuantStorageTypeInterface.cpp b/mlir/lib/IR/QuantStorageTypeInterface.cpp
index 9287c76cc8c6a..e9c27187f3f33 100644
--- a/mlir/lib/IR/QuantStorageTypeInterface.cpp
+++ b/mlir/lib/IR/QuantStorageTypeInterface.cpp
@@ -6,14 +6,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "llvm/ADT/Sequence.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
using namespace mlir;
-using namespace mlir::detail;
//===----------------------------------------------------------------------===//
/// Tablegen Interface Definitions
More information about the Mlir-commits
mailing list