[Mlir-commits] [mlir] c1ded6a - Add mlir python APIs for creating operations, regions and blocks.
Stella Laurenzo
llvmlistbot at llvm.org
Wed Sep 23 07:59:04 PDT 2020
Author: Stella Laurenzo
Date: 2020-09-23T07:57:50-07:00
New Revision: c1ded6a759913a32b44a851f0823bbb648d2a7e1
URL: https://github.com/llvm/llvm-project/commit/c1ded6a759913a32b44a851f0823bbb648d2a7e1
DIFF: https://github.com/llvm/llvm-project/commit/c1ded6a759913a32b44a851f0823bbb648d2a7e1.diff
LOG: Add mlir python APIs for creating operations, regions and blocks.
* The API is a bit more verbose than I feel like it needs to be. In a follow-up I'd like to abbreviate some things and look into creating aliases for common accessors.
* There is a lingering lifetime hazard between the module and newly added operations. We have the facilities now to solve for this but I will do that in a follow-up.
* We may need to craft a more limited API for safely referencing successors when creating operations. We need more facilities to really prove that out and should defer for now.
Differential Revision: https://reviews.llvm.org/D87996
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/Bindings/Python/PybindUtils.h
mlir/test/Bindings/Python/ir_operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8eab7dab1675..3fad701d1641 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -24,6 +24,22 @@ using llvm::SmallVector;
// Docstrings (trivial, non-duplicated docstrings are included inline).
//------------------------------------------------------------------------------
+static const char kContextCreateOperationDocstring[] = // Docstring attached to the Python Context.create_operation binding.
+ R"(Creates a new operation.
+
+Args:
+ name: Operation name (e.g. "dialect.operation").
+ location: A Location object.
+ results: Sequence of Type representing op result types.
+ attributes: Dict of str:Attribute.
+ successors: List of Block for the operation's successors.
+ regions: Number of regions to create.
+
+Returns:
+ A new "detached" Operation object. Detached operations can be added
+ to blocks, which causes them to become "attached."
+)";
+
static const char kContextParseDocstring[] =
R"(Parses a module's assembly format from a string.
@@ -60,6 +76,13 @@ static const char kTypeStrDunderDocstring[] =
static const char kDumpDocstring[] =
R"(Dumps a debug representation of the object to stderr.)";
+static const char kAppendBlockDocstring[] = // Docstring attached to the Python BlockList.append binding.
+ R"(Appends a new block, with argument types as positional args.
+
+Returns:
+ The created block.
+)";
+
//------------------------------------------------------------------------------
// Conversion utilities.
//------------------------------------------------------------------------------
@@ -265,11 +288,25 @@ class PyBlockList {
throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
}
+ /// Creates a new block whose argument types are the positional Python
+ /// arguments, appends it to `region` and returns a PyBlock wrapping it.
+ PyBlock appendBlock(py::args pyArgTypes) {
+   // Validate the owning operation before its region is touched.
+   operation->checkValid();
+
+   // Convert each positional Python Type argument to its MlirType handle.
+   llvm::SmallVector<MlirType, 4> blockArgTypes;
+   blockArgTypes.reserve(pyArgTypes.size());
+   for (auto &pyArgType : pyArgTypes)
+     blockArgTypes.push_back(pyArgType.cast<PyType &>().type);
+
+   // Build the block, append it to the region, and hand back a wrapper that
+   // carries the same operation handle as this list.
+   MlirBlock newBlock =
+       mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
+   mlirRegionAppendOwnedBlock(region, newBlock);
+   return PyBlock(operation, newBlock);
+ }
+
+
static void bind(py::module &m) {
py::class_<PyBlockList>(m, "BlockList")
.def("__getitem__", &PyBlockList::dunderGetItem)
.def("__iter__", &PyBlockList::dunderIter)
- .def("__len__", &PyBlockList::dunderLen);
+ .def("__len__", &PyBlockList::dunderLen)
+ .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
}
private:
@@ -352,11 +389,41 @@ class PyOperationList {
"attempt to access out of bounds operation");
}
+ /// Inserts the detached operation `newOperation` into `block` at position
+ /// `index` and marks it as attached.
+ ///
+ /// Throws (via SetPyError):
+ ///   IndexError if `index` is negative or greater than the current length.
+ ///   ValueError if `newOperation` is already attached.
+ void insert(int index, PyOperation &newOperation) {
+   parentOperation->checkValid();
+   newOperation.checkValid();
+   // Index 0 is a valid insertion point; only negative indices are
+   // rejected, so the message says "non-negative" rather than "positive".
+   if (index < 0) {
+     throw SetPyError(
+         PyExc_IndexError,
+         "only non-negative insertion indices are supported for operations");
+   }
+   if (newOperation.isAttached()) {
+     throw SetPyError(
+         PyExc_ValueError,
+         "attempt to insert an operation that has already been inserted");
+   }
+   // TODO: Needing to do this check is unfortunate, especially since it will
+   // be a forward-scan, just like the following call to
+   // mlirBlockInsertOwnedOperation. Switch to insert before/after once
+   // D88148 lands.
+   // index == dunderLen() is allowed: it inserts at the end of the block.
+   if (index > dunderLen()) {
+     throw SetPyError(PyExc_IndexError,
+                      "attempt to insert operation past end");
+   }
+   mlirBlockInsertOwnedOperation(block, index, newOperation.get());
+   newOperation.setAttached();
+   // TODO: Rework the parentKeepAlive so as to avoid ownership hazards under
+   // the new ownership.
+ }
+
+
static void bind(py::module &m) {
py::class_<PyOperationList>(m, "OperationList")
.def("__getitem__", &PyOperationList::dunderGetItem)
.def("__iter__", &PyOperationList::dunderIter)
- .def("__len__", &PyOperationList::dunderLen);
+ .def("__len__", &PyOperationList::dunderLen)
+ .def("insert", &PyOperationList::insert, py::arg("index"),
+ py::arg("operation"),
+ "Inserts an operation at an indexed position");
}
private:
@@ -416,6 +483,87 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
+// Creates a new detached operation and returns it as a Python object.
+//
+// All Python-side inputs are first unpacked and validated into local staging
+// vectors; the MlirOperationState is only populated once that has succeeded.
+// Throws (via SetPyError) ValueError if `regions` is negative or if an entry
+// of `results` or `successors` is None.
+py::object PyMlirContext::createOperation(
+    std::string name, PyLocation location,
+    llvm::Optional<std::vector<PyType *>> results,
+    llvm::Optional<py::dict> attributes,
+    llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
+  // Staging containers, filled by the unpack/validate steps below and then
+  // consumed when building the operation state.
+  llvm::SmallVector<MlirType, 4> mlirResults;
+  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
+  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
+
+  // General parameter validation.
+  if (regions < 0)
+    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
+
+  // Unpack/validate results.
+  if (results) {
+    mlirResults.reserve(results->size());
+    for (PyType *result : *results) {
+      // TODO: Verify result type originates from the same context.
+      if (!result)
+        throw SetPyError(PyExc_ValueError, "result type cannot be None");
+      mlirResults.push_back(result->type);
+    }
+  }
+  // Unpack/validate attributes.
+  if (attributes) {
+    mlirAttributes.reserve(attributes->size());
+    for (auto &it : *attributes) {
+      // Called `key` rather than `name` so that the operation-name parameter
+      // is not shadowed inside this loop.
+      auto key = it.first.cast<std::string>();
+      auto &attribute = it.second.cast<PyAttribute &>();
+      // TODO: Verify attribute originates from the same context.
+      mlirAttributes.emplace_back(std::move(key), attribute.attr);
+    }
+  }
+  // Unpack/validate successors. This must fill the function-scope
+  // mlirSuccessors declared above: redeclaring a vector of the same name in
+  // this block would shadow it, leaving the outer one empty so that the
+  // successors never reach the operation state.
+  if (successors) {
+    mlirSuccessors.reserve(successors->size());
+    for (auto *successor : *successors) {
+      // TODO: Verify successor originates from the same context.
+      if (!successor)
+        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
+      mlirSuccessors.push_back(successor->get());
+    }
+  }
+
+  // Apply unpacked/validated to the operation state. Beyond this
+  // point, exceptions cannot be thrown or else the state will leak.
+  MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
+  if (!mlirResults.empty())
+    mlirOperationStateAddResults(&state, mlirResults.size(),
+                                 mlirResults.data());
+  if (!mlirAttributes.empty()) {
+    // Note that the attribute names directly reference bytes in
+    // mlirAttributes, so that vector must not be changed from here
+    // on.
+    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
+    mlirNamedAttributes.reserve(mlirAttributes.size());
+    for (auto &it : mlirAttributes)
+      mlirNamedAttributes.push_back(
+          mlirNamedAttributeGet(it.first.c_str(), it.second));
+    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+  }
+  if (!mlirSuccessors.empty())
+    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
+                                    mlirSuccessors.data());
+  if (regions) {
+    // Create `regions` fresh regions and hand them to the operation state.
+    llvm::SmallVector<MlirRegion, 4> mlirRegions;
+    mlirRegions.resize(regions);
+    for (int i = 0; i < regions; ++i)
+      mlirRegions[i] = mlirRegionCreate();
+    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
+                                      mlirRegions.data());
+  }
+
+  // Construct the operation.
+  MlirOperation operation = mlirOperationCreate(&state);
+  return PyOperation::createDetached(getRef(), operation).releaseObject();
+}
+
+
//------------------------------------------------------------------------------
// PyModule
//------------------------------------------------------------------------------
@@ -1153,6 +1301,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
[](PyMlirContext &self, bool value) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
})
+ .def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
+ py::arg("location"), py::arg("results") = py::none(),
+ py::arg("attributes") = py::none(),
+ py::arg("successors") = py::none(), py::arg("regions") = 0,
+ kContextCreateOperationDocstring)
.def(
"parse_module",
[](PyMlirContext &self, const std::string moduleAsm) {
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 06b697cfd786..41b18d216026 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -17,9 +17,12 @@
namespace mlir {
namespace python {
+class PyBlock;
+class PyLocation;
class PyMlirContext;
class PyModule;
class PyOperation;
+class PyType;
/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
@@ -112,6 +115,14 @@ class PyMlirContext {
/// Used for testing.
size_t getLiveOperationCount();
+ /// Creates a detached operation. See the create_operation python docstring.
+ pybind11::object
+ createOperation(std::string name, PyLocation location,
+ llvm::Optional<std::vector<PyType *>> results,
+ llvm::Optional<pybind11::dict> attributes,
+ llvm::Optional<std::vector<PyBlock *>> successors,
+ int regions);
+
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@@ -227,6 +238,10 @@ class PyOperation : public BaseContextObject {
}
bool isAttached() { return attached; }
+ /// Flags this operation as attached. Asserts that it was not attached
+ /// before the call.
+ void setAttached() {
+   assert(!isAttached() && "operation already attached");
+   attached = true;
+ }
void checkValid();
private:
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 29ab06a25055..0435aa461809 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -12,8 +12,16 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
+#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"
+namespace pybind11 {
+namespace detail {
+/// Makes llvm::Optional<ValueT> usable in bound signatures by reusing
+/// pybind11's generic optional_caster.
+template <typename ValueT>
+struct type_caster<llvm::Optional<ValueT>>
+    : optional_caster<llvm::Optional<ValueT>> {};
+} // namespace detail
+} // namespace pybind11
+
namespace mlir {
namespace python {
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 9522e4b1ad98..881398e1eba3 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -99,3 +99,99 @@ def walk_operations(indent, op):
walk_operations("", module.operation)
run(testTraverseOpRegionBlockIndices)
+
+
+# CHECK-LABEL: TEST: testDetachedOperation
+def testDetachedOperation():  # Creates an op, never inserts it into a block, and prints it.
+ ctx = mlir.ir.Context()
+ ctx.allow_unregistered_dialects = True
+ loc = ctx.get_unknown_location()
+ i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
+ op1 = ctx.create_operation(  # two si32 results, one region, two string attributes
+ "custom.op1", loc, results=[i32, i32], regions=1, attributes={
+ "foo": mlir.ir.StringAttr.get(ctx, "foo_value"),
+ "bar": mlir.ir.StringAttr.get(ctx, "bar_value"),
+ })
+ # CHECK: %0:2 = "custom.op1"() ( {
+ # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
+ print(op1)
+
+ # TODO: Also cover the successors argument once enough infra exists to do it properly.
+
+run(testDetachedOperation)
+
+
+# CHECK-LABEL: TEST: testOperationInsert
+def testOperationInsert():  # Inserts detached ops into a parsed function's entry block.
+ ctx = mlir.ir.Context()
+ ctx.allow_unregistered_dialects = True
+ module = ctx.parse_module(r"""
+ func @f1(%arg0: i32) -> i32 {
+ %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+ return %1 : i32
+ }
+ """)
+
+ # Create two detached test ops.
+ loc = ctx.get_unknown_location()
+ op1 = ctx.create_operation("custom.op1", loc)
+ op2 = ctx.create_operation("custom.op2", loc)
+
+ func = module.operation.regions[0].blocks[0].operations[0]  # first op in the module's first block
+ entry_block = func.regions[0].blocks[0]
+ entry_block.operations.insert(0, op1)
+ entry_block.operations.insert(1, op2)
+ # CHECK: func @f1
+ # CHECK: "custom.op1"()
+ # CHECK: "custom.op2"()
+ # CHECK: %0 = "custom.addi"
+ print(module)
+
+ # Re-inserting an op that is already attached must raise ValueError.
+ try:
+ entry_block.operations.insert(0, op1)
+ except ValueError:
+ pass
+ else:
+ assert False, "expected insert of attached op to raise"
+
+run(testOperationInsert)
+
+
+# CHECK-LABEL: TEST: testOperationWithRegion
+def testOperationWithRegion():  # Builds an op with a region and block, then nests it in a function.
+ ctx = mlir.ir.Context()
+ ctx.allow_unregistered_dialects = True
+ loc = ctx.get_unknown_location()
+ i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
+ op1 = ctx.create_operation("custom.op1", loc, regions=1)
+ block = op1.regions[0].blocks.append(i32, i32)  # new block with two si32 arguments
+ # CHECK: "custom.op1"() ( {
+ # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
+ # CHECK: "custom.terminator"() : () -> ()
+ # CHECK: }) : () -> ()
+ terminator = ctx.create_operation("custom.terminator", loc)
+ block.operations.insert(0, terminator)
+ print(op1)
+
+ # Now insert the whole operation (its region and terminator included) into another op.
+ # TODO: Verify lifetime hazard by nulling out the new owning module and
+ # accessing op1.
+ # TODO: Also verify accessing the terminator once both parents are nulled
+ # out.
+ module = ctx.parse_module(r"""
+ func @f1(%arg0: i32) -> i32 {
+ %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+ return %1 : i32
+ }
+ """)
+ func = module.operation.regions[0].blocks[0].operations[0]
+ entry_block = func.regions[0].blocks[0]
+ entry_block.operations.insert(0, op1)
+ # CHECK: func @f1
+ # CHECK: "custom.op1"()
+ # CHECK: "custom.terminator"
+ # CHECK: %0 = "custom.addi"
+ print(module)
+
+run(testOperationWithRegion)
More information about the Mlir-commits
mailing list