[Mlir-commits] [mlir] d78e0de - [MLIR][Transform][Python] Sync derived classes and their wrappers (#166871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 7 06:04:57 PST 2025


Author: Rolf Morel
Date: 2025-11-07T14:04:53Z
New Revision: d78e0ded5215824a63ac04fb87effd9eacf875eb

URL: https://github.com/llvm/llvm-project/commit/d78e0ded5215824a63ac04fb87effd9eacf875eb
DIFF: https://github.com/llvm/llvm-project/commit/d78e0ded5215824a63ac04fb87effd9eacf875eb.diff

LOG: [MLIR][Transform][Python] Sync derived classes and their wrappers (#166871)

Updates the derived Op classes for the main transform ops so that they
accept all the arguments, etc., of the auto-generated classes. Additionally,
updates the existing snake_case wrappers for the derived classes and adds the
missing ones. These wrappers shadow the snake_case wrappers of the
auto-generated classes, which were hitherto exposed alongside the derived
classes.

Added: 
    

Modified: 
    mlir/python/mlir/dialects/transform/__init__.py
    mlir/test/python/dialects/transform.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b075919d1ef0f..de414dc52c0a0 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -39,16 +39,32 @@ def __init__(
         super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
+def cast(
+    result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
+) -> OpResult:
+    return CastOp(result_type=result_type, target=target, loc=loc, ip=ip).result
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ApplyPatternsOp(ApplyPatternsOp):
     def __init__(
         self,
         target: Union[Operation, Value, OpView],
+        apply_cse: bool = False,
+        max_iterations: Optional[Union[IntegerAttr, int]] = None,
+        max_num_rewrites: Optional[Union[IntegerAttr, int]] = None,
         *,
         loc=None,
         ip=None,
     ):
-        super().__init__(target, loc=loc, ip=ip)
+        super().__init__(
+            target,
+            apply_cse=apply_cse,
+            max_iterations=max_iterations,
+            max_num_rewrites=max_num_rewrites,
+            loc=loc,
+            ip=ip,
+        )
         self.regions[0].blocks.append()
 
     @property
@@ -56,6 +72,25 @@ def patterns(self) -> Block:
         return self.regions[0].blocks[0]
 
 
+def apply_patterns(
+    target: Union[Operation, Value, OpView],
+    apply_cse: bool = False,
+    max_iterations: Optional[Union[IntegerAttr, int]] = None,
+    max_num_rewrites: Optional[Union[IntegerAttr, int]] = None,
+    *,
+    loc=None,
+    ip=None,
+) -> ApplyPatternsOp:
+    return ApplyPatternsOp(
+        target=target,
+        apply_cse=apply_cse,
+        max_iterations=max_iterations,
+        max_num_rewrites=max_num_rewrites,
+        loc=loc,
+        ip=ip,
+    )
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class GetParentOp(GetParentOp):
     def __init__(
@@ -64,6 +99,7 @@ def __init__(
         target: Union[Operation, Value],
         *,
         isolated_from_above: bool = False,
+        allow_empty_results: bool = False,
         op_name: Optional[str] = None,
         deduplicate: bool = False,
         nth_parent: int = 1,
@@ -74,6 +110,7 @@ def __init__(
             result_type,
             _get_op_result_or_value(target),
             isolated_from_above=isolated_from_above,
+            allow_empty_results=allow_empty_results,
             op_name=op_name,
             deduplicate=deduplicate,
             nth_parent=nth_parent,
@@ -82,6 +119,31 @@ def __init__(
         )
 
 
+def get_parent_op(
+    result_type: Type,
+    target: Union[Operation, Value],
+    *,
+    isolated_from_above: bool = False,
+    allow_empty_results: bool = False,
+    op_name: Optional[str] = None,
+    deduplicate: bool = False,
+    nth_parent: int = 1,
+    loc=None,
+    ip=None,
+) -> OpResult:
+    return GetParentOp(
+        result_type=result_type,
+        target=target,
+        isolated_from_above=isolated_from_above,
+        allow_empty_results=allow_empty_results,
+        op_name=op_name,
+        deduplicate=deduplicate,
+        nth_parent=nth_parent,
+        loc=loc,
+        ip=ip,
+    ).result
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class MergeHandlesOp(MergeHandlesOp):
     def __init__(
@@ -89,17 +151,32 @@ def __init__(
         handles: Sequence[Union[Operation, Value]],
         *,
         deduplicate: bool = False,
+        results: Optional[Sequence[Type]] = None,
         loc=None,
         ip=None,
     ):
         super().__init__(
             [_get_op_result_or_value(h) for h in handles],
             deduplicate=deduplicate,
+            results=results,
             loc=loc,
             ip=ip,
         )
 
 
+def merge_handles(
+    handles: Sequence[Union[Operation, Value]],
+    *,
+    deduplicate: bool = False,
+    results: Optional[Sequence[Type]] = None,
+    loc=None,
+    ip=None,
+) -> OpResult:
+    return MergeHandlesOp(
+        handles=handles, deduplicate=deduplicate, results=results, loc=loc, ip=ip
+    ).result
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ReplicateOp(ReplicateOp):
     def __init__(
@@ -119,16 +196,31 @@ def __init__(
         )
 
 
+def replicate(
+    pattern: Union[Operation, Value],
+    handles: Sequence[Union[Operation, Value]],
+    *,
+    loc=None,
+    ip=None,
+) -> Union[OpResult, OpResultList, ReplicateOp]:
+    op = ReplicateOp(pattern=pattern, handles=handles, loc=loc, ip=ip)
+    results = op.results
+    return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class SequenceOp(SequenceOp):
     def __init__(
         self,
-        failure_propagation_mode,
+        failure_propagation_mode: FailurePropagationMode,
         results: Sequence[Type],
         target: Union[Operation, Value, Type],
         extra_bindings: Optional[
             Union[Sequence[Value], Sequence[Type], Operation, OpView]
         ] = None,
+        *,
+        loc=None,
+        ip=None,
     ):
         root = (
             _get_op_result_or_value(target)
@@ -155,6 +247,8 @@ def __init__(
             failure_propagation_mode=failure_propagation_mode,
             root=root,
             extra_bindings=extra_bindings,
+            loc=loc,
+            ip=ip,
         )
         self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
 
@@ -171,16 +265,42 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
+def sequence(
+    failure_propagation_mode: FailurePropagationMode,
+    results: Sequence[Type],
+    target: Union[Operation, Value, Type],
+    extra_bindings: Optional[
+        Union[Sequence[Value], Sequence[Type], Operation, OpView]
+    ] = None,
+    *,
+    loc=None,
+    ip=None,
+) -> Union[OpResult, OpResultList, SequenceOp]:
+    op = SequenceOp(
+        results=results,
+        failure_propagation_mode=failure_propagation_mode,
+        extra_bindings=extra_bindings,
+        target=target,
+        loc=loc,
+        ip=ip,
+    )
+    results = op.results
+    return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class NamedSequenceOp(NamedSequenceOp):
     def __init__(
         self,
-        sym_name,
+        sym_name: Union[str, SymbolRefAttr],
         input_types: Sequence[Type],
         result_types: Sequence[Type],
-        sym_visibility=None,
-        arg_attrs=None,
-        res_attrs=None,
+        *,
+        sym_visibility: Optional[Union[str, StringAttr]] = None,
+        arg_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None,
+        res_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None,
+        loc=None,
+        ip=None,
     ):
         function_type = FunctionType.get(input_types, result_types)
         super().__init__(
@@ -205,6 +325,29 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
+def named_sequence(
+    sym_name: Union[str, SymbolRefAttr],
+    input_types: Sequence[Type],
+    result_types: Sequence[Type],
+    *,
+    sym_visibility: Optional[Union[str, StringAttr]] = None,
+    arg_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None,
+    res_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None,
+    loc=None,
+    ip=None,
+) -> NamedSequenceOp:
+    return NamedSequenceOp(
+        sym_name=sym_name,
+        input_types=input_types,
+        result_types=result_types,
+        sym_visibility=sym_visibility,
+        arg_attrs=arg_attrs,
+        res_attrs=res_attrs,
+        loc=loc,
+        ip=ip,
+    )
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class YieldOp(YieldOp):
     def __init__(
@@ -219,6 +362,12 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
+def yield_(
+    operands: Optional[Union[Operation, Sequence[Value]]] = None, *, loc=None, ip=None
+) -> YieldOp:
+    return YieldOp(operands=operands, loc=loc, ip=ip)
+
+
 OptionValueTypes = Union[
     Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool
 ]
@@ -247,7 +396,7 @@ def __init__(
         def option_value_to_attr(value):
             nonlocal cur_param_operand_idx
             if isinstance(value, (Value, Operation, OpView)):
-                dynamic_options.append(_get_op_result_or_value(value))
+                dynamic_options.append(value)
                 cur_param_operand_idx += 1
                 return ParamOperandAttr(cur_param_operand_idx - 1, context)
             elif isinstance(value, Attribute):

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6c5e4e5505b1c..f58442d04fc66 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -43,6 +43,26 @@ def testTypes(module: Module):
     print(param.type)
 
 
+@run
+def testSequenceOp(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [transform.AnyOpType.get()],
+        transform.AnyOpType.get(),
+    )
+    with InsertionPoint(sequence.body):
+        res = transform.CastOp(transform.AnyOpType.get(), sequence.bodyTarget)
+        res2 = transform.cast(transform.any_op_t(), res.result)
+        transform.YieldOp([res2])
+    # CHECK-LABEL: TEST: testSequenceOp
+    # CHECK: transform.sequence
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   %[[RES:.+]] = cast %[[ARG0]] : !transform.any_op to !transform.any_op
+    # CHECK:   %[[RES2:.+]] = cast %[[RES]] : !transform.any_op to !transform.any_op
+    # CHECK:   yield %[[RES2]] : !transform.any_op
+    # CHECK: }
+
+
 @run
 def testSequenceOp(module: Module):
     sequence = transform.SequenceOp(
@@ -58,6 +78,7 @@ def testSequenceOp(module: Module):
     # CHECK:   yield %[[ARG0]] : !transform.any_op
     # CHECK: }
 
+
 @run
 def testNestedSequenceOp(module: Module):
     sequence = transform.SequenceOp(
@@ -103,55 +124,65 @@ def testSequenceOpWithExtras(module: Module):
     # CHECK-LABEL: TEST: testSequenceOpWithExtras
     # CHECK: transform.sequence failures(propagate)
     # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+    sequence = transform.sequence(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        transform.yield_()
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
 
 
 @run
 def testNestedSequenceOpWithExtras(module: Module):
-  sequence = transform.SequenceOp(
+    sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
         transform.AnyOpType.get(),
         [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
     )
-  with InsertionPoint(sequence.body):
-    nested = transform.SequenceOp(
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
             transform.FailurePropagationMode.Propagate,
             [],
             sequence.bodyTarget,
             sequence.bodyExtraArgs,
         )
-    with InsertionPoint(nested.body):
-      transform.YieldOp()
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
-  # CHECK: transform.sequence failures(propagate)
-  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
-  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+        with InsertionPoint(nested.body):
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+    # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
 
 
 @run
 def testTransformPDLOps(module: Module):
-  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
-  with InsertionPoint(withPdl.body):
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [transform.AnyOpType.get()],
-        withPdl.bodyTarget,
-    )
-    with InsertionPoint(sequence.body):
-      match = transform_pdl.PDLMatchOp(
-          transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
-      )
-      transform.YieldOp(match)
-  # CHECK-LABEL: TEST: testTransformPDLOps
-  # CHECK: transform.with_pdl_patterns {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
-  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
-  # CHECK:     yield %[[RES]] : !transform.any_op
-  # CHECK:   }
-  # CHECK: }
+    withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+    with InsertionPoint(withPdl.body):
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate,
+            [transform.AnyOpType.get()],
+            withPdl.bodyTarget,
+        )
+        with InsertionPoint(sequence.body):
+            match = transform_pdl.PDLMatchOp(
+                transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+            )
+            transform.YieldOp(match)
+    # CHECK-LABEL: TEST: testTransformPDLOps
+    # CHECK: transform.with_pdl_patterns {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+    # CHECK:     yield %[[RES]] : !transform.any_op
+    # CHECK:   }
+    # CHECK: }
 
 
 @run
@@ -161,32 +192,53 @@ def testNamedSequenceOp(module: Module):
         "__transform_main",
         [transform.AnyOpType.get()],
         [transform.AnyOpType.get()],
-        arg_attrs = [{"transform.consumed": UnitAttr.get()}])
+        arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+    )
     with InsertionPoint(named_sequence.body):
         transform.YieldOp([named_sequence.bodyTarget])
     # CHECK-LABEL: TEST: testNamedSequenceOp
     # CHECK: module attributes {transform.with_named_sequence} {
-    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
-    # CHECK:   yield %[[ARG0]] : !transform.any_op
+    # CHECK:   transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
+    # CHECK:     yield %[[ARG0]] : !transform.any_op
+    named_sequence = transform.named_sequence(
+        "other_seq",
+        [transform.AnyOpType.get()],
+        [transform.AnyOpType.get()],
+        arg_attrs=[{"transform.consumed": UnitAttr.get()}],
+    )
+    with InsertionPoint(named_sequence.body):
+        transform.yield_([named_sequence.bodyTarget])
+    # CHECK:   transform.named_sequence @other_seq(%[[ARG1:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
+    # CHECK:     yield %[[ARG1]] : !transform.any_op
 
 
 @run
 def testGetParentOp(module: Module):
-  sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-  )
-  with InsertionPoint(sequence.body):
-    transform.GetParentOp(
-        transform.AnyOpType.get(),
-        sequence.bodyTarget,
-        isolated_from_above=True,
-        nth_parent=2,
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testGetParentOp
-  # CHECK: transform.sequence
-  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:   = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64}
+    with InsertionPoint(sequence.body):
+        transform.GetParentOp(
+            transform.AnyOpType.get(),
+            sequence.bodyTarget,
+            isolated_from_above=True,
+            nth_parent=2,
+        )
+        transform.get_parent_op(
+            transform.AnyOpType.get(),
+            sequence.bodyTarget,
+            isolated_from_above=True,
+            nth_parent=2,
+            allow_empty_results=True,
+            op_name="func.func",
+            deduplicate=True,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testGetParentOp
+    # CHECK: transform.sequence
+    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:   = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64}
+    # CHECK:   = get_parent_op %[[ARG1]] {allow_empty_results, deduplicate, isolated_from_above, nth_parent = 2 : i64, op_name = "func.func"}
 
 
 @run
@@ -195,43 +247,58 @@ def testMergeHandlesOp(module: Module):
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
-        transform.MergeHandlesOp([sequence.bodyTarget])
+        res = transform.MergeHandlesOp([sequence.bodyTarget])
+        transform.merge_handles([res.result], deduplicate=True)
         transform.YieldOp()
     # CHECK-LABEL: TEST: testMergeHandlesOp
     # CHECK: transform.sequence
     # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-    # CHECK:   = merge_handles %[[ARG1]]
+    # CHECK:  %[[RES1:.+]] = merge_handles %[[ARG1]] : !transform.any_op
+    # CHECK:               = merge_handles deduplicate %[[RES1]] : !transform.any_op
 
 
 @run
 def testApplyPatternsOpCompact(module: Module):
-  sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-  )
-  with InsertionPoint(sequence.body):
-    with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
-      transform.ApplyCanonicalizationPatternsOp()
-    transform.YieldOp()
-    # CHECK-LABEL: TEST: testApplyPatternsOpCompact
-    # CHECK: apply_patterns to
-    # CHECK: transform.apply_patterns.canonicalization
-    # CHECK: !transform.any_op
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
+            transform.ApplyCanonicalizationPatternsOp()
+        with InsertionPoint(
+            transform.apply_patterns(
+                sequence.bodyTarget,
+                apply_cse=True,
+                max_iterations=3,
+                max_num_rewrites=5,
+            ).patterns
+        ):
+            transform.ApplyCanonicalizationPatternsOp()
+        transform.YieldOp()
+        # CHECK-LABEL: TEST: testApplyPatternsOpCompact
+        # CHECK: apply_patterns to
+        # CHECK: transform.apply_patterns.canonicalization
+        # CHECK: } : !transform.any_op
+        # CHECK: apply_patterns to
+        # CHECK: transform.apply_patterns.canonicalization
+        # CHECK: } {apply_cse, max_iterations = 3 : i64, max_num_rewrites = 5 : i64} : !transform.any_op
 
 
 @run
 def testApplyPatternsOpWithType(module: Module):
-  sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.Propagate, [],
-      transform.OperationType.get('test.dummy')
-  )
-  with InsertionPoint(sequence.body):
-    with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
-      transform.ApplyCanonicalizationPatternsOp()
-    transform.YieldOp()
-    # CHECK-LABEL: TEST: testApplyPatternsOp
-    # CHECK: apply_patterns to
-    # CHECK: transform.apply_patterns.canonicalization
-    # CHECK: !transform.op<"test.dummy">
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("test.dummy"),
+    )
+    with InsertionPoint(sequence.body):
+        with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
+            transform.ApplyCanonicalizationPatternsOp()
+        transform.YieldOp()
+        # CHECK-LABEL: TEST: testApplyPatternsOp
+        # CHECK: apply_patterns to
+        # CHECK: transform.apply_patterns.canonicalization
+        # CHECK: !transform.op<"test.dummy">
 
 
 @run
@@ -249,11 +316,13 @@ def testReplicateOp(module: Module):
                 transform.AnyOpType.get(), sequence.bodyTarget, "second"
             )
             transform.ReplicateOp(m1, [m2])
+            transform.replicate(m1, [m2])
             transform.YieldOp()
     # CHECK-LABEL: TEST: testReplicateOp
     # CHECK: %[[FIRST:.+]] = pdl_match
     # CHECK: %[[SECOND:.+]] = pdl_match
     # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+    # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
 
 
 # CHECK-LABEL: TEST: testApplyRegisteredPassOp


        


More information about the Mlir-commits mailing list