[Mlir-commits] [mlir] Extending UniformQuantizedType with interface-based support for new storage types in Quant dialect (PR #152966)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 20 05:45:31 PDT 2025
https://github.com/Roman-Pevnyi updated https://github.com/llvm/llvm-project/pull/152966
>From 784e1d7d23b8712840f25768cee890142b94caaa Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Mon, 11 Aug 2025 09:33:14 +0200
Subject: [PATCH 1/4] Added new type interface to let UniformQuantizedType
accept types other than built-in types. Updated parser and printer in Quant dialect
---
mlir/cmake/modules/AddMLIR.cmake | 8 ++
mlir/include/mlir/IR/BuiltinTypes.h | 2 +
mlir/include/mlir/IR/BuiltinTypes.td | 29 +++++++-
mlir/include/mlir/IR/CMakeLists.txt | 2 +
mlir/include/mlir/IR/QuantizationInterface.h | 22 ++++++
mlir/include/mlir/IR/QuantizationInterface.td | 44 +++++++++++
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 65 ++++++++--------
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 74 ++++++++++---------
mlir/lib/IR/CMakeLists.txt | 4 +-
mlir/lib/IR/QuantizationInterface.cpp | 23 ++++++
10 files changed, 207 insertions(+), 66 deletions(-)
create mode 100644 mlir/include/mlir/IR/QuantizationInterface.h
create mode 100644 mlir/include/mlir/IR/QuantizationInterface.td
create mode 100644 mlir/lib/IR/QuantizationInterface.cpp
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index ff4269ed7acd2..c35308d57eadd 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -203,6 +203,14 @@ function(add_mlir_interface interface)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()
+# Declare a type interface in the include directory
+function(add_mlir_type_interface interface)
+ set(LLVM_TARGET_DEFINITIONS ${interface}.td)
+ mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
+ mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
+ add_public_tablegen_target(MLIR${interface}IncGen)
+ add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
+endfunction()
# Generate Documentation
function(add_mlir_doc doc_filename output_file output_directory command)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 86ec5c43970b1..204da9553e915 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//
+#include "mlir/IR/QuantizationInterface.h"
+
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..762f9262adbf2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/QuantizationInterface.td"
include "mlir/IR/CommonTypeConstraints.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -497,7 +498,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//
def Builtin_Integer : Builtin_Type<"Integer", "integer",
- [VectorElementTypeInterface]> {
+ [VectorElementTypeInterface, QuantizationInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -554,6 +555,32 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
+
+ /// QuantizationInterface method implementations
+ /// Return true if this is a signed integer type.
+ bool isStorageSigned() const { return !isUnsigned(); }
+ /// Get the bit width of this integer type.
+ unsigned getStorageWidth() const { return getWidth(); }
+
+ /// Get default minimum value for this integer type.
+ int64_t getDefaultMinimum() const {
+ if (isStorageSigned()) {
+ return llvm::minIntN(getStorageWidth());
+ }
+ return 0;
+ }
+ /// Get default maximum value for this integer type.
+ int64_t getDefaultMaximum() const {
+ if (isStorageSigned()) {
+ return llvm::maxIntN(getStorageWidth());
+ }
+ return llvm::maxUIntN(getStorageWidth());
+ }
+
+ /// Get the storage type as a string.
+ std::string getStorageType() const {
+ return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
+ }
}];
}
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 846547ff131e3..153502c6e981b 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,6 +1,8 @@
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
+add_mlir_type_interface(QuantizationInterface)
+
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantizationInterface.h
new file mode 100644
index 0000000000000..0d6709ff52065
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.h
@@ -0,0 +1,22 @@
+//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_QuantizationInterface_H
+#define MLIR_IR_QuantizationInterface_H
+
+#include "mlir/IR/Types.h"
+
+// Forward declarations for the types we need in the implementation
+namespace mlir {
+class IntegerType;
+} // namespace mlir
+
+#include "mlir/IR/QuantizationInterface.h.inc"
+
+#endif // MLIR_IR_QuantizationInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantizationInterface.td
new file mode 100644
index 0000000000000..1008ac8e1dcf1
--- /dev/null
+++ b/mlir/include/mlir/IR/QuantizationInterface.td
@@ -0,0 +1,44 @@
+#ifndef MLIR_IR_QUANTIZATIONINTERFACE
+#define MLIR_IR_QUANTIZATIONINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+ let description = [{
+ Interface for types that can be used as storage types in Quant dialect.
+ This interface provides methods to determine storage characteristics for quantization purposes.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Check if the storage type is signed.
+ Returns true if the type represents signed values, false for unsigned.
+ }],
+ "bool", "isStorageSigned", (ins)>,
+
+ InterfaceMethod<[{
+ Get the bit width of this storage type.
+ Returns the number of bits used to store values of this type.
+ }],
+ "unsigned", "getStorageWidth", (ins)>,
+
+ InterfaceMethod<[{
+ Get default minimum value for this storage type.
+ }],
+ "int64_t", "getDefaultMinimum", (ins)>,
+
+ InterfaceMethod<[{
+ Get default maximum value for this storage type.
+ }],
+ "int64_t", "getDefaultMaximum", (ins)>,
+
+ InterfaceMethod<[{
+ Get the storage type as a string.
+ }],
+ "std::string", "getStorageType", (ins)>
+ ];
+
+}
+
+#endif // MLIR_IR_QUANTIZATIONINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index b2227792f32ca..e7f9b1dc8a7e1 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/IR/QuantizationInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
if (!intStorageType)
return emitError() << "storage type must be integral";
- unsigned integralWidth = intStorageType.getWidth();
-
- // Verify storage width.
- if (integralWidth == 0 || integralWidth > MaxStorageBits)
- return emitError() << "illegal storage type size: " << integralWidth;
-
- // Verify storageTypeMin and storageTypeMax.
- bool isSigned =
- (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
- int64_t defaultIntegerMin =
- getDefaultMinimumForInteger(isSigned, integralWidth);
- int64_t defaultIntegerMax =
- getDefaultMaximumForInteger(isSigned, integralWidth);
- if (storageTypeMax - storageTypeMin <= 0 ||
- storageTypeMin < defaultIntegerMin ||
- storageTypeMax > defaultIntegerMax) {
- return emitError() << "illegal storage min and storage max: ("
- << storageTypeMin << ":" << storageTypeMax << ")";
+
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
+ unsigned integralWidth = quantizationInterface.getStorageWidth();
+
+ // Verify storage width.
+ if (integralWidth == 0 || integralWidth > MaxStorageBits)
+ return emitError() << "illegal storage type size: " << integralWidth;
+
+ int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+ int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+ if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
+ storageTypeMax > defaultMax) {
+ return emitError() << "illegal storage min and storage max: ("
+ << storageTypeMin << ":" << storageTypeMax << ")";
+ }
+
+ return success();
}
- return success();
+
+ return emitError() << "storage type must implement QuantizationInterface";
}
Type QuantizedType::getStorageType() const {
@@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
}
bool QuantizedType::hasStorageTypeBounds() const {
- unsigned int integralWidth = getStorageTypeIntegralWidth();
- bool isSignedInteger = isSigned();
- int64_t defaultIntegerMin =
- getDefaultMinimumForInteger(isSignedInteger, integralWidth);
- int64_t defaultIntegerMax =
- getDefaultMaximumForInteger(isSignedInteger, integralWidth);
- return defaultIntegerMin != getStorageTypeMin() ||
- defaultIntegerMax != getStorageTypeMax();
+ Type storageType = static_cast<ImplType *>(impl)->storageType;
+ auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType);
+
+ int64_t defaultMin = quantizationInterface.getDefaultMinimum();
+ int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+
+ return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
}
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
- // NOTE: If ever supporting non-integral storage types, some other scheme
- // for determining the width will be needed.
- return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
+ Type storageType = static_cast<ImplType *>(impl)->storageType;
+ auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType);
+
+ return quantizationInterface.getStorageWidth();
}
Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 9a18cff24e62a..758399a2af5e8 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,15 +10,16 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/QuantizationInterface.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
using namespace mlir;
using namespace quant;
-static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
+static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
auto typeLoc = parser.getCurrentLocation();
- IntegerType type;
+ Type type;
// Parse storage type (alpha_ident, integer_literal).
StringRef identifier;
@@ -27,20 +28,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
if (result.has_value()) {
if (!succeeded(*result))
return nullptr;
- isSigned = !type.isUnsigned();
- storageTypeWidth = type.getWidth();
- } else if (succeeded(parser.parseKeyword(&identifier))) {
- // Otherwise, this must be an unsigned integer (`u` integer-literal).
- if (!identifier.consume_front("u")) {
- parser.emitError(typeLoc, "illegal storage type prefix");
+
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(type)) {
+ isSigned = quantizationInterface.isStorageSigned();
+ storageTypeWidth = quantizationInterface.getStorageWidth();
+ } else {
+ parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
}
- if (identifier.getAsInteger(10, storageTypeWidth)) {
- parser.emitError(typeLoc, "expected storage type width");
+ } else if (succeeded(parser.parseKeyword(&identifier))) {
+ // Otherwise, this must be an unsigned integer (`u` integer-literal)
+ if (identifier.consume_front("u")) {
+ if (identifier.getAsInteger(10, storageTypeWidth)) {
+ parser.emitError(typeLoc, "expected storage type width");
+ return nullptr;
+ }
+ isSigned = false;
+ type = parser.getBuilder().getIntegerType(storageTypeWidth);
+ } else {
+ parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
}
- isSigned = false;
- type = parser.getBuilder().getIntegerType(storageTypeWidth);
} else {
return nullptr;
}
@@ -55,17 +64,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
return type;
}
-static ParseResult parseStorageRange(DialectAsmParser &parser,
- IntegerType storageType, bool isSigned,
+static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
int64_t &storageTypeMin,
int64_t &storageTypeMax) {
- int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
- isSigned, storageType.getWidth());
- int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
- isSigned, storageType.getWidth());
+ int64_t defaultMin, defaultMax;
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(storageType)) {
+ defaultMin = quantizationInterface.getDefaultMinimum();
+ defaultMax = quantizationInterface.getDefaultMaximum();
+ }
+
if (failed(parser.parseOptionalLess())) {
- storageTypeMin = defaultIntegerMin;
- storageTypeMax = defaultIntegerMax;
+ storageTypeMin = defaultMin;
+ storageTypeMax = defaultMax;
return success();
}
@@ -75,11 +86,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
parser.getCurrentLocation(&maxLoc) ||
parser.parseInteger(storageTypeMax) || parser.parseGreater())
return failure();
- if (storageTypeMin < defaultIntegerMin) {
+ if (storageTypeMin < defaultMin) {
return parser.emitError(minLoc, "illegal storage type minimum: ")
<< storageTypeMin;
}
- if (storageTypeMax > defaultIntegerMax) {
+ if (storageTypeMax > defaultMax) {
return parser.emitError(maxLoc, "illegal storage type maximum: ")
<< storageTypeMax;
}
@@ -113,7 +124,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
static Type parseAnyType(DialectAsmParser &parser) {
- IntegerType storageType;
+ Type storageType;
FloatType expressedType;
unsigned typeFlags = 0;
int64_t storageTypeMin;
@@ -134,8 +145,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
}
// Storage type range.
- if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
- storageTypeMax)) {
+ if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
return nullptr;
}
@@ -322,7 +332,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
/// scale-zero-tensor (`,` scale-zero-tensor)*
/// `}`
static Type parseUniformType(DialectAsmParser &parser) {
- IntegerType storageType;
+ Type storageType;
FloatType expressedType;
unsigned typeFlags = 0;
int64_t storageTypeMin;
@@ -350,8 +360,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
}
// Storage type range.
- if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
- storageTypeMax)) {
+ if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
return nullptr;
}
@@ -486,12 +495,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
// storage type
- unsigned storageWidth = type.getStorageTypeIntegralWidth();
- bool isSigned = type.isSigned();
- if (isSigned) {
- out << "i" << storageWidth;
- } else {
- out << "u" << storageWidth;
+ if (auto quantizationInterface =
+ llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
+ out << quantizationInterface.getStorageType();
}
// storageTypeMin and storageTypeMax if not default.
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3ef69cea18f0a..f539aca7fff48 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_library(MLIRIR
OperationSupport.cpp
PatternLoggingListener.cpp
PatternMatch.cpp
+ QuantizationInterface.cpp
Region.cpp
RegionKindInterface.cpp
SymbolTable.cpp
@@ -66,7 +67,8 @@ add_mlir_library(MLIRIR
MLIRSideEffectInterfacesIncGen
MLIRSymbolInterfacesIncGen
MLIRTensorEncodingIncGen
-
+ MLIRQuantizationInterfaceIncGen
+
LINK_LIBS PUBLIC
MLIRSupport
)
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantizationInterface.cpp
new file mode 100644
index 0000000000000..a93333278610e
--- /dev/null
+++ b/mlir/lib/IR/QuantizationInterface.cpp
@@ -0,0 +1,23 @@
+//===- QuantizationInterface.cpp
+//------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/QuantizationInterface.cpp.inc"
>From 407fae44ee95ac8f53b587e17f0d26cd8bdbaa1b Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 13 Aug 2025 13:13:42 +0200
Subject: [PATCH 2/4] Added QuantizationInterface to Float8E5M2Type and
Float8E4M3FNType
---
mlir/include/mlir/IR/BuiltinTypes.td | 54 ++++++++++++++++++++++------
mlir/lib/IR/CMakeLists.txt | 2 +-
2 files changed, 44 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 762f9262adbf2..12ac05b64b577 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -101,7 +101,8 @@ class Builtin_CachedFloatType<string name, string mnemonic,
// Float8E5M2Type
//===----------------------------------------------------------------------===//
-def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
+def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
+ ["QuantizationInterface"]> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -117,6 +118,21 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
Described in: https://arxiv.org/abs/2209.05433
}];
+
+ let extraClassDeclaration = [{
+ /// QuantizationInterface method implementations
+ bool isStorageSigned() const { return true; }
+ /// Get the bit width of this 8-bit floating point type.
+ unsigned getStorageWidth() const { return 8; }
+
+ /// Get default maximum value for this 8-bit floating point type.
+ int64_t getDefaultMaximum() const { return 57344; }
+ /// Get default minimum value for this 8-bit floating point type.
+ int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
+
+ /// Get the storage type as a string.
+ std::string getStorageType() const { return "f8E5M2"; }
+ }];
}
//===----------------------------------------------------------------------===//
@@ -143,7 +159,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
// Float8E4M3FNType
//===----------------------------------------------------------------------===//
-def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
+def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
+ ["QuantizationInterface"]> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -160,6 +177,21 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
Described in: https://arxiv.org/abs/2209.05433
}];
+
+ let extraClassDeclaration = [{
+ /// QuantizationInterface method implementations
+ bool isStorageSigned() const { return true; }
+ /// Get the bit width of this 8-bit floating point type.
+ unsigned getStorageWidth() const { return 8; }
+
+ /// Get default maximum value for this 8-bit floating point type.
+ int64_t getDefaultMaximum() const { return 448; }
+ /// Get default minimum value for this 8-bit floating point type.
+ int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
+
+ /// Get the storage type as a string.
+ std::string getStorageType() const { return "f8E4M3FN"; }
+ }];
}
//===----------------------------------------------------------------------===//
@@ -557,18 +589,11 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
/// QuantizationInterface method implementations
- /// Return true if this is a signed integer type.
+ /// Return true if this is a signed or signless integer type.
bool isStorageSigned() const { return !isUnsigned(); }
/// Get the bit width of this integer type.
unsigned getStorageWidth() const { return getWidth(); }
- /// Get default minimum value for this integer type.
- int64_t getDefaultMinimum() const {
- if (isStorageSigned()) {
- return llvm::minIntN(getStorageWidth());
- }
- return 0;
- }
/// Get default maximum value for this integer type.
int64_t getDefaultMaximum() const {
if (isStorageSigned()) {
@@ -576,7 +601,14 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
}
return llvm::maxUIntN(getStorageWidth());
}
-
+ /// Get default minimum value for this integer type.
+ int64_t getDefaultMinimum() const {
+ if (isStorageSigned()) {
+ return llvm::minIntN(getStorageWidth());
+ }
+ return 0;
+ }
+
/// Get the storage type as a string.
std::string getStorageType() const {
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index f539aca7fff48..8a50b077d9014 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -63,11 +63,11 @@ add_mlir_library(MLIRIR
MLIRCastInterfacesIncGen
MLIRDataLayoutInterfacesIncGen
MLIROpAsmInterfaceIncGen
+ MLIRQuantizationInterfaceIncGen
MLIRRegionKindInterfaceIncGen
MLIRSideEffectInterfacesIncGen
MLIRSymbolInterfacesIncGen
MLIRTensorEncodingIncGen
- MLIRQuantizationInterfaceIncGen
LINK_LIBS PUBLIC
MLIRSupport
>From d7447a8b168ad1111e49deaa526c01e89620cd76 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 20 Aug 2025 13:18:44 +0200
Subject: [PATCH 3/4] Renamed to QuantStorageTypeInterface and removed
redundant dyn_cast checks
---
mlir/include/mlir/IR/BuiltinTypes.h | 2 +-
mlir/include/mlir/IR/BuiltinTypes.td | 14 ++++----
mlir/include/mlir/IR/CMakeLists.txt | 2 +-
...nterface.h => QuantStorageTypeInterface.h} | 10 +++---
...erface.td => QuantStorageTypeInterface.td} | 8 ++---
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 35 ++++++++-----------
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 29 ++++++++-------
mlir/lib/IR/CMakeLists.txt | 5 ++-
...face.cpp => QuantStorageTypeInterface.cpp} | 4 +--
9 files changed, 50 insertions(+), 59 deletions(-)
rename mlir/include/mlir/IR/{QuantizationInterface.h => QuantStorageTypeInterface.h} (63%)
rename mlir/include/mlir/IR/{QuantizationInterface.td => QuantStorageTypeInterface.td} (83%)
rename mlir/lib/IR/{QuantizationInterface.cpp => QuantStorageTypeInterface.cpp} (89%)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 204da9553e915..d30cba29c9814 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,7 +167,7 @@ class BaseMemRefType : public Type,
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 12ac05b64b577..1a29b9549cf51 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,7 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
-include "mlir/IR/QuantizationInterface.td"
+include "mlir/IR/QuantStorageTypeInterface.td"
include "mlir/IR/CommonTypeConstraints.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -102,7 +102,7 @@ class Builtin_CachedFloatType<string name, string mnemonic,
//===----------------------------------------------------------------------===//
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
- ["QuantizationInterface"]> {
+ ["QuantStorageTypeInterface"]> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -120,7 +120,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
}];
let extraClassDeclaration = [{
- /// QuantizationInterface method implementations
+ /// QuantStorageTypeInterface method implementations
bool isStorageSigned() const { return true; }
/// Get the bit width of this 8-bit floating point type.
unsigned getStorageWidth() const { return 8; }
@@ -160,7 +160,7 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
//===----------------------------------------------------------------------===//
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
- ["QuantizationInterface"]> {
+ ["QuantStorageTypeInterface"]> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -179,7 +179,7 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
}];
let extraClassDeclaration = [{
- /// QuantizationInterface method implementations
+ /// QuantStorageTypeInterface method implementations
bool isStorageSigned() const { return true; }
/// Get the bit width of this 8-bit floating point type.
unsigned getStorageWidth() const { return 8; }
@@ -530,7 +530,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//
def Builtin_Integer : Builtin_Type<"Integer", "integer",
- [VectorElementTypeInterface, QuantizationInterface]> {
+ [VectorElementTypeInterface, QuantStorageTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -588,7 +588,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
- /// QuantizationInterface method implementations
+ /// QuantStorageTypeInterface method implementations
/// Return true if this is a signed or signless integer type.
bool isStorageSigned() const { return !isUnsigned(); }
/// Get the bit width of this integer type.
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 153502c6e981b..1e4ca1b8328c6 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
-add_mlir_type_interface(QuantizationInterface)
+add_mlir_type_interface(QuantStorageTypeInterface)
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
diff --git a/mlir/include/mlir/IR/QuantizationInterface.h b/mlir/include/mlir/IR/QuantStorageTypeInterface.h
similarity index 63%
rename from mlir/include/mlir/IR/QuantizationInterface.h
rename to mlir/include/mlir/IR/QuantStorageTypeInterface.h
index 0d6709ff52065..dce430efdbc89 100644
--- a/mlir/include/mlir/IR/QuantizationInterface.h
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.h
@@ -1,4 +1,4 @@
-//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
+//===- QuantStorageTypeInterface.h - Quantzation Interfaces --------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_IR_QuantizationInterface_H
-#define MLIR_IR_QuantizationInterface_H
+#ifndef MLIR_IR_QuantStorageTypeInterface_H
+#define MLIR_IR_QuantStorageTypeInterface_H
#include "mlir/IR/Types.h"
@@ -17,6 +17,6 @@ namespace mlir {
class IntegerType;
} // namespace mlir
-#include "mlir/IR/QuantizationInterface.h.inc"
+#include "mlir/IR/QuantStorageTypeInterface.h.inc"
-#endif // MLIR_IR_QuantizationInterface_H
+#endif // MLIR_IR_QuantStorageTypeInterface_H
diff --git a/mlir/include/mlir/IR/QuantizationInterface.td b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
similarity index 83%
rename from mlir/include/mlir/IR/QuantizationInterface.td
rename to mlir/include/mlir/IR/QuantStorageTypeInterface.td
index 1008ac8e1dcf1..37fe4aaccf610 100644
--- a/mlir/include/mlir/IR/QuantizationInterface.td
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
@@ -1,9 +1,9 @@
-#ifndef MLIR_IR_QUANTIZATIONINTERFACE
-#define MLIR_IR_QUANTIZATIONINTERFACE
+#ifndef MLIR_IR_QUANTSTORAGETYPEINTERFACE
+#define MLIR_IR_QUANTSTORAGETYPEINTERFACE
include "mlir/IR/OpBase.td"
-def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
+def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
let description = [{
Interface for types that can be used as storage types in Quant dialect.
This interface provides methods to determine storage characteristics for quantization purposes.
@@ -41,4 +41,4 @@ def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
}
-#endif // MLIR_IR_QUANTIZATIONINTERFACE
+#endif // MLIR_IR_QUANTSTORAGETYPEINTERFACE
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index e7f9b1dc8a7e1..b27cb790b9052 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -9,7 +9,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -47,23 +47,16 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
- // Verify that the storage type is integral.
- // This restriction may be lifted at some point in favor of using bf16
- // or f16 as exact representations on hardware where that is advantageous.
- auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
- if (!intStorageType)
- return emitError() << "storage type must be integral";
-
- if (auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(storageType)) {
- unsigned integralWidth = quantizationInterface.getStorageWidth();
+ if (auto quantStorageTypeInterface =
+ llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
+ unsigned integralWidth = quantStorageTypeInterface.getStorageWidth();
// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitError() << "illegal storage type size: " << integralWidth;
- int64_t defaultMin = quantizationInterface.getDefaultMinimum();
- int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
storageTypeMax > defaultMax) {
@@ -74,7 +67,7 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
- return emitError() << "storage type must implement QuantizationInterface";
+ return emitError() << "storage type must implement QuantStorageTypeInterface";
}
Type QuantizedType::getStorageType() const {
@@ -91,21 +84,23 @@ int64_t QuantizedType::getStorageTypeMax() const {
bool QuantizedType::hasStorageTypeBounds() const {
Type storageType = static_cast<ImplType *>(impl)->storageType;
- auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(storageType);
+ // verifyInvariants() requires storage types to implement this interface.
+ auto quantStorageTypeInterface =
+ llvm::cast<QuantStorageTypeInterface>(storageType);
- int64_t defaultMin = quantizationInterface.getDefaultMinimum();
- int64_t defaultMax = quantizationInterface.getDefaultMaximum();
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
}
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
Type storageType = static_cast<ImplType *>(impl)->storageType;
- auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(storageType);
+ // verifyInvariants() requires storage types to implement this interface.
+ auto quantStorageTypeInterface =
+ llvm::cast<QuantStorageTypeInterface>(storageType);
- return quantizationInterface.getStorageWidth();
+ return quantStorageTypeInterface.getStorageWidth();
}
Type QuantizedType::getExpressedType() const {
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 758399a2af5e8..f86df4eacd21f 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -10,7 +10,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/QuantizationInterface.h"
+#include "mlir/IR/QuantStorageTypeInterface.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
@@ -29,10 +29,10 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
if (!succeeded(*result))
return nullptr;
- if (auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(type)) {
- isSigned = quantizationInterface.isStorageSigned();
- storageTypeWidth = quantizationInterface.getStorageWidth();
+ if (auto quantStorageTypeInterface =
+ llvm::dyn_cast<QuantStorageTypeInterface>(type)) {
+ isSigned = quantStorageTypeInterface.isStorageSigned();
+ storageTypeWidth = quantStorageTypeInterface.getStorageWidth();
} else {
parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
@@ -67,12 +67,11 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
int64_t &storageTypeMin,
int64_t &storageTypeMax) {
- int64_t defaultMin, defaultMax;
- if (auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(storageType)) {
- defaultMin = quantizationInterface.getDefaultMinimum();
- defaultMax = quantizationInterface.getDefaultMaximum();
- }
+ // Storage types accepted by parseStorageType() implement this interface.
+ auto quantStorageTypeInterface =
+ llvm::cast<QuantStorageTypeInterface>(storageType);
+ int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum();
+ int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum();
if (failed(parser.parseOptionalLess())) {
storageTypeMin = defaultMin;
@@ -495,10 +494,10 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
// storage type
- if (auto quantizationInterface =
- llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
- out << quantizationInterface.getStorageType();
- }
+ // QuantizedType::verifyInvariants() guarantees this interface is present.
+ auto quantStorageTypeInterface =
+ llvm::cast<QuantStorageTypeInterface>(type.getStorageType());
+ out << quantStorageTypeInterface.getStorageType();
// storageTypeMin and storageTypeMax if not default.
if (type.hasStorageTypeBounds()) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 8a50b077d9014..9e0d283854b38 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -31,7 +31,7 @@ add_mlir_library(MLIRIR
OperationSupport.cpp
PatternLoggingListener.cpp
PatternMatch.cpp
- QuantizationInterface.cpp
+ QuantStorageTypeInterface.cpp
Region.cpp
RegionKindInterface.cpp
SymbolTable.cpp
@@ -63,7 +63,7 @@ add_mlir_library(MLIRIR
MLIRCastInterfacesIncGen
MLIRDataLayoutInterfacesIncGen
MLIROpAsmInterfaceIncGen
- MLIRQuantizationInterfaceIncGen
+ MLIRQuantStorageTypeInterfaceIncGen
MLIRRegionKindInterfaceIncGen
MLIRSideEffectInterfacesIncGen
MLIRSymbolInterfacesIncGen
@@ -72,4 +72,3 @@ add_mlir_library(MLIRIR
LINK_LIBS PUBLIC
MLIRSupport
)
-
diff --git a/mlir/lib/IR/QuantizationInterface.cpp b/mlir/lib/IR/QuantStorageTypeInterface.cpp
similarity index 89%
rename from mlir/lib/IR/QuantizationInterface.cpp
rename to mlir/lib/IR/QuantStorageTypeInterface.cpp
index a93333278610e..bcc2dc1f3337c 100644
--- a/mlir/lib/IR/QuantizationInterface.cpp
+++ b/mlir/lib/IR/QuantStorageTypeInterface.cpp
@@ -1,4 +1,3 @@
-//===- QuantizationInterface.cpp
-//------------------------------------------===//
+//===- QuantStorageTypeInterface.cpp - Quant storage type interface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -20,4 +19,4 @@ using namespace mlir::detail;
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//
-#include "mlir/IR/QuantizationInterface.cpp.inc"
+#include "mlir/IR/QuantStorageTypeInterface.cpp.inc"
>From f8458e773b1003c669bbed7db46c95253ef8e810 Mon Sep 17 00:00:00 2001
From: Roman Pevnyi <roman.pevnyi at intel.com>
Date: Wed, 20 Aug 2025 14:45:14 +0200
Subject: [PATCH 4/4] Added interface methods which expose a few basic packing
and alignment facts
---
mlir/include/mlir/IR/BuiltinTypes.td | 36 +++++++++++++++++++
.../mlir/IR/QuantStorageTypeInterface.td | 35 +++++++++++++++---
2 files changed, 66 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1a29b9549cf51..c7e9ee2a4c95f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -132,6 +132,18 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
/// Get the storage type as a string.
std::string getStorageType() const { return "f8E5M2"; }
+
+ /// Check if this 8-bit floating point type uses packed representation.
+ bool isPacked() const { return false; }
+
+ /// Get the logical bit width per value for this 8-bit floating point type.
+ unsigned getLogicalBitWidth() const { return 8; }
+
+ /// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
+ unsigned getElementsPerByte() const { return 1; }
+
+ /// Get the preferred alignment in bytes for this 8-bit floating point type.
+ std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}
@@ -191,6 +203,18 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
/// Get the storage type as a string.
std::string getStorageType() const { return "f8E4M3FN"; }
+
+ /// Check if this 8-bit floating point type uses packed representation.
+ bool isPacked() const { return false; }
+
+ /// Get the logical bit width per value for this 8-bit floating point type.
+ unsigned getLogicalBitWidth() const { return 8; }
+
+ /// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
+ unsigned getElementsPerByte() const { return 1; }
+
+ /// Get the preferred alignment in bytes for this 8-bit floating point type.
+ std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}
@@ -613,6 +637,18 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
std::string getStorageType() const {
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
}
+
+ /// Check if this integer type uses packed representation.
+ bool isPacked() const { return false; }
+
+ /// Get the logical bit width per value for this integer type.
+ unsigned getLogicalBitWidth() const { return getWidth(); }
+
+ /// Get the number of logical elements that fit in one byte for this integer type.
+ unsigned getElementsPerByte() const { return 1; }
+
+ /// Get the preferred alignment in bytes for this integer type.
+ std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
}];
}
diff --git a/mlir/include/mlir/IR/QuantStorageTypeInterface.td b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
index 37fe4aaccf610..b09c6e79df840 100644
--- a/mlir/include/mlir/IR/QuantStorageTypeInterface.td
+++ b/mlir/include/mlir/IR/QuantStorageTypeInterface.td
@@ -6,7 +6,8 @@ include "mlir/IR/OpBase.td"
def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
let description = [{
Interface for types that can be used as storage types in Quant dialect.
- This interface provides methods to determine storage characteristics for quantization purposes.
+ This interface provides methods to determine storage characteristics for quantization purposes,
+ including packing behavior and alignment requirements.
}];
let cppNamespace = "::mlir";
@@ -18,25 +19,51 @@ def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
"bool", "isStorageSigned", (ins)>,
InterfaceMethod<[{
- Get the bit width of this integer type.
+ Get the bit width of this type.
Returns the number of bits used to store values of this type.
}],
"unsigned", "getStorageWidth", (ins)>,
InterfaceMethod<[{
- Get default minimum value for this integer type.
+ Get the default minimum value for this type.
}],
"int64_t", "getDefaultMinimum", (ins)>,
InterfaceMethod<[{
- Get default maximum value for this integer type.
+ Get the default maximum value for this type.
}],
"int64_t", "getDefaultMaximum", (ins)>,
InterfaceMethod<[{
Get the storage type as a string.
}],
- "std::string", "getStorageType", (ins)>
+ "std::string", "getStorageType", (ins)>,
+
+ InterfaceMethod<[{
+ Check if the storage type uses packed representation.
+ Returns true if multiple values are packed into one byte (e.g., sub-byte types),
+ false if each value occupies at least one full byte.
+ }],
+ "bool", "isPacked", (ins)>,
+
+ InterfaceMethod<[{
+ Get the logical bit width per value.
+ For packed sub-byte types, this may differ from getStorageWidth().
+ }],
+ "unsigned", "getLogicalBitWidth", (ins)>,
+
+ InterfaceMethod<[{
+ Get the number of logical elements that fit in one byte.
+ For packed sub-byte types, this is how many values share one byte.
+ Types that are not packed return 1.
+ }],
+ "unsigned", "getElementsPerByte", (ins)>,
+
+ InterfaceMethod<[{
+ Returns the preferred alignment for this type, in bytes, or std::nullopt
+ if the type does not specify a preference.
+ }],
+ "std::optional<unsigned>", "getPreferredAlignmentBytes", (ins)>
];
}
More information about the Mlir-commits
mailing list