[Mlir-commits] [mlir] 4cf754c - Implement python iteration over the operation/region/block hierarchy.
Stella Laurenzo
llvmlistbot at llvm.org
Wed Sep 23 07:59:01 PDT 2020
Author: Stella Laurenzo
Date: 2020-09-23T07:57:50-07:00
New Revision: 4cf754c4bca94e957b634a854f57f4c7ec9151fb
URL: https://github.com/llvm/llvm-project/commit/4cf754c4bca94e957b634a854f57f4c7ec9151fb
DIFF: https://github.com/llvm/llvm-project/commit/4cf754c4bca94e957b634a854f57f4c7ec9151fb.diff
LOG: Implement python iteration over the operation/region/block hierarchy.
* Removes the half-completed prior attempt at region/block mutation in favor of a new approach to ownership.
* Will re-add mutation more correctly in a follow-on.
* Eliminates the detached state on blocks and regions, simplifying the ownership hierarchy.
* Adds both iterator and index based access at each level.
Differential Revision: https://reviews.llvm.org/D87982
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/CAPI/IR/CMakeLists.txt
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/ir_operation.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 61a106515196..b9c5bec3aa44 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -91,6 +91,12 @@ int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
/** Takes an MLIR context owned by the caller and destroys it. */
void mlirContextDestroy(MlirContext context);
+/** Sets whether unregistered dialects are allowed in this context. Any
+ * nonzero `allow` enables them; zero disables them. */
+void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow);
+
+/** Returns nonzero if the context allows unregistered dialects, zero
+ * otherwise. */
+int mlirContextGetAllowUnregisteredDialects(MlirContext context);
+
/*============================================================================*/
/* Location API. */
/*============================================================================*/
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 66e975e3ea56..8eab7dab1675 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -46,45 +46,6 @@ static const char kContextGetUnknownLocationDocstring[] =
static const char kContextGetFileLocationDocstring[] =
R"(Gets a Location representing a file, line and column)";
-static const char kContextCreateBlockDocstring[] =
- R"(Creates a detached block)";
-
-static const char kContextCreateRegionDocstring[] =
- R"(Creates a detached region)";
-
-static const char kRegionAppendBlockDocstring[] =
- R"(Appends a block to a region.
-
-Raises:
- ValueError: If the block is already attached to another region.
-)";
-
-static const char kRegionInsertBlockDocstring[] =
- R"(Inserts a block at a postiion in a region.
-
-Raises:
- ValueError: If the block is already attached to another region.
-)";
-
-static const char kRegionFirstBlockDocstring[] =
- R"(Gets the first block in a region.
-
-Blocks can also be accessed via the `blocks` container.
-
-Raises:
- IndexError: If the region has no blocks.
-)";
-
-static const char kBlockNextInRegionDocstring[] =
- R"(Gets the next block in the enclosing region.
-
-Blocks can also be accessed via the `blocks` container of the owning region.
-This method exists to mirror the lower level API and should not be preferred.
-
-Raises:
- IndexError: If there are no further blocks.
-)";
-
static const char kOperationStrDunderDocstring[] =
R"(Prints the assembly form of the operation with default options.
@@ -170,6 +131,241 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
} // namespace
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// Converts a Python-style sequence index into a non-negative offset into a
+/// sequence of `length` elements. Negative indices count back from the end
+/// (`-1` is the last element), as for built-in Python sequences.
+/// Raises IndexError with `message` if the index is out of range.
+intptr_t normalizeSequenceIndex(intptr_t index, intptr_t length,
+                                const char *message) {
+  if (index < 0)
+    index += length;
+  if (index < 0 || index >= length)
+    throw SetPyError(PyExc_IndexError, message);
+  return index;
+}
+
+/// Python iterator over the regions of an operation, in index order.
+class PyRegionIterator {
+public:
+  PyRegionIterator(PyOperationRef operation)
+      : operation(std::move(operation)) {}
+
+  PyRegionIterator &dunderIter() { return *this; }
+
+  /// Returns the next region, or raises StopIteration once every region has
+  /// been produced. The owning operation is validated first (see
+  /// PyOperation::checkValid).
+  PyRegion dunderNext() {
+    operation->checkValid();
+    if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+      throw py::stop_iteration();
+    }
+    MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
+    return PyRegion(operation, region);
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyRegionIterator>(m, "RegionIterator")
+        .def("__iter__", &PyRegionIterator::dunderIter)
+        .def("__next__", &PyRegionIterator::dunderNext);
+  }
+
+private:
+  PyOperationRef operation;
+  /// Index of the next region to return. Uses intptr_t for consistency with
+  /// PyRegionList's length/index types.
+  intptr_t nextIndex = 0;
+};
+
+/// Regions of an op are fixed length and indexed numerically so are represented
+/// with a sequence-like container. Supports len() and integer indexing,
+/// including negative indices counted from the end.
+class PyRegionList {
+public:
+  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
+
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirOperationGetNumRegions(operation->get());
+  }
+
+  PyRegion dunderGetItem(intptr_t index) {
+    // dunderLen checks validity.
+    index = normalizeSequenceIndex(index, dunderLen(),
+                                   "attempt to access out of bounds region");
+    MlirRegion region = mlirOperationGetRegion(operation->get(), index);
+    return PyRegion(operation, region);
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyRegionList>(m, "RegionSequence")
+        .def("__len__", &PyRegionList::dunderLen)
+        .def("__getitem__", &PyRegionList::dunderGetItem);
+  }
+
+private:
+  PyOperationRef operation;
+};
+
+/// Python iterator over the blocks of a region. It starts at a given block
+/// and follows the C-API's next-in-region links.
+class PyBlockIterator {
+public:
+  PyBlockIterator(PyOperationRef operation, MlirBlock next)
+      : operation(std::move(operation)), next(next) {}
+
+  PyBlockIterator &dunderIter() { return *this; }
+
+  /// Returns the next block, or raises StopIteration once the end of the
+  /// block list is reached.
+  PyBlock dunderNext() {
+    operation->checkValid();
+    if (mlirBlockIsNull(next)) {
+      throw py::stop_iteration();
+    }
+
+    PyBlock returnBlock(operation, next);
+    next = mlirBlockGetNextInRegion(next);
+    return returnBlock;
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyBlockIterator>(m, "BlockIterator")
+        .def("__iter__", &PyBlockIterator::dunderIter)
+        .def("__next__", &PyBlockIterator::dunderNext);
+  }
+
+private:
+  /// Operation owning the region whose blocks are being iterated.
+  PyOperationRef operation;
+  /// Next block to return; null once iteration is exhausted.
+  MlirBlock next;
+};
+
+/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
+/// we present them as a more full-featured list-like container but optimize
+/// it for forward iteration. len() and indexing walk the list, so they take
+/// time linear in the number of blocks. Blocks are always owned by a region.
+class PyBlockList {
+public:
+  PyBlockList(PyOperationRef operation, MlirRegion region)
+      : operation(std::move(operation)), region(region) {}
+
+  PyBlockIterator dunderIter() {
+    operation->checkValid();
+    return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
+  }
+
+  intptr_t dunderLen() {
+    operation->checkValid();
+    intptr_t count = 0;
+    MlirBlock block = mlirRegionGetFirstBlock(region);
+    while (!mlirBlockIsNull(block)) {
+      count += 1;
+      block = mlirBlockGetNextInRegion(block);
+    }
+    return count;
+  }
+
+  /// Returns the block at `index`. A negative index counts from the end; this
+  /// needs an extra pass over the list to compute its length.
+  PyBlock dunderGetItem(intptr_t index) {
+    operation->checkValid();
+    if (index < 0) {
+      index = normalizeSequenceIndex(index, dunderLen(),
+                                     "attempt to access out of bounds block");
+    }
+    MlirBlock block = mlirRegionGetFirstBlock(region);
+    while (!mlirBlockIsNull(block)) {
+      if (index == 0) {
+        return PyBlock(operation, block);
+      }
+      block = mlirBlockGetNextInRegion(block);
+      index -= 1;
+    }
+    throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyBlockList>(m, "BlockList")
+        .def("__getitem__", &PyBlockList::dunderGetItem)
+        .def("__iter__", &PyBlockList::dunderIter)
+        .def("__len__", &PyBlockList::dunderLen);
+  }
+
+private:
+  PyOperationRef operation;
+  MlirRegion region;
+};
+
+/// Python iterator over the operations of a block. It starts at a given
+/// operation and follows the C-API's next-in-block links.
+class PyOperationIterator {
+public:
+  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
+      : parentOperation(std::move(parentOperation)), next(next) {}
+
+  PyOperationIterator &dunderIter() { return *this; }
+
+  /// Returns the next operation as a Python object, or raises StopIteration
+  /// once the end of the block is reached.
+  py::object dunderNext() {
+    parentOperation->checkValid();
+    if (mlirOperationIsNull(next)) {
+      throw py::stop_iteration();
+    }
+
+    PyOperationRef returnOperation =
+        PyOperation::forOperation(parentOperation->getContext(), next);
+    next = mlirOperationGetNextInBlock(next);
+    return returnOperation.releaseObject();
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyOperationIterator>(m, "OperationIterator")
+        .def("__iter__", &PyOperationIterator::dunderIter)
+        .def("__next__", &PyOperationIterator::dunderNext);
+  }
+
+private:
+  /// Operation that (transitively) contains the block being iterated.
+  PyOperationRef parentOperation;
+  /// Next operation to return; null once iteration is exhausted.
+  MlirOperation next;
+};
+
+/// Operations are exposed by the C-API as a forward-only linked list. In
+/// Python, we present them as a more full-featured list-like container but
+/// optimize it for forward iteration. len() and indexing walk the list.
+/// Iterable operations are always owned by a block.
+class PyOperationList {
+public:
+  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
+      : parentOperation(std::move(parentOperation)), block(block) {}
+
+  PyOperationIterator dunderIter() {
+    parentOperation->checkValid();
+    return PyOperationIterator(parentOperation,
+                               mlirBlockGetFirstOperation(block));
+  }
+
+  intptr_t dunderLen() {
+    parentOperation->checkValid();
+    intptr_t count = 0;
+    MlirOperation childOp = mlirBlockGetFirstOperation(block);
+    while (!mlirOperationIsNull(childOp)) {
+      count += 1;
+      childOp = mlirOperationGetNextInBlock(childOp);
+    }
+    return count;
+  }
+
+  /// Returns the operation at `index`. A negative index counts from the end;
+  /// this needs an extra pass over the list to compute its length.
+  py::object dunderGetItem(intptr_t index) {
+    parentOperation->checkValid();
+    if (index < 0) {
+      index = normalizeSequenceIndex(
+          index, dunderLen(), "attempt to access out of bounds operation");
+    }
+    MlirOperation childOp = mlirBlockGetFirstOperation(block);
+    while (!mlirOperationIsNull(childOp)) {
+      if (index == 0) {
+        return PyOperation::forOperation(parentOperation->getContext(), childOp)
+            .releaseObject();
+      }
+      childOp = mlirOperationGetNextInBlock(childOp);
+      index -= 1;
+    }
+    throw SetPyError(PyExc_IndexError,
+                     "attempt to access out of bounds operation");
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyOperationList>(m, "OperationList")
+        .def("__getitem__", &PyOperationList::dunderGetItem)
+        .def("__iter__", &PyOperationList::dunderIter)
+        .def("__len__", &PyOperationList::dunderLen);
+  }
+
+private:
+  PyOperationRef parentOperation;
+  MlirBlock block;
+};
+
+} // namespace
+
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -309,24 +505,6 @@ void PyOperation::checkValid() {
}
}
-//------------------------------------------------------------------------------
-// PyBlock, PyRegion.
-//------------------------------------------------------------------------------
-
-void PyRegion::attachToParent() {
- if (!detached) {
- throw SetPyError(PyExc_ValueError, "Region is already attached to an op");
- }
- detached = false;
-}
-
-void PyBlock::attachToParent() {
- if (!detached) {
- throw SetPyError(PyExc_ValueError, "Block is already attached to an op");
- }
- detached = false;
-}
-
//------------------------------------------------------------------------------
// PyAttribute.
//------------------------------------------------------------------------------
@@ -967,6 +1145,14 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+ .def_property(
+ "allow_unregistered_dialects",
+ [](PyMlirContext &self) -> bool {
+ return mlirContextGetAllowUnregisteredDialects(self.get());
+ },
+ [](PyMlirContext &self, bool value) {
+ mlirContextSetAllowUnregisteredDialects(self.get(), value);
+ })
.def(
"parse_module",
[](PyMlirContext &self, const std::string moduleAsm) {
@@ -1026,37 +1212,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
self.get(), filename.c_str(), line, col));
},
kContextGetFileLocationDocstring, py::arg("filename"),
- py::arg("line"), py::arg("col"))
- .def(
- "create_region",
- [](PyMlirContext &self) {
- // The creating context is explicitly captured on regions to
- // facilitate illegal assemblies of objects from multiple contexts
- // that would invalidate the memory model.
- return PyRegion(self.get(), mlirRegionCreate(),
- /*detached=*/true);
- },
- py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
- .def(
- "create_block",
- [](PyMlirContext &self, std::vector<PyType> pyTypes) {
- // In order for the keep_alive extend the proper lifetime, all
- // types must be from the same context.
- for (auto pyType : pyTypes) {
- if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
- self.get())) {
- throw SetPyError(
- PyExc_ValueError,
- "All types used to construct a block must be from "
- "the same context as the block");
- }
- }
- llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
- pyTypes.end());
- return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]),
- /*detached=*/true);
- },
- py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
+ py::arg("line"), py::arg("col"));
py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
PyPrintAccumulator printAccum;
@@ -1096,17 +1252,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Operation.
py::class_<PyOperation>(m, "Operation")
.def_property_readonly(
- "first_region",
- [](PyOperation &self) {
- self.checkValid();
- if (mlirOperationGetNumRegions(self.get()) == 0) {
- throw SetPyError(PyExc_IndexError, "Operation has no regions");
- }
- return PyRegion(self.getContext()->get(),
- mlirOperationGetRegion(self.get(), 0),
- /*detached=*/false);
- },
- py::keep_alive<0, 1>(), "Gets the operation's first region")
+ "regions",
+ [](PyOperation &self) { return PyRegionList(self.getRef()); })
+ .def("__iter__",
+ [](PyOperation &self) { return PyRegionIterator(self.getRef()); })
.def(
"__str__",
[](PyOperation &self) {
@@ -1120,63 +1269,62 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of PyRegion.
py::class_<PyRegion>(m, "Region")
- .def(
- "append_block",
- [](PyRegion &self, PyBlock &block) {
- if (!mlirContextEqual(self.context, block.context)) {
- throw SetPyError(
- PyExc_ValueError,
- "Block must have been created from the same context as "
- "this region");
- }
-
- block.attachToParent();
- mlirRegionAppendOwnedBlock(self.region, block.block);
+ .def_property_readonly(
+ "blocks",
+ [](PyRegion &self) {
+ return PyBlockList(self.getParentOperation(), self.get());
},
- kRegionAppendBlockDocstring)
+ "Returns a forward-optimized sequence of blocks.")
.def(
- "insert_block",
- [](PyRegion &self, int pos, PyBlock &block) {
- if (!mlirContextEqual(self.context, block.context)) {
- throw SetPyError(
- PyExc_ValueError,
- "Block must have been created from the same context as "
- "this region");
- }
- block.attachToParent();
- // TODO: Make this return a failure and raise if out of bounds.
- mlirRegionInsertOwnedBlock(self.region, pos, block.block);
- },
- kRegionInsertBlockDocstring)
- .def_property_readonly(
- "first_block",
+ "__iter__",
[](PyRegion &self) {
- MlirBlock block = mlirRegionGetFirstBlock(self.region);
- if (mlirBlockIsNull(block)) {
- throw SetPyError(PyExc_IndexError, "Region has no blocks");
- }
- return PyBlock(self.context, block, /*detached=*/false);
+ self.checkValid();
+ MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
+ return PyBlockIterator(self.getParentOperation(), firstBlock);
},
- kRegionFirstBlockDocstring);
+ "Iterates over blocks in the region.")
+      .def("__eq__", [](PyRegion &self, py::object &other) {
+        // Objects that are not Regions never compare equal. This includes
+        // None: casting it to PyRegion * yields nullptr rather than throwing,
+        // so it must be rejected before any dereference.
+        if (!py::isinstance<PyRegion>(other))
+          return false;
+        PyRegion &otherRegion = other.cast<PyRegion &>();
+        return self.get().ptr == otherRegion.get().ptr;
+      });
// Mapping of PyBlock.
py::class_<PyBlock>(m, "Block")
.def_property_readonly(
- "next_in_region",
+ "operations",
[](PyBlock &self) {
- MlirBlock block = mlirBlockGetNextInRegion(self.block);
- if (mlirBlockIsNull(block)) {
- throw SetPyError(PyExc_IndexError,
- "Attempt to read past last block");
- }
- return PyBlock(self.context, block, /*detached=*/false);
+ return PyOperationList(self.getParentOperation(), self.get());
},
- py::keep_alive<0, 1>(), kBlockNextInRegionDocstring)
+ "Returns a forward-optimized sequence of operations.")
+ .def(
+ "__iter__",
+ [](PyBlock &self) {
+ self.checkValid();
+ MlirOperation firstOperation =
+ mlirBlockGetFirstOperation(self.get());
+ return PyOperationIterator(self.getParentOperation(),
+ firstOperation);
+ },
+ "Iterates over operations in the block.")
+      .def("__eq__",
+           [](PyBlock &self, py::object &other) {
+             // Objects that are not Blocks never compare equal. This includes
+             // None: casting it to PyBlock * yields nullptr rather than
+             // throwing, so it must be rejected before any dereference.
+             if (!py::isinstance<PyBlock>(other))
+               return false;
+             PyBlock &otherBlock = other.cast<PyBlock &>();
+             return self.get().ptr == otherBlock.get().ptr;
+           })
.def(
"__str__",
[](PyBlock &self) {
+ self.checkValid();
PyPrintAccumulator printAccum;
- mlirBlockPrint(self.block, printAccum.getCallback(),
+ mlirBlockPrint(self.get(), printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
@@ -1310,4 +1458,12 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyMemRefType::bind(m);
PyUnrankedMemRefType::bind(m);
PyTupleType::bind(m);
+
+ // Container bindings.
+ PyBlockIterator::bind(m);
+ PyBlockList::bind(m);
+ PyOperationIterator::bind(m);
+ PyOperationList::bind(m);
+ PyRegionIterator::bind(m);
+ PyRegionList::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index a7f6ee2425ad..06b697cfd786 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -249,69 +249,43 @@ class PyOperation : public BaseContextObject {
};
/// Wrapper around an MlirRegion.
-/// Note that region can exist in a detached state (where this instance is
-/// responsible for clearing) or an attached state (where its owner is
-/// responsible).
-///
-/// This python wrapper retains a redundant reference to its creating context
-/// in order to facilitate checking that parts of the operation hierarchy
-/// are only assembled from the same context.
+/// Regions are managed completely by their containing operation. Unlike the
+/// C++ API, the python API does not support detached regions.
class PyRegion {
public:
- PyRegion(MlirContext context, MlirRegion region, bool detached)
- : context(context), region(region), detached(detached) {}
- PyRegion(PyRegion &&other)
- : context(other.context), region(other.region), detached(other.detached) {
- other.detached = false;
- }
- ~PyRegion() {
- if (detached)
- mlirRegionDestroy(region);
+ /// Wraps `region`, which belongs to `parentOperation`. `region` must be
+ /// non-null (asserted).
+ PyRegion(PyOperationRef parentOperation, MlirRegion region)
+ : parentOperation(std::move(parentOperation)), region(region) {
+ assert(!mlirRegionIsNull(region) && "python region cannot be null");
}
- // Call prior to attaching the region to a parent.
- // This will transition to the attached state and will throw an exception
- // if already attached.
- void attachToParent();
+ /// Returns the underlying C-API region handle.
+ MlirRegion get() { return region; }
+ /// Returns the operation this region belongs to.
+ PyOperationRef &getParentOperation() { return parentOperation; }
- MlirContext context;
- MlirRegion region;
+ /// Validates the parent operation by delegating to PyOperation::checkValid.
+ void checkValid() { return parentOperation->checkValid(); }
private:
- bool detached;
+ PyOperationRef parentOperation;
+ MlirRegion region;
};
/// Wrapper around an MlirBlock.
-/// Note that blocks can exist in a detached state (where this instance is
-/// responsible for clearing) or an attached state (where its owner is
-/// responsible).
-///
-/// This python wrapper retains a redundant reference to its creating context
-/// in order to facilitate checking that parts of the operation hierarchy
-/// are only assembled from the same context.
+/// Blocks are managed completely by their containing operation. Unlike the
+/// C++ API, the python API does not support detached blocks.
class PyBlock {
public:
- PyBlock(MlirContext context, MlirBlock block, bool detached)
- : context(context), block(block), detached(detached) {}
- PyBlock(PyBlock &&other)
- : context(other.context), block(other.block), detached(other.detached) {
- other.detached = false;
- }
- ~PyBlock() {
- if (detached)
- mlirBlockDestroy(block);
+ /// Wraps `block`, which lives in a region of `parentOperation`. `block` must
+ /// be non-null (asserted).
+ PyBlock(PyOperationRef parentOperation, MlirBlock block)
+ : parentOperation(std::move(parentOperation)), block(block) {
+ assert(!mlirBlockIsNull(block) && "python block cannot be null");
}
- // Call prior to attaching the block to a parent.
- // This will transition to the attached state and will throw an exception
- // if already attached.
- void attachToParent();
+ /// Returns the underlying C-API block handle.
+ MlirBlock get() { return block; }
+ /// Returns the operation whose region contains this block.
+ PyOperationRef &getParentOperation() { return parentOperation; }
- MlirContext context;
- MlirBlock block;
+ /// Validates the parent operation by delegating to PyOperation::checkValid.
+ void checkValid() { return parentOperation->checkValid(); }
private:
- bool detached;
+ PyOperationRef parentOperation;
+ MlirBlock block;
};
/// Wrapper around the generic MlirAttribute.
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 4158a4c96efd..e73269ce14f1 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRCAPIIR
${MLIR_MAIN_INCLUDE_DIR}/mlir-c
LINK_LIBS PUBLIC
+ MLIRStandardOps
MLIRIR
MLIRParser
MLIRSupport
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 0304d977f494..2265df1c8234 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -10,6 +10,7 @@
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
@@ -25,6 +26,10 @@ using namespace mlir;
+// Creates a context that eagerly loads only the Standard dialect (see the
+// TODO below). The caller owns the returned context and releases it with
+// mlirContextDestroy.
MlirContext mlirContextCreate() {
auto *context = new MLIRContext(/*loadAllDialects=*/false);
+ // TODO: Come up with a story for which dialects to load into the context
+ // and do not expand this beyond StandardOps until done so. This is loaded
+ // by default here because it is hard to make progress otherwise.
+ context->loadDialect<StandardOpsDialect>();
return wrap(context);
}
@@ -34,6 +39,14 @@ int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
+void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
+  // C callers pass an int flag; any nonzero value enables the setting.
+  MLIRContext *ctx = unwrap(context);
+  ctx->allowUnregisteredDialects(allow != 0);
+}
+
+int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
+  // Report the C++ bool as a C int: 1 when allowed, 0 otherwise.
+  const bool allowed = unwrap(context)->allowsUnregisteredDialects();
+  return allowed ? 1 : 0;
+}
+
/* ========================================================================== */
/* Location API. */
/* ========================================================================== */
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 9c4c33a10ab8..9522e4b1ad98 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -1,6 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
+import itertools
import mlir
def run(f):
@@ -10,65 +11,91 @@ def run(f):
assert mlir.ir.Context._get_live_count() == 0
-# CHECK-LABEL: TEST: testDetachedRegionBlock
-def testDetachedRegionBlock():
+# Verify iterator based traversal of the op/region/block hierarchy.
+# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
+def testTraverseOpRegionBlockIterators():
ctx = mlir.ir.Context()
- t = mlir.ir.F32Type(ctx)
- region = ctx.create_region()
- block = ctx.create_block([t, t])
- # CHECK: <<UNLINKED BLOCK>>
- print(block)
+ ctx.allow_unregistered_dialects = True
+ module = ctx.parse_module(r"""
+ func @f1(%arg0: i32) -> i32 {
+ %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+ return %1 : i32
+ }
+ """)
+ op = module.operation
+ # Get the block using iterators off of the named collections.
+ regions = list(op.regions)
+ blocks = list(regions[0].blocks)
+ # CHECK: MODULE REGIONS=1 BLOCKS=1
+ print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
-run(testDetachedRegionBlock)
+ # Get the regions and blocks from the default collections.
+ default_regions = list(op)
+ default_blocks = list(default_regions[0])
+ # They should compare equal regardless of how obtained.
+ assert default_regions == regions
+ assert default_blocks == blocks
+ # Should be able to get the operations from either the named collection
+ # or the block.
+ operations = list(blocks[0].operations)
+ default_operations = list(blocks[0])
+ assert default_operations == operations
-# CHECK-LABEL: TEST: testBlockTypeContextMismatch
-def testBlockTypeContextMismatch():
- c1 = mlir.ir.Context()
- c2 = mlir.ir.Context()
- t1 = mlir.ir.F32Type(c1)
- t2 = mlir.ir.F32Type(c2)
- try:
- block = c1.create_block([t1, t2])
- except ValueError as e:
- # CHECK: ERROR: All types used to construct a block must be from the same context as the block
- print("ERROR:", e)
+ # Recursively print each region, block and operation nested under op. This
+ # uses the default iteration of Operation (over its regions), Region (over
+ # its blocks) and Block (over its operations).
+ def walk_operations(indent, op):
+ for i, region in enumerate(op):
+ print(f"{indent}REGION {i}:")
+ for j, block in enumerate(region):
+ print(f"{indent} BLOCK {j}:")
+ for k, child_op in enumerate(block):
+ print(f"{indent} OP {k}: {child_op}")
+ walk_operations(indent + " ", child_op)
-run(testBlockTypeContextMismatch)
+ # CHECK: REGION 0:
+ # CHECK: BLOCK 0:
+ # CHECK: OP 0: func
+ # CHECK: REGION 0:
+ # CHECK: BLOCK 0:
+ # CHECK: OP 0: %0 = "custom.addi"
+ # CHECK: OP 1: return
+ # CHECK: OP 1: "module_terminator"
+ walk_operations("", op)
+run(testTraverseOpRegionBlockIterators)
-# CHECK-LABEL: TEST: testBlockAppend
-def testBlockAppend():
+
+# Verify index based traversal of the op/region/block hierarchy.
+# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
+def testTraverseOpRegionBlockIndices():
ctx = mlir.ir.Context()
- t = mlir.ir.F32Type(ctx)
- region = ctx.create_region()
- try:
- region.first_block
- except IndexError:
- pass
- else:
- raise RuntimeError("Expected exception not raised")
+ ctx.allow_unregistered_dialects = True
+ module = ctx.parse_module(r"""
+ func @f1(%arg0: i32) -> i32 {
+ %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+ return %1 : i32
+ }
+ """)
- block = ctx.create_block([t, t])
- region.append_block(block)
- try:
- region.append_block(block)
- except ValueError:
- pass
- else:
- raise RuntimeError("Expected exception not raised")
+ # Recursively print the hierarchy using len() and integer indexing on the
+ # regions, blocks and operations collections rather than iterators.
+ def walk_operations(indent, op):
+ for i in range(len(op.regions)):
+ region = op.regions[i]
+ print(f"{indent}REGION {i}:")
+ for j in range(len(region.blocks)):
+ block = region.blocks[j]
+ print(f"{indent} BLOCK {j}:")
+ for k in range(len(block.operations)):
+ child_op = block.operations[k]
+ print(f"{indent} OP {k}: {child_op}")
+ walk_operations(indent + " ", child_op)
- block2 = ctx.create_block([t])
- region.insert_block(1, block2)
- # CHECK: <<UNLINKED BLOCK>>
- block_first = region.first_block
- print(block_first)
- block_next = block_first.next_in_region
- try:
- block_next = block_next.next_in_region
- except IndexError:
- pass
- else:
- raise RuntimeError("Expected exception not raised")
+ # CHECK: REGION 0:
+ # CHECK: BLOCK 0:
+ # CHECK: OP 0: func
+ # CHECK: REGION 0:
+ # CHECK: BLOCK 0:
+ # CHECK: OP 0: %0 = "custom.addi"
+ # CHECK: OP 1: return
+ # CHECK: OP 1: "module_terminator"
+ walk_operations("", module.operation)
-run(testBlockAppend)
+run(testTraverseOpRegionBlockIndices)
More information about the Mlir-commits
mailing list