[Mlir-commits] [mlir] fb761aa - [MLIR][Transform] apply_registered_pass fixes: arg order & python options auto-conversion (#143779)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 11 13:19:55 PDT 2025


Author: Rolf Morel
Date: 2025-06-11T21:19:52+01:00
New Revision: fb761aa38b0bc01ab911f5dbbfb474b70aaafbb4

URL: https://github.com/llvm/llvm-project/commit/fb761aa38b0bc01ab911f5dbbfb474b70aaafbb4
DIFF: https://github.com/llvm/llvm-project/commit/fb761aa38b0bc01ab911f5dbbfb474b70aaafbb4.diff

LOG: [MLIR][Transform] apply_registered_pass fixes: arg order & python options auto-conversion (#143779)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/python/mlir/dialects/transform/__init__.py
    mlir/test/Dialect/Transform/test-pass-application.mlir
    mlir/test/python/dialects/transform.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index f75ba27e58e76..0aa750e625436 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -434,10 +434,10 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     of targeted ops.
   }];
 
-  let arguments = (ins StrAttr:$pass_name,
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       StrAttr:$pass_name,
                        DefaultValuedAttr<DictionaryAttr, "{}">:$options,
-                       Variadic<TransformParamTypeInterface>:$dynamic_options,
-                       TransformHandleTypeInterface:$target);
+                       Variadic<TransformParamTypeInterface>:$dynamic_options);
   let results = (outs TransformHandleTypeInterface:$result);
   let assemblyFormat = [{
     $pass_name (`with` `options` `=`

diff  --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 10a04b0cc14e0..bfe96b1b3e5d4 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -224,13 +224,13 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
     def __init__(
         self,
         result: Type,
-        pass_name: Union[str, StringAttr],
         target: Union[Operation, Value, OpView],
+        pass_name: Union[str, StringAttr],
         *,
         options: Optional[
             Dict[
                 Union[str, StringAttr],
-                Union[Attribute, Value, Operation, OpView],
+                Union[Attribute, Value, Operation, OpView, str, int, bool],
             ]
         ] = None,
         loc=None,
@@ -253,17 +253,21 @@ def __init__(
                 cur_param_operand_idx += 1
             elif isinstance(value, Attribute):
                 options_dict[key] = value
+            # The following cases auto-convert Python values to attributes.
+            elif isinstance(value, bool):
+                options_dict[key] = BoolAttr.get(value)
+            elif isinstance(value, int):
+                default_int_type = IntegerType.get_signless(64, context)
+                options_dict[key] = IntegerAttr.get(default_int_type, value)
             elif isinstance(value, str):
                 options_dict[key] = StringAttr.get(value)
             else:
                 raise TypeError(f"Unsupported option type: {type(value)}")
-        if len(options_dict) > 0:
-            print(options_dict, cur_param_operand_idx)
         super().__init__(
             result,
+            _get_op_result_or_value(target),
             pass_name,
             dynamic_options,
-            target=_get_op_result_or_value(target),
             options=DictAttr.get(options_dict),
             loc=loc,
             ip=ip,
@@ -272,13 +276,13 @@ def __init__(
 
 def apply_registered_pass(
     result: Type,
-    pass_name: Union[str, StringAttr],
     target: Union[Operation, Value, OpView],
+    pass_name: Union[str, StringAttr],
     *,
     options: Optional[
         Dict[
             Union[str, StringAttr],
-            Union[Attribute, Value, Operation, OpView],
+            Union[Attribute, Value, Operation, OpView, str, int, bool],
         ]
     ] = None,
     loc=None,

diff  --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 6e6d4eb7e249f..1d1be9eda3496 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
                          "test-convergence" = true,
                          "max-num-rewrites" =  %max_rewrites }
         to %1
-        : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+        : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
     transform.yield
   }
 }
@@ -171,7 +171,6 @@ func.func @invalid_options_as_str() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
     // expected-error @+2 {{expected '{' in options dictionary}}
     %2 = transform.apply_registered_pass "canonicalize"
         with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
@@ -256,7 +255,7 @@ module attributes {transform.with_named_sequence} {
     // expected-error @+2 {{expected '{' in options dictionary}}
     transform.apply_registered_pass "canonicalize"
         with options = %pass_options to %1
-        : (!transform.any_param, !transform.any_op) -> !transform.any_op
+        : (!transform.any_op, !transform.any_param) -> !transform.any_op
     transform.yield
   }
 }
@@ -276,7 +275,7 @@ module attributes {transform.with_named_sequence} {
     // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
     transform.apply_registered_pass "canonicalize"
         with options = { "top-down" = %topdown_options } to %1
-        : (!transform.any_param, !transform.any_op) -> !transform.any_op
+        : (!transform.any_op, !transform.any_param) -> !transform.any_op
     transform.yield
   }
 }
@@ -316,12 +315,12 @@ module attributes {transform.with_named_sequence} {
     %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
     %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
     // expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
-    %2 = "transform.apply_registered_pass"(%1, %0) <{
+    %2 = "transform.apply_registered_pass"(%0, %1) <{
       options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>
-    : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    : (!transform.any_op, !transform.any_param) -> !transform.any_op
     "transform.yield"() : () -> ()
   }) : () -> ()
 }) {transform.with_named_sequence} : () -> ()
@@ -340,13 +339,13 @@ module attributes {transform.with_named_sequence} {
     %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
     %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
     // expected-error @below {{dynamic option index 0 is already used in options}}
-    %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+    %3 = "transform.apply_registered_pass"(%0, %1, %2) <{
       options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
                  "max-num-rewrites" = #transform.param_operand<index=0 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>
-    : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
     "transform.yield"() : () -> ()
   }) : () -> ()
 }) {transform.with_named_sequence} : () -> ()
@@ -364,12 +363,12 @@ module attributes {transform.with_named_sequence} {
     %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
     %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
     // expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
-    %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+    %3 = "transform.apply_registered_pass"(%0, %1, %2) <{
       options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>
-    : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
     "transform.yield"() : () -> ()
   }) : () -> ()
 }) {transform.with_named_sequence} : () -> ()

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 48bc9bad37a1e..eeb95605d7a9a 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -263,12 +263,12 @@ def testApplyRegisteredPassOp(module: Module):
     )
     with InsertionPoint(sequence.body):
         mod = transform.ApplyRegisteredPassOp(
-            transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+            transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
         )
         mod = transform.ApplyRegisteredPassOp(
             transform.AnyOpType.get(),
-            "canonicalize",
             mod.result,
+            "canonicalize",
             options={"top-down": BoolAttr.get(False)},
         )
         max_iter = transform.param_constant(
@@ -281,12 +281,12 @@ def testApplyRegisteredPassOp(module: Module):
         )
         transform.apply_registered_pass(
             transform.AnyOpType.get(),
-            "canonicalize",
             mod,
+            "canonicalize",
             options={
                 "top-down": BoolAttr.get(False),
                 "max-iterations": max_iter,
-                "test-convergence": BoolAttr.get(True),
+                "test-convergence": True,
                 "max-rewrites": max_rewrites,
             },
         )
@@ -305,4 +305,4 @@ def testApplyRegisteredPassOp(module: Module):
     # CHECK-SAME:                    "max-rewrites" =  %[[MAX_REWRITE]],
     # CHECK-SAME:                    "test-convergence" = true,
     # CHECK-SAME:                    "top-down" = false}
-    # CHECK-SAME:    to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    # CHECK-SAME:    to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op


        


More information about the Mlir-commits mailing list