[Mlir-commits] [mlir] d29d1e2 - Add python bindings for Type and IntegerType.
Stella Laurenzo
llvmlistbot at llvm.org
Wed Aug 19 09:24:15 PDT 2020
Author: Stella Laurenzo
Date: 2020-08-19T09:23:44-07:00
New Revision: d29d1e2ffd61f450b4392d2dab8060e54d040fcf
URL: https://github.com/llvm/llvm-project/commit/d29d1e2ffd61f450b4392d2dab8060e54d040fcf
DIFF: https://github.com/llvm/llvm-project/commit/d29d1e2ffd61f450b4392d2dab8060e54d040fcf.diff
LOG: Add python bindings for Type and IntegerType.
* The binding for Type is trivial and should be non-controversial.
* The way IntegerType is defined here should serve as the pattern for what I want to do next.
* I propose defining the rest of the standard types in this fashion and then generalizing for dialect types as necessary.
* Essentially, creating/accessing a concrete Type (vs. interacting with the string form) is done by "casting" to the concrete type (e.g. IntegerType can be constructed from a Type and raises a ValueError if the cast is illegal).
* This deviates from some of our previous discussions about global objects, but I think it produces a usable API, and we should go this way.
Differential Revision: https://reviews.llvm.org/D86179
Added:
mlir/test/Bindings/Python/ir_module.py
mlir/test/Bindings/Python/ir_types.py
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
Removed:
mlir/test/Bindings/Python/ir_module_test.py
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 68546bf35625..d97491b9f08a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -116,6 +116,9 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location);
/** Parses a module from the string and transfers ownership to the caller. */
MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
+/** Checks whether a module is null.
+ * Declared `static inline`: this is a C API header, and a plain C99 `inline`
+ * definition provides no external definition, so a call the compiler does not
+ * inline would fail to link. */
+static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
+
/** Takes a module owned by the caller and deletes it. */
void mlirModuleDestroy(MlirModule module);
@@ -312,6 +315,9 @@ void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
/** Parses a type. The type is owned by the context. */
MlirType mlirTypeParseGet(MlirContext context, const char *type);
+/** Checks whether a type is null.
+ * Declared `static inline` so every C translation unit including this header
+ * gets its own definition (a plain C99 `inline` definition can fail to link). */
+static inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
+
/** Checks if two types are equal. */
int mlirTypeEqual(MlirType t1, MlirType t2);
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 27e1854e7455..bdce390188fa 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -9,7 +9,10 @@
#include "IRModules.h"
#include "PybindUtils.h"
+#include "mlir-c/StandardTypes.h"
+
namespace py = pybind11;
+using namespace mlir;
using namespace mlir::python;
//------------------------------------------------------------------------------
@@ -20,6 +23,15 @@ static const char kContextParseDocstring[] =
R"(Parses a module's assembly format from a string.
Returns a new MlirModule or raises a ValueError if the parsing fails.
+
+See also: https://mlir.llvm.org/docs/LangRef/
+)";
+
+/// Python docstring for the Context.parse_type binding.
+static const char kContextParseType[] = R"(Parses the assembly form of a type.
+
+Returns a Type object or raises a ValueError if the type cannot be parsed.
+
+See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";
static const char kOperationStrDunderDocstring[] =
@@ -30,6 +42,9 @@ use the dedicated print method, which supports keyword arguments to customize
behavior.
)";
+/// Python docstring for the Type.__str__ binding.
+static const char kTypeStrDunderDocstring[] =
+ R"(Prints the assembly form of the type.)";
+
static const char kDumpDocstring[] =
R"(Dumps a debug representation of the object to stderr.)";
@@ -64,39 +79,154 @@ struct PyPrintAccumulator {
} // namespace
//------------------------------------------------------------------------------
-// Context Wrapper Class.
+// PyType.
//------------------------------------------------------------------------------
-PyMlirModule PyMlirContext::parse(const std::string &module) {
- auto moduleRef = mlirModuleCreateParse(context, module.c_str());
- if (!moduleRef.ptr) {
- throw SetPyError(PyExc_ValueError,
- "Unable to parse module assembly (see diagnostics)");
- }
- return PyMlirModule(moduleRef);
+/// Two wrappers compare equal when the C API reports their types as equal.
+bool PyType::operator==(const PyType &other) {
+  // mlirTypeEqual reports equality as a nonzero int.
+  const int sameType = mlirTypeEqual(type, other.type);
+  return sameType != 0;
}
//------------------------------------------------------------------------------
-// Module Wrapper Class.
+// Standard type subclasses.
//------------------------------------------------------------------------------
-void PyMlirModule::dump() { mlirOperationDump(mlirModuleGetOperation(module)); }
+namespace {
+
+/// CRTP base class for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+///
+/// Derived classes must define statics for:
+///   IsAFunctionTy isaFunction -- predicate accepting the concrete type kind
+///   const char *pyClassName   -- name of the Python class to register
+/// and may define `static void bindDerived(ClassTy &)` to add methods.
+template <typename T>
+class PyConcreteType : public PyType {
+public:
+  using ClassTy = py::class_<T, PyType>;
+  using IsAFunctionTy = int (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(MlirType t) : PyType(t) {}
+  PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
+
+  /// Returns `orig.type` if `T::isaFunction` accepts it; otherwise raises a
+  /// Python ValueError naming the target class and the repr of `orig`.
+  static MlirType castFrom(PyType &orig) {
+    if (!T::isaFunction(orig.type)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+                                             T::pyClassName + " (from " +
+                                             origRepr + ")");
+    }
+    return orig.type;
+  }
+
+  /// Registers the Python class `T::pyClassName` (a subclass of Type) in `m`.
+  static void bind(py::module &m) {
+    auto class_ = ClassTy(m, T::pyClassName);
+    // For a constructor, keep_alive index 1 is the new `self` and index 2 is
+    // the source Type argument. Index 0 would be __init__'s return value
+    // (None), which made the previous keep_alive<0, 1> a no-op. Keeping the
+    // source Type alive also keeps alive whatever it keeps alive (e.g. the
+    // Context, for Types returned by Context.parse_type).
+    class_.def(py::init<PyType &>(), py::keep_alive<1, 2>());
+    T::bindDerived(class_);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &) {}
+};
+
+/// Python binding for the builtin integer type, exposed as `IntegerType`.
+/// Constructible from a generic Type (raising ValueError if it is not an
+/// integer type) or via the signless/signed/unsigned static factories.
+class PyIntegerType : public PyConcreteType<PyIntegerType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // Factories: these are static methods, so keep_alive index 1 is the
+    // context argument, which then outlives the returned type (index 0).
+    c.def_static(
+        "signless",
+        [](PyMlirContext &ctx, unsigned bitwidth) {
+          return PyIntegerType(mlirIntegerTypeGet(ctx.context, bitwidth));
+        },
+        py::keep_alive<0, 1>(), "Create a signless integer type");
+    c.def_static(
+        "signed",
+        [](PyMlirContext &ctx, unsigned bitwidth) {
+          return PyIntegerType(
+              mlirIntegerTypeSignedGet(ctx.context, bitwidth));
+        },
+        py::keep_alive<0, 1>(), "Create a signed integer type");
+    c.def_static(
+        "unsigned",
+        [](PyMlirContext &ctx, unsigned bitwidth) {
+          return PyIntegerType(
+              mlirIntegerTypeUnsignedGet(ctx.context, bitwidth));
+        },
+        py::keep_alive<0, 1>(), "Create an unsigned integer type");
+
+    // Read-only properties; the C API reports the predicates as ints.
+    c.def_property_readonly(
+        "width",
+        [](PyIntegerType &it) { return mlirIntegerTypeGetWidth(it.type); },
+        "Returns the width of the integer type");
+    c.def_property_readonly(
+        "is_signless",
+        [](PyIntegerType &it) {
+          return mlirIntegerTypeIsSignless(it.type) != 0;
+        },
+        "Returns whether this is a signless integer");
+    c.def_property_readonly(
+        "is_signed",
+        [](PyIntegerType &it) { return mlirIntegerTypeIsSigned(it.type) != 0; },
+        "Returns whether this is a signed integer");
+    c.def_property_readonly(
+        "is_unsigned",
+        [](PyIntegerType &it) {
+          return mlirIntegerTypeIsUnsigned(it.type) != 0;
+        },
+        "Returns whether this is an unsigned integer");
+  }
+};
+
+} // namespace
//------------------------------------------------------------------------------
// Populates the pybind11 IR submodule.
//------------------------------------------------------------------------------
void mlir::python::populateIRSubmodule(py::module &m) {
- py::class_<PyMlirContext>(m, "MlirContext")
+  // Mapping of MlirContext.
+  py::class_<PyMlirContext>(m, "Context")
.def(py::init<>())
- .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>(),
- kContextParseDocstring);
+      // Both parse methods use keep_alive<0, 1>: the Context (`self`) stays
+      // alive for as long as the returned Module/Type object exists.
+      .def(
+          "parse_module",
+          [](PyMlirContext &self, const std::string &module) {
+            auto moduleRef =
+                mlirModuleCreateParse(self.context, module.c_str());
+            if (mlirModuleIsNull(moduleRef)) {
+              throw SetPyError(
+                  PyExc_ValueError,
+                  "Unable to parse module assembly (see diagnostics)");
+            }
+            return PyModule(moduleRef);
+          },
+          py::keep_alive<0, 1>(), kContextParseDocstring)
+      .def(
+          "parse_type",
+          [](PyMlirContext &self, const std::string &typeSpec) {
+            MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
+            if (mlirTypeIsNull(type)) {
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("Unable to parse type: '") +
+                                   typeSpec + "'");
+            }
+            return PyType(type);
+          },
+          py::keep_alive<0, 1>(), kContextParseType);
- py::class_<PyMlirModule>(m, "MlirModule")
- .def("dump", &PyMlirModule::dump, kDumpDocstring)
+  // Mapping of Module.
+  py::class_<PyModule>(m, "Module")
+      .def(
+          "dump",
+          [](PyModule &self) {
+            mlirOperationDump(mlirModuleGetOperation(self.module));
+          },
+          kDumpDocstring)
.def(
"__str__",
- [](PyMlirModule &self) {
+          [](PyModule &self) {
auto operation = mlirModuleGetOperation(self.module);
PyPrintAccumulator printAccum;
mlirOperationPrint(operation, printAccum.getCallback(),
@@ -104,4 +234,42 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return printAccum.join();
},
kOperationStrDunderDocstring);
+
+  // Mapping of Type.
+  py::class_<PyType>(m, "Type")
+      .def("__eq__",
+           [](PyType &self, py::object &other) {
+             // Comparing against a non-Type (e.g. None or a str) yields False.
+             // The Python type is checked up front instead of attempting a
+             // cast and swallowing whatever exception it throws.
+             if (!py::isinstance<PyType>(other))
+               return false;
+             return self == other.cast<PyType &>();
+           })
+      .def(
+          "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
+      .def(
+          "__str__",
+          [](PyType &self) {
+            PyPrintAccumulator printAccum;
+            mlirTypePrint(self.type, printAccum.getCallback(),
+                          printAccum.getUserData());
+            return printAccum.join();
+          },
+          kTypeStrDunderDocstring)
+      .def("__repr__", [](PyType &self) {
+        // Generally, assembly formats are not printed for __repr__ because
+        // this can cause exceptionally long debug output and exceptions.
+        // However, types are an exception as they typically have compact
+        // assembly forms and printing them is useful.
+        PyPrintAccumulator printAccum;
+        printAccum.parts.append("Type(");
+        mlirTypePrint(self.type, printAccum.getCallback(),
+                      printAccum.getUserData());
+        printAccum.parts.append(")");
+        return printAccum.join();
+      });
+
+  // Standard type bindings.
+  PyIntegerType::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 325db497e2aa..4e90a9ae9795 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -17,38 +17,44 @@ namespace mlir {
namespace python {
class PyMlirContext;
-class PyMlirModule;
+class PyModule;
/// Wrapper around MlirContext.
class PyMlirContext {
public:
PyMlirContext() { context = mlirContextCreate(); }
~PyMlirContext() { mlirContextDestroy(context); }
- /// Parses the module from asm.
- PyMlirModule parse(const std::string &module);
MlirContext context;
};
/// Wrapper around MlirModule.
-class PyMlirModule {
+/// Owns the module: it is destroyed with the wrapper unless it was moved out.
+/// The wrapper is move-only.
+class PyModule {
public:
- PyMlirModule(MlirModule module) : module(module) {}
- PyMlirModule(PyMlirModule &) = delete;
- PyMlirModule(PyMlirModule &&other) {
+  PyModule(MlirModule module) : module(module) {}
+  PyModule(const PyModule &) = delete;
+  PyModule &operator=(const PyModule &) = delete;
+  // Transfers ownership by copying the handle and nulling the source; it
+  // cannot throw, so it is marked noexcept.
+  PyModule(PyModule &&other) noexcept {
module = other.module;
other.module.ptr = nullptr;
}
- ~PyMlirModule() {
+  ~PyModule() {
if (module.ptr)
mlirModuleDestroy(module);
}
- /// Dumps the module.
- void dump();
+  /// The owned module; its ptr is null after this wrapper was moved from.
MlirModule module;
};
+/// Non-owning wrapper around a generic MlirType.
+/// The type itself is owned by the PyMlirContext that created it, so that
+/// context must outlive this wrapper.
+class PyType {
+public:
+  /// The wrapped type handle.
+  MlirType type;
+
+  PyType(MlirType type) : type(type) {}
+
+  /// Equality as reported by mlirTypeEqual.
+  bool operator==(const PyType &other);
+};
+
void populateIRSubmodule(pybind11::module &m);
} // namespace python
diff --git a/mlir/test/Bindings/Python/ir_module_test.py b/mlir/test/Bindings/Python/ir_module.py
similarity index 79%
rename from mlir/test/Bindings/Python/ir_module_test.py
rename to mlir/test/Bindings/Python/ir_module.py
index 26b7fe63369c..3e7a53995a37 100644
--- a/mlir/test/Bindings/Python/ir_module_test.py
+++ b/mlir/test/Bindings/Python/ir_module.py
@@ -3,15 +3,15 @@
import mlir
def run(f):
+ """Prints a 'TEST: <name>' header line for f, then calls f."""
- print("TEST:", f.__name__)
+ print("\nTEST:", f.__name__)
f()
# Verify successful parse.
# CHECK-LABEL: TEST: testParseSuccess
# CHECK: module @successfulParse
def testParseSuccess():
- ctx = mlir.ir.MlirContext()
- module = ctx.parse(r"""module @successfulParse {}""")
+ ctx = mlir.ir.Context()
+ module = ctx.parse_module(r"""module @successfulParse {}""")
module.dump() # Just outputs to stderr. Verifies that it functions.
print(str(module))
@@ -22,9 +22,9 @@ def testParseSuccess():
# CHECK-LABEL: TEST: testParseError
# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
def testParseError():
- ctx = mlir.ir.MlirContext()
+ ctx = mlir.ir.Context()
try:
- module = ctx.parse(r"""}SYNTAX ERROR{""")
+ module = ctx.parse_module(r"""}SYNTAX ERROR{""")
except ValueError as e:
print("testParseError:", e)
else:
@@ -40,8 +40,8 @@ def testParseError():
# CHECK: func @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
def testRoundtripUnicode():
- ctx = mlir.ir.MlirContext()
- module = ctx.parse(r"""
+ ctx = mlir.ir.Context()
+ module = ctx.parse_module(r"""
func @roundtripUnicode() attributes { foo = "😊" }
""")
print(str(module))
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
new file mode 100644
index 000000000000..cc66b1fdb208
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -0,0 +1,126 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+ """Prints a blank line and a 'TEST: <name>' header for f, then calls f."""
+ print("\nTEST:", f.__name__)
+ f()
+
+
+# CHECK-LABEL: TEST: testParsePrint
+def testParsePrint():
+ """Parses a builtin type and prints its str() and repr() forms."""
+ ctx = mlir.ir.Context()
+ t = ctx.parse_type("i32")
+ # CHECK: i32
+ print(str(t))
+ # CHECK: Type(i32)
+ print(repr(t))
+
+run(testParsePrint)
+
+
+# CHECK-LABEL: TEST: testParseError
+# TODO: Hook the diagnostic manager to capture a more meaningful error
+# message.
+def testParseError():
+ """Parsing an unknown type spec raises ValueError quoting the spec."""
+ ctx = mlir.ir.Context()
+ try:
+ t = ctx.parse_type("BAD_TYPE_DOES_NOT_EXIST")
+ except ValueError as e:
+ # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
+ print("testParseError:", e)
+ else:
+ print("Exception not produced")
+
+run(testParseError)
+
+
+# CHECK-LABEL: TEST: testTypeEq
+def testTypeEq():
+ """Type.__eq__ with identical, equal, different, and None operands."""
+ ctx = mlir.ir.Context()
+ t1 = ctx.parse_type("i32")
+ t2 = ctx.parse_type("f32")
+ t3 = ctx.parse_type("i32")
+ # CHECK: t1 == t1: True
+ print("t1 == t1:", t1 == t1)
+ # CHECK: t1 == t2: False
+ print("t1 == t2:", t1 == t2)
+ # CHECK: t1 == t3: True
+ print("t1 == t3:", t1 == t3)
+ # CHECK: t1 == None: False
+ print("t1 == None:", t1 == None)
+
+run(testTypeEq)
+
+
+# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
+def testTypeEqDoesNotRaise():
+ """Comparing a Type with non-Type objects returns a bool instead of raising."""
+ ctx = mlir.ir.Context()
+ t1 = ctx.parse_type("i32")
+ not_a_type = "foo"
+ # CHECK: False
+ print(t1 == not_a_type)
+ # CHECK: False
+ print(t1 == None)
+ # CHECK: True
+ print(t1 != None)
+
+run(testTypeEqDoesNotRaise)
+
+
+# CHECK-LABEL: TEST: testStandardTypeCasts
+def testStandardTypeCasts():
+ """IntegerType(t) casts from a Type or IntegerType; a non-integer raises."""
+ ctx = mlir.ir.Context()
+ t1 = ctx.parse_type("i32")
+ tint = mlir.ir.IntegerType(t1)
+ # Re-casting an IntegerType to IntegerType must also be accepted.
+ tself = mlir.ir.IntegerType(tint)
+ # CHECK: Type(i32)
+ print(repr(tint))
+ try:
+ tillegal = mlir.ir.IntegerType(ctx.parse_type("f32"))
+ except ValueError as e:
+ # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
+ print("ValueError:", e)
+ else:
+ print("Exception not produced")
+
+run(testStandardTypeCasts)
+
+
+# CHECK-LABEL: TEST: testIntegerType
+def testIntegerType():
+ """IntegerType width/signedness properties and the static factories."""
+ ctx = mlir.ir.Context()
+ i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
+ # CHECK: i32 width: 32
+ print("i32 width:", i32.width)
+ # CHECK: i32 signless: True
+ print("i32 signless:", i32.is_signless)
+ # CHECK: i32 signed: False
+ print("i32 signed:", i32.is_signed)
+ # CHECK: i32 unsigned: False
+ print("i32 unsigned:", i32.is_unsigned)
+
+ s32 = mlir.ir.IntegerType(ctx.parse_type("si32"))
+ # CHECK: s32 signless: False
+ print("s32 signless:", s32.is_signless)
+ # CHECK: s32 signed: True
+ print("s32 signed:", s32.is_signed)
+ # CHECK: s32 unsigned: False
+ print("s32 unsigned:", s32.is_unsigned)
+
+ u32 = mlir.ir.IntegerType(ctx.parse_type("ui32"))
+ # CHECK: u32 signless: False
+ print("u32 signless:", u32.is_signless)
+ # CHECK: u32 signed: False
+ print("u32 signed:", u32.is_signed)
+ # CHECK: u32 unsigned: True
+ print("u32 unsigned:", u32.is_unsigned)
+
+ # CHECK: signless: i16
+ print("signless:", mlir.ir.IntegerType.signless(ctx, 16))
+ # CHECK: signed: si8
+ print("signed:", mlir.ir.IntegerType.signed(ctx, 8))
+ # CHECK: unsigned: ui64
+ print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64))
+
+run(testIntegerType)
More information about the Mlir-commits
mailing list