[Mlir-commits] [mlir] [MLIR][Python] Add bindings for PDL constraint function registering (PR #160520)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 24 06:13:34 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

This is a follow-up to #159926.

That PR (#159926) exposed native rewrite function registration in PDL through the C API and the Python bindings, enabling their use with `pdl.apply_native_rewrite`.

This PR adds the same support for native constraint functions, which are invoked via `pdl.apply_native_constraint`, further completing the PDL API.


---
Full diff: https://github.com/llvm/llvm-project/pull/160520.diff


4 Files Affected:

- (modified) mlir/include/mlir-c/Rewrite.h (+14) 
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+32-5) 
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+25-5) 
- (modified) mlir/test/python/integration/dialects/pdl.py (+56) 


``````````diff
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index f4974348945c5..77be1f480eacf 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -375,6 +375,20 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
     MlirPDLPatternModule pdlModule, MlirStringRef name,
     MlirPDLRewriteFunction rewriteFn, void *userData);
 
+/// Callback type for PDL native constraint functions.
+/// Input values are passed as the array `values` of size `nValues`; output
+/// values can be appended to `results` via the `mlirPDLResultListPushBack*`
+/// APIs. Return success if the constraint holds and failure otherwise.
+typedef MlirLogicalResult (*MlirPDLConstraintFunction)(
+    MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+    MlirPDLValue *values, void *userData);
+
+/// Register a constraint function into the given PDL pattern module.
+/// `userData` will be provided as an argument to the constraint function.
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction(
+    MlirPDLPatternModule pdlModule, MlirStringRef name,
+    MlirPDLConstraintFunction constraintFn, void *userData);
+
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
 #undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index c53c6cf0dab1e..20392b9002706 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -40,6 +40,15 @@ static nb::object objectFromPDLValue(MlirPDLValue value) {
   throw std::runtime_error("unsupported PDL value type");
 }
 
+static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
+                                                    MlirPDLValue *values) {
+  std::vector<nb::object> args;
+  args.reserve(nValues);
+  for (size_t i = 0; i < nValues; ++i)
+    args.push_back(objectFromPDLValue(values[i]));
+  return args;
+}
+
 // Convert the Python object to a boolean.
 // If it evaluates to False, treat it as success;
 // otherwise, treat it as failure.
@@ -74,11 +83,22 @@ class PyPDLPatternModule {
            size_t nValues, MlirPDLValue *values,
            void *userData) -> MlirLogicalResult {
           nb::handle f = nb::handle(static_cast<PyObject *>(userData));
-          std::vector<nb::object> args;
-          args.reserve(nValues);
-          for (size_t i = 0; i < nValues; ++i)
-            args.push_back(objectFromPDLValue(values[i]));
-          return logicalResultFromObject(f(rewriter, results, args));
+          return logicalResultFromObject(
+              f(rewriter, results, objectsFromPDLValues(nValues, values)));
+        },
+        fn.ptr());
+  }
+
+  void registerConstraintFunction(const std::string &name,
+                                  const nb::callable &fn) {
+    mlirPDLPatternModuleRegisterConstraintFunction(
+        get(), mlirStringRefCreate(name.data(), name.size()),
+        [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+           size_t nValues, MlirPDLValue *values,
+           void *userData) -> MlirLogicalResult {
+          nb::handle f = nb::handle(static_cast<PyObject *>(userData));
+          return logicalResultFromObject(
+              f(rewriter, results, objectsFromPDLValues(nValues, values)));
         },
         fn.ptr());
   }
@@ -199,6 +219,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
              const nb::callable &fn) {
             self.registerRewriteFunction(name, fn);
           },
+          nb::keep_alive<1, 3>())
+      .def(
+          "register_constraint_function",
+          [](PyPDLPatternModule &self, const std::string &name,
+             const nb::callable &fn) {
+            self.registerConstraintFunction(name, fn);
+          },
           nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 9ecce956a05b9..8ee6308cadf83 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -398,6 +398,15 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
   unwrap(results)->push_back(unwrap(value));
 }
 
+inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) {
+  std::vector<MlirPDLValue> mlirValues;
+  mlirValues.reserve(values.size());
+  for (auto &value : values) {
+    mlirValues.push_back(wrap(&value));
+  }
+  return mlirValues;
+}
+
 void mlirPDLPatternModuleRegisterRewriteFunction(
     MlirPDLPatternModule pdlModule, MlirStringRef name,
     MlirPDLRewriteFunction rewriteFn, void *userData) {
@@ -405,14 +414,25 @@ void mlirPDLPatternModuleRegisterRewriteFunction(
       unwrap(name),
       [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
                             ArrayRef<PDLValue> values) -> LogicalResult {
-        std::vector<MlirPDLValue> mlirValues;
-        mlirValues.reserve(values.size());
-        for (auto &value : values) {
-          mlirValues.push_back(wrap(&value));
-        }
+        std::vector<MlirPDLValue> mlirValues = wrap(values);
         return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
                                 mlirValues.size(), mlirValues.data(),
                                 userData));
       });
 }
+
+void mlirPDLPatternModuleRegisterConstraintFunction(
+    MlirPDLPatternModule pdlModule, MlirStringRef name,
+    MlirPDLConstraintFunction constraintFn, void *userData) {
+  unwrap(pdlModule)->registerConstraintFunction(
+      unwrap(name),
+      [userData, constraintFn](PatternRewriter &rewriter,
+                               PDLResultList &results,
+                               ArrayRef<PDLValue> values) -> LogicalResult {
+        std::vector<MlirPDLValue> mlirValues = wrap(values);
+        return unwrap(constraintFn(wrap(&rewriter), wrap(&results),
+                                   mlirValues.size(), mlirValues.data(),
+                                   userData));
+      });
+}
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index e85c6c77ef955..c8e6197e03842 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -153,12 +153,43 @@ def rew():
                 )
                 pdl.ReplaceOp(op0, with_op=newOp)
 
+        @pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
+        def pat():
+            t = pdl.TypeOp(i32)
+            v0 = pdl.OperandOp()
+            v1 = pdl.OperandOp()
+            v = pdl.apply_native_constraint([pdl.ValueType.get()], "has_zero", [v0, v1])
+            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+            @pdl.rewrite()
+            def rew():
+                pdl.ReplaceOp(op0, with_values=[v])
+
     def add_fold(rewriter, results, values):
         a0, a1 = values
         results.append(IntegerAttr.get(i32, a0.value + a1.value))
 
+    def is_zero(value):
+        op = value.owner
+        if isinstance(op, Operation):
+            return op.name == "myint.constant" and op.attributes["value"].value == 0
+        return False
+
+    # Holds (returns False, i.e. success) if either operand is a constant
+    # zero, in which case the other operand is appended to the results.
+    def has_zero(rewriter, results, values):
+        v0, v1 = values
+        if is_zero(v0):
+            results.append(v1)
+            return False
+        if is_zero(v1):
+            results.append(v0)
+            return False
+        return True
+
     pdl_module = PDLModule(m)
     pdl_module.register_rewrite_function("add_fold", add_fold)
+    pdl_module.register_constraint_function("has_zero", has_zero)
     return pdl_module.freeze()
 
 
@@ -181,3 +212,28 @@ def test_pdl_register_function(module_):
     apply_patterns_and_fold_greedily(module_, frozen)
 
     return module_
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function_constraint
+# CHECK: return %arg0 : i32
+@construct_and_print_in_module
+def test_pdl_register_function_constraint(module_):
+    load_myint_dialect()
+
+    module_ = Module.parse(
+        """
+        func.func @f(%x : i32) -> i32 {
+            %c0 = "myint.constant"() { value = 1 }: () -> (i32)
+            %c1 = "myint.constant"() { value = -1 }: () -> (i32)
+            %a = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+            %b = "myint.add"(%a, %x): (i32, i32) -> (i32)
+            %c = "myint.add"(%b, %a): (i32, i32) -> (i32)
+            func.return %c : i32
+        }
+        """
+    )
+
+    frozen = get_pdl_pattern_fold()
+    apply_patterns_and_fold_greedily(module_, frozen)
+
+    return module_

``````````

</details>


https://github.com/llvm/llvm-project/pull/160520


More information about the Mlir-commits mailing list