[Mlir-commits] [mlir] [mlir][python]Python Bindings for select edit operations on Block arguments (PR #94305)
Sandeep Dasgupta
llvmlistbot at llvm.org
Mon Jun 3 19:38:10 PDT 2024
https://github.com/sdasgup3 updated https://github.com/llvm/llvm-project/pull/94305
>From 3a1c934d434e3d0d085c1e568e13f9f0cea0a06c Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Tue, 4 Jun 2024 02:06:04 +0000
Subject: [PATCH] Python Bindings for select edit operations on block arguments
---
mlir/include/mlir-c/IR.h | 7 ++++++
mlir/lib/Bindings/Python/IRCore.cpp | 19 ++++++++++++++++
mlir/lib/CAPI/IR/IR.cpp | 8 +++++++
mlir/test/python/ir/blocks.py | 35 +++++++++++++++++++++++++++++
4 files changed, 69 insertions(+)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 32abacf353133..a71592203f5a5 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -858,6 +858,13 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block,
MlirType type,
MlirLocation loc);
+/// Erase the argument at 'index' and remove it from the argument list.
+/// The implementation forwards directly to the C++ Block API without
+/// validating its inputs, so callers must pass a valid index.
+/// NOTE(review): the underlying Block::eraseArgument presumably also requires
+/// that the argument has no remaining uses — confirm and check before calling.
+MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index);
+
+/// Erases 'num' arguments from the index 'start', i.e. the range
+/// [start, start + num). No validation is performed here either.
+/// NOTE(review): presumably requires start + num <= the number of block
+/// arguments and that none of the erased arguments has remaining uses —
+/// confirm against Block::eraseArguments.
+MLIR_CAPI_EXPORTED void mlirBlockEraseArguments(MlirBlock block, unsigned start,
+ unsigned num);
+
/// Inserts an argument of the specified type at a specified index to the block.
/// Returns the newly added argument.
MLIR_CAPI_EXPORTED MlirValue mlirBlockInsertArgument(MlirBlock block,
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index de20632b4fb7d..11d08c5c30cd3 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3238,6 +3238,25 @@ void mlir::python::populateIRCore(py::module &m) {
return PyBlockArgumentList(self.getParentOperation(), self.get());
},
"Returns a list of block arguments.")
+      .def(
+          "add_argument",
+          [](PyBlock &self, const PyType &type, const PyLocation &loc) {
+            return mlirBlockAddArgument(self.get(), type, loc);
+          },
+          "Appends an argument of the specified type to the block and returns "
+          "the newly added argument.")
+      .def(
+          "erase_argument",
+          [](PyBlock &self, unsigned index) {
+            // Validate here: the C API forwards straight to the C++ Block API,
+            // which does not report errors back to the caller. Without these
+            // checks a bad index or a still-used argument would not surface
+            // as a Python exception.
+            // NOTE(review): the C++ side is expected to assert (or exhibit UB
+            // in release builds) in both cases — confirm.
+            MlirBlock block = self.get();
+            uintptr_t numArgs =
+                static_cast<uintptr_t>(mlirBlockGetNumArguments(block));
+            if (index >= numArgs)
+              throw py::index_error("block argument index out of range");
+            if (!mlirOpOperandIsNull(mlirValueGetFirstUse(
+                    mlirBlockGetArgument(block, static_cast<intptr_t>(index)))))
+              throw py::value_error(
+                  "cannot erase a block argument that still has uses");
+            mlirBlockEraseArgument(block, index);
+          },
+          "Erases the argument at 'index' and removes it from the argument "
+          "list. Raises IndexError if 'index' is out of range and ValueError "
+          "if the argument still has uses.")
+      .def(
+          "erase_arguments",
+          [](PyBlock &self, unsigned start, unsigned num) {
+            MlirBlock block = self.get();
+            uintptr_t numArgs =
+                static_cast<uintptr_t>(mlirBlockGetNumArguments(block));
+            // Two comparisons instead of 'start + num > numArgs' so the check
+            // cannot be defeated by unsigned wrap-around.
+            if (start > numArgs || num > numArgs - start)
+              throw py::index_error("block argument range out of range");
+            // Reject the whole operation before mutating anything if any
+            // argument in [start, start + num) is still used.
+            for (unsigned i = 0; i < num; ++i) {
+              MlirValue arg =
+                  mlirBlockGetArgument(block, static_cast<intptr_t>(start + i));
+              if (!mlirOpOperandIsNull(mlirValueGetFirstUse(arg)))
+                throw py::value_error(
+                    "cannot erase a block argument that still has uses");
+            }
+            mlirBlockEraseArguments(block, start, num);
+          },
+          "Erases 'num' arguments starting at index 'start'. Raises "
+          "IndexError if the range exceeds the argument list and ValueError "
+          "if any erased argument still has uses.")
.def_property_readonly(
"operations",
[](PyBlock &self) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index a72cd247e73f6..ab8ee65b35d7f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -906,6 +906,14 @@ MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc)));
}
+// Thin forwarder to Block::eraseArgument. No validation is done here; callers
+// (e.g. language bindings) are responsible for checking 'index' first.
+void mlirBlockEraseArgument(MlirBlock block, unsigned index) {
+  unwrap(block)->eraseArgument(index);
+}
+
+// Thin forwarder to Block::eraseArguments for the range [start, start + num).
+// As above, the range is not validated here.
+void mlirBlockEraseArguments(MlirBlock block, unsigned start, unsigned num) {
+  unwrap(block)->eraseArguments(start, num);
+}
+
MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type,
MlirLocation loc) {
return wrap(unwrap(block)->insertArgument(pos, unwrap(type), unwrap(loc)));
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 8b4d946c97b8d..0ca8fd9e236ab 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -145,3 +145,38 @@ def testBlockHash():
block1 = Block.create_at_start(dummy.operation.regions[0], [f32])
block2 = Block.create_at_start(dummy.operation.regions[0], [f32])
assert hash(block1) != hash(block2)
+
+
+# CHECK-LABEL: TEST: testBlockAddArgs
+ at run
+def testBlockAddArgs():
+ with Context() as ctx, Location.unknown(ctx) as loc:
+ ctx.allow_unregistered_dialects = True
+ f32 = F32Type.get()
+ op = Operation.create("test", regions=1, loc=Location.unknown())
+ blocks = op.regions[0].blocks
+ blocks.append()
+ # CHECK: ^bb0:
+ op.print(enable_debug_info=True)
+ blocks[0].add_argument(f32, loc)
+ # CHECK: ^bb0(%{{.+}}: f32 loc(unknown)):
+ op.print(enable_debug_info=True)
+
+
+# CHECK-LABEL: TEST: testBlockEraseArgs
+ at run
+def testBlockEraseArgs():
+ with Context() as ctx, Location.unknown(ctx) as loc:
+ ctx.allow_unregistered_dialects = True
+ f32 = F32Type.get()
+ op = Operation.create("test", regions=1, loc=Location.unknown())
+ blocks = op.regions[0].blocks
+ blocks.append(f32, f32, f32)
+ # CHECK: ^bb0(%{{.+}}: f32 loc(unknown), %{{.+}}: f32 loc(unknown), %{{.+}}: f32 loc(unknown)):
+ op.print(enable_debug_info=True)
+ blocks[0].erase_argument(0)
+ # CHECK: ^bb0(%{{.+}}: f32 loc(unknown), %{{.+}}: f32 loc(unknown)):
+ op.print(enable_debug_info=True)
+ blocks[0].erase_arguments(0, 2)
+ # CHECK: ^bb0:
+ op.print(enable_debug_info=True)
More information about the Mlir-commits
mailing list