[Mlir-commits] [mlir] [mlir][python] Add sugared builder for transform.named_sequence (PR #71597)

Nicolas Vasilache llvmlistbot at llvm.org
Tue Nov 7 14:55:12 PST 2023


https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/71597

None

>From 0e382dd88c7ae6d07da35f9a7312dbfd9850b630 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Tue, 7 Nov 2023 22:54:30 +0000
Subject: [PATCH] [mlir][python] Add sugared builder for transform.named_sequence

---
 .../mlir/dialects/transform/__init__.py       |  28 +++++
 mlir/test/python/dialects/transform.py        | 118 +++---------------
 2 files changed, 47 insertions(+), 99 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 166c5c5ca4ec344..23b278d374332b5 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class NamedSequenceOp(NamedSequenceOp):
+    def __init__(
+        self,
+        sym_name,
+        input_types: Sequence[Type],
+        result_types: Sequence[Type],
+    ):
+        function_type = FunctionType.get(input_types, result_types)
+        super().__init__(
+            sym_name=sym_name,
+            function_type=TypeAttr.get(function_type),
+        )
+        self.regions[0].blocks.append(*input_types)
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTarget(self) -> Value:
+        return self.body.arguments[0]
+
+    @property
+    def bodyExtraArgs(self) -> BlockArgumentList:
+        return self.body.arguments[1:]
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class YieldOp(YieldOp):
     def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index d778172a607a360..e7f448850a66aa1 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
         module = Module.create()
         with InsertionPoint(module.body):
             print("\nTEST:", f.__name__)
-            f()
+            f(module)
         print(module)
     return f
 
 
 @run
-def testTypes():
+def testTypes(module: Module):
     # CHECK-LABEL: TEST: testTypes
     # CHECK: !transform.any_op
     any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes():
 
 
 @run
-def testSequenceOp():
+def testSequenceOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
@@ -60,103 +60,23 @@ def testSequenceOp():
 
 
 @run
-def testNestedSequenceOp():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        nested = transform.SequenceOp(
-            transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
-        )
-        with InsertionPoint(nested.body):
-            doubly_nested = transform.SequenceOp(
-                transform.FailurePropagationMode.Propagate,
-                [transform.AnyOpType.get()],
-                nested.bodyTarget,
-            )
-            with InsertionPoint(doubly_nested.body):
-                transform.YieldOp([doubly_nested.bodyTarget])
-            transform.YieldOp()
-        transform.YieldOp()
-    # CHECK-LABEL: TEST: testNestedSequenceOp
-    # CHECK: transform.sequence failures(propagate) {
-    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
-    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
-    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
-    # CHECK:       yield %[[ARG2]] : !transform.any_op
-    # CHECK:     }
-    # CHECK:   }
-    # CHECK: }
-
-
-@run
-def testSequenceOpWithExtras():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [],
-        transform.AnyOpType.get(),
-        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
-    )
-    with InsertionPoint(sequence.body):
-        transform.YieldOp()
-    # CHECK-LABEL: TEST: testSequenceOpWithExtras
-    # CHECK: transform.sequence failures(propagate)
-    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
-
-
-@run
-def testNestedSequenceOpWithExtras():
-  sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [],
-        transform.AnyOpType.get(),
-        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
-    )
-  with InsertionPoint(sequence.body):
-    nested = transform.SequenceOp(
-            transform.FailurePropagationMode.Propagate,
-            [],
-            sequence.bodyTarget,
-            sequence.bodyExtraArgs,
-        )
-    with InsertionPoint(nested.body):
-      transform.YieldOp()
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
-  # CHECK: transform.sequence failures(propagate)
-  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
-  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
-
-
-@run
-def testTransformPDLOps():
-  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
-  with InsertionPoint(withPdl.body):
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
+def testNamedSequenceOp(module: Module):
+    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
+    named_sequence = transform.NamedSequenceOp(
+        '__transform_main',
+        [transform.AnyOpType.get()],
         [transform.AnyOpType.get()],
-        withPdl.bodyTarget,
     )
-    with InsertionPoint(sequence.body):
-      match = transform_pdl.PDLMatchOp(
-          transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
-      )
-      transform.YieldOp(match)
-  # CHECK-LABEL: TEST: testTransformPDLOps
-  # CHECK: transform.with_pdl_patterns {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
-  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
-  # CHECK:     yield %[[RES]] : !transform.any_op
-  # CHECK:   }
-  # CHECK: }
+    with InsertionPoint(named_sequence.body):
+        transform.YieldOp([named_sequence.bodyTarget])
+    # CHECK-LABEL: TEST: testNamedSequenceOp
+    # CHECK: module attributes {transform.with_named_sequence} {
+    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
+    # CHECK:   yield %[[ARG0]] : !transform.any_op
 
 
 @run
-def testGetParentOp():
+def testGetParentOp(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -175,7 +95,7 @@ def testGetParentOp():
 
 
 @run
-def testMergeHandlesOp():
+def testMergeHandlesOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
@@ -189,7 +109,7 @@ def testMergeHandlesOp():
 
 
 @run
-def testApplyPatternsOpCompact():
+def testApplyPatternsOpCompact(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -204,7 +124,7 @@ def testApplyPatternsOpCompact():
 
 
 @run
-def testApplyPatternsOpWithType():
+def testApplyPatternsOpWithType(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [],
       transform.OperationType.get('test.dummy')
@@ -220,7 +140,7 @@ def testApplyPatternsOpWithType():
 
 
 @run
-def testReplicateOp():
+def testReplicateOp(module: Module):
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
     with InsertionPoint(with_pdl.body):
         sequence = transform.SequenceOp(



More information about the Mlir-commits mailing list