[Mlir-commits] [mlir] Quantile Type and Low FP Support (PR #190321)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 26 05:35:15 PDT 2026
https://github.com/vsimion26 updated https://github.com/llvm/llvm-project/pull/190321
>From 36c171064abf042a7d86290489490ad12915bb9d Mon Sep 17 00:00:00 2001
From: vsimion26 <vlad.simion at intel.com>
Date: Thu, 2 Apr 2026 09:59:52 +0100
Subject: [PATCH 1/6] Added a new BuiltinType, QuantileType and support for low
precision float storage types. Added lit tests
---
mlir/include/mlir/IR/BuiltinTypes.h | 1 +
mlir/include/mlir/IR/BuiltinTypes.td | 78 ++++++++++++++++++++++++
mlir/lib/AsmParser/Parser.h | 3 +
mlir/lib/AsmParser/TokenKinds.def | 1 +
mlir/lib/AsmParser/TypeParser.cpp | 55 +++++++++++++++++
mlir/lib/IR/AsmPrinter.cpp | 15 +++++
mlir/lib/IR/BuiltinTypes.cpp | 74 +++++++++++++++++++++-
mlir/lib/IR/TypeDetail.h | 46 ++++++++++++++
mlir/test/IR/invalid-quantile-types.mlir | 39 ++++++++++++
mlir/test/IR/quantile-types.mlir | 55 +++++++++++++++++
10 files changed, 366 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/IR/invalid-quantile-types.mlir
create mode 100644 mlir/test/IR/quantile-types.mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index d30cba29c9814..ef47850929beb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -35,6 +35,7 @@ class TypeRange;
namespace detail {
struct FunctionTypeStorage;
struct IntegerTypeStorage;
+struct QuantileTypeStorage;
struct TupleTypeStorage;
} // namespace detail
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 20c41c5f79729..8acdf5f0a7703 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1115,6 +1115,84 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
let genVerifyDecl = 1;
}
+//===----------------------------------------------------------------------===//
+// QuantileType
+//===----------------------------------------------------------------------===//
+
+def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
+ [QuantStorageTypeInterface, MemRefElementTypeInterface]> {
+ let summary = "Quantile-based type with a lookup table of quantile values";
+ let description = [{
+ Syntax:
+
+ ```
+ quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>`
+ ```
+
+ A quantile type represents a quantile-based floating point encoding, where
+ discrete storage values map to floating-point values via a quantile lookup
+ table. The type has a storage type (how raw values are stored, e.g., `ui4`,
+ `si8`), a quantile type (the expressed floating-point precision, e.g.,
+ `f16`, `f32`), and an array of quantile values that define the mapping.
+
+ This type is used for weight compression schemes like NF4 (NormalFloat4)
+ and similar quantile-based formats.
+
+ #### Example:
+
+ ```mlir
+ quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}>
+ ```
+ }];
+ let parameters = (ins
+ "Type":$storageType,
+ "Type":$quantileType,
+ ArrayRefParameter<"double">:$quantiles
+ );
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "Type":$storageType,
+ "Type":$quantileType,
+ "ArrayRef<double>":$quantiles), [{
+ return $_get(storageType.getContext(), storageType, quantileType,
+ quantiles);
+ }]>
+ ];
+ let skipDefaultBuilders = 1;
+ let genStorageClass = 0;
+ let genVerifyDecl = 1;
+ let extraClassDeclaration = [{
+ /// QuantStorageTypeInterface method implementations.
+
+ /// Returns true if the storage type defaults to signed.
+ bool shouldDefaultToSigned() const;
+
+ /// Get the bit width of the storage type.
+ unsigned getStorageWidth() const;
+
+ /// Get default minimum value for the storage type.
+ int64_t getDefaultMinimum(bool isSigned) const;
+
+ /// Get default maximum value for the storage type.
+ int64_t getDefaultMaximum(bool isSigned) const;
+
+ /// Get the storage type as a string.
+ std::string getStorageTypeName(bool isSigned) const;
+
+ /// Check if the storage type uses packed representation.
+ bool isPacked() const;
+
+ /// Get the logical bit width per value.
+ unsigned getLogicalBitWidth() const;
+
+ /// Returns how many values of this type fit in one byte.
+ unsigned getElementsPerByte() const;
+
+ /// Get the preferred alignment in bytes.
+ std::optional<unsigned> getPreferredAlignmentBytes() const;
+ }];
+}
+
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index ecc128cf767b3..ccd469df3ab13 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -213,6 +213,9 @@ class Parser {
/// Parse a complex type.
Type parseComplexType();
+ /// Parse a quantile type.
+ Type parseQuantileType();
+
/// Parse an extended type.
Type parseExtendedType();
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fe7c53753e156..b2cd2d1e8dd58 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -119,6 +119,7 @@ TOK_KEYWORD(min)
TOK_KEYWORD(mod)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
+TOK_KEYWORD(quantile)
TOK_KEYWORD(size)
TOK_KEYWORD(sparse)
TOK_KEYWORD(step)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index a461ebed967a8..906c2cf70b605 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -51,6 +51,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_f8E8M0FNU:
case Token::kw_bf16:
case Token::kw_f16:
+ case Token::kw_quantile:
case Token::kw_tf32:
case Token::kw_f32:
case Token::kw_f64:
@@ -282,6 +283,8 @@ Type Parser::parseNonFunctionType() {
return parseTensorType();
case Token::kw_complex:
return parseComplexType();
+ case Token::kw_quantile:
+ return parseQuantileType();
case Token::kw_tuple:
return parseTupleType();
case Token::kw_vector:
@@ -383,6 +386,58 @@ Type Parser::parseNonFunctionType() {
}
}
+/// Parse a quantile type.
+///
+/// quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>`
+///
+Type Parser::parseQuantileType() {
+ consumeToken(Token::kw_quantile);
+
+ if (parseToken(Token::less, "expected '<' in quantile type"))
+ return nullptr;
+
+ // Parse the storage type.
+ Type storageType = parseType();
+ if (!storageType)
+ return nullptr;
+
+ if (parseToken(Token::colon, "expected ':' in quantile type"))
+ return nullptr;
+
+ // Parse the quantile (expressed) type.
+ Type quantileType = parseType();
+ if (!quantileType)
+ return nullptr;
+
+ if (parseToken(Token::comma, "expected ',' in quantile type"))
+ return nullptr;
+
+ if (parseToken(Token::l_brace, "expected '{' in quantile type"))
+ return nullptr;
+
+ // Parse the quantile values as floating point literals.
+ SmallVector<double, 16> quantiles;
+ do {
+ bool isNegative = consumeIf(Token::minus);
+ Token curTok = getToken();
+ std::optional<APFloat> apResult;
+ if (failed(parseFloatFromLiteral(apResult, curTok, isNegative,
+ APFloat::IEEEdouble())))
+ return nullptr;
+ consumeToken();
+ quantiles.push_back(apResult->convertToDouble());
+ } while (consumeIf(Token::comma) &&
+ !getToken().is(Token::r_brace));
+
+ if (parseToken(Token::r_brace, "expected '}' in quantile type"))
+ return nullptr;
+
+ if (parseToken(Token::greater, "expected '>' in quantile type"))
+ return nullptr;
+
+ return QuantileType::get(storageType, quantileType, quantiles);
+}
+
/// Parse a tensor type.
///
/// tensor-type ::= `tensor` `<` dimension-list type `>`
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75008d6cc2591..e1145eaafbc64 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2900,6 +2900,21 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
printType(complexTy.getElementType());
os << '>';
})
+ .Case<QuantileType>([&](QuantileType quantileTy) {
+ os << "quantile<";
+ printType(quantileTy.getStorageType());
+ os << ':';
+ printType(quantileTy.getQuantileType());
+ os << ", {";
+ ArrayRef<double> quantiles = quantileTy.getQuantiles();
+ // interleaveComma(llvm::seq<size_t>(0, quantiles.size()),
+ // [&](size_t i) { os << quantiles[i]; });
+ llvm::interleave(
+ llvm::seq<size_t>(0, quantiles.size()), os,
+ [&](size_t index) { os << quantiles[index]; }, ",");
+ os << "}";
+ os << '>';
+ })
.Case([&](TupleType tupleTy) {
os << "tuple<";
interleaveComma(tupleTy.getTypes(),
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 786c30851a071..faf38e6a8354e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -94,6 +94,78 @@ ComplexType::convertFromAttribute(Attribute attr,
return success();
}
+//===----------------------------------------------------------------------===//
+// QuantileType
+//===----------------------------------------------------------------------===//
+
+Type QuantileType::getStorageType() const { return getImpl()->storageType; }
+
+Type QuantileType::getQuantileType() const { return getImpl()->quantileType; }
+
+ArrayRef<double> QuantileType::getQuantiles() const {
+ return getImpl()->getQuantiles();
+}
+
+LogicalResult QuantileType::verify(function_ref<InFlightDiagnostic()> emitError,
+ Type storageType, Type quantileType,
+ ArrayRef<double> quantiles) {
+ if (!storageType.isIntOrFloat())
+ return emitError() << "storage type must be an integer or float type";
+ if (!llvm::isa<FloatType>(quantileType))
+ return emitError() << "quantile type must be a float type";
+ if (quantiles.empty())
+ return emitError() << "quantile values must not be empty";
+ return success();
+}
+
+bool QuantileType::shouldDefaultToSigned() const {
+ if (auto intType = llvm::dyn_cast<IntegerType>(getStorageType()))
+ return !intType.isUnsigned();
+ // Float types default to signed.
+ return true;
+}
+
+unsigned QuantileType::getStorageWidth() const {
+ return getStorageType().getIntOrFloatBitWidth();
+}
+
+int64_t QuantileType::getDefaultMaximum(bool isSigned) const {
+ if (isSigned)
+ return (1LL << (getStorageWidth() - 1)) - 1;
+ return (1LL << getStorageWidth()) - 1;
+}
+
+int64_t QuantileType::getDefaultMinimum(bool isSigned) const {
+ if (isSigned)
+ return -(1LL << (getStorageWidth() - 1));
+ return 0;
+}
+
+std::string QuantileType::getStorageTypeName(bool isSigned) const {
+ std::string result = "quantile<";
+ llvm::raw_string_ostream os(result);
+ os << getStorageType() << ":" << getQuantileType() << ", {";
+ ArrayRef<double> quantiles = getQuantiles();
+ llvm::interleave(
+ llvm::seq<size_t>(0, quantiles.size()), os,
+ [&](size_t index) { os << quantiles[index]; }, ",");
+ os << "}>";
+ return result;
+}
+
+bool QuantileType::isPacked() const { return getStorageWidth() <= 4; }
+
+unsigned QuantileType::getLogicalBitWidth() const { return getStorageWidth(); }
+
+unsigned QuantileType::getElementsPerByte() const {
+ unsigned width = getStorageWidth();
+ return width > 0 ? 8 / width : 0;
+}
+
+std::optional<unsigned> QuantileType::getPreferredAlignmentBytes() const {
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
@@ -435,7 +507,7 @@ bool TensorType::isValidElementType(Type type) {
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
- IndexType>(type) ||
+ IndexType, QuantileType>(type) ||
!llvm::isa<BuiltinDialect>(type.getDialect());
}
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 0e952d5c14c7e..6afcc6307d4b4 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -148,6 +148,52 @@ Attribute skipDefaultMemorySpace(Attribute memorySpace);
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt(Attribute memorySpace);
+/// Quantile Type Storage and Uniquing.
+struct QuantileTypeStorage : public TypeStorage {
+ QuantileTypeStorage(Type storageType, Type quantileType,
+ ArrayRef<double> quantiles)
+ : storageType(storageType), quantileType(quantileType),
+ quantilesData(quantiles.data()), numQuantiles(quantiles.size()) {}
+
+ /// The hash key used for uniquing.
+ using KeyTy = std::tuple<Type, Type, ArrayRef<double>>;
+
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ auto quantiles = std::get<2>(key);
+ // Bit-cast doubles to int64_t for hashing since LLVM hashing
+ // does not natively support double.
+ auto *quantilesBits = llvm::bit_cast<const int64_t *>(quantiles.data());
+ ArrayRef<int64_t> quantilesAsInts(quantilesBits, quantiles.size());
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+ llvm::hash_combine_range(quantilesAsInts.begin(),
+ quantilesAsInts.end()));
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return storageType == std::get<0>(key) &&
+ quantileType == std::get<1>(key) &&
+ getQuantiles() == std::get<2>(key);
+ }
+
+ static QuantileTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ ArrayRef<double> quantiles = allocator.copyInto(std::get<2>(key));
+ return new (allocator.allocate<QuantileTypeStorage>())
+ QuantileTypeStorage(std::get<0>(key), std::get<1>(key), quantiles);
+ }
+
+ Type getStorageType() const { return storageType; }
+ Type getQuantileType() const { return quantileType; }
+ ArrayRef<double> getQuantiles() const {
+ return ArrayRef<double>(quantilesData, numQuantiles);
+ }
+
+ Type storageType;
+ Type quantileType;
+ const double *quantilesData;
+ unsigned numQuantiles;
+};
+
} // namespace detail
} // namespace mlir
diff --git a/mlir/test/IR/invalid-quantile-types.mlir b/mlir/test/IR/invalid-quantile-types.mlir
new file mode 100644
index 0000000000000..5ee3e674b5cc1
--- /dev/null
+++ b/mlir/test/IR/invalid-quantile-types.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
+
+//===----------------------------------------------------------------------===//
+// Parser error tests
+//===----------------------------------------------------------------------===//
+
+// Test missing '<' after 'quantile' keyword.
+// expected-error @+1 {{expected '<' in quantile type}}
+func.func private @missing_lt() -> quantile ui4:f16, {1.0}>
+
+// -----
+
+// Test missing ':' between storage type and quantile type.
+// expected-error @+1 {{expected ':' in quantile type}}
+func.func private @missing_colon() -> quantile<ui4 f16, {1.0}>
+
+// -----
+
+// Test missing ',' between quantile type and quantile value list.
+// expected-error @+1 {{expected ',' in quantile type}}
+func.func private @missing_comma() -> quantile<ui4:f16 {1.0}>
+
+// -----
+
+// Test missing '{' before quantile value list.
+// expected-error @+1 {{expected '{' in quantile type}}
+func.func private @missing_lbrace() -> quantile<ui4:f16, 1.0}>
+
+// -----
+
+// Test missing '}' after quantile value list.
+// expected-error @+1 {{expected '}' in quantile type}}
+func.func private @missing_rbrace() -> quantile<ui4:f16, {1.0>
+
+// -----
+
+// Test missing '>' closing the quantile type.
+// expected-error @+1 {{expected '>' in quantile type}}
+func.func private @missing_gt() -> quantile<ui4:f16, {1.0}
diff --git a/mlir/test/IR/quantile-types.mlir b/mlir/test/IR/quantile-types.mlir
new file mode 100644
index 0000000000000..3c119ef91750d
--- /dev/null
+++ b/mlir/test/IR/quantile-types.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Verify round-trip parsing and printing of the builtin quantile type.
+
+// CHECK: func private @quantile_ui4_f16(quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
+func.func private @quantile_ui4_f16(quantile<ui4:f16, {-1.0,0.0,1.0}>) -> ()
+
+// CHECK: func private @quantile_si8_f32(quantile<si8:f32, {-1.000000e+00,-5.000000e-01,0.000000e+00,5.000000e-01,1.000000e+00}>)
+func.func private @quantile_si8_f32(quantile<si8:f32, {-1.0,-0.5,0.0,0.5,1.0}>) -> ()
+
+// CHECK: func private @quantile_i8_f32(quantile<i8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
+func.func private @quantile_i8_f32(quantile<i8:f32, {-1.0,0.0,1.0}>) -> ()
+
+// CHECK: func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
+func.func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.0,0.0,1.0}>) -> ()
+
+// CHECK: func private @quantile_ui4_bf16(quantile<ui4:bf16, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
+func.func private @quantile_ui4_bf16(quantile<ui4:bf16, {-1.0,0.0,1.0}>) -> ()
+
+// CHECK: func private @quantile_as_return() -> quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>
+func.func private @quantile_as_return() -> quantile<ui4:f16, {-1.0,0.0,1.0}>
+
+// Verify use as memref element type (requires MemRefElementTypeInterface).
+// CHECK: func private @quantile_in_memref(memref<8xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
+func.func private @quantile_in_memref(memref<8xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+
+// Verify use as tensor element type.
+// CHECK: func private @quantile_in_tensor(tensor<16xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
+func.func private @quantile_in_tensor(tensor<16xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+
+// Verify use in multidimensional tensors.
+// CHECK: func private @quantile_in_ranked_tensor(tensor<16x16x1x1xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
+func.func private @quantile_in_ranked_tensor(tensor<16x16x1x1xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+
+// Verify use in unranked tensors.
+// CHECK: func private @quantile_in_unranked_tensor(tensor<*xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
+func.func private @quantile_in_unranked_tensor(tensor<*xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+
+// Verify NF4-style 16-entry quantile table
+// CHECK-LABEL: @nf4_16_values
+// CHECK-SAME: quantile<ui4:f16, {
+func.func private @nf4_16_values(quantile<ui4:f16, {
+ -1.0,-0.6961928009986877,-0.5250730514526367,-0.39491748809814453,
+ -0.28444138169288635,-0.18477343022823334,-0.09105003625154495,0.0,
+ 0.07958029955625534,0.16093020141124725,0.24611230194568634,
+ 0.33791524171829224,0.44070982933044434,0.5626170039176941,
+ 0.7229568362236023,1.0}>) -> ()
+
+// Verify negative-only quantile list.
+// CHECK: func private @quantile_negatives(quantile<si4:f32, {-1.000000e+00,-5.000000e-01,-2.500000e-01}>)
+func.func private @quantile_negatives(quantile<si4:f32, {-1.0,-0.5,-0.25}>) -> ()
+
+// Verify that a single quantile value is accepted.
+// CHECK: func private @quantile_single_value(quantile<si8:f32, {1.000000e+00}>)
+func.func private @quantile_single_value(quantile<si8:f32, {1.0}>) -> ()
>From c55bc3afd06e83a821bcada05911f464ef2c3e93 Mon Sep 17 00:00:00 2001
From: vsimion26 <vlad.simion at intel.com>
Date: Fri, 3 Apr 2026 08:42:20 +0100
Subject: [PATCH 2/6] Added support for optional storage min/max, extended
testing for low fp storage types
---
.../mlir/Dialect/Quant/IR/QuantTypes.h | 6 +--
mlir/include/mlir/IR/BuiltinTypes.td | 18 +++++--
mlir/lib/AsmParser/TypeParser.cpp | 54 ++++++++++++++++++-
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 3 +-
mlir/lib/IR/AsmPrinter.cpp | 5 +-
mlir/lib/IR/BuiltinTypes.cpp | 22 +++++++-
mlir/lib/IR/TypeDetail.h | 31 +++++++----
.../Dialect/Quant/parse-uniform-invalid.mlir | 12 +++++
mlir/test/Dialect/Quant/parse-uniform.mlir | 19 +++++++
mlir/test/IR/invalid-quantile-types.mlir | 6 +++
mlir/test/IR/quantile-types.mlir | 8 +++
11 files changed, 163 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 34f47a15395c9..092a91e1c01d7 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -255,7 +255,7 @@ class AnyQuantizedType
/// Per-layer, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
-/// StorageType: 'i'|'u' NumBits
+/// StorageType: 'i'|'u' NumBits, 'f8E5M2', 'f8E4M3FN', 'f4E2M1FN', 'quantile'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Scale: A legal double value
/// ZeroPoint: An integer value
@@ -313,7 +313,7 @@ class UniformQuantizedType
/// Per-axis, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
-/// StorageType: 'i'|'u' NumBits
+/// StorageType: 'i'|'u' NumBits, 'f8E5M2', 'f8E4M3FN', 'f4E2M1FN', 'quantile'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// QuantParams: (Scale ':' ZeroPoint)+
@@ -398,7 +398,7 @@ class UniformQuantizedPerAxisType
/// ScaleZeroList ::= ScaleZero (',' ScaleZero)*
/// ScaleZero ::= Scale (':' ZeroPoint)?
///
-/// StorageType: 'i'|'u' NumBits
+/// StorageType: 'i'|'u' NumBits, 'f8E5M2', 'f8E4M3FN', 'f4E2M1FN', 'quantile'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// AxisSpec: An integer value
/// BlockSizeSpec: An integer value
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 8acdf5f0a7703..ee97806cf78d5 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1126,7 +1126,7 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
Syntax:
```
- quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>`
+ quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>` ( `<` int `:` int `>` )?
```
A quantile type represents a quantile-based floating point encoding, where
@@ -1135,6 +1135,10 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
`si8`), a quantile type (the expressed floating-point precision, e.g.,
`f16`, `f32`), and an array of quantile values that define the mapping.
+ Optionally, explicit minimum and maximum storage values can be specified
+ after the closing `>` as `<min:max>`, overriding the defaults derived from
+ the storage type width.
+
This type is used for weight compression schemes like NF4 (NormalFloat4)
and similar quantile-based formats.
@@ -1142,20 +1146,25 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
```mlir
quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}>
+ quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}><0:15>
```
}];
let parameters = (ins
"Type":$storageType,
"Type":$quantileType,
- ArrayRefParameter<"double">:$quantiles
+ ArrayRefParameter<"double", "quantile values">:$quantiles,
+ OptionalParameter<"std::optional<int64_t>", "explicit storage minimum">:$storageMin,
+ OptionalParameter<"std::optional<int64_t>", "explicit storage maximum">:$storageMax
);
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$storageType,
"Type":$quantileType,
- "ArrayRef<double>":$quantiles), [{
+ "ArrayRef<double>":$quantiles,
+ CArg<"std::optional<int64_t>", "std::nullopt">:$storageMin,
+ CArg<"std::optional<int64_t>", "std::nullopt">:$storageMax), [{
return $_get(storageType.getContext(), storageType, quantileType,
- quantiles);
+ quantiles, storageMin, storageMax);
}]>
];
let skipDefaultBuilders = 1;
@@ -1193,6 +1202,7 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
}];
}
+
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 906c2cf70b605..ad963a462db79 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <cstdint>
#include <limits>
@@ -435,7 +436,58 @@ Type Parser::parseQuantileType() {
if (parseToken(Token::greater, "expected '>' in quantile type"))
return nullptr;
- return QuantileType::get(storageType, quantileType, quantiles);
+ // Optionally parse explicit storage range: `<min:max>`.
+ std::optional<int64_t> storageMin, storageMax;
+ if (consumeIf(Token::less)) {
+ int64_t minVal, maxVal;
+
+ // Parse minimum value (with optional sign).
+ bool minNegative = consumeIf(Token::minus);
+ if (!getToken().is(Token::integer))
+ return (emitWrongTokenError(
+ "expected integer minimum in quantile storage range"),
+ nullptr);
+ SMLoc minLoc = getToken().getLoc();
+ if (getToken().getSpelling().getAsInteger(10, minVal))
+ return nullptr;
+ consumeToken(Token::integer);
+ minVal = minNegative ? -minVal : minVal;
+
+ if (parseToken(Token::colon, "expected ':' in quantile storage range"))
+ return nullptr;
+
+ // Parse maximum value (with optional sign).
+ bool maxNegative = consumeIf(Token::minus);
+ if (!getToken().is(Token::integer))
+ return (emitWrongTokenError(
+ "expected integer maximum in quantile storage range"),
+ nullptr);
+ SMLoc maxLoc = getToken().getLoc();
+ if (getToken().getSpelling().getAsInteger(10, maxVal))
+ return nullptr;
+ consumeToken(Token::integer);
+ maxVal = maxNegative ? -maxVal : maxVal;
+
+ if (parseToken(Token::greater, "expected '>' after quantile storage range"))
+ return nullptr;
+
+ // Validate against the underlying storage type's inherent limits.
+ if (auto qsIface = llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
+ bool isSigned = qsIface.shouldDefaultToSigned();
+ if (minVal < qsIface.getDefaultMinimum(isSigned))
+ return (emitError(minLoc, "illegal storage type minimum: ") << minVal,
+ nullptr);
+ if (maxVal > qsIface.getDefaultMaximum(isSigned))
+ return (emitError(maxLoc, "illegal storage type maximum: ") << maxVal,
+ nullptr);
+ }
+
+ storageMin = minVal;
+ storageMax = maxVal;
+ }
+
+ return QuantileType::get(storageType, quantileType, quantiles, storageMin,
+ storageMax);
}
/// Parse a tensor type.
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 1a42b90ac31e2..5bffe43c818f9 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -320,7 +320,8 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
/// block-size-info `,` scale-zero-tensor `>`
/// storage-spec ::= storage-type (`<` storage-range `>`)?
/// storage-range ::= integer-literal `:` integer-literal
-/// storage-type ::= (`i` | `u`) integer-literal
+/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN`
+/// | `f4E2M1FN` | `quantile`
/// expressed-type-spec ::= `:` `f` integer-literal
/// axis-spec ::= `:` integer-literal
/// scale-zero ::= scale (`:` zero-point)?
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e1145eaafbc64..3f6600173fc5d 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2907,13 +2907,14 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
printType(quantileTy.getQuantileType());
os << ", {";
ArrayRef<double> quantiles = quantileTy.getQuantiles();
- // interleaveComma(llvm::seq<size_t>(0, quantiles.size()),
- // [&](size_t i) { os << quantiles[i]; });
llvm::interleave(
llvm::seq<size_t>(0, quantiles.size()), os,
[&](size_t index) { os << quantiles[index]; }, ",");
os << "}";
os << '>';
+ if (auto minVal = quantileTy.getStorageMin())
+ if (auto maxVal = quantileTy.getStorageMax())
+ os << '<' << *minVal << ':' << *maxVal << '>';
})
.Case([&](TupleType tupleTy) {
os << "tuple<";
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index faf38e6a8354e..6429928b812a0 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -106,15 +106,28 @@ ArrayRef<double> QuantileType::getQuantiles() const {
return getImpl()->getQuantiles();
}
+std::optional<int64_t> QuantileType::getStorageMin() const {
+ return getImpl()->getStorageMin();
+}
+
+std::optional<int64_t> QuantileType::getStorageMax() const {
+ return getImpl()->getStorageMax();
+}
+
LogicalResult QuantileType::verify(function_ref<InFlightDiagnostic()> emitError,
Type storageType, Type quantileType,
- ArrayRef<double> quantiles) {
+ ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin,
+ std::optional<int64_t> storageMax) {
if (!storageType.isIntOrFloat())
return emitError() << "storage type must be an integer or float type";
if (!llvm::isa<FloatType>(quantileType))
return emitError() << "quantile type must be a float type";
if (quantiles.empty())
return emitError() << "quantile values must not be empty";
+ if (storageMin.has_value() != storageMax.has_value())
+ return emitError()
+ << "storage min and max must both be specified or both omitted";
return success();
}
@@ -130,12 +143,16 @@ unsigned QuantileType::getStorageWidth() const {
}
int64_t QuantileType::getDefaultMaximum(bool isSigned) const {
+ if (auto explicitMax = getStorageMax())
+ return *explicitMax;
if (isSigned)
return (1LL << (getStorageWidth() - 1)) - 1;
return (1LL << getStorageWidth()) - 1;
}
int64_t QuantileType::getDefaultMinimum(bool isSigned) const {
+ if (auto explicitMin = getStorageMin())
+ return *explicitMin;
if (isSigned)
return -(1LL << (getStorageWidth() - 1));
return 0;
@@ -150,6 +167,9 @@ std::string QuantileType::getStorageTypeName(bool isSigned) const {
llvm::seq<size_t>(0, quantiles.size()), os,
[&](size_t index) { os << quantiles[index]; }, ",");
os << "}>";
+ if (auto minVal = getStorageMin())
+ if (auto maxVal = getStorageMax())
+ os << '<' << *minVal << ':' << *maxVal << '>';
return result;
}
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 6afcc6307d4b4..9e805ba25c3f4 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -151,35 +151,44 @@ unsigned getMemorySpaceAsInt(Attribute memorySpace);
/// Quantile Type Storage and Uniquing.
struct QuantileTypeStorage : public TypeStorage {
QuantileTypeStorage(Type storageType, Type quantileType,
- ArrayRef<double> quantiles)
+ ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin,
+ std::optional<int64_t> storageMax)
: storageType(storageType), quantileType(quantileType),
- quantilesData(quantiles.data()), numQuantiles(quantiles.size()) {}
+ quantilesData(quantiles.data()), numQuantiles(quantiles.size()),
+ storageMin(storageMin), storageMax(storageMax) {}
- /// The hash key used for uniquing.
- using KeyTy = std::tuple<Type, Type, ArrayRef<double>>;
+ using KeyTy = std::tuple<Type, Type, ArrayRef<double>, std::optional<int64_t>,
+ std::optional<int64_t>>;
static llvm::hash_code hashKey(const KeyTy &key) {
auto quantiles = std::get<2>(key);
- // Bit-cast doubles to int64_t for hashing since LLVM hashing
- // does not natively support double.
auto *quantilesBits = llvm::bit_cast<const int64_t *>(quantiles.data());
ArrayRef<int64_t> quantilesAsInts(quantilesBits, quantiles.size());
+ auto hashOptInt = [](std::optional<int64_t> opt) -> llvm::hash_code {
+ return opt ? llvm::hash_combine(true, *opt)
+ : llvm::hash_combine(false, int64_t{0});
+ };
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
llvm::hash_combine_range(quantilesAsInts.begin(),
- quantilesAsInts.end()));
+ quantilesAsInts.end()),
+ hashOptInt(std::get<3>(key)),
+ hashOptInt(std::get<4>(key)));
}
bool operator==(const KeyTy &key) const {
return storageType == std::get<0>(key) &&
quantileType == std::get<1>(key) &&
- getQuantiles() == std::get<2>(key);
+ getQuantiles() == std::get<2>(key) &&
+ storageMin == std::get<3>(key) && storageMax == std::get<4>(key);
}
static QuantileTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
ArrayRef<double> quantiles = allocator.copyInto(std::get<2>(key));
return new (allocator.allocate<QuantileTypeStorage>())
- QuantileTypeStorage(std::get<0>(key), std::get<1>(key), quantiles);
+ QuantileTypeStorage(std::get<0>(key), std::get<1>(key), quantiles,
+ std::get<3>(key), std::get<4>(key));
}
Type getStorageType() const { return storageType; }
@@ -187,11 +196,15 @@ struct QuantileTypeStorage : public TypeStorage {
ArrayRef<double> getQuantiles() const {
return ArrayRef<double>(quantilesData, numQuantiles);
}
+ std::optional<int64_t> getStorageMin() const { return storageMin; }
+ std::optional<int64_t> getStorageMax() const { return storageMax; }
Type storageType;
Type quantileType;
const double *quantilesData;
unsigned numQuantiles;
+ std::optional<int64_t> storageMin;
+ std::optional<int64_t> storageMax;
};
} // namespace detail
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index 6dbc86263bd71..329b1a98dc965 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -264,3 +264,15 @@
// Illegal storage min/max: min < defaultMin
// expected-error@+1 {{illegal storage type minimum: -10}}
!qalias = !quant.uniform<f4E2M1FN<-10:6>:f32, 0.99872:127>
+
+// -----
+
+// Illegal storage min/max: max > defaultMax
+// expected-error@+1 {{illegal storage type maximum: 10}}
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,1.0}><-6:100>:f32, 0.99872:127>
+
+// -----
+
+// Illegal storage min/max: min < defaultMin
+// expected-error@+1 {{illegal storage type minimum: -10}}
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,1.0}><-100:6>:f32, 0.99872:127>
diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index a8b9e5707b474..009c9001bdfa2 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -245,3 +245,22 @@ func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
+
+// -----
+// Storage type: QuantileType
+// CHECK: !quant.uniform<quantile<f4E2M1FN:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}><-6:6>:f32, 9.987200e-01:127>
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><-6:6>:f32, 0.99872:127 >
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: QuantileType
+// CHECK: !quant.uniform<quantile<f4E2M1FN:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>:f32, 2.000000e+02>
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}>:f32, 2.0e+2 >
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
diff --git a/mlir/test/IR/invalid-quantile-types.mlir b/mlir/test/IR/invalid-quantile-types.mlir
index 5ee3e674b5cc1..500df1a83f275 100644
--- a/mlir/test/IR/invalid-quantile-types.mlir
+++ b/mlir/test/IR/invalid-quantile-types.mlir
@@ -37,3 +37,9 @@ func.func private @missing_rbrace() -> quantile<ui4:f16, {1.0>
// Test missing '>' closing the quantile type.
// expected-error @+1 {{expected '>' in quantile type}}
func.func private @missing_gt() -> quantile<ui4:f16, {1.0}
+
+// -----
+
+// Test missing '>' closing the storage range.
+// expected-error @below {{expected '>' after quantile storage range}}
+func.func private @missing_range_gt() -> quantile<ui4:f16, {1.0}><-8:7
diff --git a/mlir/test/IR/quantile-types.mlir b/mlir/test/IR/quantile-types.mlir
index 3c119ef91750d..bf6af88c5b83a 100644
--- a/mlir/test/IR/quantile-types.mlir
+++ b/mlir/test/IR/quantile-types.mlir
@@ -46,6 +46,14 @@ func.func private @nf4_16_values(quantile<ui4:f16, {
0.33791524171829224,0.44070982933044434,0.5626170039176941,
0.7229568362236023,1.0}>) -> ()
+// Verify explicit storage min/max range (unsigned storage, narrowed range).
+// CHECK: func private @quantile_with_range(quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}><0:7>)
+func.func private @quantile_with_range(quantile<ui4:f16, {-1.0,0.0,1.0}><0:7>) -> ()
+
+// Verify explicit range is preserved through round-trip.
+// CHECK: func private @quantile_signed_range(quantile<si8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-100:100>)
+func.func private @quantile_signed_range(quantile<si8:f32, {-1.0,0.0,1.0}><-100:100>) -> ()
+
// Verify negative-only quantile list.
// CHECK: func private @quantile_negatives(quantile<si4:f32, {-1.000000e+00,-5.000000e-01,-2.500000e-01}>)
func.func private @quantile_negatives(quantile<si4:f32, {-1.0,-0.5,-0.25}>) -> ()
>From a5c22c66b9252591605973e5db2b0b0875b864d5 Mon Sep 17 00:00:00 2001
From: vsimion26 <vlad.simion at intel.com>
Date: Tue, 21 Apr 2026 11:34:33 +0100
Subject: [PATCH 3/6] Added semantic contract for LUT values, adapted tests and
addressed reviews
---
.../mlir/Dialect/Quant/IR/QuantTypes.h | 2 +-
mlir/include/mlir/IR/BuiltinTypes.td | 2 +-
mlir/lib/AsmParser/TypeParser.cpp | 9 +-
mlir/lib/IR/BuiltinTypes.cpp | 25 ++++++
.../Dialect/Quant/parse-uniform-invalid.mlir | 16 +++-
mlir/test/Dialect/Quant/parse-uniform.mlir | 6 +-
mlir/test/IR/quantile-types.mlir | 82 +++++++++++--------
7 files changed, 101 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 092a91e1c01d7..09c313fadb22e 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -255,7 +255,7 @@ class AnyQuantizedType
/// Per-layer, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
-/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'quantile'
+/// StorageType: 'i'|'u' NumBits, 'f4', 'F8E5M2', 'bf8', 'quantile'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Scale: A legal double value
/// ZeroPoint: An integer value
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index ee97806cf78d5..24d05a148435f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1139,7 +1139,7 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
after the storage type as `<min,max>`, overriding the defaults derived from
the storage type width.
- This type is used for weight compression schemes like NF4 (NormalFloat4)
+ This type is used for weight compression schemes like NF4 (NormalizedFloat4)
and similar quantile-based formats.
#### Example:
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index ad963a462db79..e6c7592c24aa7 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -392,6 +392,7 @@ Type Parser::parseNonFunctionType() {
/// quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>`
///
Type Parser::parseQuantileType() {
+ SMLoc typeLoc = getToken().getLoc();
consumeToken(Token::kw_quantile);
if (parseToken(Token::less, "expected '<' in quantile type"))
@@ -486,8 +487,12 @@ Type Parser::parseQuantileType() {
storageMax = maxVal;
}
- return QuantileType::get(storageType, quantileType, quantiles, storageMin,
- storageMax);
+ auto type = QuantileType::getChecked([&]() { return emitError(typeLoc); },
+ storageType, quantileType, quantiles,
+ storageMin, storageMax);
+ if (!type)
+ return nullptr;
+ return type;
}
/// Parse a tensor type.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6429928b812a0..075214993cd81 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -128,6 +128,31 @@ LogicalResult QuantileType::verify(function_ref<InFlightDiagnostic()> emitError,
if (storageMin.has_value() != storageMax.has_value())
return emitError()
<< "storage min and max must both be specified or both omitted";
+
+ // Validate explicit storage range.
+ if (storageMin && storageMax && *storageMin >= *storageMax)
+ return emitError() << "storage min must be less than storage max";
+
+ unsigned width = storageType.getIntOrFloatBitWidth();
+ bool isSigned = !llvm::isa<IntegerType>(storageType) ||
+ !llvm::cast<IntegerType>(storageType).isUnsigned();
+ auto effectiveMin =
+ storageMin.value_or(isSigned ? -(1LL << (width - 1)) : 0LL);
+ auto effectiveMax = storageMax.value_or(isSigned ? (1LL << (width - 1)) - 1
+ : (1LL << width) - 1);
+ auto expectedSize = effectiveMax - effectiveMin + 1;
+ if (static_cast<decltype(expectedSize)>(quantiles.size()) != expectedSize)
+ return emitError() << "quantile LUT size (" << quantiles.size()
+ << ") must equal the number of representable storage "
+ "values ("
+ << expectedSize << ")";
+
+ // No NaN or infinity allowed in the LUT.
+ for (double v : quantiles)
+ if (std::isnan(v) || std::isinf(v))
+ return emitError() << "quantile values must be finite (no NaN or "
+ "infinity)";
+
return success();
}
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index 329b1a98dc965..b3867f2dc35fa 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -269,10 +269,22 @@
// Illegal storage min/max: max > defaultMax
// expected-error@+1 {{illegal storage type maximum: 10}}
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,1.0}><-6:100>:f32, 0.99872:127>
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><6:10>:f32, 0.99872:127>
// -----
// Illegal storage min/max: min < defaultMin
// expected-error@+1 {{illegal storage type minimum: -10}}
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,1.0}><-100:6>:f32, 0.99872:127>
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><-10:-6>:f32, 0.99872:127>
+
+// -----
+
+// Quantile storage range: min must be strictly less than max.
+// expected-error@+1 {{storage min must be less than storage max}}
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16,{-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><5:3>:f32, 0.99872:127>
+
+// -----
+
+// Quantile LUT size (3) does not match the 16 representable values of f4E2M1FN's default range.
+// expected-error@+1 {{quantile LUT size (3) must equal the number of representable storage values (16)}}
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,0.0,1.0}>:f32, 0.99872:127>
diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index 009c9001bdfa2..27bb6dc98470d 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -247,9 +247,9 @@ func.func @parse() -> !qalias {
}
// -----
-// Storage type: QuantileType
-// CHECK: !quant.uniform<quantile<f4E2M1FN:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}><-6:6>:f32, 9.987200e-01:127>
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><-6:6>:f32, 0.99872:127 >
+// Storage type: QuantileType with narrowed explicit range <-6:6> (13 representable values).
+// CHECK: !quant.uniform<quantile<f4E2M1FN:f16, {-1.000000e+00,-8.750000e-01,-7.500000e-01,-6.250000e-01,-5.000000e-01,-2.500000e-01,0.000000e+00,2.500000e-01,5.000000e-01,6.250000e-01,7.500000e-01,8.750000e-01,1.000000e+00}><-6:6>:f32, 9.987200e-01:127>
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,-0.875,-0.75,-0.625,-0.5,-0.25,0.0,0.25,0.5,0.625,0.75,0.875,1.0}><-6:6>:f32, 0.99872:127 >
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
diff --git a/mlir/test/IR/quantile-types.mlir b/mlir/test/IR/quantile-types.mlir
index bf6af88c5b83a..0950e4a96678a 100644
--- a/mlir/test/IR/quantile-types.mlir
+++ b/mlir/test/IR/quantile-types.mlir
@@ -2,39 +2,46 @@
// Verify round-trip parsing and printing of the builtin quantile type.
-// CHECK: func private @quantile_ui4_f16(quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
-func.func private @quantile_ui4_f16(quantile<ui4:f16, {-1.0,0.0,1.0}>) -> ()
+// CHECK-LABEL: func private @quantile_ui4_f16
+// CHECK-SAME: quantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>
+func.func private @quantile_ui4_f16(quantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>) -> ()
-// CHECK: func private @quantile_si8_f32(quantile<si8:f32, {-1.000000e+00,-5.000000e-01,0.000000e+00,5.000000e-01,1.000000e+00}>)
-func.func private @quantile_si8_f32(quantile<si8:f32, {-1.0,-0.5,0.0,0.5,1.0}>) -> ()
+// CHECK: func private @quantile_si8_f32(quantile<si8:f32, {-1.000000e+00,-5.000000e-01,0.000000e+00,5.000000e-01,1.000000e+00}><-2:2>)
+func.func private @quantile_si8_f32(quantile<si8:f32, {-1.0,-0.5,0.0,0.5,1.0}><-2:2>) -> ()
-// CHECK: func private @quantile_i8_f32(quantile<i8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
-func.func private @quantile_i8_f32(quantile<i8:f32, {-1.0,0.0,1.0}>) -> ()
+// CHECK: func private @quantile_i8_f32(quantile<i8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-1:1>)
+func.func private @quantile_i8_f32(quantile<i8:f32, {-1.0,0.0,1.0}><-1:1>) -> ()
-// CHECK: func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
-func.func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.0,0.0,1.0}>) -> ()
+// CHECK: func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-1:1>)
+func.func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.0,0.0,1.0}><-1:1>) -> ()
-// CHECK: func private @quantile_ui4_bf16(quantile<ui4:bf16, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
-func.func private @quantile_ui4_bf16(quantile<ui4:bf16, {-1.0,0.0,1.0}>) -> ()
+// CHECK-LABEL: func private @quantile_ui4_bf16
+// CHECK-SAME: quantile<ui4:bf16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>
+func.func private @quantile_ui4_bf16(quantile<ui4:bf16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>) -> ()
-// CHECK: func private @quantile_as_return() -> quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>
-func.func private @quantile_as_return() -> quantile<ui4:f16, {-1.0,0.0,1.0}>
+// CHECK-LABEL: func private @quantile_as_return
+// CHECK-SAME: quantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>
+func.func private @quantile_as_return() -> quantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>
// Verify use as memref element type (requires MemRefElementTypeInterface).
-// CHECK: func private @quantile_in_memref(memref<8xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
-func.func private @quantile_in_memref(memref<8xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+// CHECK-LABEL: func private @quantile_in_memref
+// CHECK-SAME: memref<8xquantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>>
+func.func private @quantile_in_memref(memref<8xquantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>>) -> ()
// Verify use as tensor element type.
-// CHECK: func private @quantile_in_tensor(tensor<16xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
-func.func private @quantile_in_tensor(tensor<16xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+// CHECK-LABEL: func private @quantile_in_tensor
+// CHECK-SAME: tensor<16xquantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>>
+func.func private @quantile_in_tensor(tensor<16xquantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>>) -> ()
// Verify use in multidimensional tensors.
-// CHECK: func private @quantile_in_ranked_tensor(tensor<16x16x1x1xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
-func.func private @quantile_in_ranked_tensor(tensor<16x16x1x1xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+// CHECK-LABEL: func private @quantile_in_ranked_tensor
+// CHECK-SAME: tensor<16x16x1x1xquantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>>
+func.func private @quantile_in_ranked_tensor(tensor<16x16x1x1xquantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>>) -> ()
// Verify use in unranked tensors.
-// CHECK: func private @quantile_in_unranked_tensor(tensor<*xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
-func.func private @quantile_in_unranked_tensor(tensor<*xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+// CHECK-LABEL: func private @quantile_in_unranked_tensor
+// CHECK-SAME: tensor<*xquantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>>
+func.func private @quantile_in_unranked_tensor(tensor<*xquantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>>) -> ()
// Verify NF4-style 16-entry quantile table
// CHECK-LABEL: @nf4_16_values
@@ -47,17 +54,28 @@ func.func private @nf4_16_values(quantile<ui4:f16, {
0.7229568362236023,1.0}>) -> ()
// Verify explicit storage min/max range (unsigned storage, narrowed range).
-// CHECK: func private @quantile_with_range(quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}><0:7>)
-func.func private @quantile_with_range(quantile<ui4:f16, {-1.0,0.0,1.0}><0:7>) -> ()
+// CHECK: func private @quantile_with_range(quantile<ui4:f16, {-1.000000e+00,-7.500000e-01,-5.000000e-01,-2.500000e-01,0.000000e+00,2.500000e-01,5.000000e-01,1.000000e+00}><0:7>)
+func.func private @quantile_with_range(quantile<ui4:f16, {-1.0,-0.75,-0.5,-0.25,0.0,0.25,0.5,1.0}><0:7>) -> ()
// Verify explicit range is preserved through round-trip.
-// CHECK: func private @quantile_signed_range(quantile<si8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-100:100>)
-func.func private @quantile_signed_range(quantile<si8:f32, {-1.0,0.0,1.0}><-100:100>) -> ()
-
-// Verify negative-only quantile list.
-// CHECK: func private @quantile_negatives(quantile<si4:f32, {-1.000000e+00,-5.000000e-01,-2.500000e-01}>)
-func.func private @quantile_negatives(quantile<si4:f32, {-1.0,-0.5,-0.25}>) -> ()
-
-// Verify that a single quantile value is accepted.
-// CHECK: func private @quantile_single_value(quantile<si8:f32, {1.000000e+00}>)
-func.func private @quantile_single_value(quantile<si8:f32, {1.0}>) -> ()
+// CHECK: func private @quantile_signed_range(quantile<si8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-1:1>)
+func.func private @quantile_signed_range(quantile<si8:f32, {-1.0,0.0,1.0}><-1:1>) -> ()
+
+// Verify signed 4-bit storage type uses full 16-entry LUT (all-negative values).
+// CHECK-LABEL: func private @quantile_negatives
+// CHECK-SAME: quantile<si4:f32, {-2.000000e+00,-1.875000e+00,-1.750000e+00,-1.625000e+00,-1.500000e+00,-1.375000e+00,-1.250000e+00,-1.125000e+00,-1.000000e+00,-8.750000e-01,-7.500000e-01,-6.250000e-01,-5.000000e-01,-3.750000e-01,-2.500000e-01,-1.250000e-01}>
+func.func private @quantile_negatives(quantile<si4:f32, {-2.0,-1.875,-1.75,-1.625,-1.5,-1.375,-1.25,-1.125,-1.0,-0.875,-0.75,-0.625,-0.5,-0.375,-0.25,-0.125}>) -> ()
+
+// Verify minimal 2-entry LUT for 1-bit unsigned storage type.
+// CHECK: func private @quantile_ui1_f16(quantile<ui1:f16, {-1.000000e+00,1.000000e+00}>)
+func.func private @quantile_ui1_f16(quantile<ui1:f16, {-1.0,1.0}>) -> ()
+
+// Verify LUT values in descending order
+// Storage is ui4 with explicit <0:7> range (8 entries).
+// CHECK: func private @quantile_descending(quantile<ui4:f16, {1.000000e+00,7.500000e-01,5.000000e-01,2.500000e-01,0.000000e+00,-2.500000e-01,-5.000000e-01,-1.000000e+00}><0:7>)
+func.func private @quantile_descending(quantile<ui4:f16, {1.0,0.75,0.5,0.25,0.0,-0.25,-0.5,-1.0}><0:7>) -> ()
+
+// Verify LUT values in an arbitrary order
+// Storage is ui4 with explicit <0:7> range (8 entries).
+// CHECK: func private @quantile_random_order(quantile<ui4:f16, {0.000000e+00,-5.000000e-01,1.000000e+00,-2.500000e-01,7.500000e-01,-1.000000e+00,5.000000e-01,2.500000e-01}><0:7>)
+func.func private @quantile_random_order(quantile<ui4:f16, {0.0,-0.5,1.0,-0.25,0.75,-1.0,0.5,0.25}><0:7>) -> ()
>From 50126daf26e24701145dda7296132f43225832ee Mon Sep 17 00:00:00 2001
From: vsimion26 <vlad.simion at intel.com>
Date: Tue, 21 Apr 2026 15:15:43 +0100
Subject: [PATCH 4/6] Added support for Signless Integer in verifier
---
mlir/lib/IR/BuiltinTypes.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 075214993cd81..0a642fd2bed4e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -135,7 +135,7 @@ LogicalResult QuantileType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned width = storageType.getIntOrFloatBitWidth();
bool isSigned = !llvm::isa<IntegerType>(storageType) ||
- !llvm::cast<IntegerType>(storageType).isUnsigned();
+ llvm::cast<IntegerType>(storageType).isSigned();
auto effectiveMin =
storageMin.value_or(isSigned ? -(1LL << (width - 1)) : 0LL);
auto effectiveMax = storageMax.value_or(isSigned ? (1LL << (width - 1)) - 1
@@ -158,7 +158,7 @@ LogicalResult QuantileType::verify(function_ref<InFlightDiagnostic()> emitError,
bool QuantileType::shouldDefaultToSigned() const {
if (auto intType = llvm::dyn_cast<IntegerType>(getStorageType()))
- return !intType.isUnsigned();
+ return intType.isSigned();
// Float types default to signed.
return true;
}
>From 1eb70f72aac9486995e97b0e2bd4bd0ebbc046ee Mon Sep 17 00:00:00 2001
From: vsimion26 <vlad.simion at intel.com>
Date: Tue, 26 May 2026 11:39:13 +0100
Subject: [PATCH 5/6] moved QuantileType definition to QuantDialect, updated
tests
---
.../mlir/Dialect/Quant/IR/QuantTypes.h | 83 +++++++++++
.../Dialect/Quant/IR/detail}/TypeDetail.h | 80 ++++++++++
mlir/include/mlir/IR/BuiltinTypes.h | 1 -
mlir/include/mlir/IR/BuiltinTypes.td | 86 -----------
mlir/lib/AsmParser/Parser.h | 3 -
mlir/lib/AsmParser/TokenKinds.def | 1 -
mlir/lib/AsmParser/TypeParser.cpp | 116 ---------------
mlir/lib/Dialect/Quant/IR/QuantOps.cpp | 5 +-
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 140 +++++++++++++++++-
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 75 ++++++++++
mlir/lib/IR/AsmPrinter.cpp | 16 --
mlir/lib/IR/BuiltinTypes.cpp | 119 +--------------
mlir/lib/IR/TypeDetail.h | 59 --------
.../Dialect/Quant/invalid-quantile-types.mlir | 42 ++++++
.../Dialect/Quant/parse-uniform-invalid.mlir | 16 +-
mlir/test/Dialect/Quant/parse-uniform.mlir | 10 +-
mlir/test/Dialect/Quant/quantile-types.mlir | 76 ++++++++++
mlir/test/IR/invalid-quantile-types.mlir | 45 ------
mlir/test/IR/quantile-types.mlir | 81 ----------
19 files changed, 509 insertions(+), 545 deletions(-)
rename mlir/{lib/Dialect/Quant/IR => include/mlir/Dialect/Quant/IR/detail}/TypeDetail.h (83%)
create mode 100644 mlir/test/Dialect/Quant/invalid-quantile-types.mlir
create mode 100644 mlir/test/Dialect/Quant/quantile-types.mlir
delete mode 100644 mlir/test/IR/invalid-quantile-types.mlir
delete mode 100644 mlir/test/IR/quantile-types.mlir
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 09c313fadb22e..7ef1af8d6faf4 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
+#include "mlir/Dialect/Quant/IR/detail/TypeDetail.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -18,6 +19,7 @@
#include "llvm/Support/MathExtras.h"
namespace mlir {
+
namespace quant {
namespace detail {
@@ -547,6 +549,87 @@ class CalibratedQuantizedType
double getMax() const;
};
+class QuantileType
+ : public Type::TypeBase<QuantileType, QuantizedType,
+ detail::QuantileTypeStorage,
+ mlir::QuantStorageTypeInterface::Trait> {
+public:
+ using ImplType = detail::QuantileTypeStorage;
+ using Base::Base;
+
+ // Get the underlying type used to store raw values.
+ Type getStorageType() const;
+
+ // Get the primitive expressed type of the data in the quantile table.
+ // Note that we may convert FP8 data to FP16 for storage,
+ // but we should treat its expressed type as FP8 rather than FP16.
+ Type getQuantileType() const;
+
+ /// Return the quantile table of this type.
+ ArrayRef<double> getQuantiles() const;
+
+ /// Return the explicit storage minimum, if set.
+ std::optional<int64_t> getStorageMin() const;
+
+ /// Return the explicit storage maximum, if set.
+ std::optional<int64_t> getStorageMax() const;
+
+ // Get a quantile type with the given quantile table and storage range.
+ static QuantileType get(mlir::MLIRContext *ctx, Type storageType,
+ Type quantileType, ArrayRef<double> quantiles = {},
+ std::optional<int64_t> storageMin = std::nullopt,
+ std::optional<int64_t> storageMax = std::nullopt);
+
+ static QuantileType
+ getChecked(function_ref<InFlightDiagnostic()> emitError,
+ mlir::MLIRContext *ctx, Type storageType, Type quantileType,
+ ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin = std::nullopt,
+ std::optional<int64_t> storageMax = std::nullopt);
+
+ static LogicalResult verifyInvariants(
+ function_ref<InFlightDiagnostic()> emitError, Type storageType,
+ Type quantileType, ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin, std::optional<int64_t> storageMax);
+
+ /// Support method for type inquiry through isa, cast, and dyn_cast.
+ static bool classof(mlir::Type type);
+
+ // Print this type to the given assembly printer.
+ void print(mlir::AsmPrinter &printer) const;
+
+ static constexpr llvm::StringLiteral getMnemonic() { return {"quantile"}; }
+
+ static constexpr llvm::StringLiteral name = "quantile";
+
+ // Returns true if the type defaults to signed (e.g., si8, i8 or float types),
+ // false otherwise. NOTE(review): confirm signless (i8) handling in the impl.
+ bool shouldDefaultToSigned() const;
+
+ // Get the bit width of the storage type.
+ unsigned getStorageWidth() const;
+
+ // Get the default minimum and maximum values for the storage type.
+ int64_t getDefaultMinimum([[maybe_unused]] bool isSigned) const;
+ int64_t getDefaultMaximum([[maybe_unused]] bool isSigned) const;
+
+ // Get the string representation of the storage type.
+ std::string getStorageTypeName([[maybe_unused]] bool isSigned) const;
+
+ // Get whether the type is a packed quantile float type.
+ bool isPacked() const;
+
+ // Get the logical bit width of the quantile float type, which is the bit
+ // width of the represented floating point value.
+ unsigned getLogicalBitWidth() const;
+
+ // Get the number of quantized values stored in one byte for this quantile
+ // float type.
+ unsigned getElementsPerByte() const;
+
+ // Get the preferred alignment in bytes for this quantile float type, if any.
+ std::optional<unsigned> getPreferredAlignmentBytes() const;
+};
} // namespace quant
} // namespace mlir
diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/include/mlir/Dialect/Quant/IR/detail/TypeDetail.h
similarity index 83%
rename from mlir/lib/Dialect/Quant/IR/TypeDetail.h
rename to mlir/include/mlir/Dialect/Quant/IR/detail/TypeDetail.h
index a43bce354c324..8834ee7901c18 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/detail/TypeDetail.h
@@ -422,6 +422,86 @@ struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage {
double max;
};
+struct QuantileTypeStorage : public mlir::TypeStorage { // Storage backing quant::QuantileType.
+ mlir::Type storageType;
+ mlir::Type quantileType;
+ const double *quantilesElements; // First entry of the quantile table copied in construct().
+ size_t quantilesParamsSize; // Number of entries in the quantile table.
+ std::optional<int64_t> storageMin; // Explicit storage minimum, if any.
+ std::optional<int64_t> storageMax; // Explicit storage maximum, if any.
+
+ struct KeyTy { // Key used to hash and compare storage instances.
+ KeyTy(mlir::Type storageType, mlir::Type quantileType,
+ ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin = std::nullopt,
+ std::optional<int64_t> storageMax = std::nullopt)
+ : storageType(storageType), quantileType(quantileType),
+ quantiles(quantiles), storageMin(storageMin), storageMax(storageMax) {
+ }
+
+ mlir::Type storageType;
+ mlir::Type quantileType;
+ ArrayRef<double> quantiles; // Non-owning view; construct() makes the owned copy.
+ std::optional<int64_t> storageMin;
+ std::optional<int64_t> storageMax;
+
+ mlir::Type getQuantileType() const { return quantileType; }
+ ArrayRef<double> getQuantiles() const { return quantiles; }
+
+ bool operator==(const KeyTy &other) const { // Field-wise comparison of two keys.
+ return storageType == other.storageType &&
+ quantileType == other.quantileType &&
+ quantiles == other.quantiles && storageMin == other.storageMin &&
+ storageMax == other.storageMax;
+ }
+
+ static llvm::hash_code hashOptInt(std::optional<int64_t> opt) { // Hash the presence flag with the value (0 if absent).
+ return opt ? llvm::hash_combine(true, *opt)
+ : llvm::hash_combine(false, int64_t{0});
+ }
+
+ unsigned getHashValue() const { // Hash every key field; the result is narrowed to unsigned.
+ const int64_t *quantilesCast =
+ llvm::bit_cast<const int64_t *>(quantiles.data());
+ ArrayRef<int64_t> quantilesBits(quantilesCast, quantiles.size()); // View the doubles as int64_t so their bit patterns are hashed.
+ return static_cast<unsigned>(llvm::hash_combine(
+ llvm::hash_combine_range(quantilesBits.begin(), quantilesBits.end()),
+ storageType, quantileType, hashOptInt(storageMin),
+ hashOptInt(storageMax)));
+ }
+ };
+
+ bool operator==(const KeyTy &key) const { // Compare this storage against a lookup key.
+ return storageType == key.storageType && quantileType == key.quantileType &&
+ getQuantiles() == key.quantiles && storageMin == key.storageMin &&
+ storageMax == key.storageMax;
+ }
+
+ QuantileTypeStorage(const KeyTy &key, ArrayRef<double> quantiles) // `quantiles` supplies the table pointer/size; other fields come from `key`.
+ : storageType(key.storageType), quantileType(key.quantileType),
+ quantilesElements(quantiles.data()),
+ quantilesParamsSize(quantiles.size()), storageMin(key.storageMin),
+ storageMax(key.storageMax) {}
+
+ static QuantileTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+ KeyTy key) {
+ ArrayRef<double> quantiles = allocator.copyInto(key.quantiles); // Copy the table via the allocator before storing its pointer.
+ return new (allocator.allocate<QuantileTypeStorage>())
+ QuantileTypeStorage(key, quantiles);
+ }
+
+ static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } // Forwards to KeyTy::getHashValue().
+
+ ArrayRef<double> getQuantiles() const { // Rebuild the table view from the stored pointer and size.
+ return ArrayRef<double>(quantilesElements, quantilesParamsSize);
+ }
+
+ mlir::Type getStorageType() const { return storageType; }
+ mlir::Type getQuantileType() const { return quantileType; }
+ std::optional<int64_t> getStorageMin() const { return storageMin; }
+ std::optional<int64_t> getStorageMax() const { return storageMax; }
+};
+
} // namespace detail
} // namespace quant
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index ef47850929beb..d30cba29c9814 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -35,7 +35,6 @@ class TypeRange;
namespace detail {
struct FunctionTypeStorage;
struct IntegerTypeStorage;
-struct QuantileTypeStorage;
struct TupleTypeStorage;
} // namespace detail
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 24d05a148435f..612cb793e3ecb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1115,92 +1115,6 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
let genVerifyDecl = 1;
}
-//===----------------------------------------------------------------------===//
-// QuantileType
-//===----------------------------------------------------------------------===//
-
-def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
- [QuantStorageTypeInterface, MemRefElementTypeInterface]> {
- let summary = "Quantile-based type with a lookup table of quantile values";
- let description = [{
- Syntax:
-
- ```
- quantile-type ::= `quantile` `<` type ( `<` int `,` int `>` )? `:` type `,` `{` float-list `}` `>`
- ```
-
- A quantile type represents a quantile-based floating point encoding, where
- discrete storage values map to floating-point values via a quantile lookup
- table. The type has a storage type (how raw values are stored, e.g., `ui4`,
- `si8`), a quantile type (the expressed floating-point precision, e.g.,
- `f16`, `f32`), and an array of quantile values that define the mapping.
-
- Optionally, explicit minimum and maximum storage values can be specified
- after the storage type as `<min,max>`, overriding the defaults derived from
- the storage type width.
-
- This type is used for weight compression schemes like NF4 (NormalizedFloat4)
- and similar quantile-based formats.
-
- #### Example:
-
- ```mlir
- quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}>
- quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}><-8,7>
- ```
- }];
- let parameters = (ins
- "Type":$storageType,
- "Type":$quantileType,
- ArrayRefParameter<"double", "quantile values">:$quantiles,
- OptionalParameter<"std::optional<int64_t>", "explicit storage minimum">:$storageMin,
- OptionalParameter<"std::optional<int64_t>", "explicit storage maximum">:$storageMax
- );
- let builders = [
- TypeBuilderWithInferredContext<(ins
- "Type":$storageType,
- "Type":$quantileType,
- "ArrayRef<double>":$quantiles,
- CArg<"std::optional<int64_t>", "std::nullopt">:$storageMin,
- CArg<"std::optional<int64_t>", "std::nullopt">:$storageMax), [{
- return $_get(storageType.getContext(), storageType, quantileType,
- quantiles, storageMin, storageMax);
- }]>
- ];
- let skipDefaultBuilders = 1;
- let genStorageClass = 0;
- let genVerifyDecl = 1;
- let extraClassDeclaration = [{
- /// QuantStorageTypeInterface method implementations.
-
- /// Returns true if the storage type defaults to signed.
- bool shouldDefaultToSigned() const;
-
- /// Get the bit width of the storage type.
- unsigned getStorageWidth() const;
-
- /// Get default minimum value for the storage type.
- int64_t getDefaultMinimum(bool isSigned) const;
-
- /// Get default maximum value for the storage type.
- int64_t getDefaultMaximum(bool isSigned) const;
-
- /// Get the storage type as a string.
- std::string getStorageTypeName(bool isSigned) const;
-
- /// Check if the storage type uses packed representation.
- bool isPacked() const;
-
- /// Get the logical bit width per value.
- unsigned getLogicalBitWidth() const;
-
- /// Returns how many values of this type fit in one byte.
- unsigned getElementsPerByte() const;
-
- /// Get the preferred alignment in bytes.
- std::optional<unsigned> getPreferredAlignmentBytes() const;
- }];
-}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index ccd469df3ab13..ecc128cf767b3 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -213,9 +213,6 @@ class Parser {
/// Parse a complex type.
Type parseComplexType();
- /// Parse a quantile type.
- Type parseQuantileType();
-
/// Parse an extended type.
Type parseExtendedType();
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index b2cd2d1e8dd58..fe7c53753e156 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -119,7 +119,6 @@ TOK_KEYWORD(min)
TOK_KEYWORD(mod)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
-TOK_KEYWORD(quantile)
TOK_KEYWORD(size)
TOK_KEYWORD(sparse)
TOK_KEYWORD(step)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index e6c7592c24aa7..427a39567c0cf 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -52,7 +52,6 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_f8E8M0FNU:
case Token::kw_bf16:
case Token::kw_f16:
- case Token::kw_quantile:
case Token::kw_tf32:
case Token::kw_f32:
case Token::kw_f64:
@@ -284,8 +283,6 @@ Type Parser::parseNonFunctionType() {
return parseTensorType();
case Token::kw_complex:
return parseComplexType();
- case Token::kw_quantile:
- return parseQuantileType();
case Token::kw_tuple:
return parseTupleType();
case Token::kw_vector:
@@ -387,119 +384,6 @@ Type Parser::parseNonFunctionType() {
}
}
-/// Parse a quantile type.
-///
-/// quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>`
-///
-Type Parser::parseQuantileType() {
- SMLoc typeLoc = getToken().getLoc();
- consumeToken(Token::kw_quantile);
-
- if (parseToken(Token::less, "expected '<' in quantile type"))
- return nullptr;
-
- // Parse the storage type.
- Type storageType = parseType();
- if (!storageType)
- return nullptr;
-
- if (parseToken(Token::colon, "expected ':' in quantile type"))
- return nullptr;
-
- // Parse the quantile (expressed) type.
- Type quantileType = parseType();
- if (!quantileType)
- return nullptr;
-
- if (parseToken(Token::comma, "expected ',' in quantile type"))
- return nullptr;
-
- if (parseToken(Token::l_brace, "expected '{' in quantile type"))
- return nullptr;
-
- // Parse the quantile values as floating point literals.
- SmallVector<double, 16> quantiles;
- do {
- bool isNegative = consumeIf(Token::minus);
- Token curTok = getToken();
- std::optional<APFloat> apResult;
- if (failed(parseFloatFromLiteral(apResult, curTok, isNegative,
- APFloat::IEEEdouble())))
- return nullptr;
- consumeToken();
- quantiles.push_back(apResult->convertToDouble());
- } while (consumeIf(Token::comma) &&
- !getToken().is(Token::r_brace));
-
- if (parseToken(Token::r_brace, "expected '}' in quantile type"))
- return nullptr;
-
- if (parseToken(Token::greater, "expected '>' in quantile type"))
- return nullptr;
-
- // Optionally parse explicit storage range: `<min:max>`.
- std::optional<int64_t> storageMin, storageMax;
- if (consumeIf(Token::less)) {
- int64_t minVal, maxVal;
-
- // Parse minimum value (with optional sign).
- bool minNegative = consumeIf(Token::minus);
- if (!getToken().is(Token::integer))
- return (emitWrongTokenError(
- "expected integer minimum in quantile storage range"),
- nullptr);
- SMLoc minLoc = getToken().getLoc();
- if (getToken().getSpelling().getAsInteger(10, minVal))
- return nullptr;
- consumeToken(Token::integer);
- minVal = minNegative ? -minVal : minVal;
-
- if (parseToken(Token::colon, "expected ':' in quantile storage range"))
- return nullptr;
-
- // Parse maximum value (with optional sign).
- bool maxNegative = consumeIf(Token::minus);
- if (!getToken().is(Token::integer))
- return (emitWrongTokenError(
- "expected integer maximum in quantile storage range"),
- nullptr);
- SMLoc maxLoc = getToken().getLoc();
- if (getToken().getSpelling().getAsInteger(10, maxVal))
- return nullptr;
- consumeToken(Token::integer);
- maxVal = maxNegative ? -maxVal : maxVal;
-
- if (parseToken(Token::greater, "expected '>' after quantile storage range"))
- return nullptr;
-
- // Validate against the underlying storage type's inherent limits.
- if (auto qsIface = llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
- bool isSigned = qsIface.shouldDefaultToSigned();
- if (minVal < qsIface.getDefaultMinimum(isSigned))
- return (emitError(minLoc, "illegal storage type minimum: ") << minVal,
- nullptr);
- if (maxVal > qsIface.getDefaultMaximum(isSigned))
- return (emitError(maxLoc, "illegal storage type maximum: ") << maxVal,
- nullptr);
- }
-
- storageMin = minVal;
- storageMax = maxVal;
- }
-
- auto type = QuantileType::getChecked([&]() { return emitError(typeLoc); },
- storageType, quantileType, quantiles,
- storageMin, storageMax);
- if (!type)
- return nullptr;
- return type;
-}
-
-/// Parse a tensor type.
-///
-/// tensor-type ::= `tensor` `<` dimension-list type `>`
-/// dimension-list ::= dimension-list-ranked | `*x`
-///
Type Parser::parseTensorType() {
consumeToken(Token::kw_tensor);
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index 060707437334e..4274edef54748 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "QuantDialectBytecode.h"
-#include "TypeDetail.h"
+#include "mlir/Dialect/Quant/IR/detail/TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
@@ -200,7 +200,8 @@ struct QuantInlinerInterface : public DialectInlinerInterface {
void QuantDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
- UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
+ UniformQuantizedPerAxisType, UniformQuantizedSubChannelType,
+ QuantileType>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index c5a36f7106ad3..d2c3d48cd705d 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -7,8 +7,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
-#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Quant/IR/detail/TypeDetail.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/QuantStorageTypeInterface.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -552,3 +554,139 @@ LogicalResult CalibratedQuantizedType::verifyInvariants(
double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
+
+/// Returns the QuantileType for the given parameters in `ctx` by forwarding
+/// all arguments unchanged to Base::get.
+QuantileType QuantileType::get(mlir::MLIRContext *ctx, mlir::Type storageType,
+ mlir::Type quantileType,
+ ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin,
+ std::optional<int64_t> storageMax) {
+ return Base::get(ctx, storageType, quantileType, quantiles, storageMin,
+ storageMax);
+}
+
+/// Checked counterpart of get(): forwards `emitError` and all parameters
+/// unchanged to Base::getChecked.
+QuantileType QuantileType::getChecked(
+ function_ref<InFlightDiagnostic()> emitError, mlir::MLIRContext *ctx,
+ mlir::Type storageType, mlir::Type quantileType, ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin, std::optional<int64_t> storageMax) {
+ return Base::getChecked(emitError, ctx, storageType, quantileType, quantiles,
+ storageMin, storageMax);
+}
+
+/// Verifies the invariants of a quantile type:
+///  - `storageType` is an integer or float type and `quantileType` is a float
+///    type;
+///  - `quantiles` is non-empty and contains only finite values;
+///  - `storageMin`/`storageMax` are either both present (with min < max) or
+///    both absent;
+///  - the LUT has exactly one entry per storage value in the effective range,
+///    i.e. the explicit range when given, otherwise the default range derived
+///    from the storage type's width and signedness.
+/// Emits a diagnostic through `emitError` and returns failure on the first
+/// violated invariant.
+LogicalResult QuantileType::verifyInvariants(
+    function_ref<InFlightDiagnostic()> emitError, Type storageType,
+    Type quantileType, ArrayRef<double> quantiles,
+    std::optional<int64_t> storageMin, std::optional<int64_t> storageMax) {
+  if (!storageType.isIntOrFloat())
+    return emitError() << "storage type must be an integer or float type";
+  if (!llvm::isa<mlir::FloatType>(quantileType))
+    return emitError() << "quantile type must be a float type";
+  if (quantiles.empty())
+    return emitError() << "quantile values must not be empty";
+  if (storageMin.has_value() != storageMax.has_value())
+    return emitError()
+           << "storage min and max must both be specified or both omitted";
+  if (storageMin && storageMax && *storageMin >= *storageMax)
+    return emitError() << "storage min must be less than storage max";
+
+  int64_t effectiveMin = 0;
+  int64_t effectiveMax = 0;
+  if (storageMin) {
+    // Both bounds are present here (checked above). Use them as-is so the
+    // default-range shifts below are never evaluated for explicit ranges.
+    effectiveMin = *storageMin;
+    effectiveMax = *storageMax;
+  } else {
+    unsigned width = storageType.getIntOrFloatBitWidth();
+    bool isSigned = !llvm::isa<mlir::IntegerType>(storageType) ||
+                    llvm::cast<mlir::IntegerType>(storageType).isSigned();
+    // Signed types spend one bit on the sign. A 0-bit signed type wraps
+    // `width - 1` to a huge value, and more than 62 magnitude bits would make
+    // the 64-bit shifts below undefined; reject both instead of shifting.
+    unsigned magnitudeBits = isSigned ? width - 1 : width;
+    if (magnitudeBits > 62)
+      return emitError() << "cannot derive a default storage range for a "
+                         << width
+                         << "-bit storage type; specify an explicit range";
+    effectiveMin = isSigned ? -(int64_t{1} << magnitudeBits) : int64_t{0};
+    effectiveMax = (int64_t{1} << magnitudeBits) - 1;
+  }
+
+  // effectiveMin < effectiveMax (or both are 0) at this point, so the unsigned
+  // difference is the exact value count without signed overflow. It wraps to 0
+  // only for the full int64_t range, which no non-empty LUT can match.
+  uint64_t expectedSize = static_cast<uint64_t>(effectiveMax) -
+                          static_cast<uint64_t>(effectiveMin) + 1;
+  if (static_cast<uint64_t>(quantiles.size()) != expectedSize)
+    return emitError() << "quantile LUT size (" << quantiles.size()
+                       << ") must equal the number of representable storage "
+                          "values ("
+                       << expectedSize << ")";
+
+  for (double v : quantiles)
+    if (std::isnan(v) || std::isinf(v))
+      return emitError()
+             << "quantile values must be finite (no NaN or infinity)";
+
+  return success();
+}
+
+/// Returns true when the TypeID of `type` is that of QuantileType.
+bool QuantileType::classof(mlir::Type type) {
+ return type.getTypeID() == mlir::TypeID::get<QuantileType>();
+}
+
+/// Returns the storage type held by the underlying storage object.
+mlir::Type QuantileType::getStorageType() const {
+ return static_cast<ImplType *>(impl)->getStorageType();
+}
+
+/// Returns the quantile type held by the underlying storage object.
+mlir::Type QuantileType::getQuantileType() const {
+ return static_cast<ImplType *>(impl)->getQuantileType();
+}
+
+/// Returns the quantile lookup table held by the underlying storage object.
+ArrayRef<double> QuantileType::getQuantiles() const {
+ return static_cast<ImplType *>(impl)->getQuantiles();
+}
+
+/// Returns the storage minimum recorded in the storage object, if any.
+std::optional<int64_t> QuantileType::getStorageMin() const {
+ return static_cast<ImplType *>(impl)->getStorageMin();
+}
+
+/// Returns the storage maximum recorded in the storage object, if any.
+std::optional<int64_t> QuantileType::getStorageMax() const {
+ return static_cast<ImplType *>(impl)->getStorageMax();
+}
+
+/// Integer storage types report their own signedness; every other storage
+/// type (i.e. floats) is treated as signed.
+bool QuantileType::shouldDefaultToSigned() const {
+  auto intType = mlir::dyn_cast<mlir::IntegerType>(getStorageType());
+  return !intType || intType.isSigned();
+}
+
+/// Returns the bit width of the storage type, as reported by
+/// getIntOrFloatBitWidth().
+unsigned QuantileType::getStorageWidth() const {
+ return getStorageType().getIntOrFloatBitWidth();
+}
+
+/// Returns the explicit storage maximum when one is set; otherwise the largest
+/// value of a signed or unsigned integer of the storage width.
+int64_t QuantileType::getDefaultMaximum(bool isSigned) const {
+  std::optional<int64_t> explicitMax = getStorageMax();
+  if (explicitMax.has_value())
+    return explicitMax.value();
+  // A signed range spends one bit on the sign.
+  const unsigned valueBits =
+      isSigned ? getStorageWidth() - 1 : getStorageWidth();
+  return (1LL << valueBits) - 1;
+}
+
+/// Returns the explicit storage minimum when one is set; otherwise the
+/// smallest value of a signed integer of the storage width, or 0 when
+/// unsigned.
+int64_t QuantileType::getDefaultMinimum(bool isSigned) const {
+  std::optional<int64_t> explicitMin = getStorageMin();
+  if (explicitMin.has_value())
+    return explicitMin.value();
+  if (!isSigned)
+    return 0;
+  return -(1LL << (getStorageWidth() - 1));
+}
+
+/// Renders this type as
+/// `!quant.quantile<storage:quantile, {q0,q1,...}[, <min:max>]>`.
+/// The `isSigned` argument is not consulted.
+std::string QuantileType::getStorageTypeName(bool isSigned) const {
+  std::string name;
+  llvm::raw_string_ostream stream(name);
+  stream << "!quant.quantile<";
+  stream << getStorageType() << ":" << getQuantileType() << ", {";
+  llvm::interleave(
+      this->getQuantiles(), stream, [&](double value) { stream << value; },
+      ",");
+  stream << "}";
+  // The explicit range is printed only when both bounds are present.
+  std::optional<int64_t> lower = getStorageMin();
+  std::optional<int64_t> upper = getStorageMax();
+  if (lower && upper)
+    stream << ", <" << *lower << ":" << *upper << ">";
+  stream << ">";
+  stream.flush();
+  return name;
+}
+
+/// Reports the type as packed when the storage width is at most 4 bits.
+bool QuantileType::isPacked() const { return getStorageWidth() <= 4; }
+
+/// Returns the storage width as the logical bit width.
+unsigned QuantileType::getLogicalBitWidth() const { return getStorageWidth(); }
+
+/// Returns 8 divided by the storage width (integer division), or 0 for a
+/// zero-width storage type. Widths above 8 bits therefore also yield 0.
+unsigned QuantileType::getElementsPerByte() const {
+  const unsigned bits = getStorageWidth();
+  if (bits == 0)
+    return 0;
+  return 8 / bits;
+}
+
+/// Always returns std::nullopt, i.e. no preferred alignment is reported.
+std::optional<unsigned> QuantileType::getPreferredAlignmentBytes() const {
+ return std::nullopt;
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 5bffe43c818f9..2a845bed0a535 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -480,6 +480,60 @@ static Type parseCalibratedType(DialectAsmParser &parser) {
return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
}
+/// Parses the body of a quantile type, starting at the opening `<`:
+///
+///   `<` type `:` type `,` `{` float-list? `}` (`,` `<` int `:` int `>`)? `>`
+///
+/// The first type is the storage type and the second is the quantile type.
+/// The optional trailing `<min:max>` sets both explicit storage bounds.
+/// Returns a null type on a syntax error or when getChecked rejects the
+/// parsed parameters.
+static Type parseQuantileType(DialectAsmParser &parser) {
+ Type storageType;
+ Type quantileType;
+ SmallVector<double, 1> quantiles;
+
+ if (parser.parseLess())
+ return nullptr;
+ if (parser.parseType(storageType))
+ return nullptr;
+ if (parser.parseColon())
+ return nullptr;
+ if (parser.parseType(quantileType))
+ return nullptr;
+ if (parser.parseComma())
+ return nullptr;
+ if (parser.parseLBrace())
+ return nullptr;
+
+ // Allow empty braces `{}` here; the empty LUT is then rejected with a
+ // diagnostic by QuantileType::verifyInvariants() through getChecked below.
+ if (failed(parser.parseOptionalRBrace())) {
+ do {
+ quantiles.emplace_back();
+ if (parser.parseFloat(quantiles.back()))
+ return nullptr;
+ } while (succeeded(parser.parseOptionalComma()));
+
+ if (parser.parseRBrace())
+ return nullptr;
+ }
+
+ // Optionally parse the explicit storage range `, <min:max>`, which sits
+ // inside the outer `<>` after the closing `}`.
+ std::optional<int64_t> storageMin, storageMax;
+ if (succeeded(parser.parseOptionalComma())) {
+ if (parser.parseLess())
+ return nullptr;
+ int64_t minVal, maxVal;
+ if (parser.parseInteger(minVal) || parser.parseColon() ||
+ parser.parseInteger(maxVal))
+ return nullptr;
+ storageMin = minVal;
+ storageMax = maxVal;
+ if (parser.parseGreater())
+ return nullptr;
+ }
+
+ if (parser.parseGreater())
+ return nullptr;
+
+ mlir::MLIRContext *ctx = parser.getContext();
+ return parser.getChecked<QuantileType>(ctx, storageType, quantileType,
+ quantiles, storageMin, storageMax);
+}
+
/// Parse a type registered to this dialect.
Type QuantDialect::parseType(DialectAsmParser &parser) const {
// All types start with an identifier that we switch on.
@@ -493,6 +547,8 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
return parseAnyType(parser);
if (typeNameSpelling == "calibrated")
return parseCalibratedType(parser);
+ if (typeNameSpelling == "quantile")
+ return parseQuantileType(parser);
parser.emitError(parser.getNameLoc(),
"unknown quantized type " + typeNameSpelling);
@@ -653,6 +709,23 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type,
out << ">";
}
+/// Prints `type` as `quantile<storage:quantile, {q0,q1,...}[, <min:max>]>`.
+/// The explicit range is printed only when both bounds are present.
+static void printQuantileType(QuantileType type, DialectAsmPrinter &out) {
+  out << "quantile<" << type.getStorageType() << ":" << type.getQuantileType()
+      << ", {";
+  llvm::interleave(
+      type.getQuantiles(), out, [&](double value) { out << value; }, ",");
+  out << "}";
+  std::optional<int64_t> lower = type.getStorageMin();
+  std::optional<int64_t> upper = type.getStorageMax();
+  if (lower && upper)
+    out << ", <" << *lower << ":" << *upper << ">";
+  out << ">";
+}
+
/// Print a type registered to this dialect.
void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
@@ -666,6 +739,8 @@ void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
printUniformQuantizedSubChannelType(perAxisType, os);
else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
printCalibratedQuantizedType(calibratedType, os);
+ else if (auto quantileType = llvm::dyn_cast<QuantileType>(type))
+ printQuantileType(quantileType, os);
else
llvm_unreachable("Unhandled quantized type");
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3f6600173fc5d..75008d6cc2591 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2900,22 +2900,6 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
printType(complexTy.getElementType());
os << '>';
})
- .Case<QuantileType>([&](QuantileType quantileTy) {
- os << "quantile<";
- printType(quantileTy.getStorageType());
- os << ':';
- printType(quantileTy.getQuantileType());
- os << ", {";
- ArrayRef<double> quantiles = quantileTy.getQuantiles();
- llvm::interleave(
- llvm::seq<size_t>(0, quantiles.size()), os,
- [&](size_t index) { os << quantiles[index]; }, ",");
- os << "}";
- os << '>';
- if (auto minVal = quantileTy.getStorageMin())
- if (auto maxVal = quantileTy.getStorageMax())
- os << '<' << *minVal << ':' << *maxVal << '>';
- })
.Case([&](TupleType tupleTy) {
os << "tuple<";
interleaveComma(tupleTy.getTypes(),
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 0a642fd2bed4e..786c30851a071 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -94,123 +94,6 @@ ComplexType::convertFromAttribute(Attribute attr,
return success();
}
-//===----------------------------------------------------------------------===//
-/// QuantileType
-//===----------------------------------------------------------------------===//
-
-Type QuantileType::getStorageType() const { return getImpl()->storageType; }
-
-Type QuantileType::getQuantileType() const { return getImpl()->quantileType; }
-
-ArrayRef<double> QuantileType::getQuantiles() const {
- return getImpl()->getQuantiles();
-}
-
-std::optional<int64_t> QuantileType::getStorageMin() const {
- return getImpl()->getStorageMin();
-}
-
-std::optional<int64_t> QuantileType::getStorageMax() const {
- return getImpl()->getStorageMax();
-}
-
-LogicalResult QuantileType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type storageType, Type quantileType,
- ArrayRef<double> quantiles,
- std::optional<int64_t> storageMin,
- std::optional<int64_t> storageMax) {
- if (!storageType.isIntOrFloat())
- return emitError() << "storage type must be an integer or float type";
- if (!llvm::isa<FloatType>(quantileType))
- return emitError() << "quantile type must be a float type";
- if (quantiles.empty())
- return emitError() << "quantile values must not be empty";
- if (storageMin.has_value() != storageMax.has_value())
- return emitError()
- << "storage min and max must both be specified or both omitted";
-
- // Validate explicit storage range.
- if (storageMin && storageMax && *storageMin >= *storageMax)
- return emitError() << "storage min must be less than storage max";
-
- unsigned width = storageType.getIntOrFloatBitWidth();
- bool isSigned = !llvm::isa<IntegerType>(storageType) ||
- llvm::cast<IntegerType>(storageType).isSigned();
- auto effectiveMin =
- storageMin.value_or(isSigned ? -(1LL << (width - 1)) : 0LL);
- auto effectiveMax = storageMax.value_or(isSigned ? (1LL << (width - 1)) - 1
- : (1LL << width) - 1);
- auto expectedSize = effectiveMax - effectiveMin + 1;
- if (static_cast<decltype(expectedSize)>(quantiles.size()) != expectedSize)
- return emitError() << "quantile LUT size (" << quantiles.size()
- << ") must equal the number of representable storage "
- "values ("
- << expectedSize << ")";
-
- // No NaN or infinity allowed in the LUT.
- for (double v : quantiles)
- if (std::isnan(v) || std::isinf(v))
- return emitError() << "quantile values must be finite (no NaN or "
- "infinity)";
-
- return success();
-}
-
-bool QuantileType::shouldDefaultToSigned() const {
- if (auto intType = llvm::dyn_cast<IntegerType>(getStorageType()))
- return intType.isSigned();
- // Float types default to signed.
- return true;
-}
-
-unsigned QuantileType::getStorageWidth() const {
- return getStorageType().getIntOrFloatBitWidth();
-}
-
-int64_t QuantileType::getDefaultMaximum(bool isSigned) const {
- if (auto explicitMax = getStorageMax())
- return *explicitMax;
- if (isSigned)
- return (1LL << (getStorageWidth() - 1)) - 1;
- return (1LL << getStorageWidth()) - 1;
-}
-
-int64_t QuantileType::getDefaultMinimum(bool isSigned) const {
- if (auto explicitMin = getStorageMin())
- return *explicitMin;
- if (isSigned)
- return -(1LL << (getStorageWidth() - 1));
- return 0;
-}
-
-std::string QuantileType::getStorageTypeName(bool isSigned) const {
- std::string result = "quantile<";
- llvm::raw_string_ostream os(result);
- os << getStorageType() << ":" << getQuantileType() << ", {";
- ArrayRef<double> quantiles = getQuantiles();
- llvm::interleave(
- llvm::seq<size_t>(0, quantiles.size()), os,
- [&](size_t index) { os << quantiles[index]; }, ",");
- os << "}>";
- if (auto minVal = getStorageMin())
- if (auto maxVal = getStorageMax())
- os << '<' << *minVal << ':' << *maxVal << '>';
- return result;
-}
-
-bool QuantileType::isPacked() const { return getStorageWidth() <= 4; }
-
-unsigned QuantileType::getLogicalBitWidth() const { return getStorageWidth(); }
-
-unsigned QuantileType::getElementsPerByte() const {
- unsigned width = getStorageWidth();
- return width > 0 ? 8 / width : 0;
-}
-
-std::optional<unsigned> QuantileType::getPreferredAlignmentBytes() const {
- return std::nullopt;
-}
-
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
@@ -552,7 +435,7 @@ bool TensorType::isValidElementType(Type type) {
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
- IndexType, QuantileType>(type) ||
+ IndexType>(type) ||
!llvm::isa<BuiltinDialect>(type.getDialect());
}
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 9e805ba25c3f4..0e952d5c14c7e 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -148,65 +148,6 @@ Attribute skipDefaultMemorySpace(Attribute memorySpace);
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt(Attribute memorySpace);
-/// Quantile Type Storage and Uniquing.
-struct QuantileTypeStorage : public TypeStorage {
- QuantileTypeStorage(Type storageType, Type quantileType,
- ArrayRef<double> quantiles,
- std::optional<int64_t> storageMin,
- std::optional<int64_t> storageMax)
- : storageType(storageType), quantileType(quantileType),
- quantilesData(quantiles.data()), numQuantiles(quantiles.size()),
- storageMin(storageMin), storageMax(storageMax) {}
-
- using KeyTy = std::tuple<Type, Type, ArrayRef<double>, std::optional<int64_t>,
- std::optional<int64_t>>;
-
- static llvm::hash_code hashKey(const KeyTy &key) {
- auto quantiles = std::get<2>(key);
- auto *quantilesBits = llvm::bit_cast<const int64_t *>(quantiles.data());
- ArrayRef<int64_t> quantilesAsInts(quantilesBits, quantiles.size());
- auto hashOptInt = [](std::optional<int64_t> opt) -> llvm::hash_code {
- return opt ? llvm::hash_combine(true, *opt)
- : llvm::hash_combine(false, int64_t{0});
- };
- return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
- llvm::hash_combine_range(quantilesAsInts.begin(),
- quantilesAsInts.end()),
- hashOptInt(std::get<3>(key)),
- hashOptInt(std::get<4>(key)));
- }
-
- bool operator==(const KeyTy &key) const {
- return storageType == std::get<0>(key) &&
- quantileType == std::get<1>(key) &&
- getQuantiles() == std::get<2>(key) &&
- storageMin == std::get<3>(key) && storageMax == std::get<4>(key);
- }
-
- static QuantileTypeStorage *construct(TypeStorageAllocator &allocator,
- const KeyTy &key) {
- ArrayRef<double> quantiles = allocator.copyInto(std::get<2>(key));
- return new (allocator.allocate<QuantileTypeStorage>())
- QuantileTypeStorage(std::get<0>(key), std::get<1>(key), quantiles,
- std::get<3>(key), std::get<4>(key));
- }
-
- Type getStorageType() const { return storageType; }
- Type getQuantileType() const { return quantileType; }
- ArrayRef<double> getQuantiles() const {
- return ArrayRef<double>(quantilesData, numQuantiles);
- }
- std::optional<int64_t> getStorageMin() const { return storageMin; }
- std::optional<int64_t> getStorageMax() const { return storageMax; }
-
- Type storageType;
- Type quantileType;
- const double *quantilesData;
- unsigned numQuantiles;
- std::optional<int64_t> storageMin;
- std::optional<int64_t> storageMax;
-};
-
} // namespace detail
} // namespace mlir
diff --git a/mlir/test/Dialect/Quant/invalid-quantile-types.mlir b/mlir/test/Dialect/Quant/invalid-quantile-types.mlir
new file mode 100644
index 0000000000000..faf16d01a9cd5
--- /dev/null
+++ b/mlir/test/Dialect/Quant/invalid-quantile-types.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
+
+//===----------------------------------------------------------------------===//
+// Verify errors (caught by verifyInvariants(), reached through getChecked())
+//===----------------------------------------------------------------------===//
+
+// Storage type must be an integer or float.
+// expected-error @+1 {{storage type must be an integer or float type}}
+func.func private @invalid_storage_type() -> !quant.quantile<tensor<1xf32>:f32, {1.0}>
+
+// -----
+
+// Quantile (expressed) type must be a float.
+// expected-error @+1 {{quantile type must be a float type}}
+func.func private @invalid_quantile_type() -> !quant.quantile<ui4:i8, {1.0, 0.0, -1.0}>
+
+// -----
+
+// Quantile LUT must not be empty.
+// expected-error @+1 {{quantile values must not be empty}}
+func.func private @empty_quantiles() -> !quant.quantile<ui4:f16, {}>
+
+// -----
+
+// LUT size must match the number of representable storage values.
+// ui4 has 16 representable values [0,15], but only 3 are provided.
+// expected-error @+1 {{quantile LUT size (3) must equal the number of representable storage values (16)}}
+func.func private @wrong_lut_size() -> !quant.quantile<ui4:f16, {-1.0,0.0,1.0}>
+
+// -----
+
+// Explicit storage range: min must be strictly less than max.
+// si4 default range is [-8,7]; explicit 5:3 has min > max.
+// expected-error @+1 {{storage min must be less than storage max}}
+func.func private @invalid_range_order() -> !quant.quantile<si4:f32, {-2.0,-1.875,-1.75,-1.625,-1.5,-1.375,-1.25,-1.125,-1.0,-0.875,-0.75,-0.625,-0.5,-0.375,-0.25,-0.125}, <5:3>>
+
+// -----
+
+// LUT size must match the number of values in the explicit storage range.
+// The explicit range <-6:6> covers 13 storage values, but only 3 are provided.
+// expected-error @+1 {{quantile LUT size (3) must equal the number of representable storage values (13)}}
+func.func private @wrong_lut_size_with_range() -> !quant.quantile<f4E2M1FN:f16, {-1.0,0.0,1.0}, <-6:6>>
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index b3867f2dc35fa..5ff54a8804844 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -267,24 +267,18 @@
// -----
-// Illegal storage min/max: max > defaultMax
-// expected-error at +1 {{illegal storage type maximum: 10}}
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><6:10>:f32, 0.99872:127>
-
-// -----
-
-// Illegal storage min/max: min < defaultMin
-// expected-error at +1 {{illegal storage type minimum: -10}}
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><-10:-6>:f32, 0.99872:127>
+// Invalid LUT size: 16 values but explicit range 6:10 has only 5 representable values.
+// expected-error at +1 {{quantile LUT size (16) must equal the number of representable storage values (5)}}
+!qalias = !quant.uniform<!quant.quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}, <6:10>>:f32, 0.99872:127>
// -----
// Quantile storage range: min must be strictly less than max.
// expected-error at +1 {{storage min must be less than storage max}}
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16,{-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><5:3>:f32, 0.99872:127>
+!qalias = !quant.uniform<!quant.quantile<f4E2M1FN:f16,{-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}, <5:3>>:f32, 0.99872:127>
// -----
// Quantile LUT size (3) does not match the 16 representable values of f4E2M1FN's default range.
// expected-error at +1 {{quantile LUT size (3) must equal the number of representable storage values (16)}}
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,0.0,1.0}>:f32, 0.99872:127>
+!qalias = !quant.uniform<!quant.quantile<f4E2M1FN:f16, {-1.0,0.0,1.0}>:f32, 0.99872:127>
diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index 27bb6dc98470d..56219d2a3f437 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -247,9 +247,9 @@ func.func @parse() -> !qalias {
}
// -----
-// Storage type: QuantileType with narrowed explicit range <-6:6> (13 representable values).
-// CHECK: !quant.uniform<quantile<f4E2M1FN:f16, {-1.000000e+00,-8.750000e-01,-7.500000e-01,-6.250000e-01,-5.000000e-01,-2.500000e-01,0.000000e+00,2.500000e-01,5.000000e-01,6.250000e-01,7.500000e-01,8.750000e-01,1.000000e+00}><-6:6>:f32, 9.987200e-01:127>
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,-0.875,-0.75,-0.625,-0.5,-0.25,0.0,0.25,0.5,0.625,0.75,0.875,1.0}><-6:6>:f32, 0.99872:127 >
+// Storage type: QuantileType with narrowed explicit range `-6:6` (13 representable values).
+// CHECK: !quant.uniform<!quant.quantile<f4E2M1FN:f16, {-1.000000e+00,-8.750000e-01,-7.500000e-01,-6.250000e-01,-5.000000e-01,-2.500000e-01,0.000000e+00,2.500000e-01,5.000000e-01,6.250000e-01,7.500000e-01,8.750000e-01,1.000000e+00}, <-6:6>>:f32, 9.987200e-01:127>
+!qalias = !quant.uniform<!quant.quantile<f4E2M1FN:f16, {-1.0,-0.875,-0.75,-0.625,-0.5,-0.25,0.0,0.25,0.5,0.625,0.75,0.875,1.0}, <-6:6>>:f32, 0.99872:127>
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
@@ -257,8 +257,8 @@ func.func @parse() -> !qalias {
// -----
// Storage type: QuantileType
-// CHECK: !quant.uniform<quantile<f4E2M1FN:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>:f32, 2.000000e+02>
-!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}>:f32, 2.0e+2 >
+// CHECK: !quant.uniform<!quant.quantile<f4E2M1FN:f16, {
+!qalias = !quant.uniform<!quant.quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}>:f32, 2.0e+2 >
func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
diff --git a/mlir/test/Dialect/Quant/quantile-types.mlir b/mlir/test/Dialect/Quant/quantile-types.mlir
new file mode 100644
index 0000000000000..90b6d3be828cd
--- /dev/null
+++ b/mlir/test/Dialect/Quant/quantile-types.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file | FileCheck %s
+
+// -----
+// Quantile type: ui4 storage with f16 expressed, 16 entries (default range 0..15).
+// CHECK-LABEL: func private @quantile_ui4_f16
+// CHECK-SAME: !quant.quantile<ui4:f16, {
+func.func private @quantile_ui4_f16(!quant.quantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>) -> ()
+
+// -----
+// Quantile type: si8 storage with f32 expressed, explicit range -2:2 (5 entries).
+// CHECK: func private @quantile_si8_f32(!quant.quantile<si8:f32, {-1.000000e+00,-5.000000e-01,0.000000e+00,5.000000e-01,1.000000e+00}, <-2:2>>)
+func.func private @quantile_si8_f32(!quant.quantile<si8:f32, {-1.0,-0.5,0.0,0.5,1.0}, <-2:2>>) -> ()
+
+// -----
+// Quantile type: i8 (signless) storage with f32 expressed, explicit range -1:1 (3 entries).
+// CHECK: func private @quantile_i8_f32(!quant.quantile<i8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}, <-1:1>>)
+func.func private @quantile_i8_f32(!quant.quantile<i8:f32, {-1.0,0.0,1.0}, <-1:1>>) -> ()
+
+// -----
+// Quantile type: f8E4M3FN float storage with f32 expressed, explicit range -1:1 (3 entries).
+// CHECK: func private @quantile_f8_f32(!quant.quantile<f8E4M3FN:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}, <-1:1>>)
+func.func private @quantile_f8_f32(!quant.quantile<f8E4M3FN:f32, {-1.0,0.0,1.0}, <-1:1>>) -> ()
+
+// -----
+// Quantile type: ui4 storage with bf16 expressed, 16 entries.
+// CHECK-LABEL: func private @quantile_ui4_bf16
+// CHECK-SAME: !quant.quantile<ui4:bf16, {
+func.func private @quantile_ui4_bf16(!quant.quantile<ui4:bf16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>) -> ()
+
+// -----
+// Quantile type used as a return type.
+// CHECK-LABEL: func private @quantile_as_return
+// CHECK-SAME: !quant.quantile<ui4:f16, {
+func.func private @quantile_as_return() -> !quant.quantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>
+
+// -----
+// NF4-style 16-entry quantile table.
+// CHECK-LABEL: @nf4_16_values
+// CHECK-SAME: !quant.quantile<ui4:f16, {
+func.func private @nf4_16_values(!quant.quantile<ui4:f16, {
+ -1.0,-0.6961928009986877,-0.5250730514526367,-0.39491748809814453,
+ -0.28444138169288635,-0.18477343022823334,-0.09105003625154495,0.0,
+ 0.07958029955625534,0.16093020141124725,0.24611230194568634,
+ 0.33791524171829224,0.44070982933044434,0.5626170039176941,
+ 0.7229568362236023,1.0}>) -> ()
+
+// -----
+// Explicit storage min/max range (unsigned storage, narrowed range 0..7, 8 entries).
+// CHECK: func private @quantile_with_range(!quant.quantile<ui4:f16, {-1.000000e+00,-7.500000e-01,-5.000000e-01,-2.500000e-01,0.000000e+00,2.500000e-01,5.000000e-01,1.000000e+00}, <0:7>>)
+func.func private @quantile_with_range(!quant.quantile<ui4:f16, {-1.0,-0.75,-0.5,-0.25,0.0,0.25,0.5,1.0}, <0:7>>) -> ()
+
+// -----
+// Explicit range is preserved through round-trip.
+// CHECK: func private @quantile_signed_range(!quant.quantile<si8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}, <-1:1>>)
+func.func private @quantile_signed_range(!quant.quantile<si8:f32, {-1.0,0.0,1.0}, <-1:1>>) -> ()
+
+// -----
+// Signed 4-bit storage uses full 16-entry LUT (range -8..7).
+// CHECK-LABEL: func private @quantile_negatives
+// CHECK-SAME: !quant.quantile<si4:f32, {-2.000000e+00,-1.875000e+00,-1.750000e+00,-1.625000e+00,-1.500000e+00,-1.375000e+00,-1.250000e+00,-1.125000e+00,-1.000000e+00,-8.750000e-01,-7.500000e-01,-6.250000e-01,-5.000000e-01,-3.750000e-01,-2.500000e-01,-1.250000e-01}>
+func.func private @quantile_negatives(!quant.quantile<si4:f32, {-2.0,-1.875,-1.75,-1.625,-1.5,-1.375,-1.25,-1.125,-1.0,-0.875,-0.75,-0.625,-0.5,-0.375,-0.25,-0.125}>) -> ()
+
+// -----
+// 1-bit unsigned storage: minimal 2-entry LUT.
+// CHECK: func private @quantile_ui1_f16(!quant.quantile<ui1:f16, {-1.000000e+00,1.000000e+00}>)
+func.func private @quantile_ui1_f16(!quant.quantile<ui1:f16, {-1.0,1.0}>) -> ()
+
+// -----
+// LUT values in descending order (ui4, explicit range 0:7, 8 entries).
+// CHECK: func private @quantile_descending(!quant.quantile<ui4:f16, {1.000000e+00,7.500000e-01,5.000000e-01,2.500000e-01,0.000000e+00,-2.500000e-01,-5.000000e-01,-1.000000e+00}, <0:7>>)
+func.func private @quantile_descending(!quant.quantile<ui4:f16, {1.0,0.75,0.5,0.25,0.0,-0.25,-0.5,-1.0}, <0:7>>) -> ()
+
+// -----
+// LUT values in arbitrary order (ui4, explicit range 0:7, 8 entries).
+// CHECK: func private @quantile_random_order(!quant.quantile<ui4:f16, {0.000000e+00,-5.000000e-01,1.000000e+00,-2.500000e-01,7.500000e-01,-1.000000e+00,5.000000e-01,2.500000e-01}, <0:7>>)
+func.func private @quantile_random_order(!quant.quantile<ui4:f16, {0.0,-0.5,1.0,-0.25,0.75,-1.0,0.5,0.25}, <0:7>>) -> ()
\ No newline at end of file
diff --git a/mlir/test/IR/invalid-quantile-types.mlir b/mlir/test/IR/invalid-quantile-types.mlir
deleted file mode 100644
index 500df1a83f275..0000000000000
--- a/mlir/test/IR/invalid-quantile-types.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
-
-//===----------------------------------------------------------------------===//
-// Parser error tests
-//===----------------------------------------------------------------------===//
-
-// Test missing '<' after 'quantile' keyword.
-// expected-error @+1 {{expected '<' in quantile type}}
-func.func private @missing_lt() -> quantile ui4:f16, {1.0}>
-
-// -----
-
-// Test missing ':' between storage type and quantile type.
-// expected-error @+1 {{expected ':' in quantile type}}
-func.func private @missing_colon() -> quantile<ui4 f16, {1.0}>
-
-// -----
-
-// Test missing ',' between quantile type and quantile value list.
-// expected-error @+1 {{expected ',' in quantile type}}
-func.func private @missing_comma() -> quantile<ui4:f16 {1.0}>
-
-// -----
-
-// Test missing '{' before quantile value list.
-// expected-error @+1 {{expected '{' in quantile type}}
-func.func private @missing_lbrace() -> quantile<ui4:f16, 1.0}>
-
-// -----
-
-// Test missing '}' after quantile value list.
-// expected-error @+1 {{expected '}' in quantile type}}
-func.func private @missing_rbrace() -> quantile<ui4:f16, {1.0>
-
-// -----
-
-// Test missing '>' closing the quantile type.
-// expected-error @+1 {{expected '>' in quantile type}}
-func.func private @missing_gt() -> quantile<ui4:f16, {1.0}
-
-// -----
-
-// Test missing '>' closing the storage range.
-// expected-error @below {{expected '>' after quantile storage range}}
-func.func private @missing_range_gt() -> quantile<ui4:f16, {1.0}><-8:7
diff --git a/mlir/test/IR/quantile-types.mlir b/mlir/test/IR/quantile-types.mlir
deleted file mode 100644
index 0950e4a96678a..0000000000000
--- a/mlir/test/IR/quantile-types.mlir
+++ /dev/null
@@ -1,81 +0,0 @@
-// RUN: mlir-opt %s | FileCheck %s
-
-// Verify round-trip parsing and printing of the builtin quantile type.
-
-// CHECK-LABEL: func private @quantile_ui4_f16
-// CHECK-SAME: quantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>
-func.func private @quantile_ui4_f16(quantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>) -> ()
-
-// CHECK: func private @quantile_si8_f32(quantile<si8:f32, {-1.000000e+00,-5.000000e-01,0.000000e+00,5.000000e-01,1.000000e+00}><-2:2>)
-func.func private @quantile_si8_f32(quantile<si8:f32, {-1.0,-0.5,0.0,0.5,1.0}><-2:2>) -> ()
-
-// CHECK: func private @quantile_i8_f32(quantile<i8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-1:1>)
-func.func private @quantile_i8_f32(quantile<i8:f32, {-1.0,0.0,1.0}><-1:1>) -> ()
-
-// CHECK: func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-1:1>)
-func.func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.0,0.0,1.0}><-1:1>) -> ()
-
-// CHECK-LABEL: func private @quantile_ui4_bf16
-// CHECK-SAME: quantile<ui4:bf16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>
-func.func private @quantile_ui4_bf16(quantile<ui4:bf16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>) -> ()
-
-// CHECK-LABEL: func private @quantile_as_return
-// CHECK-SAME: quantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>
-func.func private @quantile_as_return() -> quantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>
-
-// Verify use as memref element type (requires MemRefElementTypeInterface).
-// CHECK-LABEL: func private @quantile_in_memref
-// CHECK-SAME: memref<8xquantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>>
-func.func private @quantile_in_memref(memref<8xquantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>>) -> ()
-
-// Verify use as tensor element type.
-// CHECK-LABEL: func private @quantile_in_tensor
-// CHECK-SAME: tensor<16xquantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>>
-func.func private @quantile_in_tensor(tensor<16xquantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>>) -> ()
-
-// Verify use in multidimensional tensors.
-// CHECK-LABEL: func private @quantile_in_ranked_tensor
-// CHECK-SAME: tensor<16x16x1x1xquantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>>
-func.func private @quantile_in_ranked_tensor(tensor<16x16x1x1xquantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>>) -> ()
-
-// Verify use in unranked tensors.
-// CHECK-LABEL: func private @quantile_in_unranked_tensor
-// CHECK-SAME: tensor<*xquantile<ui4:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>>
-func.func private @quantile_in_unranked_tensor(tensor<*xquantile<ui4:f16, {-1.0,-0.8667,-0.7333,-0.6,-0.4667,-0.3333,-0.2,-0.0667,0.0667,0.2,0.3333,0.4667,0.6,0.7333,0.8667,1.0}>>) -> ()
-
-// Verify NF4-style 16-entry quantile table
-// CHECK-LABEL: @nf4_16_values
-// CHECK-SAME: quantile<ui4:f16, {
-func.func private @nf4_16_values(quantile<ui4:f16, {
- -1.0,-0.6961928009986877,-0.5250730514526367,-0.39491748809814453,
- -0.28444138169288635,-0.18477343022823334,-0.09105003625154495,0.0,
- 0.07958029955625534,0.16093020141124725,0.24611230194568634,
- 0.33791524171829224,0.44070982933044434,0.5626170039176941,
- 0.7229568362236023,1.0}>) -> ()
-
-// Verify explicit storage min/max range (unsigned storage, narrowed range).
-// CHECK: func private @quantile_with_range(quantile<ui4:f16, {-1.000000e+00,-7.500000e-01,-5.000000e-01,-2.500000e-01,0.000000e+00,2.500000e-01,5.000000e-01,1.000000e+00}><0:7>)
-func.func private @quantile_with_range(quantile<ui4:f16, {-1.0,-0.75,-0.5,-0.25,0.0,0.25,0.5,1.0}><0:7>) -> ()
-
-// Verify explicit range is preserved through round-trip.
-// CHECK: func private @quantile_signed_range(quantile<si8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-1:1>)
-func.func private @quantile_signed_range(quantile<si8:f32, {-1.0,0.0,1.0}><-1:1>) -> ()
-
-// Verify signed 4-bit storage type uses full 16-entry LUT (all-negative values).
-// CHECK-LABEL: func private @quantile_negatives
-// CHECK-SAME: quantile<si4:f32, {-2.000000e+00,-1.875000e+00,-1.750000e+00,-1.625000e+00,-1.500000e+00,-1.375000e+00,-1.250000e+00,-1.125000e+00,-1.000000e+00,-8.750000e-01,-7.500000e-01,-6.250000e-01,-5.000000e-01,-3.750000e-01,-2.500000e-01,-1.250000e-01}>
-func.func private @quantile_negatives(quantile<si4:f32, {-2.0,-1.875,-1.75,-1.625,-1.5,-1.375,-1.25,-1.125,-1.0,-0.875,-0.75,-0.625,-0.5,-0.375,-0.25,-0.125}>) -> ()
-
-// Verify minimal 2-entry LUT for 1-bit unsigned storage type.
-// CHECK: func private @quantile_ui1_f16(quantile<ui1:f16, {-1.000000e+00,1.000000e+00}>)
-func.func private @quantile_ui1_f16(quantile<ui1:f16, {-1.0,1.0}>) -> ()
-
-// Verify LUT values in descending order
-// Storage is ui4 with explicit <0:7> range (8 entries).
-// CHECK: func private @quantile_descending(quantile<ui4:f16, {1.000000e+00,7.500000e-01,5.000000e-01,2.500000e-01,0.000000e+00,-2.500000e-01,-5.000000e-01,-1.000000e+00}><0:7>)
-func.func private @quantile_descending(quantile<ui4:f16, {1.0,0.75,0.5,0.25,0.0,-0.25,-0.5,-1.0}><0:7>) -> ()
-
-// Verify LUT values in an arbitrary order
-// Storage is ui4 with explicit <0:7> range (8 entries).
-// CHECK: func private @quantile_random_order(quantile<ui4:f16, {0.000000e+00,-5.000000e-01,1.000000e+00,-2.500000e-01,7.500000e-01,-1.000000e+00,5.000000e-01,2.500000e-01}><0:7>)
-func.func private @quantile_random_order(quantile<ui4:f16, {0.0,-0.5,1.0,-0.25,0.75,-1.0,0.5,0.25}><0:7>) -> ()
>From d88b0b38c454391c04738cc201840785517718ef Mon Sep 17 00:00:00 2001
From: vsimion26 <vlad.simion at intel.com>
Date: Tue, 26 May 2026 11:55:39 +0100
Subject: [PATCH 6/6] nit fixes
---
.../mlir/Dialect/Quant/IR/QuantTypes.h | 24 +++++++++++++++++++
1 file changed, 24 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 7ef1af8d6faf4..40f31beb2781d 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -549,6 +549,30 @@ class CalibratedQuantizedType
double getMax() const;
};
+/*Syntax:
+
+ ```
+ quantile-type ::= `!quant.quantile` `<` type `:` type `,` `{` float-list `}`
+ (`,` `<` int `:` int `>`)? `>`
+ ```
+
+ A quantile type represents a quantile-based floating-point encoding, where
+ discrete storage values are fully defined by the floating-point entries
+ of a quantile lookup table (LUT) of F8/F16/F32 values.
+
+ Optionally, explicit minimum and maximum storage values can be specified
+ after the LUT as `<min:max>`.
+
+ This type is used for weight compression schemes like NF4 (NormalizedFloat4)
+ and similar quantile-based formats.
+
+ Example:
+
+ MLIR:
+ !quant.quantile<ui1:f16, {-1.0,1.0}>
+ !quant.quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}, <0:4>>
+*/
+
class QuantileType
: public Type::TypeBase<QuantileType, QuantizedType,
detail::QuantileTypeStorage,
More information about the Mlir-commits
mailing list