[Mlir-commits] [mlir] [MLIR][Python] Add bindings for PDL native rewrite function registering (PR #159926)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 23 07:12:13 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/159926

>From 118e4c7da941001295148006c3a45193ff078259 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 01:47:41 +0800
Subject: [PATCH 01/12] [MLIR][Python] Add bindings for PDL native function
 registering

---
 mlir/include/mlir-c/Rewrite.h                | 32 +++++++
 mlir/lib/Bindings/Python/Rewrite.cpp         | 74 +++++++++++++--
 mlir/lib/CAPI/Transforms/Rewrite.cpp         | 99 ++++++++++++++++++++
 mlir/test/python/integration/dialects/pdl.py | 82 ++++++++++++++++
 4 files changed, 281 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 374d2fb78de88..c20558fc8f9d9 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+DEFINE_C_API_STRUCT(MlirPDLValue, const void);
+DEFINE_C_API_STRUCT(MlirPDLResultList, void);
 
 MLIR_CAPI_EXPORTED MlirPDLPatternModule
 mlirPDLPatternModuleFromModule(MlirModule op);
@@ -323,6 +326,35 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
 
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+
+/// Type queries and accessors for a PDL value passed to a native rewrite
+/// function. Each `mlirPDLValueAs*` accessor requires that the matching
+/// `mlirPDLValueIs*` query returns true for the same `value`.
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);
+
+/// Append one entry to the result list of a native rewrite function, e.g. the
+/// values produced for a `pdl.apply_native_rewrite` op.
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);
+MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
+                                                      MlirType value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+                                   MlirOperation value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+                                   MlirAttribute value);
+
+/// Signature of a native PDL rewrite function implemented through the C API.
+/// `values` points to `nValues` argument handles that are only valid for the
+/// duration of the call; produced values are appended to `results`.
+/// `userData` is the pointer supplied at registration time.
+typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
+    MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+    MlirPDLValue *values, void *userData);
+
+/// Registers `rewriteFn` under `name` so that PDL patterns in `module` can
+/// invoke it as a native rewrite. `userData` is forwarded unchanged to every
+/// invocation and is not owned by the module: the caller must keep it alive
+/// for as long as the function may be called.
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
+    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLRewriteFunction rewriteFn, void *userData);
+
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
 #undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5b7de50f02e6a..89dda560702ba 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -9,10 +9,13 @@
 #include "Rewrite.h"
 
 #include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "mlir-c/IR.h"
 #include "mlir-c/Rewrite.h"
+#include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir/Config/mlir-config.h"
+#include "nanobind/nanobind.h"
 
 namespace nb = nanobind;
 using namespace mlir;
@@ -36,6 +39,22 @@ class PyPDLPatternModule {
   }
   MlirPDLPatternModule get() { return module; }
 
+  void registerRewriteFunction(const std::string &name,
+                               const nb::callable &fn) {
+    mlirPDLPatternModuleRegisterRewriteFunction(
+        get(), mlirStringRefCreate(name.data(), name.size()),
+        [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+           size_t nValues, MlirPDLValue *values,
+           void *userData) -> MlirLogicalResult {
+          auto f = nb::handle(static_cast<PyObject *>(userData));
+          auto valueVec = std::vector<MlirPDLValue>(values, values + nValues);
+          return nb::cast<bool>(f(rewriter, results, valueVec))
+                     ? mlirLogicalResultSuccess()
+                     : mlirLogicalResultFailure();
+        },
+        fn.ptr());
+  }
+
 private:
   MlirPDLPatternModule module;
 };
@@ -76,10 +95,43 @@ class PyFrozenRewritePatternSet {
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
   //----------------------------------------------------------------------------
-  // Mapping of the top-level PassManager
+  // Mapping of the PDLModule
   //----------------------------------------------------------------------------
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  nb::class_<MlirPDLValue>(m, "PDLValue").def("get", [](MlirPDLValue value) {
+    if (mlirPDLValueIsValue(value)) {
+      return nb::cast(mlirPDLValueAsValue(value));
+    }
+    if (mlirPDLValueIsOperation(value)) {
+      return nb::cast(mlirPDLValueAsOperation(value));
+    }
+    if (mlirPDLValueIsAttribute(value)) {
+      return nb::cast(mlirPDLValueAsAttribute(value));
+    }
+    if (mlirPDLValueIsType(value)) {
+      return nb::cast(mlirPDLValueAsType(value));
+    }
+
+    throw std::runtime_error("unsupported PDL value type");
+  });
+  nb::class_<MlirPDLResultList>(m, "PDLResultList")
+      .def("push_back",
+           [](MlirPDLResultList results, const PyValue &value) {
+             mlirPDLResultListPushBackValue(results, value);
+           })
+      .def("push_back",
+           [](MlirPDLResultList results, const PyOperation &op) {
+             mlirPDLResultListPushBackOperation(results, op);
+           })
+      .def("push_back",
+           [](MlirPDLResultList results, const PyType &type) {
+             mlirPDLResultListPushBackType(results, type);
+           })
+      .def("push_back", [](MlirPDLResultList results, MlirAttribute attr) {
+        mlirPDLResultListPushBackAttribute(results, attr);
+      });
   nb::class_<PyPDLPatternModule>(m, "PDLModule")
       .def(
           "__init__",
@@ -88,10 +140,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
                 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
           },
           "module"_a, "Create a PDL module from the given module.")
-      .def("freeze", [](PyPDLPatternModule &self) {
-        return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
-            mlirRewritePatternSetFromPDLPatternModule(self.get())));
-      });
+      .def(
+          "freeze",
+          [](PyPDLPatternModule &self) {
+            return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+                mlirRewritePatternSetFromPDLPatternModule(self.get())));
+          },
+          nb::keep_alive<0, 1>())
+      .def(
+          "register_rewrite_function",
+          [](PyPDLPatternModule &self, const std::string &name,
+             const nb::callable &fn) {
+            self.registerRewriteFunction(name, fn);
+          },
+          nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 6f85357a14a18..0033abde986ea 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -13,6 +13,8 @@
 #include "mlir/CAPI/Rewrite.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Recovers the C++ rewriter behind an opaque C handle, asserting non-null.
+inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
+  void *raw = rewriter.ptr;
+  assert(raw && "unexpected null rewriter");
+  return static_cast<mlir::PatternRewriter *>(raw);
+}
+
+/// Exposes a C++ rewriter as an opaque C handle; no ownership is transferred.
+inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
+  MlirPatternRewriter handle = {rewriter};
+  return handle;
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
@@ -331,4 +346,88 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
   op.ptr = nullptr;
   return wrap(m);
 }
+
+/// Recovers the C++ PDL value behind an opaque C handle, asserting non-null.
+inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
+  const void *raw = value.ptr;
+  assert(raw && "unexpected null PDL value");
+  return static_cast<const mlir::PDLValue *>(raw);
+}
+
+/// Exposes a C++ PDL value as an opaque C handle without copying it.
+inline MlirPDLValue wrap(const mlir::PDLValue *value) {
+  MlirPDLValue handle = {value};
+  return handle;
+}
+
+/// Recovers the C++ result list behind an opaque C handle, asserting non-null.
+inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
+  void *raw = results.ptr;
+  assert(raw && "unexpected null PDL results");
+  return static_cast<mlir::PDLResultList *>(raw);
+}
+
+/// Exposes a C++ result list as an opaque C handle without copying it.
+inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
+  MlirPDLResultList handle = {results};
+  return handle;
+}
+
+// Each query/accessor pair unwraps the handle once and forwards to
+// PDLValue::isa<T>() / PDLValue::cast<T>().
+bool mlirPDLValueIsValue(MlirPDLValue value) {
+  const mlir::PDLValue &pdl = *unwrap(value);
+  return pdl.isa<mlir::Value>();
+}
+
+MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
+  const mlir::PDLValue &pdl = *unwrap(value);
+  return wrap(pdl.cast<mlir::Value>());
+}
+
+bool mlirPDLValueIsType(MlirPDLValue value) {
+  const mlir::PDLValue &pdl = *unwrap(value);
+  return pdl.isa<mlir::Type>();
+}
+
+MlirType mlirPDLValueAsType(MlirPDLValue value) {
+  const mlir::PDLValue &pdl = *unwrap(value);
+  return wrap(pdl.cast<mlir::Type>());
+}
+
+bool mlirPDLValueIsOperation(MlirPDLValue value) {
+  const mlir::PDLValue &pdl = *unwrap(value);
+  return pdl.isa<mlir::Operation *>();
+}
+
+MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
+  const mlir::PDLValue &pdl = *unwrap(value);
+  return wrap(pdl.cast<mlir::Operation *>());
+}
+
+bool mlirPDLValueIsAttribute(MlirPDLValue value) {
+  const mlir::PDLValue &pdl = *unwrap(value);
+  return pdl.isa<mlir::Attribute>();
+}
+
+MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
+  const mlir::PDLValue &pdl = *unwrap(value);
+  return wrap(pdl.cast<mlir::Attribute>());
+}
+
+// Each push-back unwraps both handles and appends to the C++ result list.
+void mlirPDLResultListPushBackValue(MlirPDLResultList results,
+                                    MlirValue value) {
+  mlir::PDLResultList &resultList = *unwrap(results);
+  resultList.push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
+  mlir::PDLResultList &resultList = *unwrap(results);
+  resultList.push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+                                        MlirOperation value) {
+  mlir::PDLResultList &resultList = *unwrap(results);
+  resultList.push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+                                        MlirAttribute value) {
+  mlir::PDLResultList &resultList = *unwrap(results);
+  resultList.push_back(unwrap(value));
+}
+
+/// Registers a C rewrite callback under `name` in the given PDL pattern module.
+///
+/// On each invocation the C++ arguments are adapted to the C API: the rewriter
+/// and result list are passed as handles to the objects of the current call,
+/// and the PDL values are exposed as an array of `MlirPDLValue` handles that
+/// point into that call's argument list. The array is a local buffer, so the
+/// callback must not retain it (or its elements) after returning. `userData`
+/// is forwarded unchanged on every call and is never released here.
+void mlirPDLPatternModuleRegisterRewriteFunction(
+    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLRewriteFunction rewriteFn, void *userData) {
+  unwrap(module)->registerRewriteFunction(
+      unwrap(name),
+      [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
+                            ArrayRef<PDLValue> values) -> LogicalResult {
+        // Wrap each PDLValue by address. Reserve up front so the buffer is
+        // allocated once per invocation instead of growing element by element.
+        std::vector<MlirPDLValue> mlirValues;
+        mlirValues.reserve(values.size());
+        for (const PDLValue &value : values)
+          mlirValues.push_back(wrap(&value));
+        return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
+                                mlirValues.size(), mlirValues.data(),
+                                userData));
+      });
+}
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index dd6c74ce622c8..8954e3622d3ef 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -86,3 +86,85 @@ def add_func(a, b):
     frozen = get_pdl_patterns()
     apply_patterns_and_fold_greedily(module_.operation, frozen)
     return module_
+
+
+# If we use arith.constant and arith.addi here,
+# these C++-defined folding/canonicalization will be applied
+# implicitly in the greedy pattern rewrite driver to
+# make our Python-defined folding useless,
+# so here we define a new dialect to workaround this.
+def load_myint_dialect():
+    from mlir.dialects import irdl
+    m = Module.create()
+    with InsertionPoint(m.body):
+        myint = irdl.dialect("myint")
+        with InsertionPoint(myint.body):
+            constant = irdl.operation_("constant")
+            with InsertionPoint(constant.body):
+                iattr = irdl.base(base_name="#builtin.integer")
+                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+                irdl.attributes_([iattr], ["value"])
+                irdl.results_([i32], ["cst"], [irdl.Variadicity.single])
+            add = irdl.operation_("add")
+            with InsertionPoint(add.body):
+                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+                irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single])
+                irdl.results_([i32], ["res"], [irdl.Variadicity.single])
+
+    m.operation.verify()
+    irdl.load_dialects(m)
+
+# this PDL pattern is to fold constant additions,
+# i.e. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1
+def get_pdl_pattern_fold():
+    m = Module.create()
+    with InsertionPoint(m.body):
+        @pdl.pattern(benefit=1, sym_name="myint_add_fold")
+        def pat():
+            t = pdl.TypeOp(IntegerType.get_signless(32))
+            a0 = pdl.AttributeOp()
+            a1 = pdl.AttributeOp()
+            c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t])
+            c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t])
+            v0 = pdl.ResultOp(c0, 0)
+            v1 = pdl.ResultOp(c1, 0)
+            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+            @pdl.rewrite()
+            def rew():
+                sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1])
+                newOp = pdl.OperationOp(
+                    name="myint.constant", attributes={"value": sum}, types=[t]
+                )
+                pdl.ReplaceOp(op0, with_op=newOp)
+
+    pdl_module = PDLModule(m)
+    def add_fold(rewriter, results, values):
+        a0, a1 = [i.get() for i in values]
+        i32 = IntegerType.get_signless(32)
+        results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
+        return True
+    pdl_module.register_rewrite_function("add_fold", add_fold)
+    return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function
+# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32
+ at construct_and_print_in_module
+def test_pdl_register_function(module_):
+    load_myint_dialect()
+
+    module_ = Module.parse(
+        """
+        %c0 = "myint.constant"() { value = 2 }: () -> (i32)
+        %c1 = "myint.constant"() { value = 3 }: () -> (i32)
+        %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+        "myint.add"(%x, %c1): (i32, i32) -> (i32)
+        """
+    )
+
+    frozen = get_pdl_pattern_fold()
+    apply_patterns_and_fold_greedily(module_, frozen)
+
+    return module_

>From f1315a6088031967a673406c239173333bd3103a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 20:56:58 +0800
Subject: [PATCH 02/12] fix style

---
 mlir/test/python/integration/dialects/pdl.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 8954e3622d3ef..64628db072be9 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -114,9 +114,9 @@ def load_myint_dialect():
     m.operation.verify()
     irdl.load_dialects(m)
 
-# this PDL pattern is to fold constant additions,
+# This PDL pattern is to fold constant additions,
 # i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1
+# where constant2 = constant0 + constant1.
 def get_pdl_pattern_fold():
     m = Module.create()
     with InsertionPoint(m.body):
@@ -139,12 +139,13 @@ def rew():
                 )
                 pdl.ReplaceOp(op0, with_op=newOp)
 
-    pdl_module = PDLModule(m)
     def add_fold(rewriter, results, values):
         a0, a1 = [i.get() for i in values]
         i32 = IntegerType.get_signless(32)
         results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
         return True
+
+    pdl_module = PDLModule(m)
     pdl_module.register_rewrite_function("add_fold", add_fold)
     return pdl_module.freeze()
 

>From 2f3da2e0c356721311d2828b1293c35204d2fd6e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 22:08:15 +0800
Subject: [PATCH 03/12] format

---
 mlir/test/python/integration/dialects/pdl.py | 21 ++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 64628db072be9..5b33802cefaba 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -95,6 +95,7 @@ def add_func(a, b):
 # so here we define a new dialect to workaround this.
 def load_myint_dialect():
     from mlir.dialects import irdl
+
     m = Module.create()
     with InsertionPoint(m.body):
         myint = irdl.dialect("myint")
@@ -108,32 +109,44 @@ def load_myint_dialect():
             add = irdl.operation_("add")
             with InsertionPoint(add.body):
                 i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
-                irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single])
+                irdl.operands_(
+                    [i32, i32],
+                    ["lhs", "rhs"],
+                    [irdl.Variadicity.single, irdl.Variadicity.single]
+                )
                 irdl.results_([i32], ["res"], [irdl.Variadicity.single])
 
     m.operation.verify()
     irdl.load_dialects(m)
 
+
 # This PDL pattern is to fold constant additions,
 # i.e. add(constant0, constant1) -> constant2
 # where constant2 = constant0 + constant1.
 def get_pdl_pattern_fold():
     m = Module.create()
     with InsertionPoint(m.body):
+
         @pdl.pattern(benefit=1, sym_name="myint_add_fold")
         def pat():
             t = pdl.TypeOp(IntegerType.get_signless(32))
             a0 = pdl.AttributeOp()
             a1 = pdl.AttributeOp()
-            c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t])
-            c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t])
+            c0 = pdl.OperationOp(
+                name="myint.constant", attributes={"value": a0}, types=[t]
+            )
+            c1 = pdl.OperationOp(
+                name="myint.constant", attributes={"value": a1}, types=[t]
+            )
             v0 = pdl.ResultOp(c0, 0)
             v1 = pdl.ResultOp(c1, 0)
             op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
 
             @pdl.rewrite()
             def rew():
-                sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1])
+                sum = pdl.apply_native_rewrite(
+                    [pdl.AttributeType.get()], "add_fold", [a0, a1]
+                )
                 newOp = pdl.OperationOp(
                     name="myint.constant", attributes={"value": sum}, types=[t]
                 )

>From b2de4ec9ea125ce6ff181fce352da75fec00d9ba Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 22:27:58 +0800
Subject: [PATCH 04/12] remove useless bindings

---
 mlir/lib/Bindings/Python/Rewrite.cpp         | 43 +++++++++++---------
 mlir/test/python/integration/dialects/pdl.py |  8 ++--
 2 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 89dda560702ba..e1cf61677fef2 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -39,6 +39,23 @@ class PyPDLPatternModule {
   }
   MlirPDLPatternModule get() { return module; }
 
+  static nb::object fromPDLValue(MlirPDLValue value) {
+    if (mlirPDLValueIsValue(value)) {
+      return nb::cast(mlirPDLValueAsValue(value));
+    }
+    if (mlirPDLValueIsOperation(value)) {
+      return nb::cast(mlirPDLValueAsOperation(value));
+    }
+    if (mlirPDLValueIsAttribute(value)) {
+      return nb::cast(mlirPDLValueAsAttribute(value));
+    }
+    if (mlirPDLValueIsType(value)) {
+      return nb::cast(mlirPDLValueAsType(value));
+    }
+
+    throw std::runtime_error("unsupported PDL value type");
+  }
+
   void registerRewriteFunction(const std::string &name,
                                const nb::callable &fn) {
     mlirPDLPatternModuleRegisterRewriteFunction(
@@ -47,8 +64,12 @@ class PyPDLPatternModule {
            size_t nValues, MlirPDLValue *values,
            void *userData) -> MlirLogicalResult {
           auto f = nb::handle(static_cast<PyObject *>(userData));
-          auto valueVec = std::vector<MlirPDLValue>(values, values + nValues);
-          return nb::cast<bool>(f(rewriter, results, valueVec))
+          std::vector<nb::object> args;
+          args.reserve(nValues);
+          for (size_t i = 0; i < nValues; ++i) {
+            args.push_back(fromPDLValue(values[i]));
+          }
+          return nb::cast<bool>(f(rewriter, results, args))
                      ? mlirLogicalResultSuccess()
                      : mlirLogicalResultFailure();
         },
@@ -97,25 +118,9 @@ class PyFrozenRewritePatternSet {
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
   nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
   //----------------------------------------------------------------------------
-  // Mapping of the PDLModule
+  // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-  nb::class_<MlirPDLValue>(m, "PDLValue").def("get", [](MlirPDLValue value) {
-    if (mlirPDLValueIsValue(value)) {
-      return nb::cast(mlirPDLValueAsValue(value));
-    }
-    if (mlirPDLValueIsOperation(value)) {
-      return nb::cast(mlirPDLValueAsOperation(value));
-    }
-    if (mlirPDLValueIsAttribute(value)) {
-      return nb::cast(mlirPDLValueAsAttribute(value));
-    }
-    if (mlirPDLValueIsType(value)) {
-      return nb::cast(mlirPDLValueAsType(value));
-    }
-
-    throw std::runtime_error("unsupported PDL value type");
-  });
   nb::class_<MlirPDLResultList>(m, "PDLResultList")
       .def("push_back",
            [](MlirPDLResultList results, const PyValue &value) {
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 5b33802cefaba..c78f2d4f9a0dc 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -112,7 +112,7 @@ def load_myint_dialect():
                 irdl.operands_(
                     [i32, i32],
                     ["lhs", "rhs"],
-                    [irdl.Variadicity.single, irdl.Variadicity.single]
+                    [irdl.Variadicity.single, irdl.Variadicity.single],
                 )
                 irdl.results_([i32], ["res"], [irdl.Variadicity.single])
 
@@ -125,11 +125,12 @@ def load_myint_dialect():
 # where constant2 = constant0 + constant1.
 def get_pdl_pattern_fold():
     m = Module.create()
+    i32 = IntegerType.get_signless(32)
     with InsertionPoint(m.body):
 
         @pdl.pattern(benefit=1, sym_name="myint_add_fold")
         def pat():
-            t = pdl.TypeOp(IntegerType.get_signless(32))
+            t = pdl.TypeOp(i32)
             a0 = pdl.AttributeOp()
             a1 = pdl.AttributeOp()
             c0 = pdl.OperationOp(
@@ -153,8 +154,7 @@ def rew():
                 pdl.ReplaceOp(op0, with_op=newOp)
 
     def add_fold(rewriter, results, values):
-        a0, a1 = [i.get() for i in values]
-        i32 = IntegerType.get_signless(32)
+        a0, a1 = values
         results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
         return True
 

>From d6db1e5be4d3f11127c8b978276efc19da9f9eb0 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 22:33:59 +0800
Subject: [PATCH 05/12] fix header order

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index e1cf61677fef2..52be91223c5f8 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -9,11 +9,13 @@
 #include "Rewrite.h"
 
 #include "IRModule.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/IR.h"
 #include "mlir-c/Rewrite.h"
 #include "mlir-c/Support.h"
+// clang-format off
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
 #include "mlir/Config/mlir-config.h"
 #include "nanobind/nanobind.h"
 

>From c2f727e12c525a00c0192037cb79ead204cf5ab5 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 23:04:35 +0800
Subject: [PATCH 06/12] fix

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 52be91223c5f8..eceb5895fd901 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -136,7 +136,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
            [](MlirPDLResultList results, const PyType &type) {
              mlirPDLResultListPushBackType(results, type);
            })
-      .def("push_back", [](MlirPDLResultList results, MlirAttribute attr) {
+      .def("push_back", [](MlirPDLResultList results, const PyAttribute &attr) {
         mlirPDLResultListPushBackAttribute(results, attr);
       });
   nb::class_<PyPDLPatternModule>(m, "PDLModule")

>From 64350847220748ddc87ff764cf402bb03367baf6 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Sun, 21 Sep 2025 11:27:08 +0800
Subject: [PATCH 07/12] Apply suggestion from @makslevental

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
 mlir/lib/Bindings/Python/Rewrite.cpp | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index eceb5895fd901..d8194388b195b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -42,18 +42,14 @@ class PyPDLPatternModule {
   MlirPDLPatternModule get() { return module; }
 
   static nb::object fromPDLValue(MlirPDLValue value) {
-    if (mlirPDLValueIsValue(value)) {
+    if (mlirPDLValueIsValue(value))
       return nb::cast(mlirPDLValueAsValue(value));
-    }
-    if (mlirPDLValueIsOperation(value)) {
+    if (mlirPDLValueIsOperation(value))
       return nb::cast(mlirPDLValueAsOperation(value));
-    }
-    if (mlirPDLValueIsAttribute(value)) {
+    if (mlirPDLValueIsAttribute(value))
       return nb::cast(mlirPDLValueAsAttribute(value));
-    }
-    if (mlirPDLValueIsType(value)) {
+    if (mlirPDLValueIsType(value))
       return nb::cast(mlirPDLValueAsType(value));
-    }
 
     throw std::runtime_error("unsupported PDL value type");
   }

>From 0653ac60d67ed853d1dea24ac352d5bd59cb806b Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Sun, 21 Sep 2025 11:27:25 +0800
Subject: [PATCH 08/12] Apply suggestion from @makslevental

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
 mlir/lib/Bindings/Python/Rewrite.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index d8194388b195b..aee8534b33f7e 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -64,9 +64,8 @@ class PyPDLPatternModule {
           auto f = nb::handle(static_cast<PyObject *>(userData));
           std::vector<nb::object> args;
           args.reserve(nValues);
-          for (size_t i = 0; i < nValues; ++i) {
+          for (size_t i = 0; i < nValues; ++i)
             args.push_back(fromPDLValue(values[i]));
-          }
           return nb::cast<bool>(f(rewriter, results, args))
                      ? mlirLogicalResultSuccess()
                      : mlirLogicalResultFailure();

>From e9845d639f6243d6133cd79d4a861563dc6784ca Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 21 Sep 2025 13:25:11 +0800
Subject: [PATCH 09/12] move out pdlvalue cast and logical result conversion

---
 mlir/lib/Bindings/Python/Rewrite.cpp         | 42 +++++++++++---------
 mlir/test/python/integration/dialects/pdl.py |  1 -
 2 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index ca4e3e331e9b5..3870e6887dfdb 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -27,6 +27,27 @@ using namespace mlir::python;
 namespace {
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+/// Converts a PDL value into the matching Python object (Value, Operation,
+/// Attribute or Type, checked in that order). Throws std::runtime_error for
+/// any other kind of PDL value. NOTE(review): PDLValue may also carry range
+/// kinds, which are rejected here; confirm whether they need support.
+nb::object objectFromPDLValue(MlirPDLValue value) {
+  if (mlirPDLValueIsValue(value))
+    return nb::cast(mlirPDLValueAsValue(value));
+  if (mlirPDLValueIsOperation(value))
+    return nb::cast(mlirPDLValueAsOperation(value));
+  if (mlirPDLValueIsAttribute(value))
+    return nb::cast(mlirPDLValueAsAttribute(value));
+  if (mlirPDLValueIsType(value))
+    return nb::cast(mlirPDLValueAsType(value));
+
+  throw std::runtime_error("unsupported PDL value type");
+}
+
+/// Maps a Python rewrite callback's return value to a logical result.
+/// `None` (no explicit return) and `False` mean success; `True` means failure.
+/// Any other object goes through nb::cast<bool>. NOTE(review): this is the
+/// inverse of the common "True == success" convention; keep the Python-facing
+/// documentation in sync with it.
+MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
+  if (obj.is_none())
+    return mlirLogicalResultSuccess();
+
+  return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
+                             : mlirLogicalResultSuccess();
+}
+
 /// Owning Wrapper around a PDLPatternModule.
 class PyPDLPatternModule {
 public:
@@ -41,19 +62,6 @@ class PyPDLPatternModule {
   }
   MlirPDLPatternModule get() { return module; }
 
-  static nb::object fromPDLValue(MlirPDLValue value) {
-    if (mlirPDLValueIsValue(value))
-      return nb::cast(mlirPDLValueAsValue(value));
-    if (mlirPDLValueIsOperation(value))
-      return nb::cast(mlirPDLValueAsOperation(value));
-    if (mlirPDLValueIsAttribute(value))
-      return nb::cast(mlirPDLValueAsAttribute(value));
-    if (mlirPDLValueIsType(value))
-      return nb::cast(mlirPDLValueAsType(value));
-
-    throw std::runtime_error("unsupported PDL value type");
-  }
-
   void registerRewriteFunction(const std::string &name,
                                const nb::callable &fn) {
     mlirPDLPatternModuleRegisterRewriteFunction(
@@ -61,14 +69,12 @@ class PyPDLPatternModule {
         [](MlirPatternRewriter rewriter, MlirPDLResultList results,
            size_t nValues, MlirPDLValue *values,
            void *userData) -> MlirLogicalResult {
-          auto f = nb::handle(static_cast<PyObject *>(userData));
+          nb::handle f = nb::handle(static_cast<PyObject *>(userData));
           std::vector<nb::object> args;
           args.reserve(nValues);
           for (size_t i = 0; i < nValues; ++i)
-            args.push_back(fromPDLValue(values[i]));
-          return nb::cast<bool>(f(rewriter, results, args))
-                     ? mlirLogicalResultSuccess()
-                     : mlirLogicalResultFailure();
+            args.push_back(objectFromPDLValue(values[i]));
+          return logicalResultFromObject(f(rewriter, results, args));
         },
         fn.ptr());
   }
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index c78f2d4f9a0dc..8fbe1a7151f63 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -156,7 +156,6 @@ def rew():
     def add_fold(rewriter, results, values):
         a0, a1 = values
         results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
-        return True
 
     pdl_module = PDLModule(m)
     pdl_module.register_rewrite_function("add_fold", add_fold)

>From bf7940966941de09273a170c98fa5e6ffbdaf867 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 21 Sep 2025 14:28:32 +0800
Subject: [PATCH 10/12] add sigs

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 51 ++++++++++++++++++++--------
 1 file changed, 36 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 3870e6887dfdb..7f71f134e44ae 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -125,21 +125,42 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
   nb::class_<MlirPDLResultList>(m, "PDLResultList")
-      .def("push_back",
-           [](MlirPDLResultList results, const PyValue &value) {
-             mlirPDLResultListPushBackValue(results, value);
-           })
-      .def("push_back",
-           [](MlirPDLResultList results, const PyOperation &op) {
-             mlirPDLResultListPushBackOperation(results, op);
-           })
-      .def("push_back",
-           [](MlirPDLResultList results, const PyType &type) {
-             mlirPDLResultListPushBackType(results, type);
-           })
-      .def("push_back", [](MlirPDLResultList results, const PyAttribute &attr) {
-        mlirPDLResultListPushBackAttribute(results, attr);
-      });
+      .def(
+          "push_back",
+          [](MlirPDLResultList results, const PyValue &value) {
+            mlirPDLResultListPushBackValue(results, value);
+          },
+          // clang-format off
+          nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
+          // clang-format on
+          )
+      .def(
+          "push_back",
+          [](MlirPDLResultList results, const PyOperation &op) {
+            mlirPDLResultListPushBackOperation(results, op);
+          },
+          // clang-format off
+          nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
+          // clang-format on
+          )
+      .def(
+          "push_back",
+          [](MlirPDLResultList results, const PyType &type) {
+            mlirPDLResultListPushBackType(results, type);
+          },
+          // clang-format off
+          nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
+          // clang-format on
+          )
+      .def(
+          "push_back",
+          [](MlirPDLResultList results, const PyAttribute &attr) {
+            mlirPDLResultListPushBackAttribute(results, attr);
+          },
+          // clang-format off
+          nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
+          // clang-format on
+      );
   nb::class_<PyPDLPatternModule>(m, "PDLModule")
       .def(
           "__init__",

>From e821fd4339180b45bdcf8bb2aa77edb404b6a7c2 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 23 Sep 2025 22:08:42 +0800
Subject: [PATCH 11/12] apply review suggestions

---
 mlir/include/mlir-c/Rewrite.h                | 30 ++++++++++++---
 mlir/include/mlir/IR/PDLPatternMatch.h.inc   |  2 +-
 mlir/lib/Bindings/Python/Rewrite.cpp         | 40 +++++++++++---------
 mlir/lib/CAPI/Transforms/Rewrite.cpp         | 25 +++---------
 mlir/test/python/integration/dialects/pdl.py |  2 +-
 5 files changed, 54 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index c20558fc8f9d9..f4974348945c5 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -327,32 +327,52 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
 
-MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value);
+/// Cast the MlirPDLValue to an MlirValue.
+/// Return a null value if the cast fails, just like llvm::dyn_cast.
 MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);
-MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value);
+
+/// Cast the MlirPDLValue to an MlirType.
+/// Return a null value if the cast fails, just like llvm::dyn_cast.
 MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);
-MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value);
+
+/// Cast the MlirPDLValue to an MlirOperation.
+/// Return a null value if the cast fails, just like llvm::dyn_cast.
 MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);
-MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value);
+
+/// Cast the MlirPDLValue to an MlirAttribute.
+/// Return a null value if the cast fails, just like llvm::dyn_cast.
 MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);
 
+/// Push the MlirValue into the given MlirPDLResultList.
 MLIR_CAPI_EXPORTED void
 mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);
+
+/// Push the MlirType into the given MlirPDLResultList.
 MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
                                                       MlirType value);
+
+/// Push the MlirOperation into the given MlirPDLResultList.
 MLIR_CAPI_EXPORTED void
 mlirPDLResultListPushBackOperation(MlirPDLResultList results,
                                    MlirOperation value);
+
+/// Push the MlirAttribute into the given MlirPDLResultList.
 MLIR_CAPI_EXPORTED void
 mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
                                    MlirAttribute value);
 
+/// Callback type for PDL native rewrite functions. Inputs are passed as the
+/// array `values` of length `nValues`; outputs are added to `results` via the
+/// `mlirPDLResultListPushBack*` APIs. `userData` is the pointer supplied at
+/// registration. The return value indicates whether the rewrite succeeded.
 typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
     MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
     MlirPDLValue *values, void *userData);
 
+/// Register a rewrite function with the given PDL pattern module. `userData`
+/// is passed to each call of `rewriteFn`; the module does not take ownership.
 MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
-    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLPatternModule pdlModule, MlirStringRef name,
     MlirPDLRewriteFunction rewriteFn, void *userData);
 
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
index 96ba98a850de0..d5fb57d7c360d 100644
--- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -53,7 +53,7 @@ public:
   /// value is not an instance of `T`.
   template <typename T,
             typename ResultT = std::conditional_t<
-                std::is_convertible<T, bool>::value, T, std::optional<T>>>
+                std::is_constructible_v<bool, T>, T, std::optional<T>>>
   ResultT dyn_cast() const {
     return isa<T>() ? castImpl<T>() : ResultT();
   }
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 7f71f134e44ae..72062f660458b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -27,20 +27,24 @@ using namespace mlir::python;
 namespace {
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-nb::object objectFromPDLValue(MlirPDLValue value) {
-  if (mlirPDLValueIsValue(value))
-    return nb::cast(mlirPDLValueAsValue(value));
-  if (mlirPDLValueIsOperation(value))
-    return nb::cast(mlirPDLValueAsOperation(value));
-  if (mlirPDLValueIsAttribute(value))
-    return nb::cast(mlirPDLValueAsAttribute(value));
-  if (mlirPDLValueIsType(value))
-    return nb::cast(mlirPDLValueAsType(value));
+static nb::object objectFromPDLValue(MlirPDLValue value) {
+  if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
+    return nb::cast(v);
+  if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
+    return nb::cast(v);
+  if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
+    return nb::cast(v);
+  if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v))
+    return nb::cast(v);
 
   throw std::runtime_error("unsupported PDL value type");
 }
 
-MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
+// Convert the return value of a Python rewrite callback to an
+// MlirLogicalResult: None is treated as success; any other value is cast
+// to bool, with False meaning success and True meaning failure.
+// Note the inverted convention: returning True signals failure.
+static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
   if (obj.is_none())
     return mlirLogicalResultSuccess();
 
@@ -126,39 +130,39 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
   nb::class_<MlirPDLResultList>(m, "PDLResultList")
       .def(
-          "push_back",
+          "append",
           [](MlirPDLResultList results, const PyValue &value) {
             mlirPDLResultListPushBackValue(results, value);
           },
           // clang-format off
-          nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
+          nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
           // clang-format on
           )
       .def(
-          "push_back",
+          "append",
           [](MlirPDLResultList results, const PyOperation &op) {
             mlirPDLResultListPushBackOperation(results, op);
           },
           // clang-format off
-          nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
+          nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
           // clang-format on
           )
       .def(
-          "push_back",
+          "append",
           [](MlirPDLResultList results, const PyType &type) {
             mlirPDLResultListPushBackType(results, type);
           },
           // clang-format off
-          nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
+          nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
           // clang-format on
           )
       .def(
-          "push_back",
+          "append",
           [](MlirPDLResultList results, const PyAttribute &attr) {
             mlirPDLResultListPushBackAttribute(results, attr);
           },
           // clang-format off
-          nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
+          nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
           // clang-format on
       );
   nb::class_<PyPDLPatternModule>(m, "PDLModule")
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 0033abde986ea..8b41e7022bc18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -363,36 +363,20 @@ inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
   return {results};
 }
 
-bool mlirPDLValueIsValue(MlirPDLValue value) {
-  return unwrap(value)->isa<mlir::Value>();
-}
-
 MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
-  return wrap(unwrap(value)->cast<mlir::Value>());
-}
-
-bool mlirPDLValueIsType(MlirPDLValue value) {
-  return unwrap(value)->isa<mlir::Type>();
+  return wrap(unwrap(value)->dyn_cast<mlir::Value>());
 }
 
 MlirType mlirPDLValueAsType(MlirPDLValue value) {
-  return wrap(unwrap(value)->cast<mlir::Type>());
-}
-
-bool mlirPDLValueIsOperation(MlirPDLValue value) {
-  return unwrap(value)->isa<mlir::Operation *>();
+  return wrap(unwrap(value)->dyn_cast<mlir::Type>());
 }
 
 MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
-  return wrap(unwrap(value)->cast<mlir::Operation *>());
-}
-
-bool mlirPDLValueIsAttribute(MlirPDLValue value) {
-  return unwrap(value)->isa<mlir::Attribute>();
+  return wrap(unwrap(value)->dyn_cast<mlir::Operation *>());
 }
 
 MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
-  return wrap(unwrap(value)->cast<mlir::Attribute>());
+  return wrap(unwrap(value)->dyn_cast<mlir::Attribute>());
 }
 
 void mlirPDLResultListPushBackValue(MlirPDLResultList results,
@@ -422,6 +406,7 @@ void mlirPDLPatternModuleRegisterRewriteFunction(
       [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
                             ArrayRef<PDLValue> values) -> LogicalResult {
         std::vector<MlirPDLValue> mlirValues;
+        mlirValues.reserve(values.size());
         for (auto &value : values) {
           mlirValues.push_back(wrap(&value));
         }
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 8fbe1a7151f63..e85c6c77ef955 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -155,7 +155,7 @@ def rew():
 
     def add_fold(rewriter, results, values):
         a0, a1 = values
-        results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
+        results.append(IntegerAttr.get(i32, a0.value + a1.value))
 
     pdl_module = PDLModule(m)
     pdl_module.register_rewrite_function("add_fold", add_fold)

>From b08315f65006cfaebaaa56e5f401217ca7300290 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 23 Sep 2025 22:11:54 +0800
Subject: [PATCH 12/12] rename

---
 mlir/lib/CAPI/Transforms/Rewrite.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8b41e7022bc18..9ecce956a05b9 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -399,9 +399,9 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
 }
 
 void mlirPDLPatternModuleRegisterRewriteFunction(
-    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLPatternModule pdlModule, MlirStringRef name,
     MlirPDLRewriteFunction rewriteFn, void *userData) {
-  unwrap(module)->registerRewriteFunction(
+  unwrap(pdlModule)->registerRewriteFunction(
       unwrap(name),
       [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
                             ArrayRef<PDLValue> values) -> LogicalResult {



More information about the Mlir-commits mailing list