[Mlir-commits] [mlir] e08ca4b - Add Float8E4M3FN type to MLIR.

Benjamin Kramer llvmlistbot at llvm.org
Wed Nov 16 01:30:40 PST 2022


Author: Reed
Date: 2022-11-16T10:24:25+01:00
New Revision: e08ca4bb1dfec860eefce636f4eff472fc7081ea

URL: https://github.com/llvm/llvm-project/commit/e08ca4bb1dfec860eefce636f4eff472fc7081ea
DIFF: https://github.com/llvm/llvm-project/commit/e08ca4bb1dfec860eefce636f4eff472fc7081ea.diff

LOG: Add Float8E4M3FN type to MLIR.

The paper https://arxiv.org/abs/2209.05433 introduces two new FP8 dtypes: E5M2 (called Float8E5M2 in LLVM) and E4M3 (called Float8E4M3FN in LLVM; the "FN" suffix indicates that the type has only finite values, with no infinities, plus a special NaN encoding). Support for Float8E5M2 in APFloat and MLIR was added in https://reviews.llvm.org/D133823. Support for Float8E4M3FN in APFloat was added in https://reviews.llvm.org/D137760. This change adds Float8E4M3FN to MLIR as well.

There is an RFC for adding the FP8 dtypes here: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279.

This change mirrors the MLIR changes in the patch that added Float8E5M2, with Float8E4M3FN added in place of Float8E5M2. It also removes a stray trailing space from the Float8E5M2 type description.

Reviewed By: stellaraccident, bkramer, rriddle

Differential Revision: https://reviews.llvm.org/D138075

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/Types.h
    mlir/lib/AsmParser/TokenKinds.def
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Types.cpp
    mlir/test/IR/attribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 9bd3d510b2483..1c4a1638205be 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -74,6 +74,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
 
+/// Checks whether the given type is an f8E4M3FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
+
+/// Creates an f8E4M3FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 870c834ce2b0d..f2a547ef4f75b 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
 
   // Types.
   FloatType getFloat8E5M2Type();
+  FloatType getFloat8E4M3FNType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 1925127251558..ceba71d517589 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -47,6 +47,7 @@ class FloatType : public Type {
   static FloatType getF80(MLIRContext *ctx);
   static FloatType getF128(MLIRContext *ctx);
   static FloatType getFloat8E5M2(MLIRContext *ctx);
+  static FloatType getFloat8E4M3FN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -374,14 +375,18 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type,
-                  Float64Type, Float80Type, Float128Type>();
+  return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
+                  Float32Type, Float64Type, Float80Type, Float128Type>();
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
   return Float8E5M2Type::get(ctx);
 }
 
+inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
+  return Float8E4M3FNType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 50d8b3a0cb44a..fbd9c6350fcf2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -89,7 +89,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
       * bit encoding: S1E5M2
       * exponent bias: 15
       * infinities: supported with exponent set to all 1s and mantissa 0s
-      * NaNs: supported with exponent bits set to all 1s and mantissa of 
+      * NaNs: supported with exponent bits set to all 1s and mantissa of
         (01, 10, or 11)
       * denormals when exponent is 0
 
@@ -97,6 +97,27 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E4M3FNType
+
+def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
+  let summary = "8-bit floating point with 3 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it follows
+    similar conventions, with the exception that there are no infinity values
+    and only two NaN representations. This type has the following
+    characteristics:
+
+      * bit encoding: S1E4M3
+      * exponent bias: 7
+      * infinities: Not supported
+      * NaNs: supported with exponent bits and mantissa bits set to all 1s
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2209.05433
+  }];
+}
 
 //===----------------------------------------------------------------------===//
 // BFloat16Type

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 1c4db1b6c0f95..9d64a77742ef1 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -124,6 +124,7 @@ class Type {
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
   bool isFloat8E5M2() const;
+  bool isFloat8E4M3FN() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;

diff  --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 02eba88f78b0d..9bd7b60afd282 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -94,6 +94,7 @@ TOK_KEYWORD(f32)
 TOK_KEYWORD(f64)
 TOK_KEYWORD(f80)
 TOK_KEYWORD(f8E5M2)
+TOK_KEYWORD(f8E4M3FN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)

diff  --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index fa428b2f06fab..fc8c3fdbb58d7 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -31,6 +31,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_vector:
   case Token::inttype:
   case Token::kw_f8E5M2:
+  case Token::kw_f8E4M3FN:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -290,6 +291,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E5M2:
     consumeToken(Token::kw_f8E5M2);
     return builder.getFloat8E5M2Type();
+  case Token::kw_f8E4M3FN:
+    consumeToken(Token::kw_f8E4M3FN);
+    return builder.getFloat8E4M3FNType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index ad9a5bc6640e2..596a760b99e89 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -76,6 +76,14 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
 }
 
+bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
+  return unwrap(type).isFloat8E4M3FN();
+}
+
+MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 9a3d3e031dc31..32a26470d94a5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2244,6 +2244,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       })
       .Case<IndexType>([&](Type) { os << "index"; })
       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
+      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 053ffce1b1579..2f4e07990a0d1 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -37,6 +37,10 @@ FloatType Builder::getFloat8E5M2Type() {
   return FloatType::getFloat8E5M2(context);
 }
 
+FloatType Builder::getFloat8E4M3FNType() {
+  return FloatType::getFloat8E4M3FN(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d65c5e9d28b1e..f4d64c97836d1 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,7 +88,7 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  if (isa<Float8E5M2Type>())
+  if (isa<Float8E5M2Type, Float8E4M3FNType>())
     return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
@@ -107,6 +107,8 @@ unsigned FloatType::getWidth() {
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
   if (isa<Float8E5M2Type>())
     return APFloat::Float8E5M2();
+  if (isa<Float8E4M3FNType>())
+    return APFloat::Float8E4M3FN();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 182e249810e13..298f722da361c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -207,6 +207,7 @@ class MLIRContextImpl {
 
   /// Cached Type Instances.
   Float8E5M2Type f8E5M2Ty;
+  Float8E4M3FNType f8E4M3FNTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -278,6 +279,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   //// Types.
   /// Floating-point Types.
   impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
+  impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -861,6 +863,9 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
   return context->getImpl().f8E5M2Ty;
 }
+Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
+  return context->getImpl().f8E4M3FNTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }

diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index b97388bf33f52..670974bbf8373 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -19,6 +19,7 @@ using namespace mlir::detail;
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
 bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
+bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 540578e61527a..ebfbb89825030 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -40,6 +40,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E5M2
     float_attr = 2. : f8E5M2
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
+    float_attr = 2. : f8E4M3FN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16


        


More information about the Mlir-commits mailing list