[Mlir-commits] [mlir] 4cf754c - Implement python iteration over the operation/region/block hierarchy.

Stella Laurenzo llvmlistbot at llvm.org
Wed Sep 23 07:59:01 PDT 2020


Author: Stella Laurenzo
Date: 2020-09-23T07:57:50-07:00
New Revision: 4cf754c4bca94e957b634a854f57f4c7ec9151fb

URL: https://github.com/llvm/llvm-project/commit/4cf754c4bca94e957b634a854f57f4c7ec9151fb
DIFF: https://github.com/llvm/llvm-project/commit/4cf754c4bca94e957b634a854f57f4c7ec9151fb.diff

LOG: Implement python iteration over the operation/region/block hierarchy.

* Removes the half-completed prior attempt at region/block mutation in favor of a new approach to ownership.
* Will re-add mutation more correctly in a follow-on.
* Eliminates the detached state on blocks and regions, simplifying the ownership hierarchy.
* Adds both iterator and index based access at each level.

Differential Revision: https://reviews.llvm.org/D87982

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/CAPI/IR/CMakeLists.txt
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 61a106515196..b9c5bec3aa44 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -91,6 +91,12 @@ int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
 /** Takes an MLIR context owned by the caller and destroys it. */
 void mlirContextDestroy(MlirContext context);
 
+/** Sets whether unregistered dialects are allowed; `allow` is a C boolean. */
+void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow);
+
+/** Returns nonzero if the context allows unregistered dialects, else 0. */
+int mlirContextGetAllowUnregisteredDialects(MlirContext context);
+
 /*============================================================================*/
 /* Location API.                                                              */
 /*============================================================================*/

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 66e975e3ea56..8eab7dab1675 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -46,45 +46,6 @@ static const char kContextGetUnknownLocationDocstring[] =
 static const char kContextGetFileLocationDocstring[] =
     R"(Gets a Location representing a file, line and column)";
 
-static const char kContextCreateBlockDocstring[] =
-    R"(Creates a detached block)";
-
-static const char kContextCreateRegionDocstring[] =
-    R"(Creates a detached region)";
-
-static const char kRegionAppendBlockDocstring[] =
-    R"(Appends a block to a region.
-
-Raises:
-  ValueError: If the block is already attached to another region.
-)";
-
-static const char kRegionInsertBlockDocstring[] =
-    R"(Inserts a block at a postiion in a region.
-
-Raises:
-  ValueError: If the block is already attached to another region.
-)";
-
-static const char kRegionFirstBlockDocstring[] =
-    R"(Gets the first block in a region.
-
-Blocks can also be accessed via the `blocks` container.
-
-Raises:
-  IndexError: If the region has no blocks.
-)";
-
-static const char kBlockNextInRegionDocstring[] =
-    R"(Gets the next block in the enclosing region.
-
-Blocks can also be accessed via the `blocks` container of the owning region.
-This method exists to mirror the lower level API and should not be preferred.
-
-Raises:
-  IndexError: If there are no further blocks.
-)";
-
 static const char kOperationStrDunderDocstring[] =
     R"(Prints the assembly form of the operation with default options.
 
@@ -170,6 +131,241 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
 
 } // namespace
 
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+namespace {
+
+class PyRegionIterator { // Python iterator over an operation's regions.
+public:
+  PyRegionIterator(PyOperationRef operation)
+      : operation(std::move(operation)) {}
+
+  PyRegionIterator &dunderIter() { return *this; }
+
+  PyRegion dunderNext() {
+    operation->checkValid();
+    if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+      throw py::stop_iteration();
+    }
+    MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
+    return PyRegion(operation, region);
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyRegionIterator>(m, "RegionIterator")
+        .def("__iter__", &PyRegionIterator::dunderIter)
+        .def("__next__", &PyRegionIterator::dunderNext);
+  }
+
+private:
+  PyOperationRef operation;
+  intptr_t nextIndex = 0; // Same width as the C API region count/position.
+};
+
+/// Regions of an op are fixed-length and numerically indexed, so they are
+/// exposed to Python as a sequence (`__len__` plus `__getitem__`).
+class PyRegionList {
+public:
+  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
+
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirOperationGetNumRegions(operation->get());
+  }
+
+  PyRegion dunderGetItem(intptr_t index) {
+    // Validity is checked by dunderLen() (not reached for negative indices).
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds region");
+    }
+    MlirRegion region = mlirOperationGetRegion(operation->get(), index);
+    return PyRegion(operation, region);
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyRegionList>(m, "RegionSequence")
+        .def("__len__", &PyRegionList::dunderLen)
+        .def("__getitem__", &PyRegionList::dunderGetItem);
+  }
+
+private:
+  PyOperationRef operation;
+};
+
+class PyBlockIterator { // Walks a region's blocks via the C-API linked list.
+public:
+  PyBlockIterator(PyOperationRef operation, MlirBlock next)
+      : operation(std::move(operation)), next(next) {}
+
+  PyBlockIterator &dunderIter() { return *this; }
+
+  PyBlock dunderNext() {
+    operation->checkValid(); // Validate before using raw handles.
+    if (mlirBlockIsNull(next)) {
+      throw py::stop_iteration();
+    }
+
+    PyBlock returnBlock(operation, next); // Capture before advancing.
+    next = mlirBlockGetNextInRegion(next);
+    return returnBlock;
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyBlockIterator>(m, "BlockIterator")
+        .def("__iter__", &PyBlockIterator::dunderIter)
+        .def("__next__", &PyBlockIterator::dunderNext);
+  }
+
+private:
+  PyOperationRef operation; // Op whose region owns the blocks.
+  MlirBlock next;           // Null once the iterator is exhausted.
+};
+
+/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
+/// we present them as a more full-featured list-like container but optimize
+/// it for forward iteration. Blocks are always owned by a region.
+class PyBlockList {
+public:
+  PyBlockList(PyOperationRef operation, MlirRegion region)
+      : operation(std::move(operation)), region(region) {}
+
+  PyBlockIterator dunderIter() {
+    operation->checkValid();
+    return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
+  }
+
+  intptr_t dunderLen() { // Linear: walks the whole list.
+    operation->checkValid();
+    intptr_t count = 0;
+    MlirBlock block = mlirRegionGetFirstBlock(region);
+    while (!mlirBlockIsNull(block)) {
+      count += 1;
+      block = mlirBlockGetNextInRegion(block);
+    }
+    return count;
+  }
+
+  PyBlock dunderGetItem(intptr_t index) { // O(index); no negative indexing.
+    operation->checkValid();
+    if (index < 0) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds block");
+    }
+    MlirBlock block = mlirRegionGetFirstBlock(region);
+    while (!mlirBlockIsNull(block)) {
+      if (index == 0) {
+        return PyBlock(operation, block);
+      }
+      block = mlirBlockGetNextInRegion(block);
+      index -= 1;
+    }
+    throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyBlockList>(m, "BlockList")
+        .def("__getitem__", &PyBlockList::dunderGetItem)
+        .def("__iter__", &PyBlockList::dunderIter)
+        .def("__len__", &PyBlockList::dunderLen);
+  }
+
+private:
+  PyOperationRef operation; // Op that owns `region`.
+  MlirRegion region;        // Region whose blocks are listed.
+};
+
+class PyOperationIterator { // Walks a block's ops via the C-API linked list.
+public:
+  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
+      : parentOperation(std::move(parentOperation)), next(next) {}
+
+  PyOperationIterator &dunderIter() { return *this; }
+
+  py::object dunderNext() {
+    parentOperation->checkValid(); // Validate before using raw handles.
+    if (mlirOperationIsNull(next)) {
+      throw py::stop_iteration();
+    }
+
+    PyOperationRef returnOperation =
+        PyOperation::forOperation(parentOperation->getContext(), next);
+    next = mlirOperationGetNextInBlock(next); // Advance; null at the end.
+    return returnOperation.releaseObject();
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyOperationIterator>(m, "OperationIterator")
+        .def("__iter__", &PyOperationIterator::dunderIter)
+        .def("__next__", &PyOperationIterator::dunderNext);
+  }
+
+private:
+  PyOperationRef parentOperation; // Op owning the block being walked.
+  MlirOperation next;             // Null once the iterator is exhausted.
+};
+
+/// Operations are exposed by the C-API as a forward-only linked list. In
+/// Python, we present them as a more full-featured list-like container but
+/// optimize it for forward iteration. Iterable operations are always owned
+/// by a block.
+class PyOperationList {
+public:
+  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
+      : parentOperation(std::move(parentOperation)), block(block) {}
+
+  PyOperationIterator dunderIter() {
+    parentOperation->checkValid();
+    return PyOperationIterator(parentOperation,
+                               mlirBlockGetFirstOperation(block));
+  }
+
+  intptr_t dunderLen() { // Linear: walks the whole list.
+    parentOperation->checkValid();
+    intptr_t count = 0;
+    MlirOperation childOp = mlirBlockGetFirstOperation(block);
+    while (!mlirOperationIsNull(childOp)) {
+      count += 1;
+      childOp = mlirOperationGetNextInBlock(childOp);
+    }
+    return count;
+  }
+
+  py::object dunderGetItem(intptr_t index) { // O(index); no negative indexing.
+    parentOperation->checkValid();
+    if (index < 0) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds operation");
+    }
+    MlirOperation childOp = mlirBlockGetFirstOperation(block);
+    while (!mlirOperationIsNull(childOp)) {
+      if (index == 0) {
+        return PyOperation::forOperation(parentOperation->getContext(), childOp)
+            .releaseObject();
+      }
+      childOp = mlirOperationGetNextInBlock(childOp);
+      index -= 1;
+    }
+    throw SetPyError(PyExc_IndexError,
+                     "attempt to access out of bounds operation");
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyOperationList>(m, "OperationList")
+        .def("__getitem__", &PyOperationList::dunderGetItem)
+        .def("__iter__", &PyOperationList::dunderIter)
+        .def("__len__", &PyOperationList::dunderLen);
+  }
+
+private:
+  PyOperationRef parentOperation; // Op whose region owns `block`.
+  MlirBlock block;                // Block whose operations are listed.
+};
+
+} // namespace
+
 //------------------------------------------------------------------------------
 // PyMlirContext
 //------------------------------------------------------------------------------
@@ -309,24 +505,6 @@ void PyOperation::checkValid() {
   }
 }
 
-//------------------------------------------------------------------------------
-// PyBlock, PyRegion.
-//------------------------------------------------------------------------------
-
-void PyRegion::attachToParent() {
-  if (!detached) {
-    throw SetPyError(PyExc_ValueError, "Region is already attached to an op");
-  }
-  detached = false;
-}
-
-void PyBlock::attachToParent() {
-  if (!detached) {
-    throw SetPyError(PyExc_ValueError, "Block is already attached to an op");
-  }
-  detached = false;
-}
-
 //------------------------------------------------------------------------------
 // PyAttribute.
 //------------------------------------------------------------------------------
@@ -967,6 +1145,14 @@ void mlir::python::populateIRSubmodule(py::module &m) {
              return ref.releaseObject();
            })
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+      .def_property(
+          "allow_unregistered_dialects",
+          [](PyMlirContext &self) -> bool {
+            return mlirContextGetAllowUnregisteredDialects(self.get());
+          },
+          [](PyMlirContext &self, bool value) {
+            mlirContextSetAllowUnregisteredDialects(self.get(), value);
+          })
       .def(
           "parse_module",
           [](PyMlirContext &self, const std::string moduleAsm) {
@@ -1026,37 +1212,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                                   self.get(), filename.c_str(), line, col));
           },
           kContextGetFileLocationDocstring, py::arg("filename"),
-          py::arg("line"), py::arg("col"))
-      .def(
-          "create_region",
-          [](PyMlirContext &self) {
-            // The creating context is explicitly captured on regions to
-            // facilitate illegal assemblies of objects from multiple contexts
-            // that would invalidate the memory model.
-            return PyRegion(self.get(), mlirRegionCreate(),
-                            /*detached=*/true);
-          },
-          py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
-      .def(
-          "create_block",
-          [](PyMlirContext &self, std::vector<PyType> pyTypes) {
-            // In order for the keep_alive extend the proper lifetime, all
-            // types must be from the same context.
-            for (auto pyType : pyTypes) {
-              if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
-                                    self.get())) {
-                throw SetPyError(
-                    PyExc_ValueError,
-                    "All types used to construct a block must be from "
-                    "the same context as the block");
-              }
-            }
-            llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
-                                                 pyTypes.end());
-            return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]),
-                           /*detached=*/true);
-          },
-          py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
+          py::arg("line"), py::arg("col"));
 
   py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
     PyPrintAccumulator printAccum;
@@ -1096,17 +1252,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of Operation.
   py::class_<PyOperation>(m, "Operation")
       .def_property_readonly(
-          "first_region",
-          [](PyOperation &self) {
-            self.checkValid();
-            if (mlirOperationGetNumRegions(self.get()) == 0) {
-              throw SetPyError(PyExc_IndexError, "Operation has no regions");
-            }
-            return PyRegion(self.getContext()->get(),
-                            mlirOperationGetRegion(self.get(), 0),
-                            /*detached=*/false);
-          },
-          py::keep_alive<0, 1>(), "Gets the operation's first region")
+          "regions",
+          [](PyOperation &self) { return PyRegionList(self.getRef()); })
+      .def("__iter__",
+           [](PyOperation &self) { return PyRegionIterator(self.getRef()); })
       .def(
           "__str__",
           [](PyOperation &self) {
@@ -1120,63 +1269,62 @@ void mlir::python::populateIRSubmodule(py::module &m) {
 
   // Mapping of PyRegion.
   py::class_<PyRegion>(m, "Region")
-      .def(
-          "append_block",
-          [](PyRegion &self, PyBlock &block) {
-            if (!mlirContextEqual(self.context, block.context)) {
-              throw SetPyError(
-                  PyExc_ValueError,
-                  "Block must have been created from the same context as "
-                  "this region");
-            }
-
-            block.attachToParent();
-            mlirRegionAppendOwnedBlock(self.region, block.block);
+      .def_property_readonly(
+          "blocks",
+          [](PyRegion &self) {
+            return PyBlockList(self.getParentOperation(), self.get());
           },
-          kRegionAppendBlockDocstring)
+          "Returns a forward-optimized sequence of blocks.")
       .def(
-          "insert_block",
-          [](PyRegion &self, int pos, PyBlock &block) {
-            if (!mlirContextEqual(self.context, block.context)) {
-              throw SetPyError(
-                  PyExc_ValueError,
-                  "Block must have been created from the same context as "
-                  "this region");
-            }
-            block.attachToParent();
-            // TODO: Make this return a failure and raise if out of bounds.
-            mlirRegionInsertOwnedBlock(self.region, pos, block.block);
-          },
-          kRegionInsertBlockDocstring)
-      .def_property_readonly(
-          "first_block",
+          "__iter__",
           [](PyRegion &self) {
-            MlirBlock block = mlirRegionGetFirstBlock(self.region);
-            if (mlirBlockIsNull(block)) {
-              throw SetPyError(PyExc_IndexError, "Region has no blocks");
-            }
-            return PyBlock(self.context, block, /*detached=*/false);
+            self.checkValid();
+            MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
+            return PyBlockIterator(self.getParentOperation(), firstBlock);
           },
-          kRegionFirstBlockDocstring);
+          "Iterates over blocks in the region.")
+      .def("__eq__", [](PyRegion &self, py::object &other) {
+        try {
+          PyRegion *otherRegion = other.cast<PyRegion *>();
+          return self.get().ptr == otherRegion->get().ptr;
+        } catch (std::exception &e) {
+          return false;
+        }
+      });
 
   // Mapping of PyBlock.
   py::class_<PyBlock>(m, "Block")
       .def_property_readonly(
-          "next_in_region",
+          "operations",
           [](PyBlock &self) {
-            MlirBlock block = mlirBlockGetNextInRegion(self.block);
-            if (mlirBlockIsNull(block)) {
-              throw SetPyError(PyExc_IndexError,
-                               "Attempt to read past last block");
-            }
-            return PyBlock(self.context, block, /*detached=*/false);
+            return PyOperationList(self.getParentOperation(), self.get());
           },
-          py::keep_alive<0, 1>(), kBlockNextInRegionDocstring)
+          "Returns a forward-optimized sequence of operations.")
+      .def(
+          "__iter__",
+          [](PyBlock &self) {
+            self.checkValid();
+            MlirOperation firstOperation =
+                mlirBlockGetFirstOperation(self.get());
+            return PyOperationIterator(self.getParentOperation(),
+                                       firstOperation);
+          },
+          "Iterates over operations in the block.")
+      .def("__eq__",
+           [](PyBlock &self, py::object &other) {
+             try {
+               PyBlock *otherBlock = other.cast<PyBlock *>();
+               return self.get().ptr == otherBlock->get().ptr;
+             } catch (std::exception &e) {
+               return false;
+             }
+           })
       .def(
           "__str__",
           [](PyBlock &self) {
+            self.checkValid();
             PyPrintAccumulator printAccum;
-            mlirBlockPrint(self.block, printAccum.getCallback(),
+            mlirBlockPrint(self.get(), printAccum.getCallback(),
                            printAccum.getUserData());
             return printAccum.join();
           },
@@ -1310,4 +1458,12 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyMemRefType::bind(m);
   PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
+
+  // Container bindings.
+  PyBlockIterator::bind(m);
+  PyBlockList::bind(m);
+  PyOperationIterator::bind(m);
+  PyOperationList::bind(m);
+  PyRegionIterator::bind(m);
+  PyRegionList::bind(m);
 }

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index a7f6ee2425ad..06b697cfd786 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -249,69 +249,43 @@ class PyOperation : public BaseContextObject {
 };
 
 /// Wrapper around an MlirRegion.
-/// Note that region can exist in a detached state (where this instance is
-/// responsible for clearing) or an attached state (where its owner is
-/// responsible).
-///
-/// This python wrapper retains a redundant reference to its creating context
-/// in order to facilitate checking that parts of the operation hierarchy
-/// are only assembled from the same context.
+/// Regions are managed completely by their containing operation. Unlike the
+/// C++ API, the Python API does not support detached regions.
 class PyRegion {
 public:
-  PyRegion(MlirContext context, MlirRegion region, bool detached)
-      : context(context), region(region), detached(detached) {}
-  PyRegion(PyRegion &&other)
-      : context(other.context), region(other.region), detached(other.detached) {
-    other.detached = false;
-  }
-  ~PyRegion() {
-    if (detached)
-      mlirRegionDestroy(region);
+  PyRegion(PyOperationRef parentOperation, MlirRegion region)
+      : parentOperation(std::move(parentOperation)), region(region) {
+    assert(!mlirRegionIsNull(region) && "python region cannot be null");
   }
 
-  // Call prior to attaching the region to a parent.
-  // This will transition to the attached state and will throw an exception
-  // if already attached.
-  void attachToParent();
+  MlirRegion get() { return region; } // Raw C handle; never null.
+  PyOperationRef &getParentOperation() { return parentOperation; }
 
-  MlirContext context;
-  MlirRegion region;
+  void checkValid() { parentOperation->checkValid(); }
 
 private:
-  bool detached;
+  PyOperationRef parentOperation; // Operation that owns this region.
+  MlirRegion region;              // Non-null (asserted in the constructor).
 };
 
 /// Wrapper around an MlirBlock.
-/// Note that blocks can exist in a detached state (where this instance is
-/// responsible for clearing) or an attached state (where its owner is
-/// responsible).
-///
-/// This python wrapper retains a redundant reference to its creating context
-/// in order to facilitate checking that parts of the operation hierarchy
-/// are only assembled from the same context.
+/// Blocks are managed completely by their containing operation. Unlike the
+/// C++ API, the Python API does not support detached blocks.
 class PyBlock {
 public:
-  PyBlock(MlirContext context, MlirBlock block, bool detached)
-      : context(context), block(block), detached(detached) {}
-  PyBlock(PyBlock &&other)
-      : context(other.context), block(other.block), detached(other.detached) {
-    other.detached = false;
-  }
-  ~PyBlock() {
-    if (detached)
-      mlirBlockDestroy(block);
+  PyBlock(PyOperationRef parentOperation, MlirBlock block)
+      : parentOperation(std::move(parentOperation)), block(block) {
+    assert(!mlirBlockIsNull(block) && "python block cannot be null");
   }
 
-  // Call prior to attaching the block to a parent.
-  // This will transition to the attached state and will throw an exception
-  // if already attached.
-  void attachToParent();
+  MlirBlock get() { return block; } // Raw C handle; never null.
+  PyOperationRef &getParentOperation() { return parentOperation; }
 
-  MlirContext context;
-  MlirBlock block;
+  void checkValid() { parentOperation->checkValid(); }
 
 private:
-  bool detached;
+  PyOperationRef parentOperation; // Operation whose region owns this block.
+  MlirBlock block;                // Non-null (asserted in the constructor).
 };
 
 /// Wrapper around the generic MlirAttribute.

diff  --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 4158a4c96efd..e73269ce14f1 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRCAPIIR
   ${MLIR_MAIN_INCLUDE_DIR}/mlir-c
 
   LINK_LIBS PUBLIC
+  MLIRStandardOps
   MLIRIR
   MLIRParser
   MLIRSupport

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 0304d977f494..2265df1c8234 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Module.h"
@@ -25,6 +26,10 @@ using namespace mlir;
 
 MlirContext mlirContextCreate() {
   auto *context = new MLIRContext(/*loadAllDialects=*/false);
+  // TODO: Come up with a story for which dialects to load into the context
+  // and do not expand this beyond StandardOps until that is done. It is loaded
+  // by default here because it is hard to make progress otherwise.
+  context->loadDialect<StandardOpsDialect>(); // Needs MLIRStandardOps linked.
   return wrap(context);
 }
 
@@ -34,6 +39,14 @@ int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
 
 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
 
+void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
+  unwrap(context)->allowUnregisteredDialects(allow); // `allow`: C boolean.
+}
+
+int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
+  return unwrap(context)->allowsUnregisteredDialects();
+}
+
 /* ========================================================================== */
 /* Location API.                                                              */
 /* ========================================================================== */

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 9c4c33a10ab8..9522e4b1ad98 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -1,6 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
+import itertools
 import mlir
 
 def run(f):
@@ -10,65 +11,91 @@ def run(f):
   assert mlir.ir.Context._get_live_count() == 0
 
 
-# CHECK-LABEL: TEST: testDetachedRegionBlock
-def testDetachedRegionBlock():
+# Verify iterator-based traversal of the op/region/block hierarchy.
+# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
+def testTraverseOpRegionBlockIterators():
   ctx = mlir.ir.Context()
-  t = mlir.ir.F32Type(ctx)
-  region = ctx.create_region()
-  block = ctx.create_block([t, t])
-  # CHECK: <<UNLINKED BLOCK>>
-  print(block)
+  ctx.allow_unregistered_dialects = True  # Needed to parse "custom.addi".
+  module = ctx.parse_module(r"""
+    func @f1(%arg0: i32) -> i32 {
+      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+      return %1 : i32
+    }
+  """)
+  op = module.operation
+  # Get the regions and blocks by iterating the named collections.
+  regions = list(op.regions)
+  blocks = list(regions[0].blocks)
+  # CHECK: MODULE REGIONS=1 BLOCKS=1
+  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
 
-run(testDetachedRegionBlock)
+  # Get the regions and blocks from the default collections.
+  default_regions = list(op)
+  default_blocks = list(default_regions[0])
+  # They should compare equal regardless of how they were obtained.
+  assert default_regions == regions
+  assert default_blocks == blocks
 
+  # Should be able to get the operations from either the named collection
+  # or the block.
+  operations = list(blocks[0].operations)
+  default_operations = list(blocks[0])
+  assert default_operations == operations
 
-# CHECK-LABEL: TEST: testBlockTypeContextMismatch
-def testBlockTypeContextMismatch():
-  c1 = mlir.ir.Context()
-  c2 = mlir.ir.Context()
-  t1 = mlir.ir.F32Type(c1)
-  t2 = mlir.ir.F32Type(c2)
-  try:
-    block = c1.create_block([t1, t2])
-  except ValueError as e:
-    # CHECK: ERROR: All types used to construct a block must be from the same context as the block
-    print("ERROR:", e)
+  def walk_operations(indent, op):  # Recursively prints the region/block/op tree.
+    for i, region in enumerate(op):
+      print(f"{indent}REGION {i}:")
+      for j, block in enumerate(region):
+        print(f"{indent}  BLOCK {j}:")
+        for k, child_op in enumerate(block):
+          print(f"{indent}    OP {k}: {child_op}")
+          walk_operations(indent + "      ", child_op)
 
-run(testBlockTypeContextMismatch)
+  # CHECK: REGION 0:
+  # CHECK:   BLOCK 0:
+  # CHECK:     OP 0: func
+  # CHECK:       REGION 0:
+  # CHECK:         BLOCK 0:
+  # CHECK:           OP 0: %0 = "custom.addi"
+  # CHECK:           OP 1: return
+  # CHECK:    OP 1: "module_terminator"
+  walk_operations("", op)
 
+run(testTraverseOpRegionBlockIterators)
 
-# CHECK-LABEL: TEST: testBlockAppend
-def testBlockAppend():
+
+# Verify index-based traversal of the op/region/block hierarchy.
+# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
+def testTraverseOpRegionBlockIndices():
   ctx = mlir.ir.Context()
-  t = mlir.ir.F32Type(ctx)
-  region = ctx.create_region()
-  try:
-    region.first_block
-  except IndexError:
-    pass
-  else:
-    raise RuntimeError("Expected exception not raised")
+  ctx.allow_unregistered_dialects = True  # Needed to parse "custom.addi".
+  module = ctx.parse_module(r"""
+    func @f1(%arg0: i32) -> i32 {
+      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+      return %1 : i32
+    }
+  """)
 
-  block = ctx.create_block([t, t])
-  region.append_block(block)
-  try:
-    region.append_block(block)
-  except ValueError:
-    pass
-  else:
-    raise RuntimeError("Expected exception not raised")
+  def walk_operations(indent, op):  # Prints the tree using len() and indexing.
+    for i in range(len(op.regions)):
+      region = op.regions[i]
+      print(f"{indent}REGION {i}:")
+      for j in range(len(region.blocks)):
+        block = region.blocks[j]
+        print(f"{indent}  BLOCK {j}:")
+        for k in range(len(block.operations)):
+          child_op = block.operations[k]
+          print(f"{indent}    OP {k}: {child_op}")
+          walk_operations(indent + "      ", child_op)
 
-  block2 = ctx.create_block([t])
-  region.insert_block(1, block2)
-  # CHECK: <<UNLINKED BLOCK>>
-  block_first = region.first_block
-  print(block_first)
-  block_next = block_first.next_in_region
-  try:
-    block_next = block_next.next_in_region
-  except IndexError:
-    pass
-  else:
-    raise RuntimeError("Expected exception not raised")
+  # CHECK: REGION 0:
+  # CHECK:   BLOCK 0:
+  # CHECK:     OP 0: func
+  # CHECK:       REGION 0:
+  # CHECK:         BLOCK 0:
+  # CHECK:           OP 0: %0 = "custom.addi"
+  # CHECK:           OP 1: return
+  # CHECK:    OP 1: "module_terminator"
+  walk_operations("", module.operation)
 
-run(testBlockAppend)
+run(testTraverseOpRegionBlockIndices)


        


More information about the Mlir-commits mailing list