[Mlir-commits] [mlir] [MLIR][Transform] friendlier Python-bindings apply_registered_pass op (PR #143159)

Rolf Morel llvmlistbot at llvm.org
Fri Jun 6 08:27:15 PDT 2025


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/143159

>From 461d7cfaf359ee07f34c9b3eb91f402b97afe312 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:13:44 -0700
Subject: [PATCH 1/2] [MLIR][Transform] friendlier Python-bindings
 apply_registered_pass op

In particular, use similar syntax for providing options as in the
(pretty-)printed IR.
---
 .../mlir/dialects/transform/__init__.py       | 35 ++++++++++++++++++
 mlir/test/python/dialects/transform.py        | 36 +++++++++++++++++++
 2 files changed, 71 insertions(+)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+    def __init__(
+        self,
+        result: Type,
+        pass_name: Union[str, StringAttr],
+        target: Value,
+        *,
+        options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+        loc=None,
+        ip=None,
+    ):
+        static_options = []
+        dynamic_options = []
+        for opt in options:
+            if isinstance(opt, str):
+                static_options.append(StringAttr.get(opt))
+            elif isinstance(opt, StringAttr):
+                static_options.append(opt)
+            elif isinstance(opt, Value):
+                static_options.append(UnitAttr.get())
+                dynamic_options.append(_get_op_result_or_value(opt))
+            else:
+                raise TypeError(f"Unsupported option type: {type(opt)}")
+        super().__init__(
+            result,
+            pass_name,
+            dynamic_options,
+            target=_get_op_result_or_value(target),
+            options=static_options,
+            loc=loc,
+            ip=ip,
+        )
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
     # CHECK: %[[FIRST:.+]] = pdl_match
     # CHECK: %[[SECOND:.+]] = pdl_match
     # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+@run
+def testApplyRegisteredPassOp(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+        )
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+        )
+        max_iter = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+        )
+        max_rewrites = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+        )
+        transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(),
+            "canonicalize",
+            mod,
+            options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+    # CHECK: transform.sequence
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %[[MAX_ITER:.+]] = transform.param.constant
+    # CHECK:   %[[MAX_REWRITE:.+]] = transform.param.constant
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
+    # CHECK-SAME:    with options = "top-down=false" %[[MAX_ITER]]
+    # CHECK-SAME:   "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

>From 18360e7f5279bc89d35cef81be22b579faf0fb28 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:26:39 -0700
Subject: [PATCH 2/2] snake_case_helper

---
 mlir/python/mlir/dialects/transform/__init__.py | 4 ++++
 mlir/test/python/dialects/transform.py          | 7 +++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index cdcdeadd54cd3..90282df49fb7d 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -249,6 +249,10 @@ def __init__(
         )
 
 
+def apply_registered_pass(result, pass_name, target, *, options=[], loc=None, ip=None) -> Value:
+  return ApplyRegisteredPassOp(result=result, pass_name=pass_name, target=target, options=options, loc=loc, ip=ip).result
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index dc0987e769a09..6492b58570814 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -266,7 +266,10 @@ def testApplyRegisteredPassOp(module: Module):
             transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
         )
         mod = transform.ApplyRegisteredPassOp(
-            transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+            transform.AnyOpType.get(),
+            "canonicalize",
+            mod.result,
+            options=("top-down=false",),
         )
         max_iter = transform.param_constant(
             transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
@@ -274,7 +277,7 @@ def testApplyRegisteredPassOp(module: Module):
         max_rewrites = transform.param_constant(
             transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
         )
-        transform.ApplyRegisteredPassOp(
+        transform.apply_registered_pass(
             transform.AnyOpType.get(),
             "canonicalize",
             mod,



More information about the Mlir-commits mailing list