[Mlir-commits] [mlir] [MLIR][Transform] friendlier Python-bindings apply_registered_pass op (PR #143159)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 6 08:23:02 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

<details>
<summary>Changes</summary>

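Make the Python builder for `transform.apply_registered_pass` friendlier to use. In particular, options are now passed with syntax similar to that of the (pretty-)printed IR: a single sequence that mixes static option strings with SSA values for dynamic options, in the order in which they should appear.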
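To illustrate that correspondence, the following self-contained sketch renders a mixed option list the way the test's CHECK lines show the printed op: static options as quoted strings and dynamic options as SSA names, in the order given. `ParamHandle` and `render_options` are hypothetical stand-ins for illustration only. They are not part of the MLIR Python bindings, and the sketch does not import `mlir`.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass(frozen=True)
class ParamHandle:
    """Hypothetical stand-in for an SSA value such as %0 (a !transform.any_param)."""

    name: str


def render_options(options: Sequence[Union[str, ParamHandle]]) -> str:
    """Render options as the printed IR shows them, preserving their order."""
    parts: List[str] = []
    for opt in options:
        if isinstance(opt, str):
            parts.append(f'"{opt}"')  # static option: quoted string
        elif isinstance(opt, ParamHandle):
            parts.append(opt.name)  # dynamic option: SSA name of the param
        else:
            raise TypeError(f"Unsupported option type: {type(opt)}")
    if not parts:
        return ""  # no options: the printed op omits the clause entirely
    return "with options = " + " ".join(parts)
```

For example, `render_options(["top-down=false", ParamHandle("%0")])` returns `with options = "top-down=false" %0`, mirroring a builder call such as `options=("top-down=false", max_iter)`.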

---
Full diff: https://github.com/llvm/llvm-project/pull/143159.diff


2 Files Affected:

- (modified) mlir/python/mlir/dialects/transform/__init__.py (+35) 
- (modified) mlir/test/python/dialects/transform.py (+36) 


``````````diff
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+    def __init__(
+        self,
+        result: Type,
+        pass_name: Union[str, StringAttr],
+        target: Union[Operation, OpView, Value],
+        *,
+        options: Sequence[Union[str, StringAttr, Value, Operation, OpView]] = (),
+        loc=None,
+        ip=None,
+    ):
+        static_options = []
+        dynamic_options = []
+        for opt in options:
+            if isinstance(opt, str):
+                static_options.append(StringAttr.get(opt))
+            elif isinstance(opt, StringAttr):
+                static_options.append(opt)
+            elif isinstance(opt, (Value, Operation, OpView)):
+                static_options.append(UnitAttr.get())
+                dynamic_options.append(_get_op_result_or_value(opt))
+            else:
+                raise TypeError(f"Unsupported option type: {type(opt)}")
+        super().__init__(
+            result,
+            pass_name,
+            dynamic_options,
+            target=_get_op_result_or_value(target),
+            options=static_options,
+            loc=loc,
+            ip=ip,
+        )
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
     # CHECK: %[[FIRST:.+]] = pdl_match
     # CHECK: %[[SECOND:.+]] = pdl_match
     # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+@run
+def testApplyRegisteredPassOp(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+        )
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+        )
+        max_iter = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+        )
+        max_rewrites = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+        )
+        transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(),
+            "canonicalize",
+            mod,
+            options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+    # CHECK: transform.sequence
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %[[MAX_ITER:.+]] = transform.param.constant
+    # CHECK:   %[[MAX_REWRITE:.+]] = transform.param.constant
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
+    # CHECK-SAME:    with options = "top-down=false" %[[MAX_ITER]]
+    # CHECK-SAME:   "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

``````````
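The builder in the diff encodes the interleaving with two parallel lists. Every option occupies a slot in the static `options` list, and each dynamic `Value` is represented there by a `UnitAttr` placeholder while the value itself is appended to `dynamic_options`. The sketch below reproduces that partition without MLIR. It also adds an inverse that restores the original order by consuming one dynamic value per placeholder; this inverse is an assumption about how the placeholders are read back, consistent with the CHECK lines but not taken from the C++ printer. `ParamHandle`, `PLACEHOLDER`, `split_options`, and `merge_options` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple, Union


@dataclass(frozen=True)
class ParamHandle:
    """Hypothetical stand-in for an mlir.ir.Value of type !transform.any_param."""

    name: str


# Hypothetical stand-in for UnitAttr.get(): marks the slot of a dynamic option.
PLACEHOLDER = object()


def split_options(
    options: Sequence[Union[str, ParamHandle]],
) -> Tuple[List[object], List[ParamHandle]]:
    """Partition options like the builder: static slots plus dynamic operands."""
    static: List[object] = []
    dynamic: List[ParamHandle] = []
    for opt in options:
        if isinstance(opt, str):
            static.append(opt)
        elif isinstance(opt, ParamHandle):
            static.append(PLACEHOLDER)  # remember where the dynamic value sits
            dynamic.append(opt)
        else:
            raise TypeError(f"Unsupported option type: {type(opt)}")
    return static, dynamic


def merge_options(
    static: Sequence[object], dynamic: Sequence[ParamHandle]
) -> List[object]:
    """Assumed inverse: substitute the next dynamic value at each placeholder."""
    remaining = list(dynamic)
    merged: List[object] = []
    for opt in static:
        if opt is PLACEHOLDER:
            if not remaining:
                raise ValueError("more placeholders than dynamic options")
            merged.append(remaining.pop(0))
        else:
            merged.append(opt)  # a static option string, kept as-is
    if remaining:
        raise ValueError("more dynamic options than placeholders")
    return merged
```

Splitting `["top-down=false", p0, "test-convergence=true", p1]` yields the static list `["top-down=false", PLACEHOLDER, "test-convergence=true", PLACEHOLDER]` and the dynamic list `[p0, p1]`, and merging the two gives back the original sequence. Keeping a placeholder in the static list is what lets one positional Python sequence map onto the op's separate attribute and operand lists without losing the relative order.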

</details>


https://github.com/llvm/llvm-project/pull/143159

