[Mlir-commits] [mlir] [mlir][python] bind block predecessors and successors (PR #145116)
Maksim Levental
llvmlistbot at llvm.org
Sat Jun 21 14:33:58 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/145116
>From c760955f9ada410ef8325866fed7944cd67f355b Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Fri, 20 Jun 2025 19:20:09 -0400
Subject: [PATCH 1/4] [mlir][python] bind block successors and predecessors
---
mlir/include/mlir-c/IR.h | 14 +++++
mlir/lib/Bindings/Python/IRCore.cpp | 94 ++++++++++++++++++++++++++++-
mlir/lib/CAPI/IR/IR.cpp | 20 ++++++
mlir/test/python/ir/blocks.py | 20 +++++-
4 files changed, 144 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 1a8e8737f7fed..71aaee931d543 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
MLIR_CAPI_EXPORTED void
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
+/// Returns the number of successor blocks of the block.
+MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
+
+/// Returns the `pos`-th successor of the block. `pos` is not range-checked by
+/// this function; it must lie in [0, mlirBlockGetNumSuccessors(block)).
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
+ intptr_t pos);
+
+/// Returns the number of predecessor blocks of the block. The count is
+/// obtained by walking the block's predecessor range, so it may be linear in
+/// the number of predecessors. NOTE(review): a block targeted by several
+/// branch operands is presumably counted once per operand — confirm against
+/// mlir::Block's predecessor iterator.
+MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);
+
+/// Returns the `pos`-th predecessor of the block. `pos` should lie in
+/// [0, mlirBlockGetNumPredecessors(block)). The lookup advances from the first
+/// predecessor, so its cost may grow with `pos`. The order of predecessors is
+/// not specified by this API.
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
+ intptr_t pos);
+
//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cbd35f2974ae9..c12f036352b30 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2626,6 +2626,84 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
PyOperationRef operation;
};
+/// A list of block successors, exposed to Python as `Block.successors`.
+/// Elements are fetched lazily through `mlirBlockGetSuccessor`; the length is
+/// captured when the list (or a slice of it) is created. The list holds a
+/// reference to the parent operation of the block, which is also attached to
+/// every returned successor block (successors live in the same region as the
+/// block, so they share that parent operation).
+class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockSuccessors";
+
+  /// Creates a view of `block`'s successors. `operation` is the parent
+  /// operation of `block`. A `length` of -1 means "all successors".
+  PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+                    intptr_t startIndex = 0, intptr_t length = -1,
+                    intptr_t step = 1)
+      : Sliceable(startIndex, length == -1 ? countSuccessors(block) : length,
+                  step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+  /// Returns the current number of successors of `block`. Validity is checked
+  /// first so that a block whose parent operation has been invalidated is
+  /// reported instead of being dereferenced.
+  static intptr_t countSuccessors(PyBlock &block) {
+    block.checkValid();
+    return mlirBlockGetNumSuccessors(block.get());
+  }
+
+  intptr_t getRawNumElements() { return countSuccessors(block); }
+
+  PyBlock getRawElement(intptr_t pos) {
+    // The list may outlive the IR it was created from; validate before
+    // touching the underlying MlirBlock.
+    block.checkValid();
+    MlirBlock successor = mlirBlockGetSuccessor(block.get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyBlockSuccessors(block, operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
+/// A list of block predecessors, exposed to Python as `Block.predecessors`.
+/// Elements are fetched lazily through `mlirBlockGetPredecessor`; the length
+/// is captured when the list (or a slice of it) is created. The list holds a
+/// reference to the parent operation of the block, which is also attached to
+/// every returned predecessor block (predecessors live in the same region).
+///
+/// NOTE(review): `mlirBlockGetPredecessor` advances from the first predecessor
+/// on every call, so indexing costs O(pos) and iterating the whole list is
+/// presumably quadratic in the number of predecessors. A block targeted by
+/// several branch operands may also appear more than once — confirm.
+class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockPredecessors";
+
+  /// Creates a view of `block`'s predecessors. `operation` is the parent
+  /// operation of `block`. A `length` of -1 means "all predecessors".
+  PyBlockPredecessors(PyBlock block, PyOperationRef operation,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? countPredecessors(block) : length, step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockPredecessors, PyBlock>;
+
+  /// Returns the current number of predecessors of `block`. Validity is
+  /// checked first so that a block whose parent operation has been
+  /// invalidated is reported instead of being dereferenced.
+  static intptr_t countPredecessors(PyBlock &block) {
+    block.checkValid();
+    return mlirBlockGetNumPredecessors(block.get());
+  }
+
+  intptr_t getRawNumElements() { return countPredecessors(block); }
+
+  PyBlock getRawElement(intptr_t pos) {
+    // The list may outlive the IR it was created from; validate before
+    // touching the underlying MlirBlock.
+    block.checkValid();
+    MlirBlock predecessor = mlirBlockGetPredecessor(block.get(), pos);
+    return PyBlock(operation, predecessor);
+  }
+
+  PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyBlockPredecessors(block, operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
class PyOpAttributeMap {
@@ -3655,7 +3733,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("operation"),
"Appends an operation to this block. If the operation is currently "
- "in another block, it will be moved.");
+ "in another block, it will be moved.")
+ .def_prop_ro(
+ "successors",
+ [](PyBlock &self) {
+ return PyBlockSuccessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block successors.")
+ .def_prop_ro(
+ "predecessors",
+ [](PyBlock &self) {
+ return PyBlockPredecessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block predecessors.");
//----------------------------------------------------------------------------
// Mapping of PyInsertionPoint.
@@ -4099,6 +4189,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
+ PyBlockSuccessors::bind(m);
+ PyBlockPredecessors::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e0e386d55ede1..fbc66bcf5c2d0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
unwrap(block)->print(stream);
}
+/// Number of successors of `block`, converted to the C API's `intptr_t`
+/// index type.
+intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
+  const auto numSuccessors = unwrap(block)->getNumSuccessors();
+  return static_cast<intptr_t>(numSuccessors);
+}
+
+/// Returns the successor of `block` at index `pos` (narrowed to the
+/// `unsigned` index used by mlir::Block).
+MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
+  Block *cppBlock = unwrap(block);
+  Block *successor = cppBlock->getSuccessor(static_cast<unsigned>(pos));
+  return wrap(successor);
+}
+
+/// Counts the predecessors of `block` by measuring the distance between the
+/// begin and end of its predecessor iterator range.
+intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
+  Block *cppBlock = unwrap(block);
+  Block::pred_iterator first = cppBlock->pred_begin();
+  Block::pred_iterator last = cppBlock->pred_end();
+  return static_cast<intptr_t>(std::distance(first, last));
+}
+
+/// Returns the predecessor of `block` at index `pos`, walking forward from the
+/// first predecessor (cost grows with `pos`). Returns a null block when `pos`
+/// is negative or not smaller than the number of predecessors, instead of
+/// advancing an iterator past the end of the range (undefined behavior).
+MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
+  if (pos < 0)
+    return wrap(static_cast<Block *>(nullptr));
+  Block *cppBlock = unwrap(block);
+  // `pos` only counts down from a non-negative value and the loop returns as
+  // soon as it reaches zero, so it can never underflow.
+  for (Block::pred_iterator it = cppBlock->pred_begin(),
+                            end = cppBlock->pred_end();
+       it != end; ++it, --pos) {
+    if (pos == 0)
+      return wrap(*it);
+  }
+  return wrap(static_cast<Block *>(nullptr));
+}
+
//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 70ccaeeb5435b..ced5fce434728 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -1,12 +1,11 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
-import io
-import itertools
-from mlir.ir import *
+
from mlir.dialects import builtin
from mlir.dialects import cf
from mlir.dialects import func
+from mlir.ir import *
def run(f):
@@ -54,10 +53,25 @@ def testBlockCreation():
with InsertionPoint(middle_block) as middle_ip:
assert middle_ip.block == middle_block
cf.BranchOp([i32_arg], dest=successor_block)
+
module.print(enable_debug_info=True)
# Ensure region back references are coherent.
assert entry_block.region == middle_block.region == successor_block.region
+ assert len(entry_block.predecessors) == 0
+
+ assert len(entry_block.successors) == 1
+ assert middle_block == entry_block.successors[0]
+ assert len(middle_block.predecessors) == 1
+ assert entry_block == middle_block.predecessors[0]
+
+ assert len(middle_block.successors) == 1
+ assert successor_block == middle_block.successors[0]
+ assert len(successor_block.predecessors) == 1
+ assert middle_block == successor_block.predecessors[0]
+
+ assert len(successor_block.successors) == 0
+
# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
>From 035e009fe5d04bd373a5bc336697b8724bf53018 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 21 Jun 2025 16:20:39 -0400
Subject: [PATCH 2/4] add C API test
---
mlir/test/CAPI/ir.c | 55 +++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 55 insertions(+)
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 68da79f69cc0a..f1ae1aabc9bcf 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2440,6 +2440,58 @@ void testDiagnostics(void) {
mlirContextDestroy(ctx);
}
+/// Parses a function whose blocks form a straight chain
+/// (entry -> middle -> successor, each edge a single `cf.br`) and checks the
+/// block successor/predecessor C API against it.
+///
+/// Returns 0 on success and a distinct non-zero code for the first failing
+/// check. The checks are explicit `if`s rather than `assert`s so that they
+/// still run (and can fail the test) when NDEBUG compiles asserts out.
+///
+/// NOTE(review): `R"(...)"` raw string literals are a C++11 feature; in C
+/// they are at best a compiler extension and are not accepted by every C
+/// compiler. Consider adjacent plain string literals for `moduleString`.
+int testBlockPredecessorsSuccessors(MlirContext ctx) {
+  // CHECK-LABEL: @testBlockPredecessorsSuccessors
+  fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
+
+  const char *moduleString = R"(
+ #loc2 = loc("arg1")
+ #loc3 = loc("middle")
+ #loc4 = loc("successor")
+ module {
+ func.func @test(%arg0: i32 loc("arg0"), %arg1: i16 loc("arg1")) {
+ cf.br ^bb1(%arg1 : i16) loc(#loc)
+ ^bb1(%0: i16 loc("middle")): // pred: ^bb0
+ cf.br ^bb2(%arg0 : i32) loc(#loc)
+ ^bb2(%1: i32 loc("successor")): // pred: ^bb1
+ return loc(#loc)
+ } loc(#loc)
+ } loc(#loc)
+ #loc = loc(unknown)
+ )";
+
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // A parse failure yields a null module; do not walk it.
+  if (mlirModuleIsNull(module))
+    return 1;
+
+  MlirOperation moduleOp = mlirModuleGetOperation(module);
+  MlirRegion moduleRegion = mlirOperationGetRegion(moduleOp, 0);
+  MlirBlock moduleBlock = mlirRegionGetFirstBlock(moduleRegion);
+  MlirOperation function = mlirBlockGetFirstOperation(moduleBlock);
+  MlirRegion funcRegion = mlirOperationGetRegion(function, 0);
+  MlirBlock entryBlock = mlirRegionGetFirstBlock(funcRegion);
+  MlirBlock middleBlock = mlirBlockGetNextInRegion(entryBlock);
+  MlirBlock successorBlock = mlirBlockGetNextInRegion(middleBlock);
+
+  // Each count is checked before indexing, so an unexpected count never
+  // leads to an out-of-range successor/predecessor lookup.
+  int failure = 0;
+  if (mlirBlockGetNumPredecessors(entryBlock) != 0)
+    failure = 2;
+  else if (mlirBlockGetNumSuccessors(entryBlock) != 1 ||
+           !mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)))
+    failure = 3;
+  else if (mlirBlockGetNumPredecessors(middleBlock) != 1 ||
+           !mlirBlockEqual(entryBlock,
+                           mlirBlockGetPredecessor(middleBlock, 0)))
+    failure = 4;
+  else if (mlirBlockGetNumSuccessors(middleBlock) != 1 ||
+           !mlirBlockEqual(successorBlock,
+                           mlirBlockGetSuccessor(middleBlock, 0)))
+    failure = 5;
+  else if (mlirBlockGetNumPredecessors(successorBlock) != 1 ||
+           !mlirBlockEqual(middleBlock,
+                           mlirBlockGetPredecessor(successorBlock, 0)))
+    failure = 6;
+  else if (mlirBlockGetNumSuccessors(successorBlock) != 0)
+    failure = 7;
+
+  // Release the module on every path, including failures.
+  mlirModuleDestroy(module);
+
+  return failure;
+}
+
int main(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
@@ -2486,6 +2538,9 @@ int main(void) {
testExplicitThreadPools();
testDiagnostics();
+ if (testBlockPredecessorsSuccessors(ctx))
+ return 17;
+
// CHECK: DESTROY MAIN CONTEXT
// CHECK: reportResourceDelete: resource_i64_blob
fprintf(stderr, "DESTROY MAIN CONTEXT\n");
>From c5787ec9dd7c77135be37fa7b8989dc7cd2eddb4 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 21 Jun 2025 17:19:39 -0400
Subject: [PATCH 3/4] Update ir.c
---
mlir/test/CAPI/ir.c | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index f1ae1aabc9bcf..8160757f565d1 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2444,7 +2444,7 @@ int testBlockPredecessorsSuccessors(MlirContext ctx) {
// CHECK-LABEL: @testBlockPredecessorsSuccessors
fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
- const char *moduleString = R"(
+ const char *moduleString = R"""(
#loc2 = loc("arg1")
#loc3 = loc("middle")
#loc4 = loc("successor")
@@ -2458,7 +2458,7 @@ int testBlockPredecessorsSuccessors(MlirContext ctx) {
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
- )";
+ )""";
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
>From dc0e8af31de2c5d676f8f9b4d1611c26464ca5fb Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 21 Jun 2025 17:33:48 -0400
Subject: [PATCH 4/4] Update ir.c
---
mlir/test/CAPI/ir.c | 20 ++++++++------------
1 file changed, 8 insertions(+), 12 deletions(-)
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8160757f565d1..88a3e70b4ec5d 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2445,19 +2445,15 @@ int testBlockPredecessorsSuccessors(MlirContext ctx) {
fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
const char *moduleString = R"""(
- #loc2 = loc("arg1")
- #loc3 = loc("middle")
- #loc4 = loc("successor")
module {
- func.func @test(%arg0: i32 loc("arg0"), %arg1: i16 loc("arg1")) {
- cf.br ^bb1(%arg1 : i16) loc(#loc)
- ^bb1(%0: i16 loc("middle")): // pred: ^bb0
- cf.br ^bb2(%arg0 : i32) loc(#loc)
- ^bb2(%1: i32 loc("successor")): // pred: ^bb1
- return loc(#loc)
- } loc(#loc)
- } loc(#loc)
- #loc = loc(unknown)
+ func.func @test(%arg0: i32, %arg1: i16) {
+ cf.br ^bb1(%arg1 : i16)
+ ^bb1(%0: i16): // pred: ^bb0
+ cf.br ^bb2(%arg0 : i32)
+ ^bb2(%1: i32): // pred: ^bb1
+ return
+ }
+ }
)""";
MlirModule module =
More information about the Mlir-commits
mailing list