[Mlir-commits] [mlir] 1f6c4d8 - [mlir] Add Index Type, Floating Point Type and None Type subclasses to python bindings.

Stella Laurenzo llvmlistbot at llvm.org
Mon Aug 24 12:09:43 PDT 2020


Author: zhanghb97
Date: 2020-08-24T18:54:54Z
New Revision: 1f6c4d829c2dad147e30dcb0611eb9886dae9155

URL: https://github.com/llvm/llvm-project/commit/1f6c4d829c2dad147e30dcb0611eb9886dae9155
DIFF: https://github.com/llvm/llvm-project/commit/1f6c4d829c2dad147e30dcb0611eb9886dae9155.diff

LOG: [mlir] Add Index Type, Floating Point Type and None Type subclasses to python bindings.

Based on the PyType and PyConcreteType classes, this patch implements Python bindings for the Index Type, Floating Point Type (BF16, F16, F32, F64), and None Type subclasses.
These three kinds of subclasses share the same binding strategy:
- The function pointer `isaFunction` points to `mlirTypeIsA***`.
- The `mlir***TypeGet` C API is bound to the `***Type` constructor on the Python side.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D86466

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index ae48e33d3530..2f5735f83975 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -305,6 +305,102 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
   }
 };
 
+/// Index Type subclass - IndexType.
+class PyIndexType : public PyConcreteType<PyIndexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr const char *pyClassName = "IndexType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirIndexTypeGet(context.context);
+            return PyIndexType(t);
+          }),
+          py::keep_alive<0, 1>(), "Create a index type.");
+  }
+};
+
+/// Floating Point Type subclass - BF16Type.
+class PyBF16Type : public PyConcreteType<PyBF16Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr const char *pyClassName = "BF16Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirBF16TypeGet(context.context);
+            return PyBF16Type(t);
+          }),
+          py::keep_alive<0, 1>(), "Create a bf16 type.");
+  }
+};
+
+/// Floating Point Type subclass - F16Type.
+class PyF16Type : public PyConcreteType<PyF16Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr const char *pyClassName = "F16Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirF16TypeGet(context.context);
+            return PyF16Type(t);
+          }),
+          py::keep_alive<0, 1>(), "Create a f16 type.");
+  }
+};
+
+/// Floating Point Type subclass - F32Type.
+class PyF32Type : public PyConcreteType<PyF32Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr const char *pyClassName = "F32Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirF32TypeGet(context.context);
+            return PyF32Type(t);
+          }),
+          py::keep_alive<0, 1>(), "Create a f32 type.");
+  }
+};
+
+/// Floating Point Type subclass - F64Type.
+class PyF64Type : public PyConcreteType<PyF64Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr const char *pyClassName = "F64Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirF64TypeGet(context.context);
+            return PyF64Type(t);
+          }),
+          py::keep_alive<0, 1>(), "Create a f64 type.");
+  }
+};
+
+/// None Type subclass - NoneType.
+class PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr const char *pyClassName = "NoneType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def(py::init([](PyMlirContext &context) {
+            MlirType t = mlirNoneTypeGet(context.context);
+            return PyNoneType(t);
+          }),
+          py::keep_alive<0, 1>(), "Create a none type.");
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -489,4 +585,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
 
   // Standard type bindings.
   PyIntegerType::bind(m);
+  PyIndexType::bind(m);
+  PyBF16Type::bind(m);
+  PyF16Type::bind(m);
+  PyF32Type::bind(m);
+  PyF64Type::bind(m);
+  PyNoneType::bind(m);
 }

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 1dce0a95c812..32e26c57518a 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -124,3 +124,33 @@ def testIntegerType():
   print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64))
 
 run(testIntegerType)
+
+# CHECK-LABEL: TEST: testIndexType
+def testIndexType():
+  ctx = mlir.ir.Context()
+  # CHECK: index type: index
+  print("index type:", mlir.ir.IndexType(ctx))
+
+run(testIndexType)
+
+# CHECK-LABEL: TEST: testFloatType
+def testFloatType():
+  ctx = mlir.ir.Context()
+  # CHECK: float: bf16
+  print("float:", mlir.ir.BF16Type(ctx))
+  # CHECK: float: f16
+  print("float:", mlir.ir.F16Type(ctx))
+  # CHECK: float: f32
+  print("float:", mlir.ir.F32Type(ctx))
+  # CHECK: float: f64
+  print("float:", mlir.ir.F64Type(ctx))
+
+run(testFloatType)
+
+# CHECK-LABEL: TEST: testNoneType
+def testNoneType():
+  ctx = mlir.ir.Context()
+  # CHECK: none type: none
+  print("none type:", mlir.ir.NoneType(ctx))
+
+run(testNoneType)


        


More information about the Mlir-commits mailing list