[Mlir-commits] [mlir] [mlir][python] bind block successors (PR #145116)

Maksim Levental llvmlistbot at llvm.org
Fri Jun 20 20:09:49 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/145116

>From c49fd8abf843a6e8af3be8585435d46deab82a84 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Fri, 20 Jun 2025 19:20:09 -0400
Subject: [PATCH] [mlir][python] bind block successors and predecessors

---
 mlir/include/mlir-c/IR.h            | 14 +++++
 mlir/lib/Bindings/Python/IRCore.cpp | 94 ++++++++++++++++++++++++++++-
 mlir/lib/CAPI/IR/IR.cpp             | 20 ++++++
 mlir/test/python/ir/blocks.py       | 19 +++++-
 4 files changed, 143 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 1a8e8737f7fed..71aaee931d543 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
 MLIR_CAPI_EXPORTED void
 mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
 
+/// Returns the number of successor blocks of the block.
+MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
+
+/// Returns `pos`-th successor of the block.
+/// `pos` must be in the range [0, mlirBlockGetNumSuccessors(block)).
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
+                                                   intptr_t pos);
+
+/// Returns the number of predecessor blocks of the block.
+/// The implementation obtains this count by iterating over the predecessors,
+/// so prefer caching the result over calling this repeatedly.
+MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);
+
+/// Returns `pos`-th predecessor of the block.
+/// `pos` must be in the range [0, mlirBlockGetNumPredecessors(block)).
+/// The implementation reaches the element by advancing an iterator from the
+/// first predecessor, so the cost of a lookup grows with `pos`.
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
+                                                     intptr_t pos);
+
 //===----------------------------------------------------------------------===//
 // Value API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cbd35f2974ae9..7527049952bb4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2626,6 +2626,85 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
   PyOperationRef operation;
 };
 
+/// Sliceable Python view over the successor blocks of a block. Each element
+/// is fetched by index through mlirBlockGetSuccessor. The view holds on to
+/// both the block and the operation that was handed in with it, and thus
+/// extends the lifetime of that operation and block.
+class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockSuccessors";
+
+  /// Creates a view of `block`'s successors. Passing -1 for `length` (the
+  /// default) covers every successor the block reports at construction time.
+  PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+                    intptr_t startIndex = 0, intptr_t length = -1,
+                    intptr_t step = 1)
+      : Sliceable(startIndex, resolveLength(block, length), step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+  /// Maps the -1 "everything" sentinel to the block's actual successor count;
+  /// any other length is used as given.
+  static intptr_t resolveLength(PyBlock &source, intptr_t length) {
+    if (length != -1)
+      return length;
+    return mlirBlockGetNumSuccessors(source.get());
+  }
+
+  /// Total number of successors of the underlying block.
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumSuccessors(block.get());
+  }
+
+  /// Wraps the `pos`-th successor, parented to the same operation as the view.
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock successor = mlirBlockGetSuccessor(block.get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  /// Produces another view over a sub-range of the same successor list.
+  PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    PyBlockSuccessors sub(block, operation, startIndex, length, step);
+    return sub;
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
+/// Sliceable Python view over the predecessor blocks of a block. Random access
+/// is NOT cheap here: each element is fetched through mlirBlockGetPredecessor,
+/// which advances an iterator from the first predecessor, so fetching element
+/// `i` walks over the `i` predecessors before it. The view holds on to the
+/// block and the operation that was handed in with it, and thus extends the
+/// lifetime of that operation and block.
+class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockPredecessors";
+
+  /// Creates a view of `block`'s predecessors. Passing -1 for `length` (the
+  /// default) covers every predecessor counted at construction time.
+  PyBlockPredecessors(PyBlock block, PyOperationRef operation,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumPredecessors(block.get())
+                               : length,
+                  step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockPredecessors, PyBlock>;
+
+  /// Total number of predecessors of the underlying block.
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumPredecessors(block.get());
+  }
+
+  /// Wraps the `pos`-th predecessor, parented to the same operation as the
+  /// view. The local is named distinctly so it does not shadow the `block`
+  /// member.
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock predecessor = mlirBlockGetPredecessor(block.get(), pos);
+    return PyBlock(operation, predecessor);
+  }
+
+  /// Produces another view over a sub-range of the same predecessor list.
+  PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyBlockPredecessors(block, operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
 /// A list of operation attributes. Can be indexed by name, producing
 /// attributes, or by index, producing named attributes.
 class PyOpAttributeMap {
@@ -3655,7 +3734,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("operation"),
           "Appends an operation to this block. If the operation is currently "
-          "in another block, it will be moved.");
+          "in another block, it will be moved.")
+      .def_prop_ro(
+          "successors",
+          [](PyBlock &self) {
+            return PyBlockSuccessors(self, self.getParentOperation());
+          },
+          "Returns the list of Block successors.")
+      .def_prop_ro(
+          "predecessors",
+          [](PyBlock &self) {
+            return PyBlockPredecessors(self, self.getParentOperation());
+          },
+          "Returns the list of Block predecessors.");
 
   //----------------------------------------------------------------------------
   // Mapping of PyInsertionPoint.
@@ -4099,6 +4190,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   PyBlockArgumentList::bind(m);
   PyBlockIterator::bind(m);
   PyBlockList::bind(m);
+  // Block.successors / Block.predecessors return these view types, so both
+  // must be registered or nanobind cannot convert the property results.
+  PyBlockSuccessors::bind(m);
+  PyBlockPredecessors::bind(m);
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
   PyOpAttributeMap::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e0e386d55ede1..fbc66bcf5c2d0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
   unwrap(block)->print(stream);
 }
 
+intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
+  return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
+}
+
+MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
+  // A negative `pos` would wrap to a huge unsigned index in the cast below.
+  assert(pos >= 0 && "successor position must be non-negative");
+  return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
+}
+
+intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
+  Block *b = unwrap(block);
+  return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
+}
+
+MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
+  Block *b = unwrap(block);
+  assert(pos >= 0 && "predecessor position must be non-negative");
+  Block::pred_iterator it = b->pred_begin(), end = b->pred_end();
+  // Step one predecessor at a time (instead of std::advance) so that an
+  // out-of-range `pos` stops at `end` rather than running past it.
+  for (intptr_t i = 0; i < pos && it != end; ++i)
+    ++it;
+  assert(it != end && "predecessor position out of range");
+  // With assertions disabled, hand back a null block instead of
+  // dereferencing the end iterator.
+  if (pos < 0 || it == end)
+    return wrap(static_cast<Block *>(nullptr));
+  return wrap(*it);
+}
+
 //===----------------------------------------------------------------------===//
 // Value API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 70ccaeeb5435b..6200242ee6835 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -1,12 +1,11 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import io
-import itertools
-from mlir.ir import *
+
 from mlir.dialects import builtin
 from mlir.dialects import cf
 from mlir.dialects import func
+from mlir.ir import *
 
 
 def run(f):
@@ -54,10 +53,24 @@ def testBlockCreation():
             with InsertionPoint(middle_block) as middle_ip:
                 assert middle_ip.block == middle_block
                 cf.BranchOp([i32_arg], dest=successor_block)
+
         module.print(enable_debug_info=True)
         # Ensure region back references are coherent.
         assert entry_block.region == middle_block.region == successor_block.region
 
+        assert len(entry_block.successors) == 1
+        assert len(entry_block.predecessors) == 0
+        assert middle_block == entry_block.successors[0]
+        assert len(middle_block.predecessors) == 1
+        assert entry_block == middle_block.predecessors[0]
+
+        assert len(middle_block.successors) == 1
+        assert successor_block == middle_block.successors[0]
+        assert len(successor_block.predecessors) == 1
+        assert middle_block == successor_block.predecessors[0]
+
+        assert len(successor_block.successors) == 0
+
 
 # CHECK-LABEL: TEST: testBlockCreationArgLocs
 @run



More information about the Mlir-commits mailing list