[Mlir-commits] [mlir] d29d1e2 - Add python bindings for Type and IntegerType.

Stella Laurenzo llvmlistbot at llvm.org
Wed Aug 19 09:24:15 PDT 2020


Author: Stella Laurenzo
Date: 2020-08-19T09:23:44-07:00
New Revision: d29d1e2ffd61f450b4392d2dab8060e54d040fcf

URL: https://github.com/llvm/llvm-project/commit/d29d1e2ffd61f450b4392d2dab8060e54d040fcf
DIFF: https://github.com/llvm/llvm-project/commit/d29d1e2ffd61f450b4392d2dab8060e54d040fcf.diff

LOG: Add python bindings for Type and IntegerType.

* The binding for Type is trivial and should be non-controversial.
* The way that I define the IntegerType should serve as a pattern for what I want to do next.
* I propose defining the rest of the standard types in this fashion and then generalizing for dialect types as necessary.
* Essentially, creating/accessing a concrete Type (vs interacting with the string form) is done by "casting" to the concrete type (i.e. IntegerType can be constructed with a Type and will throw if the cast is illegal).
* This deviates from some of our previous discussions about global objects, but I think it produces a usable API and we should go this way.

Differential Revision: https://reviews.llvm.org/D86179

Added: 
    mlir/test/Bindings/Python/ir_module.py
    mlir/test/Bindings/Python/ir_types.py

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h

Removed: 
    mlir/test/Bindings/Python/ir_module_test.py


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 68546bf35625..d97491b9f08a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -116,6 +116,9 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location);
 /** Parses a module from the string and transfers ownership to the caller. */
 MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
 
+/** Checks whether a module is null. */
+inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
+
 /** Takes a module owned by the caller and deletes it. */
 void mlirModuleDestroy(MlirModule module);
 
@@ -312,6 +315,9 @@ void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
 /** Parses a type. The type is owned by the context. */
 MlirType mlirTypeParseGet(MlirContext context, const char *type);
 
+/** Checks whether a type is null. */
+inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
+
 /** Checks if two types are equal. */
 int mlirTypeEqual(MlirType t1, MlirType t2);
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 27e1854e7455..bdce390188fa 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -9,7 +9,10 @@
 #include "IRModules.h"
 #include "PybindUtils.h"
 
+#include "mlir-c/StandardTypes.h"
+
 namespace py = pybind11;
+using namespace mlir;
 using namespace mlir::python;
 
 //------------------------------------------------------------------------------
@@ -20,6 +23,15 @@ static const char kContextParseDocstring[] =
     R"(Parses a module's assembly format from a string.
 
 Returns a new MlirModule or raises a ValueError if the parsing fails.
+
+See also: https://mlir.llvm.org/docs/LangRef/
+)";
+
+static const char kContextParseType[] = R"(Parses the assembly form of a type.
+
+Returns a Type object or raises a ValueError if the type cannot be parsed.
+
+See also: https://mlir.llvm.org/docs/LangRef/#type-system
 )";
 
 static const char kOperationStrDunderDocstring[] =
@@ -30,6 +42,9 @@ use the dedicated print method, which supports keyword arguments to customize
 behavior.
 )";
 
+static const char kTypeStrDunderDocstring[] =
+    R"(Prints the assembly form of the type.)";
+
 static const char kDumpDocstring[] =
     R"(Dumps a debug representation of the object to stderr.)";
 
@@ -64,39 +79,154 @@ struct PyPrintAccumulator {
 } // namespace
 
 //------------------------------------------------------------------------------
-// Context Wrapper Class.
+// PyType.
 //------------------------------------------------------------------------------
 
-PyMlirModule PyMlirContext::parse(const std::string &module) {
-  auto moduleRef = mlirModuleCreateParse(context, module.c_str());
-  if (!moduleRef.ptr) {
-    throw SetPyError(PyExc_ValueError,
-                     "Unable to parse module assembly (see diagnostics)");
-  }
-  return PyMlirModule(moduleRef);
+bool PyType::operator==(const PyType &other) {
+  return mlirTypeEqual(type, other.type);
 }
 
 //------------------------------------------------------------------------------
-// Module Wrapper Class.
+// Standard type subclasses.
 //------------------------------------------------------------------------------
 
-void PyMlirModule::dump() { mlirOperationDump(mlirModuleGetOperation(module)); }
+namespace {
+
+/// CRTP base classes for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+template <typename T>
+class PyConcreteType : public PyType {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = py::class_<T, PyType>;
+  using IsAFunctionTy = int (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(MlirType t) : PyType(t) {}
+  PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
+
+  static MlirType castFrom(PyType &orig) {
+    if (!T::isaFunction(orig.type)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+                                             T::pyClassName + " (from " +
+                                             origRepr + ")");
+    }
+    return orig.type;
+  }
+
+  static void bind(py::module &m) {
+    auto class_ = ClassTy(m, T::pyClassName);
+    class_.def(py::init<PyType &>(), py::keep_alive<0, 1>());
+    T::bindDerived(class_);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+class PyIntegerType : public PyConcreteType<PyIntegerType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "signless",
+        [](PyMlirContext &context, unsigned width) {
+          MlirType t = mlirIntegerTypeGet(context.context, width);
+          return PyIntegerType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a signless integer type");
+    c.def_static(
+        "signed",
+        [](PyMlirContext &context, unsigned width) {
+          MlirType t = mlirIntegerTypeSignedGet(context.context, width);
+          return PyIntegerType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a signed integer type");
+    c.def_static(
+        "unsigned",
+        [](PyMlirContext &context, unsigned width) {
+          MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
+          return PyIntegerType(t);
+        },
+        py::keep_alive<0, 1>(), "Create an unsigned integer type");
+    c.def_property_readonly(
+        "width",
+        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
+        "Returns the width of the integer type");
+    c.def_property_readonly(
+        "is_signless",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsSignless(self.type);
+        },
+        "Returns whether this is a signless integer");
+    c.def_property_readonly(
+        "is_signed",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsSigned(self.type);
+        },
+        "Returns whether this is a signed integer");
+    c.def_property_readonly(
+        "is_unsigned",
+        [](PyIntegerType &self) -> bool {
+          return mlirIntegerTypeIsUnsigned(self.type);
+        },
+        "Returns whether this is an unsigned integer");
+  }
+};
+
+} // namespace
 
 //------------------------------------------------------------------------------
 // Populates the pybind11 IR submodule.
 //------------------------------------------------------------------------------
 
 void mlir::python::populateIRSubmodule(py::module &m) {
-  py::class_<PyMlirContext>(m, "MlirContext")
+  // Mapping of MlirContext
+  py::class_<PyMlirContext>(m, "Context")
       .def(py::init<>())
-      .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>(),
-           kContextParseDocstring);
+      .def(
+          "parse_module",
+          [](PyMlirContext &self, const std::string module) {
+            auto moduleRef =
+                mlirModuleCreateParse(self.context, module.c_str());
+            if (mlirModuleIsNull(moduleRef)) {
+              throw SetPyError(
+                  PyExc_ValueError,
+                  "Unable to parse module assembly (see diagnostics)");
+            }
+            return PyModule(moduleRef);
+          },
+          py::keep_alive<0, 1>(), kContextParseDocstring)
+      .def(
+          "parse_type",
+          [](PyMlirContext &self, std::string typeSpec) {
+            MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
+            if (mlirTypeIsNull(type)) {
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("Unable to parse type: '") +
+                                   typeSpec + "'");
+            }
+            return PyType(type);
+          },
+          py::keep_alive<0, 1>(), kContextParseType);
 
-  py::class_<PyMlirModule>(m, "MlirModule")
-      .def("dump", &PyMlirModule::dump, kDumpDocstring)
+  // Mapping of Module
+  py::class_<PyModule>(m, "Module")
+      .def(
+          "dump",
+          [](PyModule &self) {
+            mlirOperationDump(mlirModuleGetOperation(self.module));
+          },
+          kDumpDocstring)
       .def(
           "__str__",
-          [](PyMlirModule &self) {
+          [](PyModule &self) {
             auto operation = mlirModuleGetOperation(self.module);
             PyPrintAccumulator printAccum;
             mlirOperationPrint(operation, printAccum.getCallback(),
@@ -104,4 +234,42 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             return printAccum.join();
           },
           kOperationStrDunderDocstring);
+
+  // Mapping of Type.
+  py::class_<PyType>(m, "Type")
+      .def("__eq__",
+           [](PyType &self, py::object &other) {
+             try {
+               PyType otherType = other.cast<PyType>();
+               return self == otherType;
+             } catch (std::exception &e) {
+               return false;
+             }
+           })
+      .def(
+          "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
+      .def(
+          "__str__",
+          [](PyType &self) {
+            PyPrintAccumulator printAccum;
+            mlirTypePrint(self.type, printAccum.getCallback(),
+                          printAccum.getUserData());
+            return printAccum.join();
+          },
+          kTypeStrDunderDocstring)
+      .def("__repr__", [](PyType &self) {
+        // Generally, assembly formats are not printed for __repr__ because
+        // this can cause exceptionally long debug output and exceptions.
+        // However, types are an exception as they typically have compact
+        // assembly forms and printing them is useful.
+        PyPrintAccumulator printAccum;
+        printAccum.parts.append("Type(");
+        mlirTypePrint(self.type, printAccum.getCallback(),
+                      printAccum.getUserData());
+        printAccum.parts.append(")");
+        return printAccum.join();
+      });
+
+  // Standard type bindings.
+  PyIntegerType::bind(m);
 }

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 325db497e2aa..4e90a9ae9795 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -17,38 +17,44 @@ namespace mlir {
 namespace python {
 
 class PyMlirContext;
-class PyMlirModule;
+class PyModule;
 
 /// Wrapper around MlirContext.
 class PyMlirContext {
 public:
   PyMlirContext() { context = mlirContextCreate(); }
   ~PyMlirContext() { mlirContextDestroy(context); }
-  /// Parses the module from asm.
-  PyMlirModule parse(const std::string &module);
 
   MlirContext context;
 };
 
 /// Wrapper around MlirModule.
-class PyMlirModule {
+class PyModule {
 public:
-  PyMlirModule(MlirModule module) : module(module) {}
-  PyMlirModule(PyMlirModule &) = delete;
-  PyMlirModule(PyMlirModule &&other) {
+  PyModule(MlirModule module) : module(module) {}
+  PyModule(PyModule &) = delete;
+  PyModule(PyModule &&other) {
     module = other.module;
     other.module.ptr = nullptr;
   }
-  ~PyMlirModule() {
+  ~PyModule() {
     if (module.ptr)
       mlirModuleDestroy(module);
   }
-  /// Dumps the module.
-  void dump();
 
   MlirModule module;
 };
 
+/// Wrapper around the generic MlirType.
+/// The lifetime of a type is bound by the PyContext that created it.
+class PyType {
+public:
+  PyType(MlirType type) : type(type) {}
+  bool operator==(const PyType &other);
+
+  MlirType type;
+};
+
 void populateIRSubmodule(pybind11::module &m);
 
 } // namespace python

diff  --git a/mlir/test/Bindings/Python/ir_module_test.py b/mlir/test/Bindings/Python/ir_module.py
similarity index 79%
rename from mlir/test/Bindings/Python/ir_module_test.py
rename to mlir/test/Bindings/Python/ir_module.py
index 26b7fe63369c..3e7a53995a37 100644
--- a/mlir/test/Bindings/Python/ir_module_test.py
+++ b/mlir/test/Bindings/Python/ir_module.py
@@ -3,15 +3,15 @@
 import mlir
 
 def run(f):
-  print("TEST:", f.__name__)
+  print("\nTEST:", f.__name__)
   f()
 
 # Verify successful parse.
 # CHECK-LABEL: TEST: testParseSuccess
 # CHECK: module @successfulParse
 def testParseSuccess():
-  ctx = mlir.ir.MlirContext()
-  module = ctx.parse(r"""module @successfulParse {}""")
+  ctx = mlir.ir.Context()
+  module = ctx.parse_module(r"""module @successfulParse {}""")
   module.dump()  # Just outputs to stderr. Verifies that it functions.
   print(str(module))
 
@@ -22,9 +22,9 @@ def testParseSuccess():
 # CHECK-LABEL: TEST: testParseError
 # CHECK: testParseError: Unable to parse module assembly (see diagnostics)
 def testParseError():
-  ctx = mlir.ir.MlirContext()
+  ctx = mlir.ir.Context()
   try:
-    module = ctx.parse(r"""}SYNTAX ERROR{""")
+    module = ctx.parse_module(r"""}SYNTAX ERROR{""")
   except ValueError as e:
     print("testParseError:", e)
   else:
@@ -40,8 +40,8 @@ def testParseError():
 # CHECK: func @roundtripUnicode()
 # CHECK: foo = "\F0\9F\98\8A"
 def testRoundtripUnicode():
-  ctx = mlir.ir.MlirContext()
-  module = ctx.parse(r"""
+  ctx = mlir.ir.Context()
+  module = ctx.parse_module(r"""
     func @roundtripUnicode() attributes { foo = "😊" }
   """)
   print(str(module))

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
new file mode 100644
index 000000000000..cc66b1fdb208
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -0,0 +1,126 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+
+
+# CHECK-LABEL: TEST: testParsePrint
+def testParsePrint():
+  ctx = mlir.ir.Context()
+  t = ctx.parse_type("i32")
+  # CHECK: i32
+  print(str(t))
+  # CHECK: Type(i32)
+  print(repr(t))
+
+run(testParsePrint)
+
+
+# CHECK-LABEL: TEST: testParseError
+# TODO: Hook the diagnostic manager to capture a more meaningful error
+# message.
+def testParseError():
+  ctx = mlir.ir.Context()
+  try:
+    t = ctx.parse_type("BAD_TYPE_DOES_NOT_EXIST")
+  except ValueError as e:
+    # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
+    print("testParseError:", e)
+  else:
+    print("Exception not produced")
+
+run(testParseError)
+
+
+# CHECK-LABEL: TEST: testTypeEq
+def testTypeEq():
+  ctx = mlir.ir.Context()
+  t1 = ctx.parse_type("i32")
+  t2 = ctx.parse_type("f32")
+  t3 = ctx.parse_type("i32")
+  # CHECK: t1 == t1: True
+  print("t1 == t1:", t1 == t1)
+  # CHECK: t1 == t2: False
+  print("t1 == t2:", t1 == t2)
+  # CHECK: t1 == t3: True
+  print("t1 == t3:", t1 == t3)
+  # CHECK: t1 == None: False
+  print("t1 == None:", t1 == None)
+
+run(testTypeEq)
+
+
+# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
+def testTypeEqDoesNotRaise():
+  ctx = mlir.ir.Context()
+  t1 = ctx.parse_type("i32")
+  not_a_type = "foo"
+  # CHECK: False
+  print(t1 == not_a_type)
+  # CHECK: False
+  print(t1 == None)
+  # CHECK: True
+  print(t1 != None)
+
+run(testTypeEqDoesNotRaise)
+
+
+# CHECK-LABEL: TEST: testStandardTypeCasts
+def testStandardTypeCasts():
+  ctx = mlir.ir.Context()
+  t1 = ctx.parse_type("i32")
+  tint = mlir.ir.IntegerType(t1)
+  tself = mlir.ir.IntegerType(tint)
+  # CHECK: Type(i32)
+  print(repr(tint))
+  try:
+    tillegal = mlir.ir.IntegerType(ctx.parse_type("f32"))
+  except ValueError as e:
+    # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
+    print("ValueError:", e)
+  else:
+    print("Exception not produced")
+
+run(testStandardTypeCasts)
+
+
+# CHECK-LABEL: TEST: testIntegerType
+def testIntegerType():
+  ctx = mlir.ir.Context()
+  i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
+  # CHECK: i32 width: 32
+  print("i32 width:", i32.width)
+  # CHECK: i32 signless: True
+  print("i32 signless:", i32.is_signless)
+  # CHECK: i32 signed: False
+  print("i32 signed:", i32.is_signed)
+  # CHECK: i32 unsigned: False
+  print("i32 unsigned:", i32.is_unsigned)
+
+  s32 = mlir.ir.IntegerType(ctx.parse_type("si32"))
+  # CHECK: s32 signless: False
+  print("s32 signless:", s32.is_signless)
+  # CHECK: s32 signed: True
+  print("s32 signed:", s32.is_signed)
+  # CHECK: s32 unsigned: False
+  print("s32 unsigned:", s32.is_unsigned)
+
+  u32 = mlir.ir.IntegerType(ctx.parse_type("ui32"))
+  # CHECK: u32 signless: False
+  print("u32 signless:", u32.is_signless)
+  # CHECK: u32 signed: False
+  print("u32 signed:", u32.is_signed)
+  # CHECK: u32 unsigned: True
+  print("u32 unsigned:", u32.is_unsigned)
+
+  # CHECK: signless: i16
+  print("signless:", mlir.ir.IntegerType.signless(ctx, 16))
+  # CHECK: signed: si8
+  print("signed:", mlir.ir.IntegerType.signed(ctx, 8))
+  # CHECK: unsigned: ui64
+  print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64))
+
+run(testIntegerType)


        


More information about the Mlir-commits mailing list