[Mlir-commits] [mlir] [MLIR][Python] Expose the insertion point of pattern rewriter (PR #161001)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 3 20:56:12 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/161001

>From 28d65b8d5e0a059f790ff2a56423ab9e813c5e72 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 27 Sep 2025 18:31:40 +0800
Subject: [PATCH 1/7] [MLIR][Python] Expose the insertion point of pattern
 rewriter

---
 mlir/include/mlir-c/Rewrite.h                | 11 +++
 mlir/lib/Bindings/Python/Rewrite.cpp         | 16 ++++-
 mlir/lib/CAPI/Transforms/Rewrite.cpp         | 16 +++++
 mlir/test/python/integration/dialects/pdl.py | 76 +++++++++++++++++++-
 4 files changed, 116 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 77be1f480eacf..b0f60901c5301 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -101,6 +101,9 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
 MLIR_CAPI_EXPORTED MlirBlock
 mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
 
+MLIR_CAPI_EXPORTED MlirOperation
+mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
+
 //===----------------------------------------------------------------------===//
 /// Block and operation creation/insertion/cloning
 //===----------------------------------------------------------------------===//
@@ -310,6 +313,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
 
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Cast the PatternRewriter to a RewriterBase
+MLIR_CAPI_EXPORTED MlirRewriterBase
+mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 20392b9002706..b520d8d3f1ecc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -143,7 +143,21 @@ class PyFrozenRewritePatternSet {
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
-  nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
+  nb::class_<MlirPatternRewriter>(m, "PatternRewriter")
+      .def("ip", [](MlirPatternRewriter rewriter) {
+        MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
+        MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+        MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+        MlirOperation owner = mlirBlockGetParentOperation(block);
+        auto ctx = PyMlirContext::forContext(mlirRewriterBaseGetContext(base))
+                       ->getRef();
+        if (mlirOperationIsNull(op)) {
+          auto parent = PyOperation::forOperation(ctx, owner);
+          return PyInsertionPoint(PyBlock(parent, block));
+        }
+
+        return PyInsertionPoint(*PyOperation::forOperation(ctx, op).get());
+      });
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308cadf83..b149d35f0d88b 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -70,6 +70,18 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
   return wrap(unwrap(rewriter)->getBlock());
 }
 
+MlirOperation
+mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
+  mlir::RewriterBase *base = unwrap(rewriter);
+  mlir::Block *block = base->getInsertionBlock();
+  auto it = base->getInsertionPoint();
+  if (!block || it == block->end()) { // No insertion point set, or at end.
+    return {nullptr};
+  }
+
+  return wrap(std::addressof(*it));
+}
+
 //===----------------------------------------------------------------------===//
 /// Block and operation creation/insertion/cloning
 //===----------------------------------------------------------------------===//
@@ -316,6 +328,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
   return {rewriter};
 }
 
+MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
+  return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index c8e6197e03842..b8c7e277f1776 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -121,8 +121,10 @@ def load_myint_dialect():
 
 
 # This PDL pattern is to fold constant additions,
-# i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1.
+# including two patterns:
+# 1. add(constant0, constant1) -> constant2
+#    where constant2 = constant0 + constant1;
+# 2. add(x, 0) or add(0, x) -> x.
 def get_pdl_pattern_fold():
     m = Module.create()
     i32 = IntegerType.get_signless(32)
@@ -237,3 +239,73 @@ def test_pdl_register_function_constraint(module_):
     apply_patterns_and_fold_greedily(module_, frozen)
 
     return module_
+
+
+# This pattern expands a constant into additions
+# unless the constant is at most 1,
+# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
+def get_pdl_pattern_expand():
+    m = Module.create()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(m.body):
+
+        @pdl.pattern(benefit=1, sym_name="myint_constant_expand")
+        def pat():
+            t = pdl.TypeOp(i32)
+            cst = pdl.AttributeOp()
+            pdl.apply_native_constraint([], "is_one", [cst])
+            op0 = pdl.OperationOp(name="myint.constant", attributes={"value": cst}, types=[t])
+
+            @pdl.rewrite()
+            def rew():
+                expanded = pdl.apply_native_rewrite([pdl.OperationType.get()], "expand", [cst])
+                pdl.ReplaceOp(op0, with_op=expanded)
+
+    def is_one(rewriter, results, values):
+        cst = values[0].value
+        return cst <= 1
+    
+    def expand(rewriter, results, values):
+        cst = values[0].value
+        c1 = cst // 2
+        c2 = cst - c1
+        with rewriter.ip():
+            op1 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c1)})
+            op2 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c2)})
+            res = Operation.create("myint.add", results=[i32], operands=[op1.result, op2.result])
+        results.append(res)
+
+    pdl_module = PDLModule(m)
+    pdl_module.register_constraint_function("is_one", is_one)
+    pdl_module.register_rewrite_function("expand", expand)
+    return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function_expand
+# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
+# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
+# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
+# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
+# CHECK: return %8 : i32
+@construct_and_print_in_module
+def test_pdl_register_function_expand(module_):
+    load_myint_dialect()
+
+    module_ = Module.parse(
+        """
+        func.func @f() -> i32 {
+          %0 = "myint.constant"() { value = 5 }: () -> (i32)
+          return %0 : i32
+        }
+        """
+    )
+
+    frozen = get_pdl_pattern_expand()
+    apply_patterns_and_fold_greedily(module_, frozen)
+
+    return module_

>From beab53db10c980b9a326abc1c1bb3ca73b4bbddd Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 27 Sep 2025 19:16:43 +0800
Subject: [PATCH 2/7] format

---
 mlir/test/python/integration/dialects/pdl.py | 27 +++++++++++++++-----
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index b8c7e277f1776..752d213673a70 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
             print(module)
     return f
 
+
 def get_pdl_patterns():
     # Create a rewrite from add to mul. This will match
     # - operation name is arith.addi
@@ -254,25 +255,39 @@ def pat():
             t = pdl.TypeOp(i32)
             cst = pdl.AttributeOp()
             pdl.apply_native_constraint([], "is_one", [cst])
-            op0 = pdl.OperationOp(name="myint.constant", attributes={"value": cst}, types=[t])
+            op0 = pdl.OperationOp(
+                name="myint.constant", attributes={"value": cst}, types=[t]
+            )
 
             @pdl.rewrite()
             def rew():
-                expanded = pdl.apply_native_rewrite([pdl.OperationType.get()], "expand", [cst])
+                expanded = pdl.apply_native_rewrite(
+                    [pdl.OperationType.get()], "expand", [cst]
+                )
                 pdl.ReplaceOp(op0, with_op=expanded)
 
     def is_one(rewriter, results, values):
         cst = values[0].value
         return cst <= 1
-    
+
     def expand(rewriter, results, values):
         cst = values[0].value
         c1 = cst // 2
         c2 = cst - c1
         with rewriter.ip():
-            op1 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c1)})
-            op2 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c2)})
-            res = Operation.create("myint.add", results=[i32], operands=[op1.result, op2.result])
+            op1 = Operation.create(
+                "myint.constant",
+                results=[i32],
+                attributes={"value": IntegerAttr.get(i32, c1)},
+            )
+            op2 = Operation.create(
+                "myint.constant",
+                results=[i32],
+                attributes={"value": IntegerAttr.get(i32, c2)},
+            )
+            res = Operation.create(
+                "myint.add", results=[i32], operands=[op1.result, op2.result]
+            )
         results.append(res)
 
     pdl_module = PDLModule(m)

>From 68fbb0f18f8e86d51f16c4d6d9f4936133ef6d13 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 27 Sep 2025 22:20:11 +0800
Subject: [PATCH 3/7] add comment for c api

---
 mlir/include/mlir-c/Rewrite.h | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index b0f60901c5301..c53470ca09960 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -101,6 +101,9 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
 MLIR_CAPI_EXPORTED MlirBlock
 mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
 
+/// Returns the operation right after the current insertion point
+/// of the rewriter. A null MlirOperation will be returned
+// if the current insertion block is empty.
 MLIR_CAPI_EXPORTED MlirOperation
 mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
 

>From 02f17827662ae3f5fd6fc5ff498aa3196bfb97c1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 27 Sep 2025 22:27:29 +0800
Subject: [PATCH 4/7] fix doc

---
 mlir/include/mlir-c/Rewrite.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index c53470ca09960..5dd285ee076c4 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -103,7 +103,7 @@ mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
 
 /// Returns the operation right after the current insertion point
 /// of the rewriter. A null MlirOperation will be returned
-// if the current insertion block is empty.
+/// if the current insertion point is at the end of the block.
 MLIR_CAPI_EXPORTED MlirOperation
 mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
 

>From 68264afd67d5a87172901cefc836fc92b3fb30da Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 3 Oct 2025 16:46:52 +0800
Subject: [PATCH 5/7] refactor to a class

---
 mlir/lib/Bindings/Python/IRCore.cpp  |  3 ++
 mlir/lib/Bindings/Python/IRModule.h  |  2 ++
 mlir/lib/Bindings/Python/Rewrite.cpp | 48 ++++++++++++++++++----------
 3 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 32b2b0c648cff..7b1710656243a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
     : refOperation(beforeOperationBase.getOperation().getRef()),
       block((*refOperation)->getBlock()) {}
 
+PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
+    : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}
+
 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
   PyOperation &operation = operationBase.getOperation();
   if (operation.isAttached())
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index edbd73eade906..e706be3b4d32a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -841,6 +841,8 @@ class PyInsertionPoint {
   PyInsertionPoint(const PyBlock &block);
   /// Creates an insertion point positioned before a reference operation.
   PyInsertionPoint(PyOperationBase &beforeOperationBase);
+  /// Creates an insertion point positioned before the referenced operation.
+  PyInsertionPoint(PyOperationRef beforeOperationRef);
 
   /// Shortcut to create an insertion point at the beginning of the block.
   static PyInsertionPoint atBlockBegin(PyBlock &block);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9be084faf66a2..10b539a7b3c07 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -26,6 +26,31 @@ using namespace mlir::python;
 
 namespace {
 
+class PyPatternRewriter {
+public:
+  PyPatternRewriter(MlirPatternRewriter rewriter)
+      : rewriter(rewriter), base(mlirPatternRewriterAsBase(rewriter)),
+        ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
+
+  PyInsertionPoint getInsertionPoint() const {
+    MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
+    MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
+
+    if (mlirOperationIsNull(op)) {
+      MlirOperation owner = mlirBlockGetParentOperation(block);
+      auto parent = PyOperation::forOperation(ctx, owner);
+      return PyInsertionPoint(PyBlock(parent, block));
+    }
+
+    return PyInsertionPoint(PyOperation::forOperation(ctx, op));
+  }
+
+private:
+  MlirPatternRewriter rewriter [[maybe_unused]];
+  MlirRewriterBase base;
+  PyMlirContextRef ctx;
+};
+
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 static nb::object objectFromPDLValue(MlirPDLValue value) {
   if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -84,7 +109,8 @@ class PyPDLPatternModule {
            void *userData) -> MlirLogicalResult {
           nb::handle f = nb::handle(static_cast<PyObject *>(userData));
           return logicalResultFromObject(
-              f(rewriter, results, objectsFromPDLValues(nValues, values)));
+              f(PyPatternRewriter(rewriter), results,
+                objectsFromPDLValues(nValues, values)));
         },
         fn.ptr());
   }
@@ -98,7 +124,8 @@ class PyPDLPatternModule {
            void *userData) -> MlirLogicalResult {
           nb::handle f = nb::handle(static_cast<PyObject *>(userData));
           return logicalResultFromObject(
-              f(rewriter, results, objectsFromPDLValues(nValues, values)));
+              f(PyPatternRewriter(rewriter), results,
+                objectsFromPDLValues(nValues, values)));
         },
         fn.ptr());
   }
@@ -143,21 +170,8 @@ class PyFrozenRewritePatternSet {
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
-  nb::class_<MlirPatternRewriter>(m, "PatternRewriter")
-      .def("ip", [](MlirPatternRewriter rewriter) {
-        MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
-        MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
-        MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
-        MlirOperation owner = mlirBlockGetParentOperation(block);
-        auto ctx = PyMlirContext::forContext(mlirRewriterBaseGetContext(base))
-                       ->getRef();
-        if (mlirOperationIsNull(op)) {
-          auto parent = PyOperation::forOperation(ctx, owner);
-          return PyInsertionPoint(PyBlock(parent, block));
-        }
-
-        return PyInsertionPoint(*PyOperation::forOperation(ctx, op).get());
-      });
+  nb::class_<PyPatternRewriter>(m, "PyPatternRewriter")
+      .def("ip", &PyPatternRewriter::getInsertionPoint);
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------

>From a87e193fb7c1ad0645eadd45b4440594b58689d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 4 Oct 2025 00:31:13 +0800
Subject: [PATCH 6/7] make ip a property

---
 mlir/lib/Bindings/Python/Rewrite.cpp         | 5 +++--
 mlir/test/python/integration/dialects/pdl.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 10b539a7b3c07..e94c583b85aba 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -170,8 +170,9 @@ class PyFrozenRewritePatternSet {
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
-  nb::class_<PyPatternRewriter>(m, "PyPatternRewriter")
-      .def("ip", &PyPatternRewriter::getInsertionPoint);
+  nb::class_<PyPatternRewriter>(m, "PatternRewriter")
+      .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+                   "The current insertion point of the PatternRewriter.");
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 752d213673a70..fe27dd4203a21 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -274,7 +274,7 @@ def expand(rewriter, results, values):
         cst = values[0].value
         c1 = cst // 2
         c2 = cst - c1
-        with rewriter.ip():
+        with rewriter.ip:
             op1 = Operation.create(
                 "myint.constant",
                 results=[i32],

>From 8694c3405acfbcb5cd2ffa3fd9c3eecb98cfb572 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 4 Oct 2025 11:54:35 +0800
Subject: [PATCH 7/7] remove useless field

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index e94c583b85aba..9e3d9703c82e8 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -29,7 +29,7 @@ namespace {
 class PyPatternRewriter {
 public:
   PyPatternRewriter(MlirPatternRewriter rewriter)
-      : rewriter(rewriter), base(mlirPatternRewriterAsBase(rewriter)),
+      : base(mlirPatternRewriterAsBase(rewriter)),
         ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
 
   PyInsertionPoint getInsertionPoint() const {
@@ -46,7 +46,6 @@ class PyPatternRewriter {
   }
 
 private:
-  MlirPatternRewriter rewriter [[maybe_unused]];
   MlirRewriterBase base;
   PyMlirContextRef ctx;
 };



More information about the Mlir-commits mailing list