[Mlir-commits] [mlir] [MLIR][Python] Add bindings for PDL native function registering (PR #159926)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 20 07:08:32 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/159926

>From 118e4c7da941001295148006c3a45193ff078259 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 01:47:41 +0800
Subject: [PATCH 1/3] [MLIR][Python] Add bindings for PDL native function
 registering

---
 mlir/include/mlir-c/Rewrite.h                | 32 +++++++
 mlir/lib/Bindings/Python/Rewrite.cpp         | 74 +++++++++++++--
 mlir/lib/CAPI/Transforms/Rewrite.cpp         | 99 ++++++++++++++++++++
 mlir/test/python/integration/dialects/pdl.py | 82 ++++++++++++++++
 4 files changed, 281 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 374d2fb78de88..c20558fc8f9d9 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+DEFINE_C_API_STRUCT(MlirPDLValue, const void);
+DEFINE_C_API_STRUCT(MlirPDLResultList, void);
 
 MLIR_CAPI_EXPORTED MlirPDLPatternModule
 mlirPDLPatternModuleFromModule(MlirModule op);
@@ -323,6 +326,35 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
 
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);
+
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);
+MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
+                                                      MlirType value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+                                   MlirOperation value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+                                   MlirAttribute value);
+
+typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
+    MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+    MlirPDLValue *values, void *userData);
+
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
+    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLRewriteFunction rewriteFn, void *userData);
+
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
 #undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5b7de50f02e6a..89dda560702ba 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -9,10 +9,13 @@
 #include "Rewrite.h"
 
 #include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "mlir-c/IR.h"
 #include "mlir-c/Rewrite.h"
+#include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir/Config/mlir-config.h"
+#include "nanobind/nanobind.h"
 
 namespace nb = nanobind;
 using namespace mlir;
@@ -36,6 +39,22 @@ class PyPDLPatternModule {
   }
   MlirPDLPatternModule get() { return module; }
 
+  void registerRewriteFunction(const std::string &name,
+                               const nb::callable &fn) {
+    mlirPDLPatternModuleRegisterRewriteFunction(
+        get(), mlirStringRefCreate(name.data(), name.size()),
+        [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+           size_t nValues, MlirPDLValue *values,
+           void *userData) -> MlirLogicalResult {
+          auto f = nb::handle(static_cast<PyObject *>(userData));
+          auto valueVec = std::vector<MlirPDLValue>(values, values + nValues);
+          return nb::cast<bool>(f(rewriter, results, valueVec))
+                     ? mlirLogicalResultSuccess()
+                     : mlirLogicalResultFailure();
+        },
+        fn.ptr());
+  }
+
 private:
   MlirPDLPatternModule module;
 };
@@ -76,10 +95,43 @@ class PyFrozenRewritePatternSet {
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
   //----------------------------------------------------------------------------
-  // Mapping of the top-level PassManager
+  // Mapping of the PDLModule
   //----------------------------------------------------------------------------
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  nb::class_<MlirPDLValue>(m, "PDLValue").def("get", [](MlirPDLValue value) {
+    if (mlirPDLValueIsValue(value)) {
+      return nb::cast(mlirPDLValueAsValue(value));
+    }
+    if (mlirPDLValueIsOperation(value)) {
+      return nb::cast(mlirPDLValueAsOperation(value));
+    }
+    if (mlirPDLValueIsAttribute(value)) {
+      return nb::cast(mlirPDLValueAsAttribute(value));
+    }
+    if (mlirPDLValueIsType(value)) {
+      return nb::cast(mlirPDLValueAsType(value));
+    }
+
+    throw std::runtime_error("unsupported PDL value type");
+  });
+  nb::class_<MlirPDLResultList>(m, "PDLResultList")
+      .def("push_back",
+           [](MlirPDLResultList results, const PyValue &value) {
+             mlirPDLResultListPushBackValue(results, value);
+           })
+      .def("push_back",
+           [](MlirPDLResultList results, const PyOperation &op) {
+             mlirPDLResultListPushBackOperation(results, op);
+           })
+      .def("push_back",
+           [](MlirPDLResultList results, const PyType &type) {
+             mlirPDLResultListPushBackType(results, type);
+           })
+      .def("push_back", [](MlirPDLResultList results, MlirAttribute attr) {
+        mlirPDLResultListPushBackAttribute(results, attr);
+      });
   nb::class_<PyPDLPatternModule>(m, "PDLModule")
       .def(
           "__init__",
@@ -88,10 +140,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
                 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
           },
           "module"_a, "Create a PDL module from the given module.")
-      .def("freeze", [](PyPDLPatternModule &self) {
-        return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
-            mlirRewritePatternSetFromPDLPatternModule(self.get())));
-      });
+      .def(
+          "freeze",
+          [](PyPDLPatternModule &self) {
+            return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+                mlirRewritePatternSetFromPDLPatternModule(self.get())));
+          },
+          nb::keep_alive<0, 1>())
+      .def(
+          "register_rewrite_function",
+          [](PyPDLPatternModule &self, const std::string &name,
+             const nb::callable &fn) {
+            self.registerRewriteFunction(name, fn);
+          },
+          nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 6f85357a14a18..0033abde986ea 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -13,6 +13,8 @@
 #include "mlir/CAPI/Rewrite.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
+  assert(rewriter.ptr && "unexpected null rewriter");
+  return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+}
+
+inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
+  return {rewriter};
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
@@ -331,4 +346,88 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
   op.ptr = nullptr;
   return wrap(m);
 }
+
+inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
+  assert(value.ptr && "unexpected null PDL value");
+  return static_cast<const mlir::PDLValue *>(value.ptr);
+}
+
+inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
+
+inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
+  assert(results.ptr && "unexpected null PDL results");
+  return static_cast<mlir::PDLResultList *>(results.ptr);
+}
+
+inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
+  return {results};
+}
+
+bool mlirPDLValueIsValue(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Value>();
+}
+
+MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Value>());
+}
+
+bool mlirPDLValueIsType(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Type>();
+}
+
+MlirType mlirPDLValueAsType(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Type>());
+}
+
+bool mlirPDLValueIsOperation(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Operation *>();
+}
+
+MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Operation *>());
+}
+
+bool mlirPDLValueIsAttribute(MlirPDLValue value) {
+  return unwrap(value)->isa<mlir::Attribute>();
+}
+
+MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
+  return wrap(unwrap(value)->cast<mlir::Attribute>());
+}
+
+void mlirPDLResultListPushBackValue(MlirPDLResultList results,
+                                    MlirValue value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+                                        MlirOperation value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+                                        MlirAttribute value) {
+  unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLPatternModuleRegisterRewriteFunction(
+    MlirPDLPatternModule module, MlirStringRef name,
+    MlirPDLRewriteFunction rewriteFn, void *userData) {
+  unwrap(module)->registerRewriteFunction(
+      unwrap(name),
+      [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
+                            ArrayRef<PDLValue> values) -> LogicalResult {
+        std::vector<MlirPDLValue> mlirValues;
+        mlirValues.reserve(values.size());
+        for (const PDLValue &value : values)
+          mlirValues.push_back(wrap(&value));
+        return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
+                                mlirValues.size(), mlirValues.data(),
+                                userData));
+      });
+}
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index dd6c74ce622c8..8954e3622d3ef 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -86,3 +86,85 @@ def add_func(a, b):
     frozen = get_pdl_patterns()
     apply_patterns_and_fold_greedily(module_.operation, frozen)
     return module_
+
+
+# If we used arith.constant and arith.addi here,
+# their C++-defined folders/canonicalizations would be applied
+# implicitly by the greedy pattern rewrite driver,
+# making our Python-defined folding redundant,
+# so here we define a new dialect to workaround this.
+def load_myint_dialect():
+    from mlir.dialects import irdl
+    m = Module.create()
+    with InsertionPoint(m.body):
+        myint = irdl.dialect("myint")
+        with InsertionPoint(myint.body):
+            constant = irdl.operation_("constant")
+            with InsertionPoint(constant.body):
+                iattr = irdl.base(base_name="#builtin.integer")
+                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+                irdl.attributes_([iattr], ["value"])
+                irdl.results_([i32], ["cst"], [irdl.Variadicity.single])
+            add = irdl.operation_("add")
+            with InsertionPoint(add.body):
+                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+                irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single])
+                irdl.results_([i32], ["res"], [irdl.Variadicity.single])
+
+    m.operation.verify()
+    irdl.load_dialects(m)
+
+# this PDL pattern is to fold constant additions,
+# i.e. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1
+def get_pdl_pattern_fold():
+    m = Module.create()
+    with InsertionPoint(m.body):
+        @pdl.pattern(benefit=1, sym_name="myint_add_fold")
+        def pat():
+            t = pdl.TypeOp(IntegerType.get_signless(32))
+            a0 = pdl.AttributeOp()
+            a1 = pdl.AttributeOp()
+            c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t])
+            c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t])
+            v0 = pdl.ResultOp(c0, 0)
+            v1 = pdl.ResultOp(c1, 0)
+            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+            @pdl.rewrite()
+            def rew():
+                sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1])
+                newOp = pdl.OperationOp(
+                    name="myint.constant", attributes={"value": sum}, types=[t]
+                )
+                pdl.ReplaceOp(op0, with_op=newOp)
+
+    pdl_module = PDLModule(m)
+    def add_fold(rewriter, results, values):
+        a0, a1 = [i.get() for i in values]
+        i32 = IntegerType.get_signless(32)
+        results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
+        return True
+    pdl_module.register_rewrite_function("add_fold", add_fold)
+    return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function
+# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32
+@construct_and_print_in_module
+def test_pdl_register_function(module_):
+    load_myint_dialect()
+
+    module_ = Module.parse(
+        """
+        %c0 = "myint.constant"() { value = 2 }: () -> (i32)
+        %c1 = "myint.constant"() { value = 3 }: () -> (i32)
+        %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+        "myint.add"(%x, %c1): (i32, i32) -> (i32)
+        """
+    )
+
+    frozen = get_pdl_pattern_fold()
+    apply_patterns_and_fold_greedily(module_, frozen)
+
+    return module_

>From f1315a6088031967a673406c239173333bd3103a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 20:56:58 +0800
Subject: [PATCH 2/3] fix style

---
 mlir/test/python/integration/dialects/pdl.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 8954e3622d3ef..64628db072be9 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -114,9 +114,9 @@ def load_myint_dialect():
     m.operation.verify()
     irdl.load_dialects(m)
 
-# this PDL pattern is to fold constant additions,
+# This PDL pattern is to fold constant additions,
 # i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1
+# where constant2 = constant0 + constant1.
 def get_pdl_pattern_fold():
     m = Module.create()
     with InsertionPoint(m.body):
@@ -139,12 +139,13 @@ def rew():
                 )
                 pdl.ReplaceOp(op0, with_op=newOp)
 
-    pdl_module = PDLModule(m)
     def add_fold(rewriter, results, values):
         a0, a1 = [i.get() for i in values]
         i32 = IntegerType.get_signless(32)
         results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
         return True
+
+    pdl_module = PDLModule(m)
     pdl_module.register_rewrite_function("add_fold", add_fold)
     return pdl_module.freeze()
 

>From 2f3da2e0c356721311d2828b1293c35204d2fd6e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 22:08:15 +0800
Subject: [PATCH 3/3] format

---
 mlir/test/python/integration/dialects/pdl.py | 21 ++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 64628db072be9..5b33802cefaba 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -95,6 +95,7 @@ def add_func(a, b):
 # so here we define a new dialect to workaround this.
 def load_myint_dialect():
     from mlir.dialects import irdl
+
     m = Module.create()
     with InsertionPoint(m.body):
         myint = irdl.dialect("myint")
@@ -108,32 +109,44 @@ def load_myint_dialect():
             add = irdl.operation_("add")
             with InsertionPoint(add.body):
                 i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
-                irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single])
+                irdl.operands_(
+                    [i32, i32],
+                    ["lhs", "rhs"],
+                    [irdl.Variadicity.single, irdl.Variadicity.single]
+                )
                 irdl.results_([i32], ["res"], [irdl.Variadicity.single])
 
     m.operation.verify()
     irdl.load_dialects(m)
 
+
 # This PDL pattern is to fold constant additions,
 # i.e. add(constant0, constant1) -> constant2
 # where constant2 = constant0 + constant1.
 def get_pdl_pattern_fold():
     m = Module.create()
     with InsertionPoint(m.body):
+
         @pdl.pattern(benefit=1, sym_name="myint_add_fold")
         def pat():
             t = pdl.TypeOp(IntegerType.get_signless(32))
             a0 = pdl.AttributeOp()
             a1 = pdl.AttributeOp()
-            c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t])
-            c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t])
+            c0 = pdl.OperationOp(
+                name="myint.constant", attributes={"value": a0}, types=[t]
+            )
+            c1 = pdl.OperationOp(
+                name="myint.constant", attributes={"value": a1}, types=[t]
+            )
             v0 = pdl.ResultOp(c0, 0)
             v1 = pdl.ResultOp(c1, 0)
             op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
 
             @pdl.rewrite()
             def rew():
-                sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1])
+                sum = pdl.apply_native_rewrite(
+                    [pdl.AttributeType.get()], "add_fold", [a0, a1]
+                )
                 newOp = pdl.OperationOp(
                     name="myint.constant", attributes={"value": sum}, types=[t]
                 )



More information about the Mlir-commits mailing list