[Mlir-commits] [mlir] cca053c - [mlir][transform][linalg][python] Add mix-in for FuseIntoContainingOp.
Ingo Müller
llvmlistbot at llvm.org
Wed Jul 19 07:42:50 PDT 2023
Author: Ingo Müller
Date: 2023-07-19T14:42:41Z
New Revision: cca053c1f07bb8349019bcb7f72c80ed8176bc11
URL: https://github.com/llvm/llvm-project/commit/cca053c1f07bb8349019bcb7f72c80ed8176bc11
DIFF: https://github.com/llvm/llvm-project/commit/cca053c1f07bb8349019bcb7f72c80ed8176bc11.diff
LOG: [mlir][transform][linalg][python] Add mix-in for FuseIntoContainingOp.
The class did not have any mix-in until now. The new mix-in has two
overloads for the constructor of the class: one with all arguments and
one without the result types, which are defaulted to `AnyOpType`.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D155695
Added:
Modified:
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 7f90a464741b75..1936f4b0e0da7e 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -93,6 +93,70 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
)
+class FuseIntoContainingOp:
+ """Specialization for FuseIntoContainingOp class."""
+
+ @overload
+ def __init__(
+ self,
+ fused_op_type: Type,
+ new_containing_op_type: Type,
+ producer_op: Union[Operation, OpView, Value],
+ containing_op: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ producer_op: Union[Operation, OpView, Value],
+ containing_op: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
+ new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
+ producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(fused_op_type_or_producer_op, Type):
+ if not isinstance(new_containing_op_type_or_containing_op, Type):
+ raise TypeError(
+ "If 'fused_op_type_or_producer_op' is a type, then "
+ "'new_containing_op_type_or_containing_op' is expected "
+ "to be one as well."
+ )
+ fused_op_type = fused_op_type_or_producer_op
+ new_containing_op_type = new_containing_op_type_or_containing_op
+ producer_op = producer_op_or_none
+ containing_op = containing_op_or_none
+ else:
+ fused_op_type = transform.AnyOpType.get()
+ new_containing_op_type = transform.AnyOpType.get()
+ producer_op = fused_op_type_or_producer_op
+ containing_op = new_containing_op_type_or_containing_op
+
+ super().__init__(
+ fused_op_type,
+ new_containing_op_type,
+ producer_op,
+ containing_op,
+ loc=loc,
+ ip=ip,
+ )
+
+
class GeneralizeOp:
"""Specialization for GeneralizeOp class."""
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index ceeb62c4dfa9bd..0bcfd81d75ff11 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -30,6 +30,45 @@ def testDecompose():
# CHECK: transform.structured.decompose
+@run
+def testFuseIntoContainingOpTypes():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+ containing = structured.MatchOp.match_op_names(
+ sequence.bodyTarget, ["test.dummy"]
+ )
+ structured.FuseIntoContainingOp(
+ transform.OperationType.get("test.dummy"),
+ transform.OperationType.get("test.dummy"),
+ fused,
+ containing,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
+ # CHECK: = transform.structured.fuse_into_containing_op
+ # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)
+
+
+@run
+def testFuseIntoContainingOpCompact():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+ containing = structured.MatchOp.match_op_names(
+ sequence.bodyTarget, ["test.dummy"]
+ )
+ structured.FuseIntoContainingOp(fused, containing)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
+ # CHECK: = transform.structured.fuse_into_containing_op
+ # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
@run
def testGeneralize():
sequence = transform.SequenceOp(
More information about the Mlir-commits
mailing list