[Mlir-commits] [mlir] [mlir][python] Add sugared builder for transform.named_sequence (PR #71597)

Nicolas Vasilache llvmlistbot at llvm.org
Tue Nov 7 14:55:50 PST 2023


================
@@ -60,103 +60,23 @@ def testSequenceOp():
 
 
 @run
-def testNestedSequenceOp():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        nested = transform.SequenceOp(
-            transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
-        )
-        with InsertionPoint(nested.body):
-            doubly_nested = transform.SequenceOp(
-                transform.FailurePropagationMode.Propagate,
-                [transform.AnyOpType.get()],
-                nested.bodyTarget,
-            )
-            with InsertionPoint(doubly_nested.body):
-                transform.YieldOp([doubly_nested.bodyTarget])
-            transform.YieldOp()
-        transform.YieldOp()
-    # CHECK-LABEL: TEST: testNestedSequenceOp
-    # CHECK: transform.sequence failures(propagate) {
-    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
-    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
-    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
-    # CHECK:       yield %[[ARG2]] : !transform.any_op
-    # CHECK:     }
-    # CHECK:   }
-    # CHECK: }
-
-
-@run
-def testSequenceOpWithExtras():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [],
-        transform.AnyOpType.get(),
-        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
-    )
-    with InsertionPoint(sequence.body):
-        transform.YieldOp()
-    # CHECK-LABEL: TEST: testSequenceOpWithExtras
-    # CHECK: transform.sequence failures(propagate)
-    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
-
-
-@run
-def testNestedSequenceOpWithExtras():
-  sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [],
-        transform.AnyOpType.get(),
-        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
-    )
-  with InsertionPoint(sequence.body):
-    nested = transform.SequenceOp(
-            transform.FailurePropagationMode.Propagate,
-            [],
-            sequence.bodyTarget,
-            sequence.bodyExtraArgs,
-        )
-    with InsertionPoint(nested.body):
-      transform.YieldOp()
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
-  # CHECK: transform.sequence failures(propagate)
-  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
-  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
-
-
-@run
-def testTransformPDLOps():
-  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
-  with InsertionPoint(withPdl.body):
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
+def testNamedSequenceOp(module: Module):
+    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
----------------
nicolasvasilache wrote:

I am not sure whether this is the idiomatic way to add an attribute to an operation after the fact.

https://github.com/llvm/llvm-project/pull/71597


More information about the Mlir-commits mailing list