[Mlir-commits] [mlir] [MLIR][Python] Add bindings for PDL native function registering (PR #159926)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 20 06:39:15 PDT 2025


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/159926


In the MLIR Python bindings, we can currently use PDL to define simple patterns and then execute them with the greedy rewrite driver. However, when dealing with more complex patterns—such as constant folding for integer addition—we find that we need `apply_native_rewrite` to actually perform arithmetic (i.e., compute the sum of two constants). For example, consider the following PDL pseudocode:

```mlir
pdl.pattern : benefit(1) {
  %a0 = pdl.attribute
  %a1 = pdl.attribute
  %c0 = pdl.operation "arith.constant" {value = %a0}
  %c1 = pdl.operation "arith.constant" {value = %a1}

  %op = pdl.operation "arith.addi"(%c0, %c1)

  %sum = pdl.apply_native_rewrite "addIntegers"(%a0, %a1)
  %new_cst = pdl.operation "arith.constant" {value = %sum}

  pdl.replace %op with %new_cst
}
```

Here, `addIntegers` cannot be expressed in PDL alone—it requires a *native rewrite function*. This PR introduces a mechanism to support exactly that, allowing complex rewrite patterns to be expressed in Python and enabling many passes to be implemented directly in Python as well.

As a test case, we defined two new operations (`myint.constant` and `myint.add`) in Python and implemented a constant-folding rewrite pattern for them. The core code looks like this:

```python
# Build a fresh module and define the PDL pattern inside its body.
m = Module.create()
with InsertionPoint(m.body):
    @pdl.pattern(benefit=1, sym_name="myint_add_fold")
    def pat():
        # The matching part is elided in this excerpt; the full test in the
        # patch below defines t, a0, a1, v0 and v1 before this point.
        ...
        op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])

        @pdl.rewrite()
        def rew():
            # Delegate the arithmetic to the native rewrite function that is
            # registered under the name "add_fold"; it yields one attribute.
            sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1])
            # Build the folded constant from that attribute and use it to
            # replace the matched "myint.add" operation.
            newOp = pdl.OperationOp(
                name="myint.constant", attributes={"value": sum}, types=[t]
            )
            pdl.ReplaceOp(op0, with_op=newOp)


def add_fold(rewriter, results, values):
    """Native PDL rewrite callback that adds two integer attributes.

    Each entry of ``values`` is unwrapped with ``get()``; exactly two
    entries are expected. The sum of their ``value`` fields is pushed onto
    ``results`` as a 32-bit signless ``IntegerAttr``. ``rewriter`` is not
    used. Always returns ``True``.
    """
    lhs, rhs = [pdl_value.get() for pdl_value in values]
    result_type = IntegerType.get_signless(32)
    folded = IntegerAttr.get(result_type, lhs.value + rhs.value)
    results.push_back(folded)
    return True


# Wrap the module that holds the pattern in a PDLModule, then bind the Python
# callback to the name the pattern passes to `pdl.apply_native_rewrite`.
pdl_module = PDLModule(m)
pdl_module.register_rewrite_function("add_fold", add_fold)
```


From 118e4c7da941001295148006c3a45193ff078259 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 01:47:41 +0800
Subject: [PATCH 1/2] [MLIR][Python] Add bindings for PDL native function
 registering

---
 mlir/include/mlir-c/Rewrite.h                | 32 +++++++
 mlir/lib/Bindings/Python/Rewrite.cpp         | 74 +++++++++++++--
 mlir/lib/CAPI/Transforms/Rewrite.cpp         | 99 ++++++++++++++++++++
 mlir/test/python/integration/dialects/pdl.py | 82 ++++++++++++++++
 4 files changed, 281 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 374d2fb78de88..c20558fc8f9d9 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+DEFINE_C_API_STRUCT(MlirPDLValue, const void);
+DEFINE_C_API_STRUCT(MlirPDLResultList, void);
 
 MLIR_CAPI_EXPORTED MlirPDLPatternModule
 mlirPDLPatternModuleFromModule(MlirModule op);
@@ -323,6 +326,35 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
 
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);
+
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);
+MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
+                                                      MlirType value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+                                   MlirOperation value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+                                   MlirAttribute value);
+
+typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
+    MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+    MlirPDLValue *values, void *userData);
+
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
+    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLRewriteFunction rewriteFn, void *userData);
+
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
 #undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5b7de50f02e6a..89dda560702ba 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -9,10 +9,13 @@
 #include "Rewrite.h"
 
 #include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "mlir-c/IR.h"
 #include "mlir-c/Rewrite.h"
+#include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir/Config/mlir-config.h"
+#include "nanobind/nanobind.h"
 
 namespace nb = nanobind;
 using namespace mlir;
@@ -36,6 +39,22 @@ class PyPDLPatternModule {
   }
   MlirPDLPatternModule get() { return module; }
 
+  void registerRewriteFunction(const std::string &name,
+                               const nb::callable &fn) {
+    mlirPDLPatternModuleRegisterRewriteFunction(
+        get(), mlirStringRefCreate(name.data(), name.size()),
+        [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+           size_t nValues, MlirPDLValue *values,
+           void *userData) -> MlirLogicalResult {
+          auto f = nb::handle(static_cast<PyObject *>(userData));
+          auto valueVec = std::vector<MlirPDLValue>(values, values + nValues);
+          return nb::cast<bool>(f(rewriter, results, valueVec))
+                     ? mlirLogicalResultSuccess()
+                     : mlirLogicalResultFailure();
+        },
+        fn.ptr());
+  }
+
 private:
   MlirPDLPatternModule module;
 };
@@ -76,10 +95,43 @@ class PyFrozenRewritePatternSet {
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
   //----------------------------------------------------------------------------
-  // Mapping of the top-level PassManager
+  // Mapping of the PDLModule
   //----------------------------------------------------------------------------
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  nb::class_<MlirPDLValue>(m, "PDLValue").def("get", [](MlirPDLValue value) {
+    if (mlirPDLValueIsValue(value)) {
+      return nb::cast(mlirPDLValueAsValue(value));
+    }
+    if (mlirPDLValueIsOperation(value)) {
+      return nb::cast(mlirPDLValueAsOperation(value));
+    }
+    if (mlirPDLValueIsAttribute(value)) {
+      return nb::cast(mlirPDLValueAsAttribute(value));
+    }
+    if (mlirPDLValueIsType(value)) {
+      return nb::cast(mlirPDLValueAsType(value));
+    }
+
+    throw std::runtime_error("unsupported PDL value type");
+  });
+  nb::class_<MlirPDLResultList>(m, "PDLResultList")
+      .def("push_back",
+           [](MlirPDLResultList results, const PyValue &value) {
+             mlirPDLResultListPushBackValue(results, value);
+           })
+      .def("push_back",
+           [](MlirPDLResultList results, const PyOperation &op) {
+             mlirPDLResultListPushBackOperation(results, op);
+           })
+      .def("push_back",
+           [](MlirPDLResultList results, const PyType &type) {
+             mlirPDLResultListPushBackType(results, type);
+           })
+      .def("push_back", [](MlirPDLResultList results, MlirAttribute attr) {
+        mlirPDLResultListPushBackAttribute(results, attr);
+      });
   nb::class_<PyPDLPatternModule>(m, "PDLModule")
       .def(
           "__init__",
@@ -88,10 +140,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
                 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
           },
           "module"_a, "Create a PDL module from the given module.")
-      .def("freeze", [](PyPDLPatternModule &self) {
-        return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
-            mlirRewritePatternSetFromPDLPatternModule(self.get())));
-      });
+      .def(
+          "freeze",
+          [](PyPDLPatternModule &self) {
+            return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+                mlirRewritePatternSetFromPDLPatternModule(self.get())));
+          },
+          nb::keep_alive<0, 1>())
+      .def(
+          "register_rewrite_function",
+          [](PyPDLPatternModule &self, const std::string &name,
+             const nb::callable &fn) {
+            self.registerRewriteFunction(name, fn);
+          },
+          nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 6f85357a14a18..0033abde986ea 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -13,6 +13,8 @@
 #include "mlir/CAPI/Rewrite.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
+  assert(rewriter.ptr && "unexpected null rewriter");
+  return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+}
+
+inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
+  return {rewriter};
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
@@ -331,4 +346,88 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
   op.ptr = nullptr;
   return wrap(m);
 }
+
+inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
+  assert(value.ptr && "unexpected null PDL value");
+  return static_cast<const mlir::PDLValue *>(value.ptr);
+}
+
+inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
+
+inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
+  assert(results.ptr && "unexpected null PDL results");
+  return static_cast<mlir::PDLResultList *>(results.ptr);
+}
+
+inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
+  return {results};
+}
+
+bool mlirPDLValueIsValue(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Value>();
+}
+
+MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Value>());
+}
+
+bool mlirPDLValueIsType(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Type>();
+}
+
+MlirType mlirPDLValueAsType(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Type>());
+}
+
+bool mlirPDLValueIsOperation(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Operation *>();
+}
+
+MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Operation *>());
+}
+
+bool mlirPDLValueIsAttribute(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Attribute>();
+}
+
+MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Attribute>());
+}
+
+void mlirPDLResultListPushBackValue(MlirPDLResultList results,
+                                    MlirValue value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+                                        MlirOperation value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+                                        MlirAttribute value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLPatternModuleRegisterRewriteFunction(
+    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLRewriteFunction rewriteFn, void *userData) {
+  unwrap(module)->registerRewriteFunction(
+      unwrap(name),
+      [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
+                            ArrayRef<PDLValue> values) -> LogicalResult {
+        std::vector<MlirPDLValue> mlirValues;
+        for (auto &value : values) {
+          mlirValues.push_back(wrap(&value));
+        }
+        return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
+                                mlirValues.size(), mlirValues.data(),
+                                userData));
+      });
+}
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index dd6c74ce622c8..8954e3622d3ef 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -86,3 +86,85 @@ def add_func(a, b):
     frozen = get_pdl_patterns()
     apply_patterns_and_fold_greedily(module_.operation, frozen)
     return module_
+
+
+# If we used arith.constant and arith.addi here,
+# their C++-defined folders/canonicalizations would be applied
+# implicitly by the greedy pattern rewrite driver,
+# making our Python-defined folding useless,
+# so we define a new dialect here to work around this.
+def load_myint_dialect():
+    from mlir.dialects import irdl
+    m = Module.create()
+    with InsertionPoint(m.body):
+        myint = irdl.dialect("myint")
+        with InsertionPoint(myint.body):
+            constant = irdl.operation_("constant")
+            with InsertionPoint(constant.body):
+                iattr = irdl.base(base_name="#builtin.integer")
+                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+                irdl.attributes_([iattr], ["value"])
+                irdl.results_([i32], ["cst"], [irdl.Variadicity.single])
+            add = irdl.operation_("add")
+            with InsertionPoint(add.body):
+                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+                irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single])
+                irdl.results_([i32], ["res"], [irdl.Variadicity.single])
+
+    m.operation.verify()
+    irdl.load_dialects(m)
+
+# this PDL pattern is to fold constant additions,
+# i.e. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1
+def get_pdl_pattern_fold():
+    m = Module.create()
+    with InsertionPoint(m.body):
+        @pdl.pattern(benefit=1, sym_name="myint_add_fold")
+        def pat():
+            t = pdl.TypeOp(IntegerType.get_signless(32))
+            a0 = pdl.AttributeOp()
+            a1 = pdl.AttributeOp()
+            c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t])
+            c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t])
+            v0 = pdl.ResultOp(c0, 0)
+            v1 = pdl.ResultOp(c1, 0)
+            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+            @pdl.rewrite()
+            def rew():
+                sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1])
+                newOp = pdl.OperationOp(
+                    name="myint.constant", attributes={"value": sum}, types=[t]
+                )
+                pdl.ReplaceOp(op0, with_op=newOp)
+
+    pdl_module = PDLModule(m)
+    def add_fold(rewriter, results, values):
+        a0, a1 = [i.get() for i in values]
+        i32 = IntegerType.get_signless(32)
+        results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
+        return True
+    pdl_module.register_rewrite_function("add_fold", add_fold)
+    return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function
+# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32
+@construct_and_print_in_module
+def test_pdl_register_function(module_):
+    load_myint_dialect()
+
+    module_ = Module.parse(
+        """
+        %c0 = "myint.constant"() { value = 2 }: () -> (i32)
+        %c1 = "myint.constant"() { value = 3 }: () -> (i32)
+        %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+        "myint.add"(%x, %c1): (i32, i32) -> (i32)
+        """
+    )
+
+    frozen = get_pdl_pattern_fold()
+    apply_patterns_and_fold_greedily(module_, frozen)
+
+    return module_

From f1315a6088031967a673406c239173333bd3103a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 20:56:58 +0800
Subject: [PATCH 2/2] fix style

---
 mlir/test/python/integration/dialects/pdl.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 8954e3622d3ef..64628db072be9 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -114,9 +114,9 @@ def load_myint_dialect():
     m.operation.verify()
     irdl.load_dialects(m)
 
-# this PDL pattern is to fold constant additions,
+# This PDL pattern is to fold constant additions,
 # i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1
+# where constant2 = constant0 + constant1.
 def get_pdl_pattern_fold():
     m = Module.create()
     with InsertionPoint(m.body):
@@ -139,12 +139,13 @@ def rew():
                 )
                 pdl.ReplaceOp(op0, with_op=newOp)
 
-    pdl_module = PDLModule(m)
     def add_fold(rewriter, results, values):
         a0, a1 = [i.get() for i in values]
         i32 = IntegerType.get_signless(32)
         results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
         return True
+
+    pdl_module = PDLModule(m)
     pdl_module.register_rewrite_function("add_fold", add_fold)
     return pdl_module.freeze()
 



More information about the Mlir-commits mailing list