[Mlir-commits] [mlir] fb761aa - [MLIR][Transform] apply_registered_op fixes: arg order & python options auto-conversion (#143779)
llvmlistbot at llvm.org
Wed Jun 11 13:19:55 PDT 2025
Author: Rolf Morel
Date: 2025-06-11T21:19:52+01:00
New Revision: fb761aa38b0bc01ab911f5dbbfb474b70aaafbb4
URL: https://github.com/llvm/llvm-project/commit/fb761aa38b0bc01ab911f5dbbfb474b70aaafbb4
DIFF: https://github.com/llvm/llvm-project/commit/fb761aa38b0bc01ab911f5dbbfb474b70aaafbb4.diff
LOG: [MLIR][Transform] apply_registered_op fixes: arg order & python options auto-conversion (#143779)
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/python/mlir/dialects/transform/__init__.py
mlir/test/Dialect/Transform/test-pass-application.mlir
mlir/test/python/dialects/transform.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index f75ba27e58e76..0aa750e625436 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -434,10 +434,10 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
of targeted ops.
}];
- let arguments = (ins StrAttr:$pass_name,
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ StrAttr:$pass_name,
DefaultValuedAttr<DictionaryAttr, "{}">:$options,
- Variadic<TransformParamTypeInterface>:$dynamic_options,
- TransformHandleTypeInterface:$target);
+ Variadic<TransformParamTypeInterface>:$dynamic_options);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
$pass_name (`with` `options` `=`
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 10a04b0cc14e0..bfe96b1b3e5d4 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -224,13 +224,13 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
def __init__(
self,
result: Type,
- pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
+ pass_name: Union[str, StringAttr],
*,
options: Optional[
Dict[
Union[str, StringAttr],
- Union[Attribute, Value, Operation, OpView],
+ Union[Attribute, Value, Operation, OpView, str, int, bool],
]
] = None,
loc=None,
@@ -253,17 +253,21 @@ def __init__(
cur_param_operand_idx += 1
elif isinstance(value, Attribute):
options_dict[key] = value
+ # The following cases auto-convert Python values to attributes.
+ elif isinstance(value, bool):
+ options_dict[key] = BoolAttr.get(value)
+ elif isinstance(value, int):
+ default_int_type = IntegerType.get_signless(64, context)
+ options_dict[key] = IntegerAttr.get(default_int_type, value)
elif isinstance(value, str):
options_dict[key] = StringAttr.get(value)
else:
raise TypeError(f"Unsupported option type: {type(value)}")
- if len(options_dict) > 0:
- print(options_dict, cur_param_operand_idx)
super().__init__(
result,
+ _get_op_result_or_value(target),
pass_name,
dynamic_options,
- target=_get_op_result_or_value(target),
options=DictAttr.get(options_dict),
loc=loc,
ip=ip,
@@ -272,13 +276,13 @@ def __init__(
def apply_registered_pass(
result: Type,
- pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
+ pass_name: Union[str, StringAttr],
*,
options: Optional[
Dict[
Union[str, StringAttr],
- Union[Attribute, Value, Operation, OpView],
+ Union[Attribute, Value, Operation, OpView, str, int, bool],
]
] = None,
loc=None,
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 6e6d4eb7e249f..1d1be9eda3496 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
"test-convergence" = true,
"max-num-rewrites" = %max_rewrites }
to %1
- : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
transform.yield
}
}
@@ -171,7 +171,6 @@ func.func @invalid_options_as_str() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
// expected-error @+2 {{expected '{' in options dictionary}}
%2 = transform.apply_registered_pass "canonicalize"
with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
@@ -256,7 +255,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @+2 {{expected '{' in options dictionary}}
transform.apply_registered_pass "canonicalize"
with options = %pass_options to %1
- : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param) -> !transform.any_op
transform.yield
}
}
@@ -276,7 +275,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
transform.apply_registered_pass "canonicalize"
with options = { "top-down" = %topdown_options } to %1
- : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param) -> !transform.any_op
transform.yield
}
}
@@ -316,12 +315,12 @@ module attributes {transform.with_named_sequence} {
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
- %2 = "transform.apply_registered_pass"(%1, %0) <{
+ %2 = "transform.apply_registered_pass"(%0, %1) <{
options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
- : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
@@ -340,13 +339,13 @@ module attributes {transform.with_named_sequence} {
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 0 is already used in options}}
- %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+ %3 = "transform.apply_registered_pass"(%0, %1, %2) <{
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"max-num-rewrites" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
- : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
@@ -364,12 +363,12 @@ module attributes {transform.with_named_sequence} {
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
- %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+ %3 = "transform.apply_registered_pass"(%0, %1, %2) <{
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
- : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 48bc9bad37a1e..eeb95605d7a9a 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -263,12 +263,12 @@ def testApplyRegisteredPassOp(module: Module):
)
with InsertionPoint(sequence.body):
mod = transform.ApplyRegisteredPassOp(
- transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+ transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
)
mod = transform.ApplyRegisteredPassOp(
transform.AnyOpType.get(),
- "canonicalize",
mod.result,
+ "canonicalize",
options={"top-down": BoolAttr.get(False)},
)
max_iter = transform.param_constant(
@@ -281,12 +281,12 @@ def testApplyRegisteredPassOp(module: Module):
)
transform.apply_registered_pass(
transform.AnyOpType.get(),
- "canonicalize",
mod,
+ "canonicalize",
options={
"top-down": BoolAttr.get(False),
"max-iterations": max_iter,
- "test-convergence": BoolAttr.get(True),
+ "test-convergence": True,
"max-rewrites": max_rewrites,
},
)
@@ -305,4 +305,4 @@ def testApplyRegisteredPassOp(module: Module):
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
# CHECK-SAME: "test-convergence" = true,
# CHECK-SAME: "top-down" = false}
- # CHECK-SAME: to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
More information about the Mlir-commits mailing list