[Mlir-commits] [mlir] [MLIR][Python] Add bindings for PDL constraint function registering (PR #160520)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 24 19:09:27 PDT 2025


================
@@ -153,12 +153,43 @@ def rew():
                 )
                 pdl.ReplaceOp(op0, with_op=newOp)
 
+        @pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
+        def pat():
+            t = pdl.TypeOp(i32)
+            v0 = pdl.OperandOp()
+            v1 = pdl.OperandOp()
+            v = pdl.apply_native_constraint([pdl.ValueType.get()], "has_zero", [v0, v1])
+            op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+            @pdl.rewrite()
+            def rew():
+                pdl.ReplaceOp(op0, with_values=[v])
+
     def add_fold(rewriter, results, values):
         a0, a1 = values
         results.append(IntegerAttr.get(i32, a0.value + a1.value))
 
+    def is_zero(value):
+        op = value.owner
+        if isinstance(op, Operation):
+            return op.name == "myint.constant" and op.attributes["value"].value == 0
+        return False
+
+    # Check if either operand is a constant zero,
+    # and append the other operand to the results if so.
+    def has_zero(rewriter, results, values):
+        v0, v1 = values
+        if is_zero(v0):
+            results.append(v1)
+            return False
+        if is_zero(v1):
+            results.append(v0)
----------------
PragmaTwice wrote:

Ahh, sorry, there's no type annotation here, so it's easy to get this wrong. `results` is of type `MlirPDLResultList`, and we expose an `append` method for this type (currently its only method). `values` is of type `std::vector<nb::object>`, so on the Python side it is a list of `Value`/`Attribute`/`Type`/... objects.
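
Since neither argument is annotated, it can help to exercise a constraint callable without MLIR at all. The sketch below uses two hypothetical stand-ins: `FakeResultList`, which like `MlirPDLResultList` exposes only `append` (plus an `items` attribute so the sketch can inspect it), and `FakeValue`, which records whether it was produced by a zero constant. It reimplements a `has_zero`-style constraint against them, assuming (as the review comment implies) that a falsy return value signals success and a truthy one signals failure. This is an illustration, not the PR's code.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


class FakeResultList:
    """Hypothetical stand-in for MlirPDLResultList: only `append` is part of
    the mimicked interface; `items` exists so the sketch can inspect results."""

    def __init__(self) -> None:
        self.items: List[Any] = []

    def append(self, value: Any) -> None:
        self.items.append(value)


@dataclass
class FakeValue:
    """Hypothetical stand-in for an MLIR Value: `const` holds the integer
    if the value is produced by a constant op, otherwise None."""
    name: str
    const: Optional[int] = None


def is_zero(value: FakeValue) -> bool:
    return value.const == 0


def has_zero(rewriter, results, values):
    # `values` is a plain Python list, one entry per constraint argument.
    v0, v1 = values
    if is_zero(v0):
        results.append(v1)  # becomes the constraint's single result
        return False        # falsy: constraint succeeded
    if is_zero(v1):
        results.append(v0)
        return False
    return True             # truthy: constraint failed, nothing appended
```

With `x = FakeValue("x")` and `zero = FakeValue("zero", const=0)`, calling `has_zero(None, results, [x, zero])` returns `False` and leaves `results.items == [x]`; the `None` stands in for the rewriter, which this constraint never uses.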

Whatever the callable appends to `results` becomes the results of `pdl.apply_native_constraint`. For example:
```
%res = pdl.apply_native_constraint "some_constraint"(%v1, %v2 : !pdl.value, !pdl.value) : !pdl.value
```
Then the callable registered as `some_constraint` is invoked roughly like this:
```
values = [v1, v2]  # corresponding to arguments %v1 and %v2
if not some_constraint(rewriter, results, values):
    # a falsy return value means the constraint succeeded;
    # results[0] then corresponds to %res
    assert len(results) == 1
```
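
To make the calling convention concrete from the caller's side, here is a minimal sketch under the semantics described above: a falsy return value means success, and on success the callable must have appended exactly one entry per result of the `pdl.apply_native_constraint` op. `RecordingResultList`, `run_native_constraint`, and `other_if_zero` are hypothetical names used only for illustration; this is not how the bindings actually dispatch the call.

```python
from typing import Any, Callable, List, Optional


class RecordingResultList:
    """Hypothetical result list: `append` like MlirPDLResultList, plus
    `__len__`/`__getitem__` so the driver below can read it back."""

    def __init__(self) -> None:
        self._items: List[Any] = []

    def append(self, value: Any) -> None:
        self._items.append(value)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> Any:
        return self._items[index]


def run_native_constraint(
    fn: Callable[[Any, RecordingResultList, List[Any]], Any],
    args: List[Any],
    num_results: int,
    rewriter: Any = None,
) -> Optional[List[Any]]:
    """Hypothetical driver mimicking a `pdl.apply_native_constraint` op
    with `num_results` results. Returns the results on success, None on failure."""
    results = RecordingResultList()
    if fn(rewriter, results, list(args)):
        return None  # truthy return: constraint failed, match is rejected
    # Falsy return: constraint succeeded and must provide every result.
    if len(results) != num_results:
        raise ValueError(f"expected {num_results} result(s), got {len(results)}")
    return [results[i] for i in range(len(results))]


def other_if_zero(rewriter, results, values):
    """Hypothetical constraint on plain ints: if either operand is 0,
    produce the other operand as the single result."""
    a, b = values
    if a == 0:
        results.append(b)
        return False
    if b == 0:
        results.append(a)
        return False
    return True
```

Here `run_native_constraint(other_if_zero, [0, 7], 1)` yields `[7]`, while `[3, 4]` yields `None`. Checking the result count in the driver mirrors the `assert` in the snippet above; raising an exception instead keeps the check active even under `python -O`.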

https://github.com/llvm/llvm-project/pull/160520

