[Mlir-commits] [mlir] [mlir][python] Add sugared builder for transform.named_sequence (PR #71597)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Nov 8 01:06:43 PST 2023
================
@@ -60,103 +60,23 @@ def testSequenceOp():
@run
-def testNestedSequenceOp():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
- )
- with InsertionPoint(nested.body):
- doubly_nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [transform.AnyOpType.get()],
- nested.bodyTarget,
- )
- with InsertionPoint(doubly_nested.body):
- transform.YieldOp([doubly_nested.bodyTarget])
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOp
- # CHECK: transform.sequence failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
- # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
- # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
- # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
- # CHECK: yield %[[ARG2]] : !transform.any_op
- # CHECK: }
- # CHECK: }
- # CHECK: }
-
-
-@run
-def testSequenceOpWithExtras():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.AnyOpType.get(),
- [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
- )
- with InsertionPoint(sequence.body):
- transform.YieldOp()
- # CHECK-LABEL: TEST: testSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
-
-
-@run
-def testNestedSequenceOpWithExtras():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.AnyOpType.get(),
- [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
- )
- with InsertionPoint(sequence.body):
- nested = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- sequence.bodyTarget,
- sequence.bodyExtraArgs,
- )
- with InsertionPoint(nested.body):
- transform.YieldOp()
- transform.YieldOp()
- # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
- # CHECK: transform.sequence failures(propagate)
- # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
- # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
-
-
-@run
-def testTransformPDLOps():
- withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
- with InsertionPoint(withPdl.body):
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
+def testNamedSequenceOp(module: Module):
+ module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
----------------
ftynse wrote:
This can also be done by the caller (`def run`) rather than by the callee, which will remove the need to pass the module into every test. Nothing will fail if the module has the attribute but doesn't actually contain any named sequences.
https://github.com/llvm/llvm-project/pull/71597
More information about the Mlir-commits
mailing list