[Mlir-commits] [mlir] [mlir][python]Python Bindings for select edit operations on Block arguments (PR #94305)

Sandeep Dasgupta llvmlistbot at llvm.org
Mon Jun 3 19:38:10 PDT 2024


https://github.com/sdasgup3 updated https://github.com/llvm/llvm-project/pull/94305

>From 3a1c934d434e3d0d085c1e568e13f9f0cea0a06c Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta <sdasgup at google.com>
Date: Tue, 4 Jun 2024 02:06:04 +0000
Subject: [PATCH] Python Bindings for select edit operations on block arguments

---
 mlir/include/mlir-c/IR.h            |  7 ++++++
 mlir/lib/Bindings/Python/IRCore.cpp | 19 ++++++++++++++++
 mlir/lib/CAPI/IR/IR.cpp             |  8 +++++++
 mlir/test/python/ir/blocks.py       | 35 +++++++++++++++++++++++++++++
 4 files changed, 69 insertions(+)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 32abacf353133..a71592203f5a5 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -858,6 +858,13 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block,
                                                   MlirType type,
                                                   MlirLocation loc);
 
+/// Erases the argument at 'index' and removes it from the argument list.
+MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index);
+
+/// Erases 'num' consecutive arguments starting at index 'start'.
+MLIR_CAPI_EXPORTED void mlirBlockEraseArguments(MlirBlock block, unsigned start,
+                                                unsigned num);
+
 /// Inserts an argument of the specified type at a specified index to the block.
 /// Returns the newly added argument.
 MLIR_CAPI_EXPORTED MlirValue mlirBlockInsertArgument(MlirBlock block,
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index de20632b4fb7d..11d08c5c30cd3 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3238,6 +3238,25 @@ void mlir::python::populateIRCore(py::module &m) {
             return PyBlockArgumentList(self.getParentOperation(), self.get());
           },
           "Returns a list of block arguments.")
+      .def(
+          "add_argument",
+          [](PyBlock &self, const PyType &type, const PyLocation &loc) {
+            return mlirBlockAddArgument(self.get(), type, loc);
+          },
+          "Append an argument of the specified type to the block and returns "
+          "the newly added argument.")
+      .def(
+          "erase_argument",
+          [](PyBlock &self, unsigned index) {
+            return mlirBlockEraseArgument(self.get(), index);
+          },
+          "Erase the argument at 'index' and remove it from the argument list.")
+      .def(
+          "erase_arguments",
+          [](PyBlock &self, unsigned start, unsigned num) {
+            return mlirBlockEraseArguments(self.get(), start, num);
+          },
+          "Erases 'num' arguments from the index 'start'.")
       .def_property_readonly(
           "operations",
           [](PyBlock &self) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index a72cd247e73f6..ab8ee65b35d7f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -906,6 +906,14 @@ MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
   return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc)));
 }
 
+void mlirBlockEraseArgument(MlirBlock block, unsigned index) {
+  unwrap(block)->eraseArgument(index); // Void API: call, nothing to return.
+}
+
+void mlirBlockEraseArguments(MlirBlock block, unsigned start, unsigned num) {
+  unwrap(block)->eraseArguments(start, num); // Void API: call, no return.
+}
+
 MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type,
                                   MlirLocation loc) {
   return wrap(unwrap(block)->insertArgument(pos, unwrap(type), unwrap(loc)));
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 8b4d946c97b8d..0ca8fd9e236ab 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -145,3 +145,38 @@ def testBlockHash():
             block1 = Block.create_at_start(dummy.operation.regions[0], [f32])
             block2 = Block.create_at_start(dummy.operation.regions[0], [f32])
             assert hash(block1) != hash(block2)
+
+
+# CHECK-LABEL: TEST: testBlockAddArgs
+@run
+def testBlockAddArgs():
+    with Context() as ctx, Location.unknown(ctx) as loc:
+        ctx.allow_unregistered_dialects = True
+        f32 = F32Type.get()
+        op = Operation.create("test", regions=1, loc=Location.unknown())
+        blocks = op.regions[0].blocks
+        blocks.append()
+        # CHECK: ^bb0:
+        op.print(enable_debug_info=True)
+        blocks[0].add_argument(f32, loc)  # Append one f32 block argument.
+        # CHECK: ^bb0(%{{.+}}: f32 loc(unknown)):
+        op.print(enable_debug_info=True)
+
+
+# CHECK-LABEL: TEST: testBlockEraseArgs
+@run
+def testBlockEraseArgs():
+    with Context() as ctx, Location.unknown(ctx) as loc:
+        ctx.allow_unregistered_dialects = True
+        f32 = F32Type.get()
+        op = Operation.create("test", regions=1, loc=Location.unknown())
+        blocks = op.regions[0].blocks
+        blocks.append(f32, f32, f32)
+        # CHECK: ^bb0(%{{.+}}: f32 loc(unknown), %{{.+}}: f32 loc(unknown), %{{.+}}: f32 loc(unknown)):
+        op.print(enable_debug_info=True)
+        blocks[0].erase_argument(0)  # Drop the first of three arguments.
+        # CHECK: ^bb0(%{{.+}}: f32 loc(unknown), %{{.+}}: f32 loc(unknown)):
+        op.print(enable_debug_info=True)
+        blocks[0].erase_arguments(0, 2)  # Drop the remaining two.
+        # CHECK: ^bb0:
+        op.print(enable_debug_info=True)



More information about the Mlir-commits mailing list