[Mlir-commits] [mlir] [MLIR][Transform] friendlier Python-bindings apply_registered_pass op (PR #143159)
Rolf Morel
llvmlistbot at llvm.org
Fri Jun 6 08:27:15 PDT 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/143159
>From 461d7cfaf359ee07f34c9b3eb91f402b97afe312 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:13:44 -0700
Subject: [PATCH 1/2] [MLIR][Transform] friendlier Python-bindings
apply_registered_pass op
In particular, use similar syntax for providing options as in the
(pretty-)printed IR.
---
.../mlir/dialects/transform/__init__.py | 35 ++++++++++++++++++
mlir/test/python/dialects/transform.py | 36 +++++++++++++++++++
2 files changed, 71 insertions(+)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+ def __init__(
+ self,
+ result: Type,
+ pass_name: Union[str, StringAttr],
+ target: Value,
+ *,
+ options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+ loc=None,
+ ip=None,
+ ):
+ static_options = []
+ dynamic_options = []
+ for opt in options:
+ if isinstance(opt, str):
+ static_options.append(StringAttr.get(opt))
+ elif isinstance(opt, StringAttr):
+ static_options.append(opt)
+ elif isinstance(opt, Value):
+ static_options.append(UnitAttr.get())
+ dynamic_options.append(_get_op_result_or_value(opt))
+ else:
+ raise TypeError(f"Unsupported option type: {type(opt)}")
+ super().__init__(
+ result,
+ pass_name,
+ dynamic_options,
+ target=_get_op_result_or_value(target),
+ options=static_options,
+ loc=loc,
+ ip=ip,
+ )
+
+
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+ at run
+def testApplyRegisteredPassOp(module: Module):
+ # Builds a transform sequence that applies the "canonicalize" pass three
+ # times: without options, with one static string option, and with static
+ # string options interleaved with param-valued options. The FileCheck
+ # directives after the body verify the printed IR for each case.
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ # No options.
+ mod = transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+ )
+ # A single option given as a plain string.
+ mod = transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+ )
+ # Options that are produced by ops rather than written inline.
+ max_iter = transform.param_constant(
+ transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+ )
+ max_rewrites = transform.param_constant(
+ transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+ )
+ # Strings and param values mixed in one sequence; the directives below
+ # expect them to be printed in the order given here.
+ transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(),
+ "canonicalize",
+ mod,
+ options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+ # CHECK: transform.sequence
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[MAX_ITER:.+]] = transform.param.constant
+ # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
+ # CHECK-SAME: with options = "top-down=false" %[[MAX_ITER]]
+ # CHECK-SAME: "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
>From 18360e7f5279bc89d35cef81be22b579faf0fb28 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:26:39 -0700
Subject: [PATCH 2/2] snake_case_helper
---
mlir/python/mlir/dialects/transform/__init__.py | 4 ++++
mlir/test/python/dialects/transform.py | 7 +++++--
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index cdcdeadd54cd3..90282df49fb7d 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -249,6 +249,10 @@ def __init__(
)
+def apply_registered_pass(result, pass_name, target, *, options=[], loc=None, ip=None) -> Value:
+ return ApplyRegisteredPassOp(result=result, pass_name=pass_name, target=target, options=options, loc=loc, ip=ip).result
+
+
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index dc0987e769a09..6492b58570814 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -266,7 +266,10 @@ def testApplyRegisteredPassOp(module: Module):
transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
)
mod = transform.ApplyRegisteredPassOp(
- transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+ transform.AnyOpType.get(),
+ "canonicalize",
+ mod.result,
+ options=("top-down=false",),
)
max_iter = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
@@ -274,7 +277,7 @@ def testApplyRegisteredPassOp(module: Module):
max_rewrites = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
)
- transform.ApplyRegisteredPassOp(
+ transform.apply_registered_pass(
transform.AnyOpType.get(),
"canonicalize",
mod,
More information about the Mlir-commits
mailing list