[Mlir-commits] [mlir] cb38805 - Revert "[mlir][python]Add sugared buider for transform.named_sequence (#71597)"

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 8 01:34:42 PST 2023


Author: Nicolas Vasilache
Date: 2023-11-08T09:34:29Z
New Revision: cb3880515fe9065ae3378f3d145a17bf75fa6740

URL: https://github.com/llvm/llvm-project/commit/cb3880515fe9065ae3378f3d145a17bf75fa6740
DIFF: https://github.com/llvm/llvm-project/commit/cb3880515fe9065ae3378f3d145a17bf75fa6740.diff

LOG: Revert "[mlir][python]Add sugared buider for transform.named_sequence (#71597)"

This reverts commit 4f51b2bfe3ec11b597272be6caa00efb575bc59f.

Added: 
    

Modified: 
    mlir/python/mlir/dialects/transform/__init__.py
    mlir/test/python/dialects/transform.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 23b278d374332b5..166c5c5ca4ec344 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,34 +165,6 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class NamedSequenceOp(NamedSequenceOp):
-    def __init__(
-        self,
-        sym_name,
-        input_types: Sequence[Type],
-        result_types: Sequence[Type],
-    ):
-        function_type = FunctionType.get(input_types, result_types)
-        super().__init__(
-            sym_name=sym_name,
-            function_type=TypeAttr.get(function_type),
-        )
-        self.regions[0].blocks.append(*input_types)
-
-    @property
-    def body(self) -> Block:
-        return self.regions[0].blocks[0]
-
-    @property
-    def bodyTarget(self) -> Value:
-        return self.body.arguments[0]
-
-    @property
-    def bodyExtraArgs(self) -> BlockArgumentList:
-        return self.body.arguments[1:]
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class YieldOp(YieldOp):
     def __init__(

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index e7f448850a66aa1..d778172a607a360 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
         module = Module.create()
         with InsertionPoint(module.body):
             print("\nTEST:", f.__name__)
-            f(module)
+            f()
         print(module)
     return f
 
 
 @run
-def testTypes(module: Module):
+def testTypes():
     # CHECK-LABEL: TEST: testTypes
     # CHECK: !transform.any_op
     any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes(module: Module):
 
 
 @run
-def testSequenceOp(module: Module):
+def testSequenceOp():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
@@ -60,23 +60,103 @@ def testSequenceOp(module: Module):
 
 
 @run
-def testNamedSequenceOp(module: Module):
-    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
-    named_sequence = transform.NamedSequenceOp(
-        '__transform_main',
-        [transform.AnyOpType.get()],
+def testNestedSequenceOp():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
+        )
+        with InsertionPoint(nested.body):
+            doubly_nested = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [transform.AnyOpType.get()],
+                nested.bodyTarget,
+            )
+            with InsertionPoint(doubly_nested.body):
+                transform.YieldOp([doubly_nested.bodyTarget])
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOp
+    # CHECK: transform.sequence failures(propagate) {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+    # CHECK:       yield %[[ARG2]] : !transform.any_op
+    # CHECK:     }
+    # CHECK:   }
+    # CHECK: }
+
+
+@run
+def testSequenceOpWithExtras():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+
+
+@run
+def testNestedSequenceOpWithExtras():
+  sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+  with InsertionPoint(sequence.body):
+    nested = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate,
+            [],
+            sequence.bodyTarget,
+            sequence.bodyExtraArgs,
+        )
+    with InsertionPoint(nested.body):
+      transform.YieldOp()
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+  # CHECK: transform.sequence failures(propagate)
+  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+
+
+@run
+def testTransformPDLOps():
+  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+  with InsertionPoint(withPdl.body):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
+        withPdl.bodyTarget,
     )
-    with InsertionPoint(named_sequence.body):
-        transform.YieldOp([named_sequence.bodyTarget])
-    # CHECK-LABEL: TEST: testNamedSequenceOp
-    # CHECK: module attributes {transform.with_named_sequence} {
-    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
-    # CHECK:   yield %[[ARG0]] : !transform.any_op
+    with InsertionPoint(sequence.body):
+      match = transform_pdl.PDLMatchOp(
+          transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+      )
+      transform.YieldOp(match)
+  # CHECK-LABEL: TEST: testTransformPDLOps
+  # CHECK: transform.with_pdl_patterns {
+  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+  # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+  # CHECK:     yield %[[RES]] : !transform.any_op
+  # CHECK:   }
+  # CHECK: }
 
 
 @run
-def testGetParentOp(module: Module):
+def testGetParentOp():
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -95,7 +175,7 @@ def testGetParentOp(module: Module):
 
 
 @run
-def testMergeHandlesOp(module: Module):
+def testMergeHandlesOp():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
@@ -109,7 +189,7 @@ def testMergeHandlesOp(module: Module):
 
 
 @run
-def testApplyPatternsOpCompact(module: Module):
+def testApplyPatternsOpCompact():
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -124,7 +204,7 @@ def testApplyPatternsOpCompact(module: Module):
 
 
 @run
-def testApplyPatternsOpWithType(module: Module):
+def testApplyPatternsOpWithType():
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [],
       transform.OperationType.get('test.dummy')
@@ -140,7 +220,7 @@ def testApplyPatternsOpWithType(module: Module):
 
 
 @run
-def testReplicateOp(module: Module):
+def testReplicateOp():
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
     with InsertionPoint(with_pdl.body):
         sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list