[Mlir-commits] [mlir] c645ea5 - Add InsertionPoint and context managers to the Python API.

Stella Laurenzo llvmlistbot at llvm.org
Thu Oct 29 17:53:28 PDT 2020


Author: Stella Laurenzo
Date: 2020-10-29T17:50:13-07:00
New Revision: c645ea5e29e5e42598c3be67a28698405e8bc563

URL: https://github.com/llvm/llvm-project/commit/c645ea5e29e5e42598c3be67a28698405e8bc563
DIFF: https://github.com/llvm/llvm-project/commit/c645ea5e29e5e42598c3be67a28698405e8bc563.diff

LOG: Add InsertionPoint and context managers to the Python API.

* Removes index-based insertion. All insertion now happens through the insertion point.
* Introduces thread-local context managers so that newly created operations are implicitly inserted at the active insertion point.
* Introduces (but does not yet use) the ability to bind the Context to the thread-local context stack. The intent is to refactor all methods to take the context optionally and fall back to the default when one is available.
* Adds C APIs for mlirOperationGetParentOperation(), mlirOperationGetBlock() and mlirBlockGetTerminator().
* Removes an assert in PyOperation creation that was overly restrictive. There is already a TODO to rework the parentKeepAlive field that it was guarding, and without the assert the behavior is no worse than before.

Differential Revision: https://reviews.llvm.org/D90368

Added: 
    mlir/test/Bindings/Python/insertion_point.py

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/dialects.py
    mlir/test/Bindings/Python/ir_operation.py
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index af0ab1fdf341..52152960f415 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -285,6 +285,14 @@ static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
  * not perform deep comparison. */
 int mlirOperationEqual(MlirOperation op, MlirOperation other);
 
+/** Gets the block that owns this operation, returning null if the operation is
+ * not owned. */
+MlirBlock mlirOperationGetBlock(MlirOperation op);
+
+/** Gets the operation that owns this operation, returning null if the operation
+ * is not owned. */
+MlirOperation mlirOperationGetParentOperation(MlirOperation op);
+
 /** Returns the number of regions attached to the given operation. */
 intptr_t mlirOperationGetNumRegions(MlirOperation op);
 
@@ -408,6 +416,9 @@ MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
 /** Returns the first operation in the block. */
 MlirOperation mlirBlockGetFirstOperation(MlirBlock block);
 
+/** Returns the terminator operation in the block or null if no terminator. */
+MlirOperation mlirBlockGetTerminator(MlirBlock block);
+
 /** Takes an operation owned by the caller and appends it to the block. */
 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation);
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 0ec5b566a3b5..5e0e45d0784d 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -467,41 +467,11 @@ class PyOperationList {
                      "attempt to access out of bounds operation");
   }
 
-  void insert(int index, PyOperation &newOperation) {
-    parentOperation->checkValid();
-    newOperation.checkValid();
-    if (index < 0) {
-      throw SetPyError(
-          PyExc_IndexError,
-          "only positive insertion indices are supported for operations");
-    }
-    if (newOperation.isAttached()) {
-      throw SetPyError(
-          PyExc_ValueError,
-          "attempt to insert an operation that has already been inserted");
-    }
-    // TODO: Needing to do this check is unfortunate, especially since it will
-    // be a forward-scan, just like the following call to
-    // mlirBlockInsertOwnedOperation. Switch to insert before/after once
-    // D88148 lands.
-    if (index > dunderLen()) {
-      throw SetPyError(PyExc_IndexError,
-                       "attempt to insert operation past end");
-    }
-    mlirBlockInsertOwnedOperation(block, index, newOperation.get());
-    newOperation.setAttached();
-    // TODO: Rework the parentKeepAlive so as to avoid ownership hazards under
-    // the new ownership.
-  }
-
   static void bind(py::module &m) {
     py::class_<PyOperationList>(m, "OperationList")
         .def("__getitem__", &PyOperationList::dunderGetItem)
         .def("__iter__", &PyOperationList::dunderIter)
-        .def("__len__", &PyOperationList::dunderLen)
-        .def("insert", &PyOperationList::insert, py::arg("index"),
-             py::arg("operation"),
-             "Inserts an operation at an indexed position");
+        .def("__len__", &PyOperationList::dunderLen);
   }
 
 private:
@@ -668,7 +638,75 @@ py::object PyMlirContext::createOperation(
 
   // Construct the operation.
   MlirOperation operation = mlirOperationCreate(&state);
-  return PyOperation::createDetached(getRef(), operation).releaseObject();
+  PyOperationRef created = PyOperation::createDetached(getRef(), operation);
+
+  // InsertPoint active?
+  PyInsertionPoint *ip =
+      PyThreadContextEntry::getDefaultInsertionPoint(/*required=*/false);
+  if (ip)
+    ip->insert(*created.get());
+
+  return created.releaseObject();
+}
+
+//------------------------------------------------------------------------------
+// PyThreadContextEntry management
+//------------------------------------------------------------------------------
+
+std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
+  static thread_local std::vector<PyThreadContextEntry> stack;
+  return stack;
+}
+
+PyThreadContextEntry *PyThreadContextEntry::getTos() {
+  auto &stack = getStack();
+  if (stack.empty())
+    return nullptr;
+  return &stack.back();
+}
+
+void PyThreadContextEntry::push(pybind11::object context,
+                                pybind11::object insertionPoint) {
+  auto &stack = getStack();
+  stack.emplace_back(std::move(context), std::move(insertionPoint));
+}
+
+PyMlirContext *PyThreadContextEntry::getContext() {
+  if (!context)
+    return nullptr;
+  return py::cast<PyMlirContext *>(context);
+}
+
+PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
+  if (!insertionPoint)
+    return nullptr;
+  return py::cast<PyInsertionPoint *>(insertionPoint);
+}
+
+PyMlirContext *PyThreadContextEntry::getDefaultContext(bool required) {
+  auto *tos = getTos();
+  PyMlirContext *context = tos ? tos->getContext() : nullptr;
+  if (required && !context) {
+    throw SetPyError(
+        PyExc_RuntimeError,
+        "A default context is required for this call but is not provided. "
+        "Establish a default by surrounding the code with "
+        "'with context:'");
+  }
+  return context;
+}
+
+PyInsertionPoint *
+PyThreadContextEntry::getDefaultInsertionPoint(bool required) {
+  auto *tos = getTos();
+  PyInsertionPoint *ip = tos ? tos->getInsertionPoint() : nullptr;
+  if (required && !ip)
+    throw SetPyError(PyExc_RuntimeError,
+                     "A default insertion point is required for this call but "
+                     "is not provided. "
+                     "Establish a default by surrounding the code with "
+                     "'with InsertionPoint(...):'");
+  return ip;
 }
 
 //------------------------------------------------------------------------------
@@ -791,7 +829,6 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
   }
   // Use existing.
   PyOperation *existing = it->second.second;
-  assert(existing->parentKeepAlive.is(parentKeepAlive));
   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
   return PyOperationRef(existing, std::move(pyRef));
 }
@@ -858,6 +895,22 @@ py::object PyOperation::getAsm(bool binary,
   return fileObject.attr("getvalue")();
 }
 
+PyOperationRef PyOperation::getParentOperation() {
+  if (!isAttached()) // Only attached operations can have a parent.
+    throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
+  MlirOperation operation = mlirOperationGetParentOperation(get());
+  if (mlirOperationIsNull(operation))
+    throw SetPyError(PyExc_ValueError, "Operation has no parent.");
+  return PyOperation::forOperation(getContext(), operation);
+}
+
+PyBlock PyOperation::getBlock() {
+  PyOperationRef parentOperation = getParentOperation(); // Raises if none.
+  MlirBlock block = mlirOperationGetBlock(get());
+  assert(!mlirBlockIsNull(block) && "Attached operation has null block");
+  return PyBlock{std::move(parentOperation), block};
+}
+
 PyOpView::PyOpView(py::object operation)
     : operationObject(std::move(operation)),
       operation(py::cast<PyOperation *>(this->operationObject)) {}
@@ -897,6 +950,76 @@ py::object PyOpView::createRawSubclass(py::object userClass) {
   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
 }
 
+//------------------------------------------------------------------------------
+// PyInsertionPoint.
+//------------------------------------------------------------------------------
+
+PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
+
+PyInsertionPoint::PyInsertionPoint(PyOperation &beforeOperation)
+    : block(beforeOperation.getBlock()),
+      refOperation(beforeOperation.getRef()) {}
+
+void PyInsertionPoint::insert(PyOperation &operation) {
+  operation.checkValid(); // Validate before touching the MLIR operation.
+  if (operation.isAttached())
+    throw SetPyError(PyExc_ValueError,
+                     "Attempt to insert operation that is already attached");
+  block.getParentOperation()->checkValid();
+  MlirOperation beforeOp = {nullptr}; // Null: append at the block end.
+  if (refOperation) { // Insert before the reference operation.
+    (*refOperation)->checkValid();
+    beforeOp = (*refOperation)->get();
+  }
+  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get());
+  operation.setAttached();
+}
+
+PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
+  MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
+  if (mlirOperationIsNull(firstOp)) {
+    // Just insert at end.
+    return PyInsertionPoint(block);
+  }
+
+  // Insert before first op.
+  PyOperationRef firstOpRef = PyOperation::forOperation(
+      block.getParentOperation()->getContext(), firstOp);
+  return PyInsertionPoint{block, std::move(firstOpRef)};
+}
+
+PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
+  MlirOperation terminator = mlirBlockGetTerminator(block.get());
+  if (mlirOperationIsNull(terminator))
+    throw SetPyError(PyExc_ValueError, "Block has no terminator");
+  PyOperationRef terminatorOpRef = PyOperation::forOperation(
+      block.getParentOperation()->getContext(), terminator);
+  return PyInsertionPoint{block, std::move(terminatorOpRef)};
+}
+
+py::object PyInsertionPoint::contextEnter() {
+  auto context = block.getParentOperation()->getContext().getObject();
+  py::object self = py::cast(this);
+  PyThreadContextEntry::push(/*context=*/std::move(context),
+                             /*insertionPoint=*/self);
+  return self;
+}
+
+void PyInsertionPoint::contextExit(pybind11::object excType,
+                                   pybind11::object excVal,
+                                   pybind11::object excTb) {
+  auto &stack = PyThreadContextEntry::getStack();
+  if (stack.empty())
+    throw SetPyError(PyExc_RuntimeError,
+                     "Unbalanced insertion point enter/exit");
+  auto &tos = stack.back();
+  PyInsertionPoint *current = tos.getInsertionPoint();
+  if (current != this)
+    throw SetPyError(PyExc_RuntimeError,
+                     "Unbalanced insertion point enter/exit");
+  stack.pop_back();
+}
+
 //------------------------------------------------------------------------------
 // PyAttribute.
 //------------------------------------------------------------------------------
@@ -2388,6 +2511,24 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           "Returns the assembly form of the block.");
 
+  //----------------------------------------------------------------------------
+  // Mapping of PyInsertionPoint.
+  //----------------------------------------------------------------------------
+
+  py::class_<PyInsertionPoint>(m, "InsertionPoint")
+      .def(py::init<PyBlock &>(), py::arg("block"),
+           "Inserts after the last operation but still inside the block.")
+      .def("__enter__", &PyInsertionPoint::contextEnter)
+      .def("__exit__", &PyInsertionPoint::contextExit)
+      .def(py::init<PyOperation &>(), py::arg("beforeOperation"),
+           "Inserts before a referenced operation.")
+      .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
+                  py::arg("block"), "Inserts at the beginning of the block.")
+      .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
+                  py::arg("block"), "Inserts before the block terminator.")
+      .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
+           "Inserts an operation.");
+
   //----------------------------------------------------------------------------
   // Mapping of PyAttribute.
   //----------------------------------------------------------------------------

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 89cca9c1c85f..6b1b69941958 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
 
+#include <vector>
+
 #include <pybind11/pybind11.h>
 
 #include "mlir-c/IR.h"
@@ -18,6 +20,7 @@ namespace mlir {
 namespace python {
 
 class PyBlock;
+class PyInsertionPoint;
 class PyLocation;
 class PyMlirContext;
 class PyModule;
@@ -61,6 +64,7 @@ class PyObjectRef {
     return stolen;
   }
 
+  T *get() { return referrent; }
   T *operator->() {
     assert(referrent && object);
     return referrent;
@@ -76,9 +80,48 @@ class PyObjectRef {
   pybind11::object object;
 };
 
-using PyMlirContextRef = PyObjectRef<PyMlirContext>;
+/// Tracks an entry in the thread context stack. New entries are pushed onto
+/// it for each `with` block that activates a new InsertionPoint or Context.
+/// Pushing either a context or an insertion point resets the other:
+///   - a new context activates a new entry with a null insertion point.
+///   - a new insertion point activates a new entry with the context that the
+///     insertion point is bound to.
+class PyThreadContextEntry {
+public:
+  PyThreadContextEntry(pybind11::object context,
+                       pybind11::object insertionPoint)
+      : context(std::move(context)), insertionPoint(std::move(insertionPoint)) {
+  }
+
+  /// Gets the top of stack context and returns nullptr if not defined.
+  /// If required is true and there is no default, a nice user-facing exception
+  /// is raised.
+  static PyMlirContext *getDefaultContext(bool required);
+
+  /// Gets the top of stack insertion point and returns nullptr if not defined.
+  /// If required is true and there is no default, a nice user-facing exception
+  /// is raised.
+  static PyInsertionPoint *getDefaultInsertionPoint(bool required);
+
+  PyMlirContext *getContext();           // nullptr if no context is bound.
+  PyInsertionPoint *getInsertionPoint(); // nullptr if no insertion point.
+
+  /// Stack management. getTos() returns nullptr when the stack is empty.
+  static PyThreadContextEntry *getTos();
+  static void push(pybind11::object context, pybind11::object insertionPoint);
+
+  /// Gets the stack for the current thread (a function-local thread_local).
+  static std::vector<PyThreadContextEntry> &getStack();
+
+private:
+  /// An object reference to the PyMlirContext (may be null).
+  pybind11::object context;
+  /// An object reference to the current insertion point (may be null).
+  pybind11::object insertionPoint;
+};
 
+using PyMlirContextRef = PyObjectRef<PyMlirContext>;
 /// Wrapper around MlirContext.
 class PyMlirContext {
 public:
   PyMlirContext() = delete;
@@ -287,8 +330,7 @@ class PyOperation : public BaseContextObject {
 public:
   ~PyOperation();
   /// Returns a PyOperation for the given MlirOperation, optionally associating
-  /// it with a parentKeepAlive (which must match on all such calls for the
-  /// same operation).
+  /// it with a parentKeepAlive.
   static PyOperationRef
   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
                pybind11::object parentKeepAlive = pybind11::object());
@@ -326,6 +368,14 @@ class PyOperation : public BaseContextObject {
                           bool enableDebugInfo, bool prettyDebugInfo,
                           bool printGenericOpForm, bool useLocalScope);
 
+  /// Gets the owning block or raises an exception if the operation has no
+  /// owning block.
+  PyBlock getBlock();
+
+  /// Gets the parent operation or raises an exception if the operation has
+  /// no parent.
+  PyOperationRef getParentOperation();
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,
@@ -403,6 +453,41 @@ class PyBlock {
   MlirBlock block;
 };
 
+/// An insertion point maintains a pointer to a Block and a reference operation.
+/// Calls to insert() will insert a new operation before the
+/// reference operation. If there is no reference operation, it is appended to
+/// the end of the block.
+class PyInsertionPoint {
+public:
+  /// Creates an insertion point positioned after the last operation in the
+  /// block, but still inside the block.
+  PyInsertionPoint(PyBlock &block);
+  /// Creates an insertion point before an attached operation (else raises).
+  PyInsertionPoint(PyOperation &beforeOperation);
+
+  /// Shortcut to create an insertion point at the beginning of the block.
+  static PyInsertionPoint atBlockBegin(PyBlock &block);
+  /// Creates an insertion point before the block terminator (raises if none).
+  static PyInsertionPoint atBlockTerminator(PyBlock &block);
+
+  /// Inserts a detached operation; raises if it is already attached.
+  void insert(PyOperation &operation);
+
+  /// Context manager protocol: pushes/pops this on the thread context stack.
+  pybind11::object contextEnter();
+  void contextExit(pybind11::object excType, pybind11::object excVal,
+                   pybind11::object excTb);
+
+private:
+  // Trampoline constructor that avoids null initializing members while
+  // looking up parents.
+  PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
+      : block(std::move(block)), refOperation(std::move(refOperation)) {}
+
+  PyBlock block; // Block that receives inserted operations.
+  llvm::Optional<PyOperationRef> refOperation; // None: append at block end.
+};
+
 /// Wrapper around the generic MlirAttribute.
 /// The lifetime of a type is bound by the PyContext that created it.
 class PyAttribute : public BaseContextObject {

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index f3c91d1fae24..7bc89f66e2d5 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -249,6 +249,14 @@ int mlirOperationEqual(MlirOperation op, MlirOperation other) {
   return unwrap(op) == unwrap(other);
 }
 
+MlirBlock mlirOperationGetBlock(MlirOperation op) {
+  return wrap(unwrap(op)->getBlock());
+}
+
+MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
+  return wrap(unwrap(op)->getParentOp());
+}
+
 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
 }
@@ -403,6 +411,16 @@ MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
   return wrap(&cppBlock->front());
 }
 
+MlirOperation mlirBlockGetTerminator(MlirBlock block) {
+  Block *cppBlock = unwrap(block);
+  if (cppBlock->empty()) // An empty block has no terminator.
+    return wrap(static_cast<Operation *>(nullptr));
+  Operation &back = cppBlock->back(); // Only the last op is considered.
+  if (!back.isKnownTerminator()) // e.g. a trailing unregistered op.
+    return wrap(static_cast<Operation *>(nullptr));
+  return wrap(&back);
+}
+
 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
   unwrap(block)->push_back(unwrap(operation));
 }

diff  --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
index ef95163c7743..172258dd7840 100644
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -73,30 +73,21 @@ def testCustomOpView():
   f32 = mlir.ir.F32Type.get(ctx)
   loc = ctx.get_unknown_location()
   m = ctx.create_module(loc)
-  m_block = m.body
-  # TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
-  ip = [0]
+
   def createInput():
     op = ctx.create_operation("pytest_dummy.intinput", loc, results=[f32])
-    m_block.operations.insert(ip[0], op)
-    ip[0] += 1
     # TODO: Auto result cast from operation
     return op.results[0]
 
-  # Create via dialects context collection.
-  input1 = createInput()
-  input2 = createInput()
-  op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
-  # TODO: Auto operation cast from OpView
-  # TODO: Context manager insertion point
-  m_block.operations.insert(ip[0], op1.operation)
-  ip[0] += 1
-
-  # Create via an import
-  from mlir.dialects.std import AddFOp
-  op2 = AddFOp(loc, input1, op1.result)
-  m_block.operations.insert(ip[0], op2.operation)
-  ip[0] += 1
+  with mlir.ir.InsertionPoint.at_block_terminator(m.body):
+    # Create via dialects context collection.
+    input1 = createInput()
+    input2 = createInput()
+    op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
+
+    # Create via an import
+    from mlir.dialects.std import AddFOp
+    AddFOp(loc, input1, op1.result)
 
   # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
   # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"

diff  --git a/mlir/test/Bindings/Python/insertion_point.py b/mlir/test/Bindings/Python/insertion_point.py
new file mode 100644
index 000000000000..bbdd670e6aa7
--- /dev/null
+++ b/mlir/test/Bindings/Python/insertion_point.py
@@ -0,0 +1,152 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import io
+import itertools
+from mlir.ir import *
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: test_insert_at_block_end
+def test_insert_at_block_end():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  ip = InsertionPoint(entry_block)
+  ip.insert(ctx.create_operation("custom.op2", loc))
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op2"
+  module.operation.print()
+
+run(test_insert_at_block_end)
+
+
+# CHECK-LABEL: TEST: test_insert_before_operation
+def test_insert_before_operation():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+      "custom.op2"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  ip = InsertionPoint(entry_block.operations[1])
+  ip.insert(ctx.create_operation("custom.op3", loc))
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op3"
+  # CHECK: "custom.op2"
+  module.operation.print()
+
+run(test_insert_before_operation)
+
+
+# CHECK-LABEL: TEST: test_insert_at_block_begin
+def test_insert_at_block_begin():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op2"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  ip = InsertionPoint.at_block_begin(entry_block)
+  ip.insert(ctx.create_operation("custom.op1", loc))
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op2"
+  module.operation.print()
+
+run(test_insert_at_block_begin)
+
+
+# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
+def test_insert_at_block_begin_empty():
+  # TODO: Write this test case when we can create such a situation.
+  pass
+
+run(test_insert_at_block_begin_empty)
+
+
+# CHECK-LABEL: TEST: test_insert_at_terminator
+def test_insert_at_terminator():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+      return
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  ip = InsertionPoint.at_block_terminator(entry_block)
+  ip.insert(ctx.create_operation("custom.op2", loc))
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op2"
+  module.operation.print()
+
+run(test_insert_at_terminator)
+
+
+# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
+def test_insert_at_block_terminator_missing():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  try:
+    ip = InsertionPoint.at_block_terminator(entry_block)
+  except ValueError as e:
+    # CHECK: Block has no terminator
+    print(e)
+  else:
+    assert False, "Expected exception"
+
+run(test_insert_at_block_terminator_missing)
+
+
+# CHECK-LABEL: TEST: test_insertion_point_context
+def test_insertion_point_context():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  loc = ctx.get_unknown_location()
+  module = ctx.parse_module(r"""
+    func @foo() -> () {
+      "custom.op1"() : () -> ()
+    }
+  """)
+  entry_block = module.body.operations[0].regions[0].blocks[0]
+  with InsertionPoint(entry_block):
+    ctx.create_operation("custom.op2", loc)
+    with InsertionPoint.at_block_begin(entry_block):
+      ctx.create_operation("custom.opa", loc)
+      ctx.create_operation("custom.opb", loc)
+    ctx.create_operation("custom.op3", loc)
+  # CHECK: "custom.opa"
+  # CHECK: "custom.opb"
+  # CHECK: "custom.op1"
+  # CHECK: "custom.op2"
+  # CHECK: "custom.op3"
+  module.operation.print()
+
+run(test_insertion_point_context)

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 84f303ca570b..8bc2ced60dca 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -111,7 +111,7 @@ def testBlockArgumentList():
       return
     }
   """)
-  func = module.operation.regions[0].blocks[0].operations[0]
+  func = module.body.operations[0]
   entry_block = func.regions[0].blocks[0]
   assert len(entry_block.arguments) == 3
   # CHECK: Argument 0, type i32
@@ -152,8 +152,8 @@ def testDetachedOperation():
 run(testDetachedOperation)
 
 
-# CHECK-LABEL: TEST: testOperationInsert
-def testOperationInsert():
+# CHECK-LABEL: TEST: testOperationInsertionPoint
+def testOperationInsertionPoint():
   ctx = mlir.ir.Context()
   ctx.allow_unregistered_dialects = True
   module = ctx.parse_module(r"""
@@ -168,10 +168,11 @@ def testOperationInsert():
   op1 = ctx.create_operation("custom.op1", loc)
   op2 = ctx.create_operation("custom.op2", loc)
 
-  func = module.operation.regions[0].blocks[0].operations[0]
+  func = module.body.operations[0]
   entry_block = func.regions[0].blocks[0]
-  entry_block.operations.insert(0, op1)
-  entry_block.operations.insert(1, op2)
+  ip = mlir.ir.InsertionPoint.at_block_begin(entry_block)
+  ip.insert(op1)
+  ip.insert(op2)
   # CHECK: func @f1
   # CHECK: "custom.op1"()
   # CHECK: "custom.op2"()
@@ -180,13 +181,13 @@ def testOperationInsert():
 
   # Trying to add a previously added op should raise.
   try:
-    entry_block.operations.insert(0, op1)
+    ip.insert(op1)
   except ValueError:
     pass
   else:
     assert False, "expected insert of attached op to raise"
 
-run(testOperationInsert)
+run(testOperationInsertionPoint)
 
 
 # CHECK-LABEL: TEST: testOperationWithRegion
@@ -202,7 +203,8 @@ def testOperationWithRegion():
   # CHECK:   "custom.terminator"() : () -> ()
   # CHECK: }) : () -> ()
   terminator = ctx.create_operation("custom.terminator", loc)
-  block.operations.insert(0, terminator)
+  ip = mlir.ir.InsertionPoint(block)
+  ip.insert(terminator)
   print(op1)
 
   # Now add the whole operation to another op.
@@ -216,9 +218,10 @@ def testOperationWithRegion():
       return %1 : i32
     }
   """)
-  func = module.operation.regions[0].blocks[0].operations[0]
+  func = module.body.operations[0]
   entry_block = func.regions[0].blocks[0]
-  entry_block.operations.insert(0, op1)
+  ip = mlir.ir.InsertionPoint.at_block_begin(entry_block)
+  ip.insert(op1)
   # CHECK: func @f1
   # CHECK: "custom.op1"()
   # CHECK:   "custom.terminator"
@@ -238,7 +241,7 @@ def testOperationResultList():
     }
     func @f2() -> (i32, f64, index)
   """)
-  caller = module.operation.regions[0].blocks[0].operations[0]
+  caller = module.body.operations[0]
   call = caller.regions[0].blocks[0].operations[0]
   assert len(call.results) == 3
   # CHECK: Result 0, type i32

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 7a19b2dd8f69..51b3dca88253 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -261,18 +261,32 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   MlirBlock block = mlirRegionGetFirstBlock(region);
   operation = mlirBlockGetFirstOperation(block);
   region = mlirOperationGetRegion(operation, 0);
+  MlirOperation parentOperation = operation;
   block = mlirRegionGetFirstBlock(region);
   operation = mlirBlockGetFirstOperation(block);
 
-  // In the module we created, the first operation of the first function is an
-  // "std.dim", which has an attribute and a single result that we can use to
-  // test the printing mechanism.
+  // Verify that parent operation and block report correctly.
+  fprintf(stderr, "Parent operation eq: %d\n",
+          mlirOperationEqual(mlirOperationGetParentOperation(operation),
+                             parentOperation));
+  fprintf(stderr, "Block eq: %d\n",
+          mlirBlockEqual(mlirOperationGetBlock(operation), block));
+
+  // In the module we created, the first operation of the first function is
+  // an "std.dim", which has an attribute and a single result that we can
+  // use to test the printing mechanism.
   mlirBlockPrint(block, printToStderr, NULL);
   fprintf(stderr, "\n");
   fprintf(stderr, "First operation: ");
   mlirOperationPrint(operation, printToStderr, NULL);
   fprintf(stderr, "\n");
 
+  // Get the block terminator and print it.
+  MlirOperation terminator = mlirBlockGetTerminator(block);
+  fprintf(stderr, "Terminator: ");
+  mlirOperationPrint(terminator, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
   // Get the attribute by index.
   MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
   fprintf(stderr, "Get attr 0: ");
@@ -1100,6 +1114,8 @@ int main() {
 
   printFirstOfEach(ctx, module);
   // clang-format off
+  // CHECK: Parent operation eq: 1
+  // CHECK: Block eq: 1
   // CHECK:   %[[C0:.*]] = constant 0 : index
   // CHECK:   %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
   // CHECK:   %[[C1:.*]] = constant 1 : index
@@ -1111,6 +1127,7 @@ int main() {
   // CHECK:   }
   // CHECK: return
   // CHECK: First operation: {{.*}} = constant 0 : index
+  // CHECK: Terminator: return
   // CHECK: Get attr 0: 0 : index
   // CHECK: Get attr 0 by name: 0 : index
   // CHECK: does_not_exist is null: 1


        


More information about the Mlir-commits mailing list