[Mlir-commits] [mlir] [MLIR][Python] Add bindings for PDL native function registering (PR #159926)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 20 07:08:32 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/159926
>From 118e4c7da941001295148006c3a45193ff078259 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 01:47:41 +0800
Subject: [PATCH 1/3] [MLIR][Python] Add bindings for PDL native function
registering
---
mlir/include/mlir-c/Rewrite.h | 32 +++++++
mlir/lib/Bindings/Python/Rewrite.cpp | 74 +++++++++++++--
mlir/lib/CAPI/Transforms/Rewrite.cpp | 99 ++++++++++++++++++++
mlir/test/python/integration/dialects/pdl.py | 82 ++++++++++++++++
4 files changed, 281 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 374d2fb78de88..c20558fc8f9d9 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void);
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
@@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+DEFINE_C_API_STRUCT(MlirPDLValue, const void);
+DEFINE_C_API_STRUCT(MlirPDLResultList, void);
MLIR_CAPI_EXPORTED MlirPDLPatternModule
mlirPDLPatternModuleFromModule(MlirModule op);
@@ -323,6 +326,35 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);
+
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);
+MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
+ MlirType value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+ MlirOperation value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+ MlirAttribute value);
+
+typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
+ MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+ MlirPDLValue *values, void *userData);
+
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
+ MlirPDLPatternModule module, MlirStringRef name,
+ MlirPDLRewriteFunction rewriteFn, void *userData);
+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
#undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5b7de50f02e6a..89dda560702ba 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -9,10 +9,13 @@
#include "Rewrite.h"
#include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
+#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir/Config/mlir-config.h"
+#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace mlir;
@@ -36,6 +39,22 @@ class PyPDLPatternModule {
}
MlirPDLPatternModule get() { return module; }
+ void registerRewriteFunction(const std::string &name,
+ const nb::callable &fn) {
+ mlirPDLPatternModuleRegisterRewriteFunction(
+ get(), mlirStringRefCreate(name.data(), name.size()),
+ [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+ size_t nValues, MlirPDLValue *values,
+ void *userData) -> MlirLogicalResult {
+ auto f = nb::handle(static_cast<PyObject *>(userData));
+ auto valueVec = std::vector<MlirPDLValue>(values, values + nValues);
+ return nb::cast<bool>(f(rewriter, results, valueVec))
+ ? mlirLogicalResultSuccess()
+ : mlirLogicalResultFailure();
+ },
+ fn.ptr());
+ }
+
private:
MlirPDLPatternModule module;
};
@@ -76,10 +95,43 @@ class PyFrozenRewritePatternSet {
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+ nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
//----------------------------------------------------------------------------
- // Mapping of the top-level PassManager
+ // Mapping of the PDLModule
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+ nb::class_<MlirPDLValue>(m, "PDLValue").def("get", [](MlirPDLValue value) {
+ if (mlirPDLValueIsValue(value)) {
+ return nb::cast(mlirPDLValueAsValue(value));
+ }
+ if (mlirPDLValueIsOperation(value)) {
+ return nb::cast(mlirPDLValueAsOperation(value));
+ }
+ if (mlirPDLValueIsAttribute(value)) {
+ return nb::cast(mlirPDLValueAsAttribute(value));
+ }
+ if (mlirPDLValueIsType(value)) {
+ return nb::cast(mlirPDLValueAsType(value));
+ }
+
+ throw std::runtime_error("unsupported PDL value type");
+ });
+ nb::class_<MlirPDLResultList>(m, "PDLResultList")
+ .def("push_back",
+ [](MlirPDLResultList results, const PyValue &value) {
+ mlirPDLResultListPushBackValue(results, value);
+ })
+ .def("push_back",
+ [](MlirPDLResultList results, const PyOperation &op) {
+ mlirPDLResultListPushBackOperation(results, op);
+ })
+ .def("push_back",
+ [](MlirPDLResultList results, const PyType &type) {
+ mlirPDLResultListPushBackType(results, type);
+ })
+ .def("push_back", [](MlirPDLResultList results, MlirAttribute attr) {
+ mlirPDLResultListPushBackAttribute(results, attr);
+ });
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
@@ -88,10 +140,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
},
"module"_a, "Create a PDL module from the given module.")
- .def("freeze", [](PyPDLPatternModule &self) {
- return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
- mlirRewritePatternSetFromPDLPatternModule(self.get())));
- });
+ .def(
+ "freeze",
+ [](PyPDLPatternModule &self) {
+ return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ mlirRewritePatternSetFromPDLPatternModule(self.get())));
+ },
+ nb::keep_alive<0, 1>())
+ .def(
+ "register_rewrite_function",
+ [](PyPDLPatternModule &self, const std::string &name,
+ const nb::callable &fn) {
+ self.registerRewriteFunction(name, fn);
+ },
+ nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 6f85357a14a18..0033abde986ea 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -13,6 +13,8 @@
#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
+ assert(rewriter.ptr && "unexpected null rewriter");
+ return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+}
+
+inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
+ return {rewriter};
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
@@ -331,4 +346,88 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
op.ptr = nullptr;
return wrap(m);
}
+
+inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
+ assert(value.ptr && "unexpected null PDL value");
+ return static_cast<const mlir::PDLValue *>(value.ptr);
+}
+
+inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
+
+inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
+ assert(results.ptr && "unexpected null PDL results");
+ return static_cast<mlir::PDLResultList *>(results.ptr);
+}
+
+inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
+ return {results};
+}
+
+bool mlirPDLValueIsValue(MlirPDLValue value) {
+ return unwrap(value)->isa<mlir::Value>();
+}
+
+MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
+ return wrap(unwrap(value)->cast<mlir::Value>());
+}
+
+bool mlirPDLValueIsType(MlirPDLValue value) {
+ return unwrap(value)->isa<mlir::Type>();
+}
+
+MlirType mlirPDLValueAsType(MlirPDLValue value) {
+ return wrap(unwrap(value)->cast<mlir::Type>());
+}
+
+bool mlirPDLValueIsOperation(MlirPDLValue value) {
+ return unwrap(value)->isa<mlir::Operation *>();
+}
+
+MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
+ return wrap(unwrap(value)->cast<mlir::Operation *>());
+}
+
+bool mlirPDLValueIsAttribute(MlirPDLValue value) {
+ return unwrap(value)->isa<mlir::Attribute>();
+}
+
+MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
+ return wrap(unwrap(value)->cast<mlir::Attribute>());
+}
+
+void mlirPDLResultListPushBackValue(MlirPDLResultList results,
+ MlirValue value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+ MlirOperation value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+ MlirAttribute value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLPatternModuleRegisterRewriteFunction(
+ MlirPDLPatternModule module, MlirStringRef name,
+ MlirPDLRewriteFunction rewriteFn, void *userData) {
+ unwrap(module)->registerRewriteFunction(
+ unwrap(name),
+ [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
+ ArrayRef<PDLValue> values) -> LogicalResult {
+ std::vector<MlirPDLValue> mlirValues;
+ for (auto &value : values) {
+ mlirValues.push_back(wrap(&value));
+ }
+ return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
+ mlirValues.size(), mlirValues.data(),
+ userData));
+ });
+}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index dd6c74ce622c8..8954e3622d3ef 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -86,3 +86,85 @@ def add_func(a, b):
frozen = get_pdl_patterns()
apply_patterns_and_fold_greedily(module_.operation, frozen)
return module_
+
+
+# If we use arith.constant and arith.addi here,
+# these C++-defined folding/canonicalization will be applied
+# implicitly in the greedy pattern rewrite driver to
+# make our Python-defined folding useless,
+# so here we define a new dialect to workaround this.
+def load_myint_dialect():
+ from mlir.dialects import irdl
+ m = Module.create()
+ with InsertionPoint(m.body):
+ myint = irdl.dialect("myint")
+ with InsertionPoint(myint.body):
+ constant = irdl.operation_("constant")
+ with InsertionPoint(constant.body):
+ iattr = irdl.base(base_name="#builtin.integer")
+ i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+ irdl.attributes_([iattr], ["value"])
+ irdl.results_([i32], ["cst"], [irdl.Variadicity.single])
+ add = irdl.operation_("add")
+ with InsertionPoint(add.body):
+ i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+ irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single])
+ irdl.results_([i32], ["res"], [irdl.Variadicity.single])
+
+ m.operation.verify()
+ irdl.load_dialects(m)
+
+# this PDL pattern is to fold constant additions,
+# i.e. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1
+def get_pdl_pattern_fold():
+ m = Module.create()
+ with InsertionPoint(m.body):
+ @pdl.pattern(benefit=1, sym_name="myint_add_fold")
+ def pat():
+ t = pdl.TypeOp(IntegerType.get_signless(32))
+ a0 = pdl.AttributeOp()
+ a1 = pdl.AttributeOp()
+ c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t])
+ c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t])
+ v0 = pdl.ResultOp(c0, 0)
+ v1 = pdl.ResultOp(c1, 0)
+ op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+ @pdl.rewrite()
+ def rew():
+ sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1])
+ newOp = pdl.OperationOp(
+ name="myint.constant", attributes={"value": sum}, types=[t]
+ )
+ pdl.ReplaceOp(op0, with_op=newOp)
+
+ pdl_module = PDLModule(m)
+ def add_fold(rewriter, results, values):
+ a0, a1 = [i.get() for i in values]
+ i32 = IntegerType.get_signless(32)
+ results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
+ return True
+ pdl_module.register_rewrite_function("add_fold", add_fold)
+ return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function
+# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32
+@construct_and_print_in_module
+def test_pdl_register_function(module_):
+ load_myint_dialect()
+
+ module_ = Module.parse(
+ """
+ %c0 = "myint.constant"() { value = 2 }: () -> (i32)
+ %c1 = "myint.constant"() { value = 3 }: () -> (i32)
+ %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+ "myint.add"(%x, %c1): (i32, i32) -> (i32)
+ """
+ )
+
+ frozen = get_pdl_pattern_fold()
+ apply_patterns_and_fold_greedily(module_, frozen)
+
+ return module_
>From f1315a6088031967a673406c239173333bd3103a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 20:56:58 +0800
Subject: [PATCH 2/3] fix style
---
mlir/test/python/integration/dialects/pdl.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 8954e3622d3ef..64628db072be9 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -114,9 +114,9 @@ def load_myint_dialect():
m.operation.verify()
irdl.load_dialects(m)
-# this PDL pattern is to fold constant additions,
+# This PDL pattern is to fold constant additions,
# i.e. add(constant0, constant1) -> constant2
-# where constant2 = constant0 + constant1
+# where constant2 = constant0 + constant1.
def get_pdl_pattern_fold():
m = Module.create()
with InsertionPoint(m.body):
@@ -139,12 +139,13 @@ def rew():
)
pdl.ReplaceOp(op0, with_op=newOp)
- pdl_module = PDLModule(m)
def add_fold(rewriter, results, values):
a0, a1 = [i.get() for i in values]
i32 = IntegerType.get_signless(32)
results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
return True
+
+ pdl_module = PDLModule(m)
pdl_module.register_rewrite_function("add_fold", add_fold)
return pdl_module.freeze()
>From 2f3da2e0c356721311d2828b1293c35204d2fd6e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 20 Sep 2025 22:08:15 +0800
Subject: [PATCH 3/3] format
---
mlir/test/python/integration/dialects/pdl.py | 21 ++++++++++++++++----
1 file changed, 17 insertions(+), 4 deletions(-)
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 64628db072be9..5b33802cefaba 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -95,6 +95,7 @@ def add_func(a, b):
# so here we define a new dialect to workaround this.
def load_myint_dialect():
from mlir.dialects import irdl
+
m = Module.create()
with InsertionPoint(m.body):
myint = irdl.dialect("myint")
@@ -108,32 +109,44 @@ def load_myint_dialect():
add = irdl.operation_("add")
with InsertionPoint(add.body):
i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
- irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single])
+ irdl.operands_(
+ [i32, i32],
+ ["lhs", "rhs"],
+ [irdl.Variadicity.single, irdl.Variadicity.single]
+ )
irdl.results_([i32], ["res"], [irdl.Variadicity.single])
m.operation.verify()
irdl.load_dialects(m)
+
# This PDL pattern is to fold constant additions,
# i.e. add(constant0, constant1) -> constant2
# where constant2 = constant0 + constant1.
def get_pdl_pattern_fold():
m = Module.create()
with InsertionPoint(m.body):
+
@pdl.pattern(benefit=1, sym_name="myint_add_fold")
def pat():
t = pdl.TypeOp(IntegerType.get_signless(32))
a0 = pdl.AttributeOp()
a1 = pdl.AttributeOp()
- c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t])
- c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t])
+ c0 = pdl.OperationOp(
+ name="myint.constant", attributes={"value": a0}, types=[t]
+ )
+ c1 = pdl.OperationOp(
+ name="myint.constant", attributes={"value": a1}, types=[t]
+ )
v0 = pdl.ResultOp(c0, 0)
v1 = pdl.ResultOp(c1, 0)
op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
@pdl.rewrite()
def rew():
- sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1])
+ sum = pdl.apply_native_rewrite(
+ [pdl.AttributeType.get()], "add_fold", [a0, a1]
+ )
newOp = pdl.OperationOp(
name="myint.constant", attributes={"value": sum}, types=[t]
)
More information about the Mlir-commits
mailing list