[Mlir-commits] [mlir] afb2ed8 - [mlir][Python] Add a simple PyOpOperand iterator for PyValue uses.

Mike Urbach llvmlistbot at llvm.org
Tue Dec 13 18:20:37 PST 2022


Author: Mike Urbach
Date: 2022-12-13T19:20:29-07:00
New Revision: afb2ed80cb1639236a3aa15f2c9ff96dbb45c3d0

URL: https://github.com/llvm/llvm-project/commit/afb2ed80cb1639236a3aa15f2c9ff96dbb45c3d0
DIFF: https://github.com/llvm/llvm-project/commit/afb2ed80cb1639236a3aa15f2c9ff96dbb45c3d0.diff

LOG: [mlir][Python] Add a simple PyOpOperand iterator for PyValue uses.

This adds a simple PyOpOperand based on MlirOpOperand, which has
properties for the owner op and operand number.

This also adds a PyOpOperandIterator that defines methods for __iter__
and __next__ so PyOpOperands can be iterated over using the
MlirOpOperand C API.

Finally, a `uses` pseudo-container is added to PyValue so that its uses
can be iterated generically.

Depends on D139596

Reviewed By: stellaraccident, jdd

Differential Revision: https://reviews.llvm.org/D139597

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/test/python/ir/value.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0a32ff598feb6..b46fe44e9e19f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -447,6 +447,79 @@ class PyOperationList {
   MlirBlock block;
 };
 
+/// Python wrapper for a single use of a value, i.e. an `MlirOpOperand`;
+/// bound as `OpOperand`.
+///
+/// NOTE(review): only the raw C handle is stored and nothing here keeps the
+/// owning operation alive, so an instance presumably must not outlive (or
+/// survive mutation of) the IR it refers to.
+class PyOpOperand {
+public:
+  explicit PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+  /// Returns the operation that owns this operand, wrapped as an OpView.
+  py::object getOwner() const {
+    MlirOperation owner = mlirOpOperandGetOwner(opOperand);
+    PyMlirContextRef context =
+        PyMlirContext::forContext(mlirOperationGetContext(owner));
+    return PyOperation::forOperation(context, owner)->createOpView();
+  }
+
+  /// Returns the operand number of this use within its owner.
+  size_t getOperandNumber() const {
+    return mlirOpOperandGetOperandNumber(opOperand);
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
+        .def_property_readonly("owner", &PyOpOperand::getOwner)
+        .def_property_readonly("operand_number",
+                               &PyOpOperand::getOperandNumber);
+  }
+
+private:
+  MlirOpOperand opOperand;
+};
+
+/// Python iterator over a chain of uses (`MlirOpOperand`s); bound as
+/// `OpOperandIterator`.
+///
+/// Starting from the handle given to the constructor, each `__next__` yields
+/// the current use as a `PyOpOperand` and advances via
+/// `mlirOpOperandGetNextUse`; a null handle ends the iteration.
+///
+/// NOTE(review): like `PyOpOperand`, only a raw C handle is held, so the
+/// use-list presumably must not be modified or freed while iterating.
+class PyOpOperandIterator {
+public:
+  explicit PyOpOperandIterator(MlirOpOperand opOperand)
+      : opOperand(opOperand) {}
+
+  /// Returns this object: a Python iterator is its own iterable.
+  PyOpOperandIterator &dunderIter() { return *this; }
+
+  /// Returns the current use and advances; raises StopIteration at the end.
+  PyOpOperand dunderNext() {
+    // The handle stays null once exhausted, so later calls keep raising
+    // StopIteration as the iterator protocol requires.
+    if (mlirOpOperandIsNull(opOperand))
+      throw py::stop_iteration();
+
+    PyOpOperand current(opOperand);
+    opOperand = mlirOpOperandGetNextUse(opOperand);
+    return current;
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
+        .def("__iter__", &PyOpOperandIterator::dunderIter)
+        .def("__next__", &PyOpOperandIterator::dunderNext);
+  }
+
+private:
+  MlirOpOperand opOperand;
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -3156,6 +3205,11 @@ void mlir::python::populateIRCore(py::module &m) {
             assert(false && "Value must be a block argument or an op result");
             return py::none();
           })
+      .def_property_readonly("uses",
+                             [](PyValue &self) {
+                               return PyOpOperandIterator(
+                                   mlirValueGetFirstUse(self.get()));
+                             })
       .def("__eq__",
            [](PyValue &self, PyValue &other) {
              return self.get().ptr == other.get().ptr;
@@ -3182,6 +3236,7 @@ void mlir::python::populateIRCore(py::module &m) {
       });
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
+  PyOpOperand::bind(m);
 
   //----------------------------------------------------------------------------
   // Mapping of SymbolTable.
@@ -3220,6 +3275,7 @@ void mlir::python::populateIRCore(py::module &m) {
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
   PyOpAttributeMap::bind(m);
+  PyOpOperandIterator::bind(m);
   PyOpOperandList::bind(m);
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);

diff  --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 262896ec317f9..98f55de41e150 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -89,3 +89,25 @@ def testValueHash():
   op, ret = block.operations
   assert hash(block.arguments[0]) == hash(op.operands[0])
   assert hash(op.result) == hash(ret.operands[0])
+
+# CHECK-LABEL: TEST: testValueUses
+ at run
+def testValueUses():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signless(32)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      value = Operation.create("custom.op1", results=[i32]).results[0]
+      op1 = Operation.create("custom.op2", operands=[value])
+      op2 = Operation.create("custom.op2", operands=[value])
+
+  # CHECK: Use owner: "custom.op2"
+  # CHECK: Use operand_number: 0
+  # CHECK: Use owner: "custom.op2"
+  # CHECK: Use operand_number: 0
+  for use in value.uses:
+    assert use.owner in [op1, op2]
+    print(f"Use owner: {use.owner}")
+    print(f"Use operand_number: {use.operand_number}")


        


More information about the Mlir-commits mailing list