[Mlir-commits] [mlir] 96267b6 - [mlir] Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR

Chris Jackson llvmlistbot at llvm.org
Mon Feb 13 10:26:53 PST 2023


Author: Jake Hall
Date: 2023-02-13T18:26:27Z
New Revision: 96267b6b88405c9222e69aadb669461533bc1352

URL: https://github.com/llvm/llvm-project/commit/96267b6b88405c9222e69aadb669461533bc1352
DIFF: https://github.com/llvm/llvm-project/commit/96267b6b88405c9222e69aadb669461533bc1352.diff

LOG: [mlir] Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR

Float8E5M2FNUZ and Float8E4M3FNUZ have been added to APFloat in D141863.
This change adds these types as MLIR builtin types alongside Float8E5M2
and Float8E4M3FN (added in D133823 and D138075, respectively).

Reviewed By: krzysz00

Differential Revision: https://reviews.llvm.org/D143744

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/Types.h
    mlir/lib/AsmParser/TokenKinds.def
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Types.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/IR/attribute.mlir
    mlir/test/python/ir/builtin_types.py
    mlir/utils/lldb-scripts/mlirDataFormatters.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 000397505479..8b855d8c39a4 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -81,6 +81,20 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
 
+/// Checks whether the given type is an f8E5M2FNUZ type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
+
+/// Creates an f8E5M2FNUZ type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
+
+/// Checks whether the given type is an f8E4M3FNUZ type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
+
+/// Creates an f8E4M3FNUZ type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index e0d33dd9271f..14df7b09032a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -62,6 +62,8 @@ class Builder {
   // Types.
   FloatType getFloat8E5M2Type();
   FloatType getFloat8E4M3FNType();
+  FloatType getFloat8E5M2FNUZType();
+  FloatType getFloat8E4M3FNUZType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 135fa9b559d8..33995f34ee39 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -47,6 +47,8 @@ class FloatType : public Type {
   static FloatType getF128(MLIRContext *ctx);
   static FloatType getFloat8E5M2(MLIRContext *ctx);
   static FloatType getFloat8E4M3FN(MLIRContext *ctx);
+  static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
+  static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -374,8 +376,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
-                  Float32Type, Float64Type, Float80Type, Float128Type>();
+  return type.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                  Float8E4M3FNUZType, BFloat16Type, Float16Type, Float32Type,
+                  Float64Type, Float80Type, Float128Type>();
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -386,6 +389,14 @@ inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
   return Float8E4M3FNType::get(ctx);
 }
 
+inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) {
+  return Float8E5M2FNUZType::get(ctx);
+}
+
+inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {
+  return Float8E4M3FNUZType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 5f9141d71b69..a8d1fae6a5a2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -118,6 +118,50 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E5M2FNUZType
+
+def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ"> {
+  let summary = "8-bit floating point with 2 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it follows
+    similar conventions, with the exception that there are no infinity values,
+    no negative zero, and only one NaN representation. This type has the
+    following characteristics:
+
+      * bit encoding: S1E5M2
+      * exponent bias: 16
+      * infinities: Not supported
+      * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2206.02915
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Float8E4M3FNUZType
+
+def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> {
+  let summary = "8-bit floating point with 3 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it follows
+    similar conventions, with the exception that there are no infinity values,
+    no negative zero, and only one NaN representation. This type has the
+    following characteristics:
+
+      * bit encoding: S1E4M3
+      * exponent bias: 8
+      * infinities: Not supported
+      * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2209.05433
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index cd888ac61f4d..527ccc05e090 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -488,6 +488,10 @@ def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
                BuildableType<"$_builder.getFloat8E4M3FNType()">;
 def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
              BuildableType<"$_builder.getFloat8E5M2Type()">;
+def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
+                 BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
+def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
+                 BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
 
 def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
                       "complex-type", "::mlir::ComplexType">;

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 2a0586c455eb..9f30ce1620b4 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -122,6 +122,8 @@ class Type {
   bool isIndex() const;
   bool isFloat8E5M2() const;
   bool isFloat8E4M3FN() const;
+  bool isFloat8E5M2FNUZ() const;
+  bool isFloat8E4M3FNUZ() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;

diff  --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 9bd7b60afd28..0e666c792b9d 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -95,6 +95,8 @@ TOK_KEYWORD(f64)
 TOK_KEYWORD(f80)
 TOK_KEYWORD(f8E5M2)
 TOK_KEYWORD(f8E4M3FN)
+TOK_KEYWORD(f8E5M2FNUZ)
+TOK_KEYWORD(f8E4M3FNUZ)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)

diff  --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index fab7244d1b72..47078c1ba047 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -33,6 +33,8 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::inttype:
   case Token::kw_f8E5M2:
   case Token::kw_f8E4M3FN:
+  case Token::kw_f8E5M2FNUZ:
+  case Token::kw_f8E4M3FNUZ:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -295,6 +297,12 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E4M3FN:
     consumeToken(Token::kw_f8E4M3FN);
     return builder.getFloat8E4M3FNType();
+  case Token::kw_f8E5M2FNUZ:
+    consumeToken(Token::kw_f8E5M2FNUZ);
+    return builder.getFloat8E5M2FNUZType();
+  case Token::kw_f8E4M3FNUZ:
+    consumeToken(Token::kw_f8E4M3FNUZ);
+    return builder.getFloat8E4M3FNUZType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 3cc226d7aa2c..87ffe593655b 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -139,6 +139,42 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
   }
 };
 
+/// Floating Point Type subclass - Float8E4M3FNUZ.
+class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
+          return PyFloat8E4M3FNUZType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
+  }
+};
+
+/// Floating Point Type subclass - Float8E5M2FNUZ.
+class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
+          return PyFloat8E5M2FNUZType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type> {
 public:
@@ -700,6 +736,8 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyIndexType::bind(m);
   PyFloat8E4M3FNType::bind(m);
   PyFloat8E5M2Type::bind(m);
+  PyFloat8E4M3FNUZType::bind(m);
+  PyFloat8E5M2FNUZType::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyF32Type::bind(m);

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 73a3ec414876..aea1221200af 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -84,6 +84,22 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
 }
 
+bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
+  return unwrap(type).isFloat8E5M2FNUZ();
+}
+
+MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
+}
+
+bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
+  return unwrap(type).isFloat8E4M3FNUZ();
+}
+
+MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 1ce617c0428f..8c5bb3021577 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2410,6 +2410,8 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<IndexType>([&](Type) { os << "index"; })
       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
       .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
+      .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
+      .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 38f0501adfbf..d36791fef23d 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -41,6 +41,14 @@ FloatType Builder::getFloat8E4M3FNType() {
   return FloatType::getFloat8E4M3FN(context);
 }
 
+FloatType Builder::getFloat8E5M2FNUZType() {
+  return FloatType::getFloat8E5M2FNUZ(context);
+}
+
+FloatType Builder::getFloat8E4M3FNUZType() {
+  return FloatType::getFloat8E4M3FNUZ(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 238b5bbb4eae..6e6c6b9683c7 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,7 +88,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  if (isa<Float8E5M2Type, Float8E4M3FNType>())
+  if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType>())
     return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
@@ -109,6 +110,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::Float8E5M2();
   if (isa<Float8E4M3FNType>())
     return APFloat::Float8E4M3FN();
+  if (isa<Float8E5M2FNUZType>())
+    return APFloat::Float8E5M2FNUZ();
+  if (isa<Float8E4M3FNUZType>())
+    return APFloat::Float8E4M3FNUZ();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 5bbedda19641..176a8abe16fc 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -209,6 +209,8 @@ class MLIRContextImpl {
   /// Cached Type Instances.
   Float8E5M2Type f8E5M2Ty;
   Float8E4M3FNType f8E4M3FNTy;
+  Float8E5M2FNUZType f8E5M2FNUZTy;
+  Float8E4M3FNUZType f8E4M3FNUZTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -281,6 +283,8 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   /// Floating-point Types.
   impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
   impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
+  impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
+  impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -870,6 +874,12 @@ Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
 Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
   return context->getImpl().f8E4M3FNTy;
 }
+Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
+  return context->getImpl().f8E5M2FNUZTy;
+}
+Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
+  return context->getImpl().f8E4M3FNUZTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }

diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 070ed4b14686..e739786bd399 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -36,6 +36,8 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
 bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
 bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
+bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
+bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 63a3125ec715..7d5ff23f60ab 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -52,6 +52,8 @@ __all__ = [
     "DictAttr",
     "Float8E4M3FNType",
     "Float8E5M2Type",
+    "Float8E4M3FNUZType",
+    "Float8E5M2FNUZType",
     "F16Type",
     "F32Type",
     "F64Type",
@@ -593,6 +595,20 @@ class Float8E5M2Type(Type):
     @staticmethod
     def isinstance(arg: Any) -> bool: ...
 
+class Float8E4M3FNUZType(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E4M3FNUZType: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
+class Float8E5M2FNUZType(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E5M2FNUZType: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
 # TODO: Auto-generated. Audit and fix.
 class F16Type(Type):
     def __init__(self, cast_from_type: Type) -> None: ...

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index d494824ec7e7..de840f950f45 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -44,6 +44,14 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
     float_attr = 2. : f8E4M3FN
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
+    float_attr = 2. : f8E5M2FNUZ
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
+    float_attr = 2. : f8E4M3FNUZ
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index e160216cb15f..7af81859e3e7 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -197,6 +197,10 @@ def testFloatType():
     print("float:", Float8E4M3FNType.get())
     # CHECK: float: f8E5M2
     print("float:", Float8E5M2Type.get())
+    # CHECK: float: f8E5M2FNUZ
+    print("float:", Float8E5M2FNUZType.get())
+    # CHECK: float: f8E4M3FNUZ
+    print("float:", Float8E4M3FNUZType.get())
     # CHECK: float: bf16
     print("float:", BF16Type.get())
     # CHECK: float: f16

diff  --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index f04516b18111..908a734f6e30 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -52,6 +52,8 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::UnknownLoc": '"loc(unknown)"',
     "mlir::Float8E5M2Type": '"f8E5M2"',
     "mlir::Float8E4M3FNType": '"f8E4M3FN"',
+    "mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"',
+    "mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
     "mlir::BFloat16Type": '"bf16"',
     "mlir::Float16Type": '"f16"',
     "mlir::Float32Type": '"f32"',


        


More information about the Mlir-commits mailing list