[Mlir-commits] [mlir] [mlir][python]Add sugared builder for transform.named_sequence (PR #71597)
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Nov 7 14:55:50 PST 2023
================
@@ -60,103 +60,23 @@ def testSequenceOp():
@run
-def testNestedSequenceOp():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
- )
- with InsertionPoint(nested.body):
- doubly_nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [transform.AnyOpType.get()],
- nested.bodyTarget,
- )
- with InsertionPoint(doubly_nested.body):
- transform.YieldOp([doubly_nested.bodyTarget])
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOp
- # CHECK: transform.sequence failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
- # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
- # CHECK: yield %[[ARG2]] : !transform.any_op
- # CHECK: }
- # CHECK: }
- # CHECK: }
-
-
-@run
-def testSequenceOpWithExtras():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.AnyOpType.get(),
- [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
- )
- with InsertionPoint(sequence.body):
- transform.YieldOp()
- # CHECK-LABEL: TEST: testSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
-
-
-@run
-def testNestedSequenceOpWithExtras():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.AnyOpType.get(),
- [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
- )
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- sequence.bodyTarget,
- sequence.bodyExtraArgs,
- )
- with InsertionPoint(nested.body):
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
- # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
-
-
-@run
-def testTransformPDLOps():
- withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
- with InsertionPoint(withPdl.body):
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
+def testNamedSequenceOp(module: Module):
+ module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
----------------
nicolasvasilache wrote:
I am not sure whether this is the idiomatic way to add attributes after the fact.
https://github.com/llvm/llvm-project/pull/71597
More information about the Mlir-commits
mailing list