[Mlir-commits] [mlir] [mlir][python] bind block predecessors and successors (PR #145116)

Maksim Levental llvmlistbot at llvm.org
Sat Jun 21 14:33:58 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/145116

>From c760955f9ada410ef8325866fed7944cd67f355b Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Fri, 20 Jun 2025 19:20:09 -0400
Subject: [PATCH 1/4] [mlir][python] bind block successors and predecessors

---
 mlir/include/mlir-c/IR.h            | 14 +++++
 mlir/lib/Bindings/Python/IRCore.cpp | 94 ++++++++++++++++++++++++++++-
 mlir/lib/CAPI/IR/IR.cpp             | 20 ++++++
 mlir/test/python/ir/blocks.py       | 20 +++++-
 4 files changed, 144 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 1a8e8737f7fed..71aaee931d543 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
 MLIR_CAPI_EXPORTED void
 mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
 
+/// Returns the number of successor blocks of `block`.
+MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
+
+/// Returns the `pos`-th successor; `pos` must be in [0, NumSuccessors).
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
+                                                   intptr_t pos);
+
+/// Returns the number of predecessor blocks of `block` (recomputed per call).
+MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);
+
+/// Returns the `pos`-th predecessor, `pos` in [0, NumPredecessors); O(pos).
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
+                                                     intptr_t pos);
+
 //===----------------------------------------------------------------------===//
 // Value API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cbd35f2974ae9..c12f036352b30 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2626,6 +2626,92 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
   PyOperationRef operation;
 };
 
+/// A list of block successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation and block whose successors these are, and thus
+/// extends the lifetime of this operation and block.
+class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockSuccessors";
+
+  PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+                    intptr_t startIndex = 0, intptr_t length = -1,
+                    intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? getValidatedNumSuccessors(block) : length,
+                  step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+  /// Validates `block` before reading its successor count. Used by the
+  /// constructor too, so a list is never sized from an invalidated block.
+  static intptr_t getValidatedNumSuccessors(PyBlock &block) {
+    block.checkValid();
+    return mlirBlockGetNumSuccessors(block.get());
+  }
+
+  intptr_t getRawNumElements() { return getValidatedNumSuccessors(block); }
+
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock successor = mlirBlockGetSuccessor(block.get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyBlockSuccessors(block, operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
+/// A list of block predecessors. The (returned) predecessor list is
+/// associated with the operation and block whose predecessors these are, and
+/// thus extends the lifetime of this operation and block. Each element lookup
+/// goes through mlirBlockGetPredecessor, which advances an iterator `pos` steps
+/// from the first predecessor, so iterating the whole list may be quadratic.
+class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockPredecessors";
+
+  PyBlockPredecessors(PyBlock block, PyOperationRef operation,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? getValidatedNumPredecessors(block) : length,
+                  step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockPredecessors, PyBlock>;
+
+  /// Validates `block` before reading its predecessor count. Used by the
+  /// constructor too, so a list is never sized from an invalidated block.
+  static intptr_t getValidatedNumPredecessors(PyBlock &block) {
+    block.checkValid();
+    return mlirBlockGetNumPredecessors(block.get());
+  }
+
+  intptr_t getRawNumElements() { return getValidatedNumPredecessors(block); }
+
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock predecessor = mlirBlockGetPredecessor(block.get(), pos);
+    return PyBlock(operation, predecessor);
+  }
+
+  PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyBlockPredecessors(block, operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
 /// A list of operation attributes. Can be indexed by name, producing
 /// attributes, or by index, producing named attributes.
 class PyOpAttributeMap {
@@ -3655,7 +3733,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("operation"),
           "Appends an operation to this block. If the operation is currently "
-          "in another block, it will be moved.");
+          "in another block, it will be moved.")
+      .def_prop_ro(
+          "successors",
+          [](PyBlock &self) {
+            return PyBlockSuccessors(self, self.getParentOperation());
+          },
+          "Returns the list of Block successors.")
+      .def_prop_ro(
+          "predecessors",
+          [](PyBlock &self) {
+            return PyBlockPredecessors(self, self.getParentOperation());
+          },
+          "Returns the list of Block predecessors.");
 
   //----------------------------------------------------------------------------
   // Mapping of PyInsertionPoint.
@@ -4099,6 +4189,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   PyBlockArgumentList::bind(m);
   PyBlockIterator::bind(m);
   PyBlockList::bind(m);
+  PyBlockSuccessors::bind(m);
+  PyBlockPredecessors::bind(m);
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
   PyOpAttributeMap::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e0e386d55ede1..fbc66bcf5c2d0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -1059,6 +1059,35 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
   unwrap(block)->print(stream);
 }
 
+intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
+  return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
+}
+
+MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
+  // Reject out-of-range indices (including negative ones, which the cast
+  // below would turn into huge unsigned values) in assertion-enabled builds.
+  assert(pos >= 0 && pos < mlirBlockGetNumSuccessors(block) &&
+         "successor index out of range");
+  return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
+}
+
+intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
+  Block *b = unwrap(block);
+  return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
+}
+
+MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
+  // The predecessor range is walked with an iterator, so an out-of-range `pos`
+  // would advance past `pred_end()` and dereference an invalid iterator.
+  // Reject it up front in assertion-enabled builds.
+  assert(pos >= 0 && pos < mlirBlockGetNumPredecessors(block) &&
+         "predecessor index out of range");
+  Block *b = unwrap(block);
+  Block::pred_iterator it = b->pred_begin();
+  std::advance(it, pos);
+  return wrap(*it);
+}
+
 //===----------------------------------------------------------------------===//
 // Value API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 70ccaeeb5435b..ced5fce434728 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -1,12 +1,11 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import io
-import itertools
-from mlir.ir import *
+
 from mlir.dialects import builtin
 from mlir.dialects import cf
 from mlir.dialects import func
+from mlir.ir import *
 
 
 def run(f):
@@ -54,10 +53,25 @@ def testBlockCreation():
             with InsertionPoint(middle_block) as middle_ip:
                 assert middle_ip.block == middle_block
                 cf.BranchOp([i32_arg], dest=successor_block)
+
         module.print(enable_debug_info=True)
         # Ensure region back references are coherent.
         assert entry_block.region == middle_block.region == successor_block.region
 
+        assert len(entry_block.predecessors) == 0
+
+        assert len(entry_block.successors) == 1
+        assert middle_block == entry_block.successors[0]
+        assert len(middle_block.predecessors) == 1
+        assert entry_block == middle_block.predecessors[0]
+
+        assert len(middle_block.successors) == 1
+        assert successor_block == middle_block.successors[0]
+        assert len(successor_block.predecessors) == 1
+        assert middle_block == successor_block.predecessors[0]
+
+        assert len(successor_block.successors) == 0
+
 
 # CHECK-LABEL: TEST: testBlockCreationArgLocs
 @run

>From 035e009fe5d04bd373a5bc336697b8724bf53018 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 21 Jun 2025 16:20:39 -0400
Subject: [PATCH 2/4] add C API test

---
 mlir/test/CAPI/ir.c | 55 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 68da79f69cc0a..f1ae1aabc9bcf 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2440,6 +2440,64 @@ void testDiagnostics(void) {
   mlirContextDestroy(ctx);
 }
 
+int testBlockPredecessorsSuccessors(MlirContext ctx) {
+  // CHECK-LABEL: @testBlockPredecessorsSuccessors
+  fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
+
+  const char *moduleString = R"(
+    #loc2 = loc("arg1")
+    #loc3 = loc("middle")
+    #loc4 = loc("successor")
+    module {
+      func.func @test(%arg0: i32 loc("arg0"), %arg1: i16 loc("arg1")) {
+        cf.br ^bb1(%arg1 : i16) loc(#loc)
+      ^bb1(%0: i16 loc("middle")):  // pred: ^bb0
+        cf.br ^bb2(%arg0 : i32) loc(#loc)
+      ^bb2(%1: i32 loc("successor")):  // pred: ^bb1
+        return loc(#loc)
+      } loc(#loc)
+    } loc(#loc)
+    #loc = loc(unknown)
+  )";
+
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Fail cleanly rather than walking a null module if parsing fails.
+  if (mlirModuleIsNull(module))
+    return 1;
+
+  MlirOperation moduleOp = mlirModuleGetOperation(module);
+  MlirRegion moduleRegion = mlirOperationGetRegion(moduleOp, 0);
+  MlirBlock moduleBlock = mlirRegionGetFirstBlock(moduleRegion);
+  MlirOperation function = mlirBlockGetFirstOperation(moduleBlock);
+  MlirRegion funcRegion = mlirOperationGetRegion(function, 0);
+  MlirBlock entryBlock = mlirRegionGetFirstBlock(funcRegion);
+  MlirBlock middleBlock = mlirBlockGetNextInRegion(entryBlock);
+  MlirBlock successorBlock = mlirBlockGetNextInRegion(middleBlock);
+
+  // Check the CFG with explicit conditions rather than assert() so the test
+  // still detects failures when NDEBUG is defined.
+  bool ok =
+      // ^bb0: no predecessors, one successor (^bb1).
+      mlirBlockGetNumPredecessors(entryBlock) == 0 &&
+      mlirBlockGetNumSuccessors(entryBlock) == 1 &&
+      mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)) &&
+      // ^bb1: one predecessor (^bb0), one successor (^bb2).
+      mlirBlockGetNumPredecessors(middleBlock) == 1 &&
+      mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0)) &&
+      mlirBlockGetNumSuccessors(middleBlock) == 1 &&
+      mlirBlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)) &&
+      // ^bb2: one predecessor (^bb1), no successors.
+      mlirBlockGetNumPredecessors(successorBlock) == 1 &&
+      mlirBlockEqual(middleBlock,
+                     mlirBlockGetPredecessor(successorBlock, 0)) &&
+      mlirBlockGetNumSuccessors(successorBlock) == 0;
+
+  mlirModuleDestroy(module);
+
+  return ok ? 0 : 1;
+}
+
 int main(void) {
   MlirContext ctx = mlirContextCreate();
   registerAllUpstreamDialects(ctx);
@@ -2486,6 +2544,9 @@ int main(void) {
   testExplicitThreadPools();
   testDiagnostics();
 
+  if (testBlockPredecessorsSuccessors(ctx))
+    return 17;
+
   // CHECK: DESTROY MAIN CONTEXT
   // CHECK: reportResourceDelete: resource_i64_blob
   fprintf(stderr, "DESTROY MAIN CONTEXT\n");

>From c5787ec9dd7c77135be37fa7b8989dc7cd2eddb4 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 21 Jun 2025 17:19:39 -0400
Subject: [PATCH 3/4] Update ir.c

---
 mlir/test/CAPI/ir.c | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index f1ae1aabc9bcf..8160757f565d1 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2444,7 +2444,7 @@ int testBlockPredecessorsSuccessors(MlirContext ctx) {
   // CHECK-LABEL: @testBlockPredecessorsSuccessors
   fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
 
-  const char *moduleString = R"(
+  const char *moduleString = R"""(
     #loc2 = loc("arg1")
     #loc3 = loc("middle")
     #loc4 = loc("successor")
@@ -2458,7 +2458,7 @@ int testBlockPredecessorsSuccessors(MlirContext ctx) {
       } loc(#loc)
     } loc(#loc)
     #loc = loc(unknown)
-  )";
+  )""";
 
   MlirModule module =
       mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));

>From dc0e8af31de2c5d676f8f9b4d1611c26464ca5fb Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 21 Jun 2025 17:33:48 -0400
Subject: [PATCH 4/4] Update ir.c

---
 mlir/test/CAPI/ir.c | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8160757f565d1..88a3e70b4ec5d 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2445,19 +2445,15 @@ int testBlockPredecessorsSuccessors(MlirContext ctx) {
   fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
 
-  const char *moduleString = R"""(
-    #loc2 = loc("arg1")
-    #loc3 = loc("middle")
-    #loc4 = loc("successor")
-    module {
-      func.func @test(%arg0: i32 loc("arg0"), %arg1: i16 loc("arg1")) {
-        cf.br ^bb1(%arg1 : i16) loc(#loc)
-      ^bb1(%0: i16 loc("middle")):  // pred: ^bb0
-        cf.br ^bb2(%arg0 : i32) loc(#loc)
-      ^bb2(%1: i32 loc("successor")):  // pred: ^bb1
-        return loc(#loc)
-      } loc(#loc)
-    } loc(#loc)
-    #loc = loc(unknown)
-  )""";
+  // Raw string literals are a C++ feature, not standard C, and this test is a
+  // C file; spell the module out with adjacent string literals instead.
+  const char *moduleString = "module {\n"
+                             "  func.func @test(%arg0: i32, %arg1: i16) {\n"
+                             "    cf.br ^bb1(%arg1 : i16)\n"
+                             "  ^bb1(%0: i16):  // pred: ^bb0\n"
+                             "    cf.br ^bb2(%arg0 : i32)\n"
+                             "  ^bb2(%1: i32):  // pred: ^bb1\n"
+                             "    return\n"
+                             "  }\n"
+                             "}\n";
 
   MlirModule module =



More information about the Mlir-commits mailing list