[Mlir-commits] [mlir] Quantile Type and Low FP Support (PR #190321)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 3 03:33:21 PDT 2026
https://github.com/vsimion26 updated https://github.com/llvm/llvm-project/pull/190321
>From 0727e20043b75965c31b49390b22966f60052af0 Mon Sep 17 00:00:00 2001
From: vsimion26 <vlad.simion at intel.com>
Date: Thu, 2 Apr 2026 09:59:52 +0100
Subject: [PATCH 1/2] Added a new BuiltinType, QuantileType and support for low
precision float storage types. Added lit tests
---
mlir/include/mlir/IR/BuiltinTypes.h | 1 +
mlir/include/mlir/IR/BuiltinTypes.td | 78 ++++++++++++++++++++++++
mlir/lib/AsmParser/Parser.h | 3 +
mlir/lib/AsmParser/TokenKinds.def | 1 +
mlir/lib/AsmParser/TypeParser.cpp | 55 +++++++++++++++++
mlir/lib/IR/AsmPrinter.cpp | 15 +++++
mlir/lib/IR/BuiltinTypes.cpp | 74 +++++++++++++++++++++-
mlir/lib/IR/TypeDetail.h | 46 ++++++++++++++
mlir/test/IR/invalid-quantile-types.mlir | 39 ++++++++++++
mlir/test/IR/quantile-types.mlir | 55 +++++++++++++++++
10 files changed, 366 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/IR/invalid-quantile-types.mlir
create mode 100644 mlir/test/IR/quantile-types.mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index d30cba29c9814..ef47850929beb 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -35,6 +35,7 @@ class TypeRange;
namespace detail {
struct FunctionTypeStorage;
struct IntegerTypeStorage;
+struct QuantileTypeStorage;
struct TupleTypeStorage;
} // namespace detail
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e7d0a03a85e7d..88df2620e4733 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1109,6 +1109,84 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
let genVerifyDecl = 1;
}
+//===----------------------------------------------------------------------===//
+// QuantileType
+//===----------------------------------------------------------------------===//
+
+def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
+ [QuantStorageTypeInterface, MemRefElementTypeInterface]> {
+ let summary = "Quantile-based type with a lookup table of quantile values";
+ let description = [{
+ Syntax:
+
+ ```
+ quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>`
+ ```
+
+ A quantile type represents a quantile-based floating point encoding, where
+ discrete storage values map to floating-point values via a quantile lookup
+ table. The type has a storage type (how raw values are stored, e.g., `ui4`,
+ `si8`), a quantile type (the expressed floating-point precision, e.g.,
+ `f16`, `f32`), and an array of quantile values that define the mapping.
+
+ This type is used for weight compression schemes like NF4 (NormalFloat4)
+ and similar quantile-based formats.
+
+ #### Example:
+
+ ```mlir
+ quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}>
+ ```
+ }];
+ let parameters = (ins
+ "Type":$storageType,
+ "Type":$quantileType,
+ ArrayRefParameter<"double">:$quantiles
+ );
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "Type":$storageType,
+ "Type":$quantileType,
+ "ArrayRef<double>":$quantiles), [{
+ return $_get(storageType.getContext(), storageType, quantileType,
+ quantiles);
+ }]>
+ ];
+ let skipDefaultBuilders = 1;
+ let genStorageClass = 0;
+ let genVerifyDecl = 1;
+ let extraClassDeclaration = [{
+ /// QuantStorageTypeInterface method implementations.
+
+ /// Returns true if the storage type defaults to signed.
+ bool shouldDefaultToSigned() const;
+
+ /// Get the bit width of the storage type.
+ unsigned getStorageWidth() const;
+
+ /// Get default minimum value for the storage type.
+ int64_t getDefaultMinimum(bool isSigned) const;
+
+ /// Get default maximum value for the storage type.
+ int64_t getDefaultMaximum(bool isSigned) const;
+
+ /// Get the storage type as a string.
+ std::string getStorageTypeName(bool isSigned) const;
+
+ /// Check if the storage type uses packed representation.
+ bool isPacked() const;
+
+ /// Get the logical bit width per value.
+ unsigned getLogicalBitWidth() const;
+
+ /// Returns how many values of this type fit in one byte.
+ unsigned getElementsPerByte() const;
+
+ /// Get the preferred alignment in bytes.
+ std::optional<unsigned> getPreferredAlignmentBytes() const;
+ }];
+}
+
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index ecc128cf767b3..ccd469df3ab13 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -213,6 +213,9 @@ class Parser {
/// Parse a complex type.
Type parseComplexType();
+ /// Parse a quantile type.
+ Type parseQuantileType();
+
/// Parse an extended type.
Type parseExtendedType();
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fe7c53753e156..b2cd2d1e8dd58 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -119,6 +119,7 @@ TOK_KEYWORD(min)
TOK_KEYWORD(mod)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
+TOK_KEYWORD(quantile)
TOK_KEYWORD(size)
TOK_KEYWORD(sparse)
TOK_KEYWORD(step)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index a461ebed967a8..906c2cf70b605 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -51,6 +51,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_f8E8M0FNU:
case Token::kw_bf16:
case Token::kw_f16:
+ case Token::kw_quantile:
case Token::kw_tf32:
case Token::kw_f32:
case Token::kw_f64:
@@ -282,6 +283,8 @@ Type Parser::parseNonFunctionType() {
return parseTensorType();
case Token::kw_complex:
return parseComplexType();
+ case Token::kw_quantile:
+ return parseQuantileType();
case Token::kw_tuple:
return parseTupleType();
case Token::kw_vector:
@@ -383,6 +386,58 @@ Type Parser::parseNonFunctionType() {
}
}
+/// Parse a quantile type.
+///
+/// quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>`
+///
+/// Returns the parsed QuantileType, or a null Type as soon as any expected
+/// token, nested type or float literal fails to parse.
+Type Parser::parseQuantileType() {
+ consumeToken(Token::kw_quantile);
+
+ if (parseToken(Token::less, "expected '<' in quantile type"))
+ return nullptr;
+
+ // Parse the storage type.
+ Type storageType = parseType();
+ if (!storageType)
+ return nullptr;
+
+ if (parseToken(Token::colon, "expected ':' in quantile type"))
+ return nullptr;
+
+ // Parse the quantile (expressed) type.
+ Type quantileType = parseType();
+ if (!quantileType)
+ return nullptr;
+
+ if (parseToken(Token::comma, "expected ',' in quantile type"))
+ return nullptr;
+
+ if (parseToken(Token::l_brace, "expected '{' in quantile type"))
+ return nullptr;
+
+ // Parse the quantile values as floating point literals.
+ // Each element is an optional '-' followed by one literal token that is
+ // converted with IEEE double semantics. Because this is a do/while loop, at
+ // least one literal is required. The loop also stops when a ',' is directly
+ // followed by '}', so a trailing comma such as `{1.0,}` is accepted.
+ SmallVector<double, 16> quantiles;
+ do {
+ bool isNegative = consumeIf(Token::minus);
+ Token curTok = getToken();
+ std::optional<APFloat> apResult;
+ if (failed(parseFloatFromLiteral(apResult, curTok, isNegative,
+ APFloat::IEEEdouble())))
+ return nullptr;
+ consumeToken();
+ quantiles.push_back(apResult->convertToDouble());
+ } while (consumeIf(Token::comma) &&
+ !getToken().is(Token::r_brace));
+
+ if (parseToken(Token::r_brace, "expected '}' in quantile type"))
+ return nullptr;
+
+ if (parseToken(Token::greater, "expected '>' in quantile type"))
+ return nullptr;
+
+ // NOTE(review): the storage/quantile types are not checked here before
+ // building the type. Presumably QuantileType::get only asserts on verifier
+ // failure instead of emitting a parser diagnostic - confirm whether
+ // getChecked should be used for user-written input.
+ return QuantileType::get(storageType, quantileType, quantiles);
+}
+
/// Parse a tensor type.
///
/// tensor-type ::= `tensor` `<` dimension-list type `>`
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75008d6cc2591..e1145eaafbc64 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2900,6 +2900,21 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
printType(complexTy.getElementType());
os << '>';
})
+ .Case<QuantileType>([&](QuantileType quantileTy) {
+ os << "quantile<";
+ printType(quantileTy.getStorageType());
+ os << ':';
+ printType(quantileTy.getQuantileType());
+ os << ", {";
+ ArrayRef<double> quantiles = quantileTy.getQuantiles();
+ // interleaveComma(llvm::seq<size_t>(0, quantiles.size()),
+ // [&](size_t i) { os << quantiles[i]; });
+ llvm::interleave(
+ llvm::seq<size_t>(0, quantiles.size()), os,
+ [&](size_t index) { os << quantiles[index]; }, ",");
+ os << "}";
+ os << '>';
+ })
.Case([&](TupleType tupleTy) {
os << "tuple<";
interleaveComma(tupleTy.getTypes(),
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 786c30851a071..faf38e6a8354e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -94,6 +94,78 @@ ComplexType::convertFromAttribute(Attribute attr,
return success();
}
+//===----------------------------------------------------------------------===//
+// QuantileType
+//===----------------------------------------------------------------------===//
+
+/// Returns the type used to store raw values, read from the type storage.
+Type QuantileType::getStorageType() const { return getImpl()->storageType; }
+
+/// Returns the expressed (quantile) type, read from the type storage.
+Type QuantileType::getQuantileType() const { return getImpl()->quantileType; }
+
+/// Returns the quantile lookup table held by the type storage.
+ArrayRef<double> QuantileType::getQuantiles() const {
+ return getImpl()->getQuantiles();
+}
+
+/// Checks the construction invariants of a QuantileType.
+///
+/// Reports an error through `emitError` when `storageType` is not an integer
+/// or float type, when `quantileType` is not a FloatType, or when `quantiles`
+/// is empty; otherwise returns success. The individual quantile values and
+/// their count relative to the storage width are not inspected.
+LogicalResult QuantileType::verify(function_ref<InFlightDiagnostic()> emitError,
+ Type storageType, Type quantileType,
+ ArrayRef<double> quantiles) {
+ if (!storageType.isIntOrFloat())
+ return emitError() << "storage type must be an integer or float type";
+ if (!llvm::isa<FloatType>(quantileType))
+ return emitError() << "quantile type must be a float type";
+ if (quantiles.empty())
+ return emitError() << "quantile values must not be empty";
+ return success();
+}
+
+/// Reports whether storage values are interpreted as signed by default.
+///
+/// Integer storage counts as signed unless it is explicitly unsigned; any
+/// non-integer storage type (i.e. a float) is treated as signed.
+bool QuantileType::shouldDefaultToSigned() const {
+  auto integerStorage = llvm::dyn_cast<IntegerType>(getStorageType());
+  return !integerStorage || !integerStorage.isUnsigned();
+}
+
+/// Returns the bit width of the storage type, as reported by
+/// getIntOrFloatBitWidth() on that type.
+unsigned QuantileType::getStorageWidth() const {
+ return getStorageType().getIntOrFloatBitWidth();
+}
+
+/// Returns the largest storage value representable by default.
+///
+/// With `width = getStorageWidth()`, the result is `2^(width-1) - 1` for a
+/// signed interpretation and `2^width - 1` for an unsigned one. verify()
+/// accepts any integer or float storage type, so 64-bit storage is reachable;
+/// shifting a 64-bit integer by 63 or more bits is undefined, hence the result
+/// saturates at the largest int64_t when the exact value does not fit. A
+/// zero-width storage type yields 0.
+int64_t QuantileType::getDefaultMaximum(bool isSigned) const {
+  unsigned width = getStorageWidth();
+  // Bits available for the magnitude; a signed value spends one on the sign.
+  unsigned valueBits = (isSigned && width > 0) ? width - 1 : width;
+  if (valueBits == 0)
+    return 0;
+  // int64_t can hold at most 63 magnitude bits.
+  if (valueBits > 63)
+    valueBits = 63;
+  // An all-ones mask of `valueBits` bits, built without any signed shift.
+  return static_cast<int64_t>(~uint64_t{0} >> (64 - valueBits));
+}
+
+/// Returns the smallest storage value representable by default.
+///
+/// Unsigned storage starts at 0. Signed storage starts at `-2^(width-1)` with
+/// `width = getStorageWidth()`. The result saturates at the smallest int64_t
+/// for widths of 64 bits or more, where the shift would otherwise be
+/// undefined, and a zero-width storage type yields 0.
+int64_t QuantileType::getDefaultMinimum(bool isSigned) const {
+  if (!isSigned)
+    return 0;
+  unsigned width = getStorageWidth();
+  if (width == 0)
+    return 0;
+  // Build the smallest int64_t without shifting into the sign bit.
+  if (width >= 64)
+    return -static_cast<int64_t>(~uint64_t{0} >> 1) - 1;
+  return -(int64_t{1} << (width - 1));
+}
+
+/// Builds the textual name of this type when it is used as a storage type:
+/// `quantile<` storage `:` quantile-type `, {` q0 `,` q1 ... `}>`.
+///
+/// `isSigned` is part of the interface signature but is not read here; the
+/// storage type is streamed as-is. Quantile values are written with the
+/// stream's default formatting for `double`.
+std::string QuantileType::getStorageTypeName(bool isSigned) const {
+ std::string result = "quantile<";
+ llvm::raw_string_ostream os(result);
+ os << getStorageType() << ":" << getQuantileType() << ", {";
+ ArrayRef<double> quantiles = getQuantiles();
+ // Values are separated by "," with no space.
+ llvm::interleave(
+ llvm::seq<size_t>(0, quantiles.size()), os,
+ [&](size_t index) { os << quantiles[index]; }, ",");
+ os << "}>";
+ return result;
+}
+
+/// Returns true when the storage width is at most 4 bits.
+bool QuantileType::isPacked() const { return getStorageWidth() <= 4; }
+
+/// Returns the logical bit width per value, which equals the storage width.
+unsigned QuantileType::getLogicalBitWidth() const { return getStorageWidth(); }
+
+/// Returns how many whole values fit in one byte: `8 / width` using integer
+/// division, so storage wider than 8 bits yields 0. A zero width also yields 0
+/// rather than dividing by zero.
+unsigned QuantileType::getElementsPerByte() const {
+  const unsigned bitsPerElement = getStorageWidth();
+  if (bitsPerElement == 0)
+    return 0;
+  return 8 / bitsPerElement;
+}
+
+/// Returns the preferred alignment in bytes. Always std::nullopt: this type
+/// states no preference.
+std::optional<unsigned> QuantileType::getPreferredAlignmentBytes() const {
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
@@ -435,7 +507,7 @@ bool TensorType::isValidElementType(Type type) {
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
- IndexType>(type) ||
+ IndexType, QuantileType>(type) ||
!llvm::isa<BuiltinDialect>(type.getDialect());
}
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 0e952d5c14c7e..6afcc6307d4b4 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -148,6 +148,52 @@ Attribute skipDefaultMemorySpace(Attribute memorySpace);
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt(Attribute memorySpace);
+/// Quantile Type Storage and Uniquing.
+///
+/// Holds the storage type, the quantile (expressed) type and a pointer/length
+/// pair describing the quantile table. The table is copied into the uniquer's
+/// allocator in construct(), so the storage does not refer to caller memory.
+struct QuantileTypeStorage : public TypeStorage {
+ QuantileTypeStorage(Type storageType, Type quantileType,
+ ArrayRef<double> quantiles)
+ : storageType(storageType), quantileType(quantileType),
+ quantilesData(quantiles.data()), numQuantiles(quantiles.size()) {}
+
+ /// The hash key used for uniquing.
+ using KeyTy = std::tuple<Type, Type, ArrayRef<double>>;
+
+ /// Hashes both types together with the raw 64-bit patterns of the quantiles.
+ /// NOTE(review): bit_cast is applied to the pointer, so the doubles are read
+ /// through an int64_t pointer - confirm this is acceptable under strict
+ /// aliasing rules.
+ /// NOTE(review): hashing uses bit patterns while operator== below compares
+ /// ArrayRef<double>, presumably element-wise by value. If so, 0.0 and -0.0
+ /// compare equal but hash differently and NaN never compares equal - confirm
+ /// the two agree for every key.
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ auto quantiles = std::get<2>(key);
+ // Bit-cast doubles to int64_t for hashing since LLVM hashing
+ // does not natively support double.
+ auto *quantilesBits = llvm::bit_cast<const int64_t *>(quantiles.data());
+ ArrayRef<int64_t> quantilesAsInts(quantilesBits, quantiles.size());
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+ llvm::hash_combine_range(quantilesAsInts.begin(),
+ quantilesAsInts.end()));
+ }
+
+ /// Returns true when this storage holds exactly the contents of `key`.
+ bool operator==(const KeyTy &key) const {
+ return storageType == std::get<0>(key) &&
+ quantileType == std::get<1>(key) &&
+ getQuantiles() == std::get<2>(key);
+ }
+
+ /// Creates a new storage instance from `key`, copying the quantile values
+ /// into `allocator` first so that they outlive the caller's buffer.
+ static QuantileTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ ArrayRef<double> quantiles = allocator.copyInto(std::get<2>(key));
+ return new (allocator.allocate<QuantileTypeStorage>())
+ QuantileTypeStorage(std::get<0>(key), std::get<1>(key), quantiles);
+ }
+
+ Type getStorageType() const { return storageType; }
+ Type getQuantileType() const { return quantileType; }
+ /// Rebuilds the ArrayRef view from the stored pointer and element count.
+ ArrayRef<double> getQuantiles() const {
+ return ArrayRef<double>(quantilesData, numQuantiles);
+ }
+
+ Type storageType;
+ Type quantileType;
+ // Pointer to the first quantile value; set from the ArrayRef passed to the
+ // constructor.
+ const double *quantilesData;
+ // Number of quantile values. Stored as `unsigned`, so the size_t returned by
+ // ArrayRef::size() is narrowed on construction.
+ unsigned numQuantiles;
+};
+
} // namespace detail
} // namespace mlir
diff --git a/mlir/test/IR/invalid-quantile-types.mlir b/mlir/test/IR/invalid-quantile-types.mlir
new file mode 100644
index 0000000000000..5ee3e674b5cc1
--- /dev/null
+++ b/mlir/test/IR/invalid-quantile-types.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
+
+//===----------------------------------------------------------------------===//
+// Parser error tests
+//===----------------------------------------------------------------------===//
+
+// Test missing '<' after 'quantile' keyword.
+// expected-error @+1 {{expected '<' in quantile type}}
+func.func private @missing_lt() -> quantile ui4:f16, {1.0}>
+
+// -----
+
+// Test missing ':' between storage type and quantile type.
+// expected-error @+1 {{expected ':' in quantile type}}
+func.func private @missing_colon() -> quantile<ui4 f16, {1.0}>
+
+// -----
+
+// Test missing ',' between quantile type and quantile value list.
+// expected-error @+1 {{expected ',' in quantile type}}
+func.func private @missing_comma() -> quantile<ui4:f16 {1.0}>
+
+// -----
+
+// Test missing '{' before quantile value list.
+// expected-error @+1 {{expected '{' in quantile type}}
+func.func private @missing_lbrace() -> quantile<ui4:f16, 1.0}>
+
+// -----
+
+// Test missing '}' after quantile value list.
+// expected-error @+1 {{expected '}' in quantile type}}
+func.func private @missing_rbrace() -> quantile<ui4:f16, {1.0>
+
+// -----
+
+// Test missing '>' closing the quantile type.
+// expected-error @+1 {{expected '>' in quantile type}}
+func.func private @missing_gt() -> quantile<ui4:f16, {1.0}
diff --git a/mlir/test/IR/quantile-types.mlir b/mlir/test/IR/quantile-types.mlir
new file mode 100644
index 0000000000000..3c119ef91750d
--- /dev/null
+++ b/mlir/test/IR/quantile-types.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Verify round-trip parsing and printing of the builtin quantile type.
+
+// CHECK: func private @quantile_ui4_f16(quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
+func.func private @quantile_ui4_f16(quantile<ui4:f16, {-1.0,0.0,1.0}>) -> ()
+
+// CHECK: func private @quantile_si8_f32(quantile<si8:f32, {-1.000000e+00,-5.000000e-01,0.000000e+00,5.000000e-01,1.000000e+00}>)
+func.func private @quantile_si8_f32(quantile<si8:f32, {-1.0,-0.5,0.0,0.5,1.0}>) -> ()
+
+// CHECK: func private @quantile_i8_f32(quantile<i8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
+func.func private @quantile_i8_f32(quantile<i8:f32, {-1.0,0.0,1.0}>) -> ()
+
+// CHECK: func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
+func.func private @quantile_f8_f32(quantile<f8E4M3FN:f32, {-1.0,0.0,1.0}>) -> ()
+
+// CHECK: func private @quantile_ui4_bf16(quantile<ui4:bf16, {-1.000000e+00,0.000000e+00,1.000000e+00}>)
+func.func private @quantile_ui4_bf16(quantile<ui4:bf16, {-1.0,0.0,1.0}>) -> ()
+
+// CHECK: func private @quantile_as_return() -> quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>
+func.func private @quantile_as_return() -> quantile<ui4:f16, {-1.0,0.0,1.0}>
+
+// Verify use as memref element type (requires MemRefElementTypeInterface).
+// CHECK: func private @quantile_in_memref(memref<8xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
+func.func private @quantile_in_memref(memref<8xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+
+// Verify use as tensor element type.
+// CHECK: func private @quantile_in_tensor(tensor<16xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
+func.func private @quantile_in_tensor(tensor<16xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+
+// Verify use in multidimensional tensors.
+// CHECK: func private @quantile_in_ranked_tensor(tensor<16x16x1x1xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
+func.func private @quantile_in_ranked_tensor(tensor<16x16x1x1xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+
+// Verify use in unranked tensors.
+// CHECK: func private @quantile_in_unranked_tensor(tensor<*xquantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}>>)
+func.func private @quantile_in_unranked_tensor(tensor<*xquantile<ui4:f16, {-1.0,0.0,1.0}>>) -> ()
+
+// Verify NF4-style 16-entry quantile table
+// CHECK-LABEL: @nf4_16_values
+// CHECK-SAME: quantile<ui4:f16, {
+func.func private @nf4_16_values(quantile<ui4:f16, {
+ -1.0,-0.6961928009986877,-0.5250730514526367,-0.39491748809814453,
+ -0.28444138169288635,-0.18477343022823334,-0.09105003625154495,0.0,
+ 0.07958029955625534,0.16093020141124725,0.24611230194568634,
+ 0.33791524171829224,0.44070982933044434,0.5626170039176941,
+ 0.7229568362236023,1.0}>) -> ()
+
+// Verify negative-only quantile list.
+// CHECK: func private @quantile_negatives(quantile<si4:f32, {-1.000000e+00,-5.000000e-01,-2.500000e-01}>)
+func.func private @quantile_negatives(quantile<si4:f32, {-1.0,-0.5,-0.25}>) -> ()
+
+// Verify that a single quantile value is accepted.
+// CHECK: func private @quantile_single_value(quantile<si8:f32, {1.000000e+00}>)
+func.func private @quantile_single_value(quantile<si8:f32, {1.0}>) -> ()
>From 9699f90a6e53c4dbeb557d8ed66f3a6769cb8b9a Mon Sep 17 00:00:00 2001
From: vsimion26 <vlad.simion at intel.com>
Date: Fri, 3 Apr 2026 08:42:20 +0100
Subject: [PATCH 2/2] Added support for optional storage min/max, extended
testing for low fp storage types
---
.../mlir/Dialect/Quant/IR/QuantTypes.h | 6 +--
mlir/include/mlir/IR/BuiltinTypes.td | 18 +++++--
mlir/lib/AsmParser/TypeParser.cpp | 54 ++++++++++++++++++-
mlir/lib/Dialect/Quant/IR/TypeParser.cpp | 3 +-
mlir/lib/IR/AsmPrinter.cpp | 5 +-
mlir/lib/IR/BuiltinTypes.cpp | 22 +++++++-
mlir/lib/IR/TypeDetail.h | 31 +++++++----
.../Dialect/Quant/parse-uniform-invalid.mlir | 12 +++++
mlir/test/Dialect/Quant/parse-uniform.mlir | 19 +++++++
mlir/test/IR/invalid-quantile-types.mlir | 6 +++
mlir/test/IR/quantile-types.mlir | 8 +++
11 files changed, 163 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 34f47a15395c9..092a91e1c01d7 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -255,7 +255,7 @@ class AnyQuantizedType
/// Per-layer, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
-/// StorageType: 'i'|'u' NumBits
+/// StorageType: 'i'|'u' NumBits, 'f4E2M1FN', 'f8E4M3FN', 'f8E5M2', 'quantile'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// Scale: A legal double value
/// ZeroPoint: An integer value
@@ -313,7 +313,7 @@ class UniformQuantizedType
/// Per-axis, optional parameters omitted:
/// !quant<uniform[StorageType]{Scale}>
///
-/// StorageType: 'i'|'u' NumBits
+/// StorageType: 'i'|'u' NumBits, 'f4E2M1FN', 'f8E4M3FN', 'f8E5M2', 'quantile'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// QuantizedDim: An integer value
/// QuantParams: (Scale ':' ZeroPoint)+
@@ -398,7 +398,7 @@ class UniformQuantizedPerAxisType
/// ScaleZeroList ::= ScaleZero (',' ScaleZero)*
/// ScaleZero ::= Scale (':' ZeroPoint)?
///
-/// StorageType: 'i'|'u' NumBits
+/// StorageType: 'i'|'u' NumBits, 'f4E2M1FN', 'f8E4M3FN', 'f8E5M2', 'quantile'
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
/// AxisSpec: An integer value
/// BlockSizeSpec: An integer value
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 88df2620e4733..157362b0766bd 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1120,7 +1120,7 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
Syntax:
```
- quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>`
+ quantile-type ::= `quantile` `<` type `:` type `,` `{` float-list `}` `>` ( `<` int `:` int `>` )?
```
A quantile type represents a quantile-based floating point encoding, where
@@ -1129,6 +1129,10 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
`si8`), a quantile type (the expressed floating-point precision, e.g.,
`f16`, `f32`), and an array of quantile values that define the mapping.
+ Optionally, explicit minimum and maximum storage values can be specified
+ after the closing `>` of the type as `<min:max>`, overriding the defaults
+ derived from the storage type width.
+
This type is used for weight compression schemes like NF4 (NormalFloat4)
and similar quantile-based formats.
@@ -1136,20 +1140,25 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
```mlir
quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}>
+ quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}><0:15>
```
}];
let parameters = (ins
"Type":$storageType,
"Type":$quantileType,
- ArrayRefParameter<"double">:$quantiles
+ ArrayRefParameter<"double", "quantile values">:$quantiles,
+ OptionalParameter<"std::optional<int64_t>", "explicit storage minimum">:$storageMin,
+ OptionalParameter<"std::optional<int64_t>", "explicit storage maximum">:$storageMax
);
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$storageType,
"Type":$quantileType,
- "ArrayRef<double>":$quantiles), [{
+ "ArrayRef<double>":$quantiles,
+ CArg<"std::optional<int64_t>", "std::nullopt">:$storageMin,
+ CArg<"std::optional<int64_t>", "std::nullopt">:$storageMax), [{
return $_get(storageType.getContext(), storageType, quantileType,
- quantiles);
+ quantiles, storageMin, storageMax);
}]>
];
let skipDefaultBuilders = 1;
@@ -1187,6 +1196,7 @@ def Builtin_Quantile : Builtin_Type<"Quantile", "quantile",
}];
}
+
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 906c2cf70b605..ad963a462db79 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <cstdint>
#include <limits>
@@ -435,7 +436,58 @@ Type Parser::parseQuantileType() {
if (parseToken(Token::greater, "expected '>' in quantile type"))
return nullptr;
- return QuantileType::get(storageType, quantileType, quantiles);
+ // Optionally parse explicit storage range: `<min:max>`.
+ std::optional<int64_t> storageMin, storageMax;
+ if (consumeIf(Token::less)) {
+ int64_t minVal, maxVal;
+
+ // Parse minimum value (with optional sign).
+ bool minNegative = consumeIf(Token::minus);
+ if (!getToken().is(Token::integer))
+ return (emitWrongTokenError(
+ "expected integer minimum in quantile storage range"),
+ nullptr);
+ SMLoc minLoc = getToken().getLoc();
+ if (getToken().getSpelling().getAsInteger(10, minVal))
+ return nullptr;
+ consumeToken(Token::integer);
+ minVal = minNegative ? -minVal : minVal;
+
+ if (parseToken(Token::colon, "expected ':' in quantile storage range"))
+ return nullptr;
+
+ // Parse maximum value (with optional sign).
+ bool maxNegative = consumeIf(Token::minus);
+ if (!getToken().is(Token::integer))
+ return (emitWrongTokenError(
+ "expected integer maximum in quantile storage range"),
+ nullptr);
+ SMLoc maxLoc = getToken().getLoc();
+ if (getToken().getSpelling().getAsInteger(10, maxVal))
+ return nullptr;
+ consumeToken(Token::integer);
+ maxVal = maxNegative ? -maxVal : maxVal;
+
+ if (parseToken(Token::greater, "expected '>' after quantile storage range"))
+ return nullptr;
+
+ // Validate against the underlying storage type's inherent limits.
+ if (auto qsIface = llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
+ bool isSigned = qsIface.shouldDefaultToSigned();
+ if (minVal < qsIface.getDefaultMinimum(isSigned))
+ return (emitError(minLoc, "illegal storage type minimum: ") << minVal,
+ nullptr);
+ if (maxVal > qsIface.getDefaultMaximum(isSigned))
+ return (emitError(maxLoc, "illegal storage type maximum: ") << maxVal,
+ nullptr);
+ }
+
+ storageMin = minVal;
+ storageMax = maxVal;
+ }
+
+ return QuantileType::get(storageType, quantileType, quantiles, storageMin,
+ storageMax);
}
/// Parse a tensor type.
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 1a42b90ac31e2..5bffe43c818f9 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -320,7 +320,8 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
/// block-size-info `,` scale-zero-tensor `>`
/// storage-spec ::= storage-type (`<` storage-range `>`)?
/// storage-range ::= integer-literal `:` integer-literal
-/// storage-type ::= (`i` | `u`) integer-literal
+/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN`
+/// | `f4E2M1FN` | `quantile`
/// expressed-type-spec ::= `:` `f` integer-literal
/// axis-spec ::= `:` integer-literal
/// scale-zero ::= scale (`:` zero-point)?
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e1145eaafbc64..3f6600173fc5d 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2907,13 +2907,14 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
printType(quantileTy.getQuantileType());
os << ", {";
ArrayRef<double> quantiles = quantileTy.getQuantiles();
- // interleaveComma(llvm::seq<size_t>(0, quantiles.size()),
- // [&](size_t i) { os << quantiles[i]; });
llvm::interleave(
llvm::seq<size_t>(0, quantiles.size()), os,
[&](size_t index) { os << quantiles[index]; }, ",");
os << "}";
os << '>';
+ if (auto minVal = quantileTy.getStorageMin())
+ if (auto maxVal = quantileTy.getStorageMax())
+ os << '<' << *minVal << ':' << *maxVal << '>';
})
.Case([&](TupleType tupleTy) {
os << "tuple<";
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index faf38e6a8354e..6429928b812a0 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -106,15 +106,28 @@ ArrayRef<double> QuantileType::getQuantiles() const {
return getImpl()->getQuantiles();
}
+std::optional<int64_t> QuantileType::getStorageMin() const {
+ return getImpl()->getStorageMin();
+}
+
+std::optional<int64_t> QuantileType::getStorageMax() const {
+ return getImpl()->getStorageMax();
+}
+
LogicalResult QuantileType::verify(function_ref<InFlightDiagnostic()> emitError,
Type storageType, Type quantileType,
- ArrayRef<double> quantiles) {
+ ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin,
+ std::optional<int64_t> storageMax) {
if (!storageType.isIntOrFloat())
return emitError() << "storage type must be an integer or float type";
if (!llvm::isa<FloatType>(quantileType))
return emitError() << "quantile type must be a float type";
if (quantiles.empty())
return emitError() << "quantile values must not be empty";
+ if (storageMin.has_value() != storageMax.has_value())
+ return emitError()
+ << "storage min and max must both be specified or both omitted";
return success();
}
@@ -130,12 +143,16 @@ unsigned QuantileType::getStorageWidth() const {
}
int64_t QuantileType::getDefaultMaximum(bool isSigned) const {
+ if (auto explicitMax = getStorageMax())
+ return *explicitMax;
if (isSigned)
return (1LL << (getStorageWidth() - 1)) - 1;
return (1LL << getStorageWidth()) - 1;
}
int64_t QuantileType::getDefaultMinimum(bool isSigned) const {
+ if (auto explicitMin = getStorageMin())
+ return *explicitMin;
if (isSigned)
return -(1LL << (getStorageWidth() - 1));
return 0;
@@ -150,6 +167,9 @@ std::string QuantileType::getStorageTypeName(bool isSigned) const {
llvm::seq<size_t>(0, quantiles.size()), os,
[&](size_t index) { os << quantiles[index]; }, ",");
os << "}>";
+ if (auto minVal = getStorageMin())
+ if (auto maxVal = getStorageMax())
+ os << '<' << *minVal << ':' << *maxVal << '>';
return result;
}
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 6afcc6307d4b4..9e805ba25c3f4 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -151,35 +151,44 @@ unsigned getMemorySpaceAsInt(Attribute memorySpace);
/// Quantile Type Storage and Uniquing.
struct QuantileTypeStorage : public TypeStorage {
QuantileTypeStorage(Type storageType, Type quantileType,
- ArrayRef<double> quantiles)
+ ArrayRef<double> quantiles,
+ std::optional<int64_t> storageMin,
+ std::optional<int64_t> storageMax)
: storageType(storageType), quantileType(quantileType),
- quantilesData(quantiles.data()), numQuantiles(quantiles.size()) {}
+ quantilesData(quantiles.data()), numQuantiles(quantiles.size()),
+ storageMin(storageMin), storageMax(storageMax) {}
- /// The hash key used for uniquing.
- using KeyTy = std::tuple<Type, Type, ArrayRef<double>>;
+ using KeyTy = std::tuple<Type, Type, ArrayRef<double>, std::optional<int64_t>,
+ std::optional<int64_t>>;
static llvm::hash_code hashKey(const KeyTy &key) {
auto quantiles = std::get<2>(key);
- // Bit-cast doubles to int64_t for hashing since LLVM hashing
- // does not natively support double.
auto *quantilesBits = llvm::bit_cast<const int64_t *>(quantiles.data());
ArrayRef<int64_t> quantilesAsInts(quantilesBits, quantiles.size());
+ auto hashOptInt = [](std::optional<int64_t> opt) -> llvm::hash_code {
+ return opt ? llvm::hash_combine(true, *opt)
+ : llvm::hash_combine(false, int64_t{0});
+ };
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
llvm::hash_combine_range(quantilesAsInts.begin(),
- quantilesAsInts.end()));
+ quantilesAsInts.end()),
+ hashOptInt(std::get<3>(key)),
+ hashOptInt(std::get<4>(key)));
}
bool operator==(const KeyTy &key) const {
return storageType == std::get<0>(key) &&
quantileType == std::get<1>(key) &&
- getQuantiles() == std::get<2>(key);
+ getQuantiles() == std::get<2>(key) &&
+ storageMin == std::get<3>(key) && storageMax == std::get<4>(key);
}
static QuantileTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
ArrayRef<double> quantiles = allocator.copyInto(std::get<2>(key));
return new (allocator.allocate<QuantileTypeStorage>())
- QuantileTypeStorage(std::get<0>(key), std::get<1>(key), quantiles);
+ QuantileTypeStorage(std::get<0>(key), std::get<1>(key), quantiles,
+ std::get<3>(key), std::get<4>(key));
}
Type getStorageType() const { return storageType; }
@@ -187,11 +196,15 @@ struct QuantileTypeStorage : public TypeStorage {
ArrayRef<double> getQuantiles() const {
return ArrayRef<double>(quantilesData, numQuantiles);
}
+ std::optional<int64_t> getStorageMin() const { return storageMin; }
+ std::optional<int64_t> getStorageMax() const { return storageMax; }
Type storageType;
Type quantileType;
const double *quantilesData;
unsigned numQuantiles;
+ std::optional<int64_t> storageMin;
+ std::optional<int64_t> storageMax;
};
} // namespace detail
diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
index 6dbc86263bd71..329b1a98dc965 100644
--- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir
@@ -264,3 +264,15 @@
// Illegal storage min/max: min < defaultMin
// expected-error@+1 {{illegal storage type minimum: -10}}
!qalias = !quant.uniform<f4E2M1FN<-10:6>:f32, 0.99872:127>
+
+// -----
+
+// Illegal storage min/max: max > defaultMax
+// expected-error@+1 {{illegal storage type maximum: 100}}
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,1.0}><-6:100>:f32, 0.99872:127>
+
+// -----
+
+// Illegal storage min/max: min < defaultMin
+// expected-error@+1 {{illegal storage type minimum: -100}}
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0,1.0}><-100:6>:f32, 0.99872:127>
diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index 1431403aece41..68bd537185b74 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -227,3 +227,22 @@ func.func @parse() -> !qalias {
%0 = "foo"() : () -> !qalias
return %0 : !qalias
}
+
+// -----
+// Storage type: QuantileType
+// CHECK: !quant.uniform<quantile<f4E2M1FN:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}><-6:6>:f32, 9.987200e-01:127>
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}><-6:6>:f32, 0.99872:127 >
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: QuantileType
+// CHECK: !quant.uniform<quantile<f4E2M1FN:f16, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-6.670000e-02,6.670000e-02,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}>:f32, 2.000000e+02>
+!qalias = !quant.uniform<quantile<f4E2M1FN:f16, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}>:f32, 2.0e+2 >
+func.func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
diff --git a/mlir/test/IR/invalid-quantile-types.mlir b/mlir/test/IR/invalid-quantile-types.mlir
index 5ee3e674b5cc1..500df1a83f275 100644
--- a/mlir/test/IR/invalid-quantile-types.mlir
+++ b/mlir/test/IR/invalid-quantile-types.mlir
@@ -37,3 +37,9 @@ func.func private @missing_rbrace() -> quantile<ui4:f16, {1.0>
// Test missing '>' closing the quantile type.
// expected-error @+1 {{expected '>' in quantile type}}
func.func private @missing_gt() -> quantile<ui4:f16, {1.0}
+
+// -----
+
+// Test missing '>' closing the storage range.
+// expected-error @below {{expected '>' after quantile storage range}}
+func.func private @missing_range_gt() -> quantile<ui4:f16, {1.0}><-8:7
diff --git a/mlir/test/IR/quantile-types.mlir b/mlir/test/IR/quantile-types.mlir
index 3c119ef91750d..bf6af88c5b83a 100644
--- a/mlir/test/IR/quantile-types.mlir
+++ b/mlir/test/IR/quantile-types.mlir
@@ -46,6 +46,14 @@ func.func private @nf4_16_values(quantile<ui4:f16, {
0.33791524171829224,0.44070982933044434,0.5626170039176941,
0.7229568362236023,1.0}>) -> ()
+// Verify explicit storage min/max range (unsigned storage, narrowed range).
+// CHECK: func private @quantile_with_range(quantile<ui4:f16, {-1.000000e+00,0.000000e+00,1.000000e+00}><0:7>)
+func.func private @quantile_with_range(quantile<ui4:f16, {-1.0,0.0,1.0}><0:7>) -> ()
+
+// Verify explicit range is preserved through round-trip.
+// CHECK: func private @quantile_signed_range(quantile<si8:f32, {-1.000000e+00,0.000000e+00,1.000000e+00}><-100:100>)
+func.func private @quantile_signed_range(quantile<si8:f32, {-1.0,0.0,1.0}><-100:100>) -> ()
+
// Verify negative-only quantile list.
// CHECK: func private @quantile_negatives(quantile<si4:f32, {-1.000000e+00,-5.000000e-01,-2.500000e-01}>)
func.func private @quantile_negatives(quantile<si4:f32, {-1.0,-0.5,-0.25}>) -> ()
More information about the Mlir-commits
mailing list