[Mlir-commits] [mlir] 1dccdf7 - [mlir][linalg][transform][python] Add type arg to MatchOp extension.

Ingo Müller llvmlistbot at llvm.org
Wed Jul 19 02:15:45 PDT 2023


Author: Ingo Müller
Date: 2023-07-19T09:15:41Z
New Revision: 1dccdf7f49a0cdad7121913c900ce9cb8b6e9fdc

URL: https://github.com/llvm/llvm-project/commit/1dccdf7f49a0cdad7121913c900ce9cb8b6e9fdc
DIFF: https://github.com/llvm/llvm-project/commit/1dccdf7f49a0cdad7121913c900ce9cb8b6e9fdc.diff

LOG: [mlir][linalg][transform][python] Add type arg to MatchOp extension.

The extension class for MatchOp has a class method called match_op_names.
The previous version of that method did not allow specifying the
result type. This, however, may be useful or necessary if the op consuming
the resulting handle requires a particular type (such as the
bufferization.EmptyTensorToAllocTensorOp). This patch adds an overload
to match_op_names that allows specifying the result type. When no result
type is given, the result type now defaults to !transform.any_op instead
of the previous !pdl.operation.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D155567

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index b754034c8a4809..640730997f93e1 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -85,17 +85,52 @@ def __init__(
 class MatchOp:
     """Specialization for MatchOp class."""
 
+    @overload
     @classmethod
     def match_op_names(
-        MatchOp,
+        cls,
         target: Union[Operation, Value],
         names: Sequence[str],
+        *,
         loc=None,
         ip=None,
     ):
-        pdl_operation_type = pdl.OperationType.get()
-        return MatchOp(
-            pdl_operation_type,
+       ...
+
+    @overload
+    @classmethod
+    def match_op_names(
+        cls,
+        result_type: Type,
+        target: Union[Operation, Value],
+        names: Sequence[str],
+        *,
+        loc=None,
+        ip=None,
+    ):
+       ...
+
+    @classmethod
+    def match_op_names(
+        cls,
+        result_type_or_target: Union[Type, Operation, Value],
+        target_or_names: Union[Operation, Value,  Sequence[str]],
+        names_or_none: Optional[Sequence[str]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(result_type_or_target, Type):
+           result_type = result_type_or_target
+           target = target_or_names
+           names = names_or_none
+        else:
+           result_type = transform.AnyOpType.get()
+           target = result_type_or_target
+           names = target_or_names
+
+        return cls(
+            result_type,
             _get_op_result_or_value(target),
             ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
             loc=loc,

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 2dfae47bdfb492..03a47166c824c6 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -57,6 +57,38 @@ def testInterchange():
     # CHECK: iterator_interchange = [1, 0]
 
 
+@run
+def testMatchOpNames():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMatchOpNames
+    # CHECK: transform.structured.match ops
+    # CHECK-SAME: ["test.dummy"]
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testMatchOpNamesTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MatchOp.match_op_names(
+            transform.OperationType.get("test.dummy"),
+            sequence.bodyTarget,
+            ["test.dummy"],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMatchOpNamesTyped
+    # CHECK: transform.structured.match ops
+    # CHECK-SAME: ["test.dummy"]
+    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
+
+
 @run
 def testMultitileSizes():
     sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list