[Mlir-commits] [mlir] 4d29f6e - [mlir][python] Expose fp8 types with pybind.

Mehdi Amini llvmlistbot at llvm.org
Tue Jan 3 11:19:04 PST 2023


Author: Qiao Zhang
Date: 2023-01-03T19:18:46Z
New Revision: 4d29f6ed6e73609f2f181d048d8157aeba5b73ca

URL: https://github.com/llvm/llvm-project/commit/4d29f6ed6e73609f2f181d048d8157aeba5b73ca
DIFF: https://github.com/llvm/llvm-project/commit/4d29f6ed6e73609f2f181d048d8157aeba5b73ca.diff

LOG: [mlir][python] Expose fp8 types with pybind.

Expose fp8 types with pybind.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D140746

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7a41cb1e8134f..10527af6c9eac 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -102,6 +102,42 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
   }
 };
 
+/// Floating Point Type subclass - Float8E4M3FNType.
+class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+  static constexpr const char *pyClassName = "Float8E4M3FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
+          return PyFloat8E4M3FNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
+  }
+};
+
+/// Floating Point Type subclass - Float8E5M2Type.
+class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+  static constexpr const char *pyClassName = "Float8E5M2Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E5M2TypeGet(context->get());
+          return PyFloat8E5M2Type(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e5m2 type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type> {
 public:
@@ -663,6 +699,8 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
 void mlir::python::populateIRTypes(py::module &m) {
   PyIntegerType::bind(m);
   PyIndexType::bind(m);
+  PyFloat8E4M3FNType::bind(m);
+  PyFloat8E5M2Type::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyF32Type::bind(m);

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 60bc3676f398c..505946ca1a843 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -50,6 +50,8 @@ __all__ = [
     "DiagnosticHandler",
     "DiagnosticSeverity",
     "DictAttr",
+    "Float8E4M3FNType",
+    "Float8E5M2Type",
     "F16Type",
     "F32Type",
     "F64Type",
@@ -577,6 +579,20 @@ class DictAttr(Attribute):
     @property
     def type(self) -> Type: ...
 
+class Float8E4M3FNType(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E4M3FNType: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
+class Float8E5M2Type(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E5M2Type: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
 # TODO: Auto-generated. Audit and fix.
 class F16Type(Type):
     def __init__(self, cast_from_type: Type) -> None: ...

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 91c820f121b9f..e160216cb15f4 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -193,6 +193,10 @@ def testIndexType():
 @run
 def testFloatType():
   with Context():
+    # CHECK: float: f8E4M3FN
+    print("float:", Float8E4M3FNType.get())
+    # CHECK: float: f8E5M2
+    print("float:", Float8E5M2Type.get())
     # CHECK: float: bf16
     print("float:", BF16Type.get())
     # CHECK: float: f16


        


More information about the Mlir-commits mailing list