[Mlir-commits] [mlir] cca053c - [mlir][transform][linalg][python] Add mix-in for FuseIntoContainingOp.

Ingo Müller llvmlistbot at llvm.org
Wed Jul 19 07:42:50 PDT 2023


Author: Ingo Müller
Date: 2023-07-19T14:42:41Z
New Revision: cca053c1f07bb8349019bcb7f72c80ed8176bc11

URL: https://github.com/llvm/llvm-project/commit/cca053c1f07bb8349019bcb7f72c80ed8176bc11
DIFF: https://github.com/llvm/llvm-project/commit/cca053c1f07bb8349019bcb7f72c80ed8176bc11.diff

LOG: [mlir][transform][linalg][python] Add mix-in for FuseIntoContainingOp.

The class did not have a mix-in until now. The new mix-in provides two
constructor overloads: one that takes all arguments, and one that omits
the result types, which then default to `AnyOpType`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D155695

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 7f90a464741b75..1936f4b0e0da7e 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -93,6 +93,101 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
         )
 
 
+class FuseIntoContainingOp:
+    """Specialization for FuseIntoContainingOp class."""
+
+    @overload
+    def __init__(
+        self,
+        fused_op_type: Type,
+        new_containing_op_type: Type,
+        producer_op: Union[Operation, OpView, Value],
+        containing_op: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        producer_op: Union[Operation, OpView, Value],
+        containing_op: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
+        new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
+        producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        """Build a `transform.structured.fuse_into_containing_op` op.
+
+        Two positional call forms are accepted:
+
+        - `(fused_op_type, new_containing_op_type, producer_op,
+          containing_op)` sets both result types explicitly;
+        - `(producer_op, containing_op)` sets both result types to
+          `transform.AnyOpType`.
+
+        Raises:
+          TypeError: if the positional arguments match neither form.
+        """
+        if isinstance(fused_op_type_or_producer_op, Type):
+            if not isinstance(new_containing_op_type_or_containing_op, Type):
+                raise TypeError(
+                    "If 'fused_op_type_or_producer_op' is a type, then "
+                    "'new_containing_op_type_or_containing_op' is expected "
+                    "to be one as well."
+                )
+            # With explicit result types both operands are required;
+            # otherwise `None` would be forwarded to the generated builder.
+            if producer_op_or_none is None or containing_op_or_none is None:
+                raise TypeError(
+                    "If result types are given, 'producer_op_or_none' and "
+                    "'containing_op_or_none' must be provided as well."
+                )
+            fused_op_type = fused_op_type_or_producer_op
+            new_containing_op_type = new_containing_op_type_or_containing_op
+            producer_op = producer_op_or_none
+            containing_op = containing_op_or_none
+        else:
+            # Compact form: reject a result type in operand position or
+            # extra operands instead of misrouting or silently dropping them.
+            if (
+                isinstance(new_containing_op_type_or_containing_op, Type)
+                or producer_op_or_none is not None
+                or containing_op_or_none is not None
+            ):
+                raise TypeError(
+                    "Expected either (fused_op_type, new_containing_op_type, "
+                    "producer_op, containing_op) or (producer_op, "
+                    "containing_op)."
+                )
+            fused_op_type = transform.AnyOpType.get()
+            new_containing_op_type = transform.AnyOpType.get()
+            producer_op = fused_op_type_or_producer_op
+            containing_op = new_containing_op_type_or_containing_op
+
+        super().__init__(
+            fused_op_type,
+            new_containing_op_type,
+            producer_op,
+            containing_op,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class GeneralizeOp:
     """Specialization for GeneralizeOp class."""
 

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index ceeb62c4dfa9bd..0bcfd81d75ff11 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -30,6 +30,47 @@ def testDecompose():
     # CHECK: transform.structured.decompose
 
 
+@run
+def testFuseIntoContainingOpTypes():
+    # Explicit result types must show up as the op's result types.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        containing = structured.MatchOp.match_op_names(
+            sequence.bodyTarget, ["test.dummy"]
+        )
+        structured.FuseIntoContainingOp(
+            transform.OperationType.get("test.dummy"),
+            transform.OperationType.get("test.dummy"),
+            fused,
+            containing,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
+    # CHECK: = transform.structured.fuse_into_containing_op
+    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)
+
+
+@run
+def testFuseIntoContainingOpCompact():
+    # Without explicit result types, both results default to !transform.any_op.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        containing = structured.MatchOp.match_op_names(
+            sequence.bodyTarget, ["test.dummy"]
+        )
+        structured.FuseIntoContainingOp(fused, containing)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
+    # CHECK: = transform.structured.fuse_into_containing_op
+    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
 @run
 def testGeneralize():
     sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list