[Mlir-commits] [mlir] c1ded6a - Add mlir python APIs for creating operations, regions and blocks.

Stella Laurenzo llvmlistbot at llvm.org
Wed Sep 23 07:59:04 PDT 2020


Author: Stella Laurenzo
Date: 2020-09-23T07:57:50-07:00
New Revision: c1ded6a759913a32b44a851f0823bbb648d2a7e1

URL: https://github.com/llvm/llvm-project/commit/c1ded6a759913a32b44a851f0823bbb648d2a7e1
DIFF: https://github.com/llvm/llvm-project/commit/c1ded6a759913a32b44a851f0823bbb648d2a7e1.diff

LOG: Add mlir python APIs for creating operations, regions and blocks.

* The API is a bit more verbose than I feel like it needs to be. In a follow-up I'd like to abbreviate some things and look into creating aliases for common accessors.
* There is a lingering lifetime hazard between the module and newly added operations. We have the facilities now to solve for this but I will do that in a follow-up.
* We may need to craft a more limited API for safely referencing successors when creating operations. We need more facilities to really prove that out and should defer for now.

Differential Revision: https://reviews.llvm.org/D87996

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8eab7dab1675..3fad701d1641 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -24,6 +24,22 @@ using llvm::SmallVector;
 // Docstrings (trivial, non-duplicated docstrings are included inline).
 //------------------------------------------------------------------------------
 
+static const char kContextCreateOperationDocstring[] =
+    R"(Creates a new operation.
+
+Args:
+  name: Operation name (e.g. "dialect.operation").
+  location: A Location object.
+  results: Sequence of Type representing op result types.
+  attributes: Dict of str:Attribute.
+  successors: List of Block for the operation's successors.
+  regions: Number of regions to create.
+
+Returns:
+  A new "detached" Operation object. Detached operations can be added
+  to blocks, which causes them to become "attached."
+)"; // Docstring for Context.create_operation (bound below).
+
 static const char kContextParseDocstring[] =
     R"(Parses a module's assembly format from a string.
 
@@ -60,6 +76,13 @@ static const char kTypeStrDunderDocstring[] =
 static const char kDumpDocstring[] =
     R"(Dumps a debug representation of the object to stderr.)";
 
+static const char kAppendBlockDocstring[] =
+    R"(Appends a new block, with argument types as positional args.
+
+Returns:
+  The created block.
+)"; // Docstring for BlockList.append.
+
 //------------------------------------------------------------------------------
 // Conversion utilities.
 //------------------------------------------------------------------------------
@@ -265,11 +288,25 @@ class PyBlockList {
     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
   }
 
+  /// Appends a new block whose argument types are the positional PyTypes.
+  PyBlock appendBlock(py::args pyArgTypes) {
+    operation->checkValid();
+    llvm::SmallVector<MlirType, 4> blockArgTypes;
+    blockArgTypes.reserve(pyArgTypes.size());
+    for (py::handle argType : pyArgTypes)
+      blockArgTypes.push_back(argType.cast<PyType &>().type);
+    MlirBlock newBlock =
+        mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
+    mlirRegionAppendOwnedBlock(region, newBlock);
+    return PyBlock(operation, newBlock);
+  }
+
   static void bind(py::module &m) {
     py::class_<PyBlockList>(m, "BlockList")
         .def("__getitem__", &PyBlockList::dunderGetItem)
         .def("__iter__", &PyBlockList::dunderIter)
-        .def("__len__", &PyBlockList::dunderLen);
+        .def("__len__", &PyBlockList::dunderLen)
+        .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); // *arg_types
   }
 
 private:
@@ -352,11 +389,41 @@ class PyOperationList {
                      "attempt to access out of bounds operation");
   }
 
+  /// Inserts the detached `newOperation` into this block at position `index`
+  /// (0 <= index <= len). Raises IndexError for out-of-range indices and
+  /// ValueError if the operation is already attached.
+  void insert(int index, PyOperation &newOperation) {
+    parentOperation->checkValid();
+    newOperation.checkValid();
+    if (index < 0)
+      throw SetPyError(
+          PyExc_IndexError,
+          "only non-negative insertion indices are supported for operations");
+    if (newOperation.isAttached())
+      throw SetPyError(
+          PyExc_ValueError,
+          "attempt to insert an operation that has already been inserted");
+    // TODO: Needing to do this check is unfortunate, especially since it will
+    // be a forward-scan, just like the following call to
+    // mlirBlockInsertOwnedOperation. Switch to insert before/after once
+    // D88148 lands.
+    if (index > dunderLen())
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to insert operation past end");
+    mlirBlockInsertOwnedOperation(block, index, newOperation.get());
+    newOperation.setAttached();
+    // TODO: Rework the parentKeepAlive so as to avoid ownership hazards under
+    // the new ownership.
+  }
+
   static void bind(py::module &m) {
     py::class_<PyOperationList>(m, "OperationList")
         .def("__getitem__", &PyOperationList::dunderGetItem)
         .def("__iter__", &PyOperationList::dunderIter)
-        .def("__len__", &PyOperationList::dunderLen);
+        .def("__len__", &PyOperationList::dunderLen)
+        .def("insert", &PyOperationList::insert, py::arg("index"),
+             py::arg("operation"), // Must be detached; see insert().
+             "Inserts an operation at an indexed position");
   }
 
 private:
@@ -416,6 +483,87 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
+// Creates a detached operation. Every Python argument is unpacked and
+// validated before any MlirOperationState exists, so a throw cannot leak it.
+py::object PyMlirContext::createOperation(
+    std::string name, PyLocation location,
+    llvm::Optional<std::vector<PyType *>> results,
+    llvm::Optional<py::dict> attributes,
+    llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
+  llvm::SmallVector<MlirType, 4> mlirResults;
+  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
+  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
+
+  // General parameter validation.
+  if (regions < 0)
+    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
+
+  // Unpack/validate results.
+  if (results) {
+    mlirResults.reserve(results->size());
+    for (PyType *result : *results) {
+      // TODO: Verify result types originate from the same context.
+      if (!result)
+        throw SetPyError(PyExc_ValueError, "result type cannot be None");
+      mlirResults.push_back(result->type);
+    }
+  }
+  // Unpack/validate attributes.
+  if (attributes) {
+    mlirAttributes.reserve(attributes->size());
+    for (auto &it : *attributes) {
+      // Copy the name into owned storage; it is referenced by c_str() below.
+      auto attrName = it.first.cast<std::string>();
+      auto &attribute = it.second.cast<PyAttribute &>();
+      // TODO: Verify attribute originates from the same context.
+      mlirAttributes.emplace_back(std::move(attrName), attribute.attr);
+    }
+  }
+  // Unpack/validate successors into the function-scope mlirSuccessors.
+  if (successors) {
+    mlirSuccessors.reserve(successors->size());
+    for (auto *successor : *successors) {
+      // TODO: Verify successors originate from the same context.
+      if (!successor)
+        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
+      mlirSuccessors.push_back(successor->get());
+    }
+  }
+
+  // Apply unpacked/validated to the operation state. Beyond this
+  // point, exceptions cannot be thrown or else the state will leak.
+  MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
+  if (!mlirResults.empty())
+    mlirOperationStateAddResults(&state, mlirResults.size(),
+                                 mlirResults.data());
+  if (!mlirAttributes.empty()) {
+    // Note that the attribute names directly reference bytes in
+    // mlirAttributes, so that vector must not be changed from here
+    // on.
+    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
+    mlirNamedAttributes.reserve(mlirAttributes.size());
+    for (auto &it : mlirAttributes)
+      mlirNamedAttributes.push_back(
+          mlirNamedAttributeGet(it.first.c_str(), it.second));
+    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+  }
+  if (!mlirSuccessors.empty())
+    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
+                                    mlirSuccessors.data());
+  if (regions) {
+    llvm::SmallVector<MlirRegion, 4> mlirRegions(regions);
+    for (int i = 0; i < regions; ++i)
+      mlirRegions[i] = mlirRegionCreate();
+    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
+                                      mlirRegions.data());
+  }
+
+  // Construct the operation.
+  MlirOperation operation = mlirOperationCreate(&state);
+  return PyOperation::createDetached(getRef(), operation).releaseObject();
+}
+
 //------------------------------------------------------------------------------
 // PyModule
 //------------------------------------------------------------------------------
@@ -1153,6 +1301,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           [](PyMlirContext &self, bool value) {
             mlirContextSetAllowUnregisteredDialects(self.get(), value);
           })
+      .def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
+           py::arg("location"), py::arg("results") = py::none(),
+           py::arg("attributes") = py::none(),
+           py::arg("successors") = py::none(), py::arg("regions") = 0,
+           kContextCreateOperationDocstring)
       .def(
           "parse_module",
           [](PyMlirContext &self, const std::string moduleAsm) {

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 06b697cfd786..41b18d216026 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -17,9 +17,12 @@
 namespace mlir {
 namespace python {
 
+class PyBlock;
+class PyLocation;
 class PyMlirContext;
 class PyModule;
 class PyOperation;
+class PyType;
 
 /// Template for a reference to a concrete type which captures a python
 /// reference to its underlying python object.
@@ -112,6 +115,14 @@ class PyMlirContext {
   /// Used for testing.
   size_t getLiveOperationCount();
 
+  /// Creates a new detached operation; see the create_operation docstring.
+  pybind11::object
+  createOperation(std::string name, PyLocation location,
+                  llvm::Optional<std::vector<PyType *>> results,
+                  llvm::Optional<pybind11::dict> attributes,
+                  llvm::Optional<std::vector<PyBlock *>> successors,
+                  int regions); // Number of regions to create.
+
 private:
   PyMlirContext(MlirContext context);
   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@@ -227,6 +238,10 @@ class PyOperation : public BaseContextObject {
   }
 
   bool isAttached() { return attached; }
+  void setAttached() { // Detached -> attached; must not already be attached.
+    assert(!attached && "operation already attached");
+    attached = true;
+  }
   void checkValid();
 
 private:

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 29ab06a25055..0435aa461809 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -12,8 +12,16 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/Twine.h"
 
+namespace pybind11 {
+namespace detail {
+template <typename T> // Let llvm::Optional<T> reuse pybind11's optional_caster.
+struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
+} // namespace detail
+} // namespace pybind11
+
 namespace mlir {
 namespace python {
 

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 9522e4b1ad98..881398e1eba3 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -99,3 +99,99 @@ def walk_operations(indent, op):
   walk_operations("", module.operation)
 
 run(testTraverseOpRegionBlockIndices)
+
+
+# CHECK-LABEL: TEST: testDetachedOperation
+def testDetachedOperation():  # create_operation without inserting the op
+  ctx = mlir.ir.Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
+  op1 = ctx.create_operation(  # returns a detached Operation
+      "custom.op1", loc, results=[i32, i32], regions=1, attributes={
+          "foo": mlir.ir.StringAttr.get(ctx, "foo_value"),
+          "bar": mlir.ir.StringAttr.get(ctx, "bar_value"),
+      })
+  # CHECK: %0:2 = "custom.op1"() ( {
+  # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
+  print(op1)
+
+  # TODO: Check successors once enough infra exists to do it properly.
+
+run(testDetachedOperation)
+
+
+# CHECK-LABEL: TEST: testOperationInsert
+def testOperationInsert():  # insert detached ops into an existing block
+  ctx = mlir.ir.Context()
+  ctx.allow_unregistered_dialects = True
+  module = ctx.parse_module(r"""
+    func @f1(%arg0: i32) -> i32 {
+      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+      return %1 : i32
+    }
+  """)
+
+  # Create two detached test ops.
+  loc = ctx.get_unknown_location()
+  op1 = ctx.create_operation("custom.op1", loc)
+  op2 = ctx.create_operation("custom.op2", loc)
+
+  func = module.operation.regions[0].blocks[0].operations[0]
+  entry_block = func.regions[0].blocks[0]
+  entry_block.operations.insert(0, op1)
+  entry_block.operations.insert(1, op2)
+  # CHECK: func @f1
+  # CHECK: "custom.op1"()
+  # CHECK: "custom.op2"()
+  # CHECK: %0 = "custom.addi"
+  print(module)
+
+  # Re-inserting an already attached op must raise ValueError.
+  try:
+    entry_block.operations.insert(0, op1)
+  except ValueError:
+    pass
+  else:
+    assert False, "expected insert of attached op to raise"
+
+run(testOperationInsert)
+
+
+# CHECK-LABEL: TEST: testOperationWithRegion
+def testOperationWithRegion():  # build a region body, then insert the op
+  ctx = mlir.ir.Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
+  op1 = ctx.create_operation("custom.op1", loc, regions=1)
+  block = op1.regions[0].blocks.append(i32, i32)  # two si32 block arguments
+  # CHECK: "custom.op1"() ( {
+  # CHECK: ^bb0(%arg0: si32, %arg1: si32):  // no predecessors
+  # CHECK:   "custom.terminator"() : () -> ()
+  # CHECK: }) : () -> ()
+  terminator = ctx.create_operation("custom.terminator", loc)
+  block.operations.insert(0, terminator)
+  print(op1)
+
+  # Now insert the whole operation (region included) into another block.
+  # TODO: Verify lifetime hazard by nulling out the new owning module and
+  # accessing op1.
+  # TODO: Also verify accessing the terminator once both parents are nulled
+  # out.
+  module = ctx.parse_module(r"""
+    func @f1(%arg0: i32) -> i32 {
+      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+      return %1 : i32
+    }
+  """)
+  func = module.operation.regions[0].blocks[0].operations[0]
+  entry_block = func.regions[0].blocks[0]
+  entry_block.operations.insert(0, op1)
+  # CHECK: func @f1
+  # CHECK: "custom.op1"()
+  # CHECK:   "custom.terminator"
+  # CHECK: %0 = "custom.addi"
+  print(module)
+
+run(testOperationWithRegion)


        


More information about the Mlir-commits mailing list