[Mlir-commits] [mlir] [mlir][python] Reland - Add sugared builder for transform.named_sequence (PR #71642)

llvmlistbot at llvm.org
Wed Nov 8 01:19:30 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

<details>
<summary>Changes</summary>

This reverts the revert of #71597, addresses the issues with it, and relands it.

---
Full diff: https://github.com/llvm/llvm-project/pull/71642.diff


1 File Affected:

- (modified) mlir/test/python/dialects/transform.py (+94) 


``````````diff
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 084f3ce2d502371..8212739c04a8777 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -58,6 +58,100 @@ def testSequenceOp(module: Module):
     # CHECK:   yield %[[ARG0]] : !transform.any_op
     # CHECK: }
 
+@run
+def testNestedSequenceOp(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
+        )
+        with InsertionPoint(nested.body):
+            doubly_nested = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [transform.AnyOpType.get()],
+                nested.bodyTarget,
+            )
+            with InsertionPoint(doubly_nested.body):
+                transform.YieldOp([doubly_nested.bodyTarget])
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOp
+    # CHECK: transform.sequence failures(propagate) {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+    # CHECK:       yield %[[ARG2]] : !transform.any_op
+    # CHECK:     }
+    # CHECK:   }
+    # CHECK: }
+
+
+@run
+def testSequenceOpWithExtras(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+
+
+@run
+def testNestedSequenceOpWithExtras(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate,
+            [],
+            sequence.bodyTarget,
+            sequence.bodyExtraArgs,
+        )
+        with InsertionPoint(nested.body):
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+    # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+
+
+@run
+def testTransformPDLOps(module: Module):
+    withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+    with InsertionPoint(withPdl.body):
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate,
+            [transform.AnyOpType.get()],
+            withPdl.bodyTarget,
+        )
+        with InsertionPoint(sequence.body):
+            match = transform_pdl.PDLMatchOp(
+                transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+            )
+            transform.YieldOp(match)
+    # CHECK-LABEL: TEST: testTransformPDLOps
+    # CHECK: transform.with_pdl_patterns {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+    # CHECK:     yield %[[RES]] : !transform.any_op
+    # CHECK:   }
+    # CHECK: }
 
 @run
 def testNamedSequenceOp(module: Module):

``````````
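The two `...WithExtras` tests depend on how the Python `SequenceOp` wrapper splits its entry-block arguments: `bodyTarget` is the root handle, and `bodyExtraArgs` holds the remaining arguments in the order the extra types were passed. As a rough illustration only, here is a pure-Python mock of that assumed layout. `FakeValue`, `FakeSequenceOp`, and the `%argN` naming are hypothetical and are not part of the real `mlir.dialects.transform` bindings.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class FakeValue:
    """Hypothetical stand-in for an SSA value: printed name plus type string."""

    name: str
    type: str


class FakeSequenceOp:
    """Hypothetical mock of transform.SequenceOp's entry-block arguments.

    Assumed layout: argument 0 is the root handle, followed by one argument
    per extra type, in the order the extra types were given.
    """

    def __init__(self, root_type: str, extra_types: Sequence[str] = ()):
        types = [root_type, *extra_types]
        self.arguments = [FakeValue(f"%arg{i}", t) for i, t in enumerate(types)]

    @property
    def bodyTarget(self) -> FakeValue:
        # The root handle is always the first block argument.
        return self.arguments[0]

    @property
    def bodyExtraArgs(self) -> List[FakeValue]:
        # Everything after the root handle, in declaration order.
        return self.arguments[1:]


# Mirror testNestedSequenceOpWithExtras: the outer sequence has two extras,
# and the nested sequence forwards the target plus all extras as operands.
outer = FakeSequenceOp(
    "!transform.any_op", ["!transform.any_op", '!transform.op<"foo.bar">']
)
nested_operands = [outer.bodyTarget, *outer.bodyExtraArgs]
print(", ".join(v.name for v in nested_operands))  # %arg0, %arg1, %arg2
```

Forwarding `[outer.bodyTarget, *outer.bodyExtraArgs]` gives the nested sequence three operands in declaration order, which is the shape the `sequence %[[ARG0]], %[[ARG1]], %[[ARG2]]` CHECK line expects.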

</details>
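The CHECK lines in these tests bind names with `[[NAME:regex]]` and reuse them with `[[NAME]]`, so, for example, the nested `sequence` must consume exactly the value that was bound to `ARG0`. The sketch below is a deliberately simplified emulation of that matching with Python's `re` module, not FileCheck itself: it handles only `{{regex}}`, `[[NAME:regex]]`, and `[[NAME]]` within a single line, and it ignores FileCheck's whitespace canonicalization, in-order matching across lines, and other directives. `check_line` is a hypothetical helper, and the sample IR lines are illustrative rather than captured output.

```python
import re
from typing import Dict

# FileCheck pattern tokens handled by this sketch:
#   {{regex}}       -> inline regular expression
#   [[NAME:regex]]  -> bind NAME to whatever regex matches
#   [[NAME]]        -> substitute the text previously bound to NAME
_TOKEN = re.compile(r"\{\{(.*?)\}\}|\[\[(\w+):(.*?)\]\]|\[\[(\w+)\]\]")


def check_line(pattern: str, line: str, bindings: Dict[str, str]) -> bool:
    """Return True if `line` contains `pattern`; record new bindings on success."""
    parts = []
    defined = []
    pos = 0
    for tok in _TOKEN.finditer(pattern):
        # Literal text between tokens must match verbatim.
        parts.append(re.escape(pattern[pos : tok.start()]))
        inline, def_name, def_regex, use_name = tok.groups()
        if inline is not None:
            parts.append(f"(?:{inline})")
        elif def_name is not None:
            parts.append(f"(?P<{def_name}>{def_regex})")
            defined.append(def_name)
        else:
            # Raises KeyError if the variable was never defined.
            parts.append(re.escape(bindings[use_name]))
        pos = tok.end()
    parts.append(re.escape(pattern[pos:]))
    match = re.search("".join(parts), line)
    if match is None:
        return False
    for name in defined:
        bindings[name] = match.group(name)
    return True


bindings: Dict[str, str] = {}
# Hypothetical printed IR lines, shaped like the output the tests check.
assert check_line(
    "^{{.*}}(%[[ARG0:.+]]: !transform.any_op):",
    "^bb0(%arg0: !transform.any_op):",
    bindings,
)
print(bindings)  # {'ARG0': 'arg0'}
assert check_line(
    "sequence %[[ARG0]] : !transform.any_op failures(propagate) {",
    "  sequence %arg0 : !transform.any_op failures(propagate) {",
    bindings,
)
```

Reusing `[[ARG0]]` turns the second pattern into a literal match on `%arg0`, so a nested `sequence` applied to any other value would fail the check.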


https://github.com/llvm/llvm-project/pull/71642

