[Mlir-commits] [mlir] 4d29f6e - [mlir][python] Expose fp8 types with pybind.
Mehdi Amini
llvmlistbot at llvm.org
Tue Jan 3 11:19:04 PST 2023
Author: Qiao Zhang
Date: 2023-01-03T19:18:46Z
New Revision: 4d29f6ed6e73609f2f181d048d8157aeba5b73ca
URL: https://github.com/llvm/llvm-project/commit/4d29f6ed6e73609f2f181d048d8157aeba5b73ca
DIFF: https://github.com/llvm/llvm-project/commit/4d29f6ed6e73609f2f181d048d8157aeba5b73ca.diff
LOG: [mlir][python] Expose fp8 types with pybind.
Expose fp8 types with pybind.
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D140746
Added:
Modified:
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
mlir/test/python/ir/builtin_types.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7a41cb1e8134f..10527af6c9eac 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -102,6 +102,42 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
}
};
+/// Floating Point Type subclass - Float8E4M3FNType (Python `Float8E4M3FNType`; prints as `f8E4M3FN`).
+class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+ static constexpr const char *pyClassName = "Float8E4M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
+ return PyFloat8E4M3FNType(context->getRef(), t);
+ },
+ py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
+ }
+};
+
+/// Floating Point Type subclass - Float8E5M2Type (Python `Float8E5M2Type`; prints as `f8E5M2`).
+class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+ static constexpr const char *pyClassName = "Float8E5M2Type";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat8E5M2TypeGet(context->get());
+ return PyFloat8E5M2Type(context->getRef(), t);
+ },
+ py::arg("context") = py::none(), "Create a float8_e5m2 type.");
+ }
+};
+
/// Floating Point Type subclass - BF16Type.
class PyBF16Type : public PyConcreteType<PyBF16Type> {
public:
@@ -663,6 +699,8 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
void mlir::python::populateIRTypes(py::module &m) {
PyIntegerType::bind(m);
PyIndexType::bind(m);
+ PyFloat8E4M3FNType::bind(m);
+ PyFloat8E5M2Type::bind(m);
PyBF16Type::bind(m);
PyF16Type::bind(m);
PyF32Type::bind(m);
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 60bc3676f398c..505946ca1a843 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -50,6 +50,8 @@ __all__ = [
"DiagnosticHandler",
"DiagnosticSeverity",
"DictAttr",
+ "Float8E4M3FNType",
+ "Float8E5M2Type",
"F16Type",
"F32Type",
"F64Type",
@@ -577,6 +579,20 @@ class DictAttr(Attribute):
@property
def type(self) -> Type: ...
+class Float8E4M3FNType(Type):  # Type stub for the f8E4M3FN builtin float type binding.
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @staticmethod
+ def get(*args, **kwargs) -> Float8E4M3FNType: ...
+ @staticmethod
+ def isinstance(arg: Any) -> bool: ...
+
+class Float8E5M2Type(Type):  # Type stub for the f8E5M2 builtin float type binding.
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @staticmethod
+ def get(*args, **kwargs) -> Float8E5M2Type: ...
+ @staticmethod
+ def isinstance(arg: Any) -> bool: ...
+
# TODO: Auto-generated. Audit and fix.
class F16Type(Type):
def __init__(self, cast_from_type: Type) -> None: ...
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 91c820f121b9f..e160216cb15f4 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -193,6 +193,10 @@ def testIndexType():
@run
def testFloatType():
with Context():
+ # CHECK: float: f8E4M3FN
+ print("float:", Float8E4M3FNType.get())
+ # CHECK: float: f8E5M2
+ print("float:", Float8E5M2Type.get())
# CHECK: float: bf16
print("float:", BF16Type.get())
# CHECK: float: f16
More information about the Mlir-commits
mailing list