[Mlir-commits] [mlir] [MLIR][Python] Add bindings for PDL constraint function registering (PR #160520)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 24 06:12:59 PDT 2025
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/160520
This is a follow-up to #159926.
That PR (#159926) exposed native rewrite function registration in PDL through the C API and Python, enabling use with `pdl.apply_native_rewrite`.
In this PR, we add support for native constraint functions in PDL via `pdl.apply_native_constraint`, further completing the PDL API.
From 21953e841c67e5e129f1f902196c4ab660c093e9 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 24 Sep 2025 21:03:46 +0800
Subject: [PATCH] [MLIR][Python] Add bindings for PDL constraint function
registering
---
mlir/include/mlir-c/Rewrite.h | 14 +++++
mlir/lib/Bindings/Python/Rewrite.cpp | 37 +++++++++++--
mlir/lib/CAPI/Transforms/Rewrite.cpp | 30 +++++++++--
mlir/test/python/integration/dialects/pdl.py | 56 ++++++++++++++++++++
4 files changed, 127 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index f4974348945c5..77be1f480eacf 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -375,6 +375,20 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLRewriteFunction rewriteFn, void *userData);
+/// Callback type for PDL native constraint functions. Input values are read
+/// from `values`, which holds `nValues` elements; output values are appended
+/// to `results` through the `mlirPDLResultListPushBack*` APIs. The returned
+/// logical result indicates whether the constraint holds.
+typedef MlirLogicalResult (*MlirPDLConstraintFunction)(
+ MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+ MlirPDLValue *values, void *userData);
+
+/// Register a constraint function under `name` in the given PDL pattern module.
+/// `userData` is passed through as an argument to every `constraintFn` call.
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction(
+ MlirPDLPatternModule pdlModule, MlirStringRef name,
+ MlirPDLConstraintFunction constraintFn, void *userData);
+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
#undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index c53c6cf0dab1e..20392b9002706 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -40,6 +40,15 @@ static nb::object objectFromPDLValue(MlirPDLValue value) {
throw std::runtime_error("unsupported PDL value type");
}
+// Convert the `nValues` PDL values starting at `values` into Python objects,
+// keeping their order.
+static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
+                                                    MlirPDLValue *values) {
+  std::vector<nb::object> objects;
+  objects.reserve(nValues);
+  for (MlirPDLValue *it = values, *end = values + nValues; it != end; ++it)
+    objects.push_back(objectFromPDLValue(*it));
+  return objects;
+}
+
// Convert the Python object to a boolean.
// If it evaluates to False, treat it as success;
// otherwise, treat it as failure.
@@ -74,11 +83,22 @@ class PyPDLPatternModule {
size_t nValues, MlirPDLValue *values,
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
- std::vector<nb::object> args;
- args.reserve(nValues);
- for (size_t i = 0; i < nValues; ++i)
- args.push_back(objectFromPDLValue(values[i]));
- return logicalResultFromObject(f(rewriter, results, args));
+ return logicalResultFromObject(
+ f(rewriter, results, objectsFromPDLValues(nValues, values)));
+ },
+ fn.ptr());
+ }
+
+ // Register the Python callable `fn` as the native constraint named `name`.
+ // The callable is invoked as `fn(rewriter, results, values)` and its return
+ // value is turned into a logical result by `logicalResultFromObject`.
+ // Only the raw `PyObject *` of `fn` is handed over as user data; this method
+ // does not take a reference itself.
+ void registerConstraintFunction(const std::string &name,
+                                 const nb::callable &fn) {
+   mlirPDLPatternModuleRegisterConstraintFunction(
+       get(), mlirStringRefCreate(name.data(), name.size()),
+       [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+          size_t nValues, MlirPDLValue *values,
+          void *userData) -> MlirLogicalResult {
+         nb::handle callback(static_cast<PyObject *>(userData));
+         nb::object verdict =
+             callback(rewriter, results, objectsFromPDLValues(nValues, values));
+         return logicalResultFromObject(verdict);
+       },
+       fn.ptr());
+ }
@@ -199,6 +219,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
const nb::callable &fn) {
self.registerRewriteFunction(name, fn);
},
+ nb::keep_alive<1, 3>())
+ .def(
+ "register_constraint_function",
+ [](PyPDLPatternModule &self, const std::string &name,
+ const nb::callable &fn) {
+ self.registerConstraintFunction(name, fn);
+ },
nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 9ecce956a05b9..8ee6308cadf83 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -398,6 +398,15 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
unwrap(results)->push_back(unwrap(value));
}
+/// Wraps every element of `values` into a C API handle, preserving order.
+/// Each handle is created from the address of an element of `values`;
+/// presumably it stays valid only while that storage is alive.
+/// File-local helper, hence `static` rather than an externally visible
+/// `inline` function.
+static std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) {
+  std::vector<MlirPDLValue> mlirValues;
+  mlirValues.reserve(values.size());
+  for (const PDLValue &value : values)
+    mlirValues.push_back(wrap(&value));
+  return mlirValues;
+}
+
void mlirPDLPatternModuleRegisterRewriteFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLRewriteFunction rewriteFn, void *userData) {
@@ -405,14 +414,25 @@ void mlirPDLPatternModuleRegisterRewriteFunction(
unwrap(name),
[userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> values) -> LogicalResult {
- std::vector<MlirPDLValue> mlirValues;
- mlirValues.reserve(values.size());
- for (auto &value : values) {
- mlirValues.push_back(wrap(&value));
- }
+ std::vector<MlirPDLValue> mlirValues = wrap(values);
return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
mlirValues.size(), mlirValues.data(),
userData));
});
}
+
+// Registers `constraintFn` under `name` in `pdlModule`. The registered C++
+// callback wraps its arguments into C API handles, forwards them together with
+// `userData` to `constraintFn`, and unwraps the returned logical result.
+void mlirPDLPatternModuleRegisterConstraintFunction(
+    MlirPDLPatternModule pdlModule, MlirStringRef name,
+    MlirPDLConstraintFunction constraintFn, void *userData) {
+  unwrap(pdlModule)->registerConstraintFunction(
+      unwrap(name),
+      [constraintFn, userData](PatternRewriter &rewriter,
+                               PDLResultList &results,
+                               ArrayRef<PDLValue> values) -> LogicalResult {
+        std::vector<MlirPDLValue> handles = wrap(values);
+        MlirLogicalResult status =
+            constraintFn(wrap(&rewriter), wrap(&results), handles.size(),
+                         handles.data(), userData);
+        return unwrap(status);
+      });
+}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index e85c6c77ef955..c8e6197e03842 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -153,12 +153,43 @@ def rew():
)
pdl.ReplaceOp(op0, with_op=newOp)
+ @pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
+ def pat():
+ t = pdl.TypeOp(i32)
+ v0 = pdl.OperandOp()
+ v1 = pdl.OperandOp()
+ v = pdl.apply_native_constraint([pdl.ValueType.get()], "has_zero", [v0, v1])
+ op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+ @pdl.rewrite()
+ def rew():
+ pdl.ReplaceOp(op0, with_values=[v])
+
def add_fold(rewriter, results, values):
a0, a1 = values
results.append(IntegerAttr.get(i32, a0.value + a1.value))
+ def is_zero(value):
+ op = value.owner
+ if isinstance(op, Operation):
+ return op.name == "myint.constant" and op.attributes["value"].value == 0
+ return False
+
+ # Check if either operand is a constant zero,
+ # and append the other operand to the results if so.
+ def has_zero(rewriter, results, values):
+ v0, v1 = values
+ if is_zero(v0):
+ results.append(v1)
+ return False
+ if is_zero(v1):
+ results.append(v0)
+ return False
+ return True
+
pdl_module = PDLModule(m)
pdl_module.register_rewrite_function("add_fold", add_fold)
+ pdl_module.register_constraint_function("has_zero", has_zero)
return pdl_module.freeze()
@@ -181,3 +212,28 @@ def test_pdl_register_function(module_):
apply_patterns_and_fold_greedily(module_, frozen)
return module_
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function_constraint
+# CHECK: return %arg0 : i32
+ at construct_and_print_in_module
+def test_pdl_register_function_constraint(module_):
+ load_myint_dialect()
+
+ module_ = Module.parse(
+ """
+ func.func @f(%x : i32) -> i32 {
+ %c0 = "myint.constant"() { value = 1 }: () -> (i32)
+ %c1 = "myint.constant"() { value = -1 }: () -> (i32)
+ %a = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+ %b = "myint.add"(%a, %x): (i32, i32) -> (i32)
+ %c = "myint.add"(%b, %a): (i32, i32) -> (i32)
+ func.return %c : i32
+ }
+ """
+ )
+
+ frozen = get_pdl_pattern_fold()
+ apply_patterns_and_fold_greedily(module_, frozen)
+
+ return module_
More information about the Mlir-commits
mailing list