[Mlir-commits] [mlir] e08ca4b - Add Float8E4M3FN type to MLIR.
Benjamin Kramer
llvmlistbot at llvm.org
Wed Nov 16 01:30:40 PST 2022
Author: Reed
Date: 2022-11-16T10:24:25+01:00
New Revision: e08ca4bb1dfec860eefce636f4eff472fc7081ea
URL: https://github.com/llvm/llvm-project/commit/e08ca4bb1dfec860eefce636f4eff472fc7081ea
DIFF: https://github.com/llvm/llvm-project/commit/e08ca4bb1dfec860eefce636f4eff472fc7081ea.diff
LOG: Add Float8E4M3FN type to MLIR.
The paper https://arxiv.org/abs/2209.05433 introduces two new FP8 dtypes: E5M2 (called Float8E5M2 in LLVM) and E4M3 (called Float8E4M3FN in LLVM). Support for Float8E5M2 in APFloat and MLIR was added in https://reviews.llvm.org/D133823. Support for Float8E4M3FN in APFloat was added in https://reviews.llvm.org/D137760. This change adds Float8E4M3FN to MLIR as well.
There is an RFC for adding the FP8 dtypes here: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279.
This change is identical to the MLIR changes in the patch that added Float8E5M2, except that Float8E4M3FN is added instead.
Reviewed By: stellaraccident, bkramer, rriddle
Differential Revision: https://reviews.llvm.org/D138075
Added:
Modified:
mlir/include/mlir-c/BuiltinTypes.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/Types.h
mlir/lib/AsmParser/TokenKinds.def
mlir/lib/AsmParser/TypeParser.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Types.cpp
mlir/test/IR/attribute.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 9bd3d510b2483..1c4a1638205be 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -74,6 +74,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
+/// Checks whether the given type is an f8E4M3FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
+
+/// Creates an f8E4M3FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
+
/// Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 870c834ce2b0d..f2a547ef4f75b 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
// Types.
FloatType getFloat8E5M2Type();
+ FloatType getFloat8E4M3FNType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 1925127251558..ceba71d517589 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -47,6 +47,7 @@ class FloatType : public Type {
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
+ static FloatType getFloat8E4M3FN(MLIRContext *ctx);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
@@ -374,14 +375,18 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}
inline bool FloatType::classof(Type type) {
- return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type,
- Float64Type, Float80Type, Float128Type>();
+ return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
+ Float32Type, Float64Type, Float80Type, Float128Type>();
}
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
return Float8E5M2Type::get(ctx);
}
+inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
+ return Float8E4M3FNType::get(ctx);
+}
+
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 50d8b3a0cb44a..fbd9c6350fcf2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -89,7 +89,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
* bit encoding: S1E5M2
* exponent bias: 15
* infinities: supported with exponent set to all 1s and mantissa 0s
- * NaNs: supported with exponent bits set to all 1s and mantissa of
+ * NaNs: supported with exponent bits set to all 1s and mantissa of
(01, 10, or 11)
* denormals when exponent is 0
@@ -97,6 +97,27 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Float8E4M3FNType
+
+def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
+ let summary = "8-bit floating point with 3 bit mantissa";
+ let description = [{
+ An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
+ mantissa. This is not a standard type as defined by IEEE-754, but it follows
+ similar conventions, with the exception that there are no infinity values
+ and only two NaN representations. This type has the following
+ characteristics:
+
+ * bit encoding: S1E4M3
+ * exponent bias: 7
+ * infinities: Not supported
+ * NaNs: supported with exponent bits and mantissa bits set to all 1s
+ * denormals when exponent is 0
+
+ Described in: https://arxiv.org/abs/2209.05433
+ }];
+}
//===----------------------------------------------------------------------===//
// BFloat16Type
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 1c4db1b6c0f95..9d64a77742ef1 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -124,6 +124,7 @@ class Type {
// derived types should use isa/dyn_cast.
bool isIndex() const;
bool isFloat8E5M2() const;
+ bool isFloat8E4M3FN() const;
bool isBF16() const;
bool isF16() const;
bool isF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 02eba88f78b0d..9bd7b60afd282 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -94,6 +94,7 @@ TOK_KEYWORD(f32)
TOK_KEYWORD(f64)
TOK_KEYWORD(f80)
TOK_KEYWORD(f8E5M2)
+TOK_KEYWORD(f8E4M3FN)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index fa428b2f06fab..fc8c3fdbb58d7 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -31,6 +31,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_vector:
case Token::inttype:
case Token::kw_f8E5M2:
+ case Token::kw_f8E4M3FN:
case Token::kw_bf16:
case Token::kw_f16:
case Token::kw_f32:
@@ -290,6 +291,9 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f8E5M2:
consumeToken(Token::kw_f8E5M2);
return builder.getFloat8E5M2Type();
+ case Token::kw_f8E4M3FN:
+ consumeToken(Token::kw_f8E4M3FN);
+ return builder.getFloat8E4M3FNType();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index ad9a5bc6640e2..596a760b99e89 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -76,6 +76,14 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
}
+bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
+ return unwrap(type).isFloat8E4M3FN();
+}
+
+MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
+ return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
+}
+
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
MlirType mlirBF16TypeGet(MlirContext ctx) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 9a3d3e031dc31..32a26470d94a5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2244,6 +2244,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
})
.Case<IndexType>([&](Type) { os << "index"; })
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
+ .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<Float32Type>([&](Type) { os << "f32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 053ffce1b1579..2f4e07990a0d1 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -37,6 +37,10 @@ FloatType Builder::getFloat8E5M2Type() {
return FloatType::getFloat8E5M2(context);
}
+FloatType Builder::getFloat8E4M3FNType() {
+ return FloatType::getFloat8E4M3FN(context);
+}
+
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d65c5e9d28b1e..f4d64c97836d1 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,7 +88,7 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() {
- if (isa<Float8E5M2Type>())
+ if (isa<Float8E5M2Type, Float8E4M3FNType>())
return 8;
if (isa<Float16Type, BFloat16Type>())
return 16;
@@ -107,6 +107,8 @@ unsigned FloatType::getWidth() {
const llvm::fltSemantics &FloatType::getFloatSemantics() {
if (isa<Float8E5M2Type>())
return APFloat::Float8E5M2();
+ if (isa<Float8E4M3FNType>())
+ return APFloat::Float8E4M3FN();
if (isa<BFloat16Type>())
return APFloat::BFloat();
if (isa<Float16Type>())
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 182e249810e13..298f722da361c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -207,6 +207,7 @@ class MLIRContextImpl {
/// Cached Type Instances.
Float8E5M2Type f8E5M2Ty;
+ Float8E4M3FNType f8E4M3FNTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
Float32Type f32Ty;
@@ -278,6 +279,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
//// Types.
/// Floating-point Types.
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
+ impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -861,6 +863,9 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
return context->getImpl().f8E5M2Ty;
}
+Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
+ return context->getImpl().f8E4M3FNTy;
+}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index b97388bf33f52..670974bbf8373 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -19,6 +19,7 @@ using namespace mlir::detail;
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
+bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
bool Type::isBF16() const { return isa<BFloat16Type>(); }
bool Type::isF16() const { return isa<Float16Type>(); }
bool Type::isF32() const { return isa<Float32Type>(); }
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 540578e61527a..ebfbb89825030 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -40,6 +40,10 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f8E5M2
float_attr = 2. : f8E5M2
} : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
+ float_attr = 2. : f8E4M3FN
+ } : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f16
float_attr = 2. : f16
More information about the Mlir-commits
mailing list