[Mlir-commits] [mlir] [mlir][python] Reland - Add sugared builder for transform.named_sequence (PR #71642)

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 8 01:19:02 PST 2023


https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/71642

This reverts #71597, addresses its issues, and relands it.

>From 4c9f7b616f23d844c7ebaca42c09b9f343ba43f3 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 8 Nov 2023 09:12:17 +0000
Subject: [PATCH 1/3] Revert "[mlir][python] NFC - Lint fix"

This reverts commit 8c014e5949fdbecc31a82138361f8cdf886768a9.
---
 mlir/test/python/dialects/transform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 084f3ce2d502371..e7f448850a66aa1 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -63,7 +63,7 @@ def testSequenceOp(module: Module):
 def testNamedSequenceOp(module: Module):
     module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
     named_sequence = transform.NamedSequenceOp(
-        "__transform_main",
+        '__transform_main',
         [transform.AnyOpType.get()],
         [transform.AnyOpType.get()],
     )

>From be056f640a52c1acadb85bc97eebef9d932ff760 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 8 Nov 2023 09:12:24 +0000
Subject: [PATCH 2/3] Revert "[mlir][python]Add sugared buider for
 transform.named_sequence (#71597)"

This reverts commit 4f51b2bfe3ec11b597272be6caa00efb575bc59f.
---
 .../mlir/dialects/transform/__init__.py       |  28 -----
 mlir/test/python/dialects/transform.py        | 118 +++++++++++++++---
 2 files changed, 99 insertions(+), 47 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 23b278d374332b5..166c5c5ca4ec344 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,34 +165,6 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class NamedSequenceOp(NamedSequenceOp):
-    def __init__(
-        self,
-        sym_name,
-        input_types: Sequence[Type],
-        result_types: Sequence[Type],
-    ):
-        function_type = FunctionType.get(input_types, result_types)
-        super().__init__(
-            sym_name=sym_name,
-            function_type=TypeAttr.get(function_type),
-        )
-        self.regions[0].blocks.append(*input_types)
-
-    @property
-    def body(self) -> Block:
-        return self.regions[0].blocks[0]
-
-    @property
-    def bodyTarget(self) -> Value:
-        return self.body.arguments[0]
-
-    @property
-    def bodyExtraArgs(self) -> BlockArgumentList:
-        return self.body.arguments[1:]
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class YieldOp(YieldOp):
     def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index e7f448850a66aa1..d778172a607a360 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
         module = Module.create()
         with InsertionPoint(module.body):
             print("\nTEST:", f.__name__)
-            f(module)
+            f()
         print(module)
     return f
 
 
 @run
-def testTypes(module: Module):
+def testTypes():
     # CHECK-LABEL: TEST: testTypes
     # CHECK: !transform.any_op
     any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes(module: Module):
 
 
 @run
-def testSequenceOp(module: Module):
+def testSequenceOp():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
@@ -60,23 +60,103 @@ def testSequenceOp(module: Module):
 
 
 @run
-def testNamedSequenceOp(module: Module):
-    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
-    named_sequence = transform.NamedSequenceOp(
-        '__transform_main',
-        [transform.AnyOpType.get()],
+def testNestedSequenceOp():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
+        )
+        with InsertionPoint(nested.body):
+            doubly_nested = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [transform.AnyOpType.get()],
+                nested.bodyTarget,
+            )
+            with InsertionPoint(doubly_nested.body):
+                transform.YieldOp([doubly_nested.bodyTarget])
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOp
+    # CHECK: transform.sequence failures(propagate) {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+    # CHECK:       yield %[[ARG2]] : !transform.any_op
+    # CHECK:     }
+    # CHECK:   }
+    # CHECK: }
+
+
+@run
+def testSequenceOpWithExtras():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+
+
+@run
+def testNestedSequenceOpWithExtras():
+  sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+  with InsertionPoint(sequence.body):
+    nested = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate,
+            [],
+            sequence.bodyTarget,
+            sequence.bodyExtraArgs,
+        )
+    with InsertionPoint(nested.body):
+      transform.YieldOp()
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+  # CHECK: transform.sequence failures(propagate)
+  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+
+
+@run
+def testTransformPDLOps():
+  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+  with InsertionPoint(withPdl.body):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
+        withPdl.bodyTarget,
     )
-    with InsertionPoint(named_sequence.body):
-        transform.YieldOp([named_sequence.bodyTarget])
-    # CHECK-LABEL: TEST: testNamedSequenceOp
-    # CHECK: module attributes {transform.with_named_sequence} {
-    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
-    # CHECK:   yield %[[ARG0]] : !transform.any_op
+    with InsertionPoint(sequence.body):
+      match = transform_pdl.PDLMatchOp(
+          transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+      )
+      transform.YieldOp(match)
+  # CHECK-LABEL: TEST: testTransformPDLOps
+  # CHECK: transform.with_pdl_patterns {
+  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+  # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+  # CHECK:     yield %[[RES]] : !transform.any_op
+  # CHECK:   }
+  # CHECK: }
 
 
 @run
-def testGetParentOp(module: Module):
+def testGetParentOp():
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -95,7 +175,7 @@ def testGetParentOp(module: Module):
 
 
 @run
-def testMergeHandlesOp(module: Module):
+def testMergeHandlesOp():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
@@ -109,7 +189,7 @@ def testMergeHandlesOp(module: Module):
 
 
 @run
-def testApplyPatternsOpCompact(module: Module):
+def testApplyPatternsOpCompact():
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -124,7 +204,7 @@ def testApplyPatternsOpCompact(module: Module):
 
 
 @run
-def testApplyPatternsOpWithType(module: Module):
+def testApplyPatternsOpWithType():
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [],
       transform.OperationType.get('test.dummy')
@@ -140,7 +220,7 @@ def testApplyPatternsOpWithType(module: Module):
 
 
 @run
-def testReplicateOp(module: Module):
+def testReplicateOp():
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
     with InsertionPoint(with_pdl.body):
         sequence = transform.SequenceOp(

>From 5b116bf1b211ef1a791085afa9e949cd0885d2e5 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Wed, 8 Nov 2023 09:49:57 +0100
Subject: [PATCH 3/3] [mlir][python] Reland - Add sugared builder for
 transform.named_sequence (#71597)

---
 .../mlir/dialects/transform/__init__.py       | 28 +++++++++++++
 mlir/test/python/dialects/transform.py        | 40 +++++++++++++------
 2 files changed, 55 insertions(+), 13 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 166c5c5ca4ec344..23b278d374332b5 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class NamedSequenceOp(NamedSequenceOp):
+    def __init__(
+        self,
+        sym_name,
+        input_types: Sequence[Type],
+        result_types: Sequence[Type],
+    ):
+        function_type = FunctionType.get(input_types, result_types)
+        super().__init__(
+            sym_name=sym_name,
+            function_type=TypeAttr.get(function_type),
+        )
+        self.regions[0].blocks.append(*input_types)
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTarget(self) -> Value:
+        return self.body.arguments[0]
+
+    @property
+    def bodyExtraArgs(self) -> BlockArgumentList:
+        return self.body.arguments[1:]
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class YieldOp(YieldOp):
     def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index d778172a607a360..8212739c04a8777 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
         module = Module.create()
         with InsertionPoint(module.body):
             print("\nTEST:", f.__name__)
-            f()
+            f(module)
         print(module)
     return f
 
 
 @run
-def testTypes():
+def testTypes(module: Module):
     # CHECK-LABEL: TEST: testTypes
     # CHECK: !transform.any_op
     any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes():
 
 
 @run
-def testSequenceOp():
+def testSequenceOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
@@ -58,9 +58,8 @@ def testSequenceOp():
     # CHECK:   yield %[[ARG0]] : !transform.any_op
     # CHECK: }
 
-
 @run
-def testNestedSequenceOp():
+def testNestedSequenceOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
@@ -92,7 +91,7 @@ def testNestedSequenceOp():
 
 
 @run
-def testSequenceOpWithExtras():
+def testSequenceOpWithExtras(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
@@ -107,7 +106,7 @@ def testSequenceOpWithExtras():
 
 
 @run
-def testNestedSequenceOpWithExtras():
+def testNestedSequenceOpWithExtras(module: Module):
   sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
@@ -131,7 +130,7 @@ def testNestedSequenceOpWithExtras():
 
 
 @run
-def testTransformPDLOps():
+def testTransformPDLOps(module: Module):
   withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
   with InsertionPoint(withPdl.body):
     sequence = transform.SequenceOp(
@@ -154,9 +153,24 @@ def testTransformPDLOps():
   # CHECK:   }
   # CHECK: }
 
+@run
+def testNamedSequenceOp(module: Module):
+    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
+    named_sequence = transform.NamedSequenceOp(
+        "__transform_main",
+        [transform.AnyOpType.get()],
+        [transform.AnyOpType.get()],
+    )
+    with InsertionPoint(named_sequence.body):
+        transform.YieldOp([named_sequence.bodyTarget])
+    # CHECK-LABEL: TEST: testNamedSequenceOp
+    # CHECK: module attributes {transform.with_named_sequence} {
+    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
+    # CHECK:   yield %[[ARG0]] : !transform.any_op
+
 
 @run
-def testGetParentOp():
+def testGetParentOp(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -175,7 +189,7 @@ def testGetParentOp():
 
 
 @run
-def testMergeHandlesOp():
+def testMergeHandlesOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
@@ -189,7 +203,7 @@ def testMergeHandlesOp():
 
 
 @run
-def testApplyPatternsOpCompact():
+def testApplyPatternsOpCompact(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -204,7 +218,7 @@ def testApplyPatternsOpCompact():
 
 
 @run
-def testApplyPatternsOpWithType():
+def testApplyPatternsOpWithType(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [],
       transform.OperationType.get('test.dummy')
@@ -220,7 +234,7 @@ def testApplyPatternsOpWithType():
 
 
 @run
-def testReplicateOp():
+def testReplicateOp(module: Module):
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
     with InsertionPoint(with_pdl.body):
         sequence = transform.SequenceOp(



More information about the Mlir-commits mailing list