[Mlir-commits] [mlir] [mlir][python] Reland - Add sugared builder for transform.named_sequence (PR #71642)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 8 01:19:30 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Nicolas Vasilache (nicolasvasilache)
<details>
<summary>Changes</summary>
This reverts, addresses issues with, and relands #71597.
---
Full diff: https://github.com/llvm/llvm-project/pull/71642.diff
1 File Affected:
- (modified) mlir/test/python/dialects/transform.py (+94)
``````````diff
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 084f3ce2d502371..8212739c04a8777 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -58,6 +58,100 @@ def testSequenceOp(module: Module):
# CHECK: yield %[[ARG0]] : !transform.any_op
# CHECK: }
+@run
+def testNestedSequenceOp(module: Module):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
+ )
+ with InsertionPoint(nested.body):
+ doubly_nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [transform.AnyOpType.get()],
+ nested.bodyTarget,
+ )
+ with InsertionPoint(doubly_nested.body):
+ transform.YieldOp([doubly_nested.bodyTarget])
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOp
+ # CHECK: transform.sequence failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+ # CHECK: yield %[[ARG2]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
+ # CHECK: }
+
+
+@run
+def testSequenceOpWithExtras(module: Module):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+
+
+@run
+def testNestedSequenceOpWithExtras(module: Module):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ sequence.bodyTarget,
+ sequence.bodyExtraArgs,
+ )
+ with InsertionPoint(nested.body):
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+ # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+
+
+@run
+def testTransformPDLOps(module: Module):
+ withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+ with InsertionPoint(withPdl.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [transform.AnyOpType.get()],
+ withPdl.bodyTarget,
+ )
+ with InsertionPoint(sequence.body):
+ match = transform_pdl.PDLMatchOp(
+ transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+ )
+ transform.YieldOp(match)
+ # CHECK-LABEL: TEST: testTransformPDLOps
+ # CHECK: transform.with_pdl_patterns {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+ # CHECK: yield %[[RES]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
@run
def testNamedSequenceOp(module: Module):
``````````
</details>
https://github.com/llvm/llvm-project/pull/71642
More information about the Mlir-commits
mailing list