[Mlir-commits] [mlir] [mlir][python] Reland - Add sugared builder for transform.named_sequence (PR #71642)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 8 01:19:02 PST 2023
https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/71642
This reverts #71597, addresses the issues found with it, and relands it.
>From 4c9f7b616f23d844c7ebaca42c09b9f343ba43f3 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 8 Nov 2023 09:12:17 +0000
Subject: [PATCH 1/3] Revert "[mlir][python] NFC - Lint fix"
This reverts commit 8c014e5949fdbecc31a82138361f8cdf886768a9.
---
mlir/test/python/dialects/transform.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 084f3ce2d502371..e7f448850a66aa1 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -63,7 +63,7 @@ def testSequenceOp(module: Module):
def testNamedSequenceOp(module: Module):
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
named_sequence = transform.NamedSequenceOp(
- "__transform_main",
+ '__transform_main',
[transform.AnyOpType.get()],
[transform.AnyOpType.get()],
)
>From be056f640a52c1acadb85bc97eebef9d932ff760 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 8 Nov 2023 09:12:24 +0000
Subject: [PATCH 2/3] Revert "[mlir][python]Add sugared buider for
transform.named_sequence (#71597)"
This reverts commit 4f51b2bfe3ec11b597272be6caa00efb575bc59f.
---
.../mlir/dialects/transform/__init__.py | 28 -----
mlir/test/python/dialects/transform.py | 118 +++++++++++++++---
2 files changed, 99 insertions(+), 47 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 23b278d374332b5..166c5c5ca4ec344 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,34 +165,6 @@ def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
-@_ods_cext.register_operation(_Dialect, replace=True)
-class NamedSequenceOp(NamedSequenceOp):
- def __init__(
- self,
- sym_name,
- input_types: Sequence[Type],
- result_types: Sequence[Type],
- ):
- function_type = FunctionType.get(input_types, result_types)
- super().__init__(
- sym_name=sym_name,
- function_type=TypeAttr.get(function_type),
- )
- self.regions[0].blocks.append(*input_types)
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTarget(self) -> Value:
- return self.body.arguments[0]
-
- @property
- def bodyExtraArgs(self) -> BlockArgumentList:
- return self.body.arguments[1:]
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index e7f448850a66aa1..d778172a607a360 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
- f(module)
+ f()
print(module)
return f
@run
-def testTypes(module: Module):
+def testTypes():
# CHECK-LABEL: TEST: testTypes
# CHECK: !transform.any_op
any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes(module: Module):
@run
-def testSequenceOp(module: Module):
+def testSequenceOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
@@ -60,23 +60,103 @@ def testSequenceOp(module: Module):
@run
-def testNamedSequenceOp(module: Module):
- module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
- named_sequence = transform.NamedSequenceOp(
- '__transform_main',
- [transform.AnyOpType.get()],
+def testNestedSequenceOp():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
+ )
+ with InsertionPoint(nested.body):
+ doubly_nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [transform.AnyOpType.get()],
+ nested.bodyTarget,
+ )
+ with InsertionPoint(doubly_nested.body):
+ transform.YieldOp([doubly_nested.bodyTarget])
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOp
+ # CHECK: transform.sequence failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+ # CHECK: yield %[[ARG2]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
+ # CHECK: }
+
+
+@run
+def testSequenceOpWithExtras():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+
+
+@run
+def testNestedSequenceOpWithExtras():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ sequence.bodyTarget,
+ sequence.bodyExtraArgs,
+ )
+ with InsertionPoint(nested.body):
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+ # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+
+
+@run
+def testTransformPDLOps():
+ withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+ with InsertionPoint(withPdl.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
+ withPdl.bodyTarget,
)
- with InsertionPoint(named_sequence.body):
- transform.YieldOp([named_sequence.bodyTarget])
- # CHECK-LABEL: TEST: testNamedSequenceOp
- # CHECK: module attributes {transform.with_named_sequence} {
- # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
- # CHECK: yield %[[ARG0]] : !transform.any_op
+ with InsertionPoint(sequence.body):
+ match = transform_pdl.PDLMatchOp(
+ transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+ )
+ transform.YieldOp(match)
+ # CHECK-LABEL: TEST: testTransformPDLOps
+ # CHECK: transform.with_pdl_patterns {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+ # CHECK: yield %[[RES]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
@run
-def testGetParentOp(module: Module):
+def testGetParentOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -95,7 +175,7 @@ def testGetParentOp(module: Module):
@run
-def testMergeHandlesOp(module: Module):
+def testMergeHandlesOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -109,7 +189,7 @@ def testMergeHandlesOp(module: Module):
@run
-def testApplyPatternsOpCompact(module: Module):
+def testApplyPatternsOpCompact():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -124,7 +204,7 @@ def testApplyPatternsOpCompact(module: Module):
@run
-def testApplyPatternsOpWithType(module: Module):
+def testApplyPatternsOpWithType():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [],
transform.OperationType.get('test.dummy')
@@ -140,7 +220,7 @@ def testApplyPatternsOpWithType(module: Module):
@run
-def testReplicateOp(module: Module):
+def testReplicateOp():
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(
>From 5b116bf1b211ef1a791085afa9e949cd0885d2e5 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolasvasilache at users.noreply.github.com>
Date: Wed, 8 Nov 2023 09:49:57 +0100
Subject: [PATCH 3/3] [mlir][python] Reland - Add sugared builder for
transform.named_sequence (#71597)
---
.../mlir/dialects/transform/__init__.py | 28 +++++++++++++
mlir/test/python/dialects/transform.py | 40 +++++++++++++------
2 files changed, 55 insertions(+), 13 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 166c5c5ca4ec344..23b278d374332b5 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
+@_ods_cext.register_operation(_Dialect, replace=True)
+class NamedSequenceOp(NamedSequenceOp):
+ def __init__(
+ self,
+ sym_name,
+ input_types: Sequence[Type],
+ result_types: Sequence[Type],
+ ):
+ function_type = FunctionType.get(input_types, result_types)
+ super().__init__(
+ sym_name=sym_name,
+ function_type=TypeAttr.get(function_type),
+ )
+ self.regions[0].blocks.append(*input_types)
+
+ @property
+ def body(self) -> Block:
+ return self.regions[0].blocks[0]
+
+ @property
+ def bodyTarget(self) -> Value:
+ return self.body.arguments[0]
+
+ @property
+ def bodyExtraArgs(self) -> BlockArgumentList:
+ return self.body.arguments[1:]
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index d778172a607a360..8212739c04a8777 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
- f()
+ f(module)
print(module)
return f
@run
-def testTypes():
+def testTypes(module: Module):
# CHECK-LABEL: TEST: testTypes
# CHECK: !transform.any_op
any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes():
@run
-def testSequenceOp():
+def testSequenceOp(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
@@ -58,9 +58,8 @@ def testSequenceOp():
# CHECK: yield %[[ARG0]] : !transform.any_op
# CHECK: }
-
@run
-def testNestedSequenceOp():
+def testNestedSequenceOp(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -92,7 +91,7 @@ def testNestedSequenceOp():
@run
-def testSequenceOpWithExtras():
+def testSequenceOpWithExtras(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
@@ -107,7 +106,7 @@ def testSequenceOpWithExtras():
@run
-def testNestedSequenceOpWithExtras():
+def testNestedSequenceOpWithExtras(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
@@ -131,7 +130,7 @@ def testNestedSequenceOpWithExtras():
@run
-def testTransformPDLOps():
+def testTransformPDLOps(module: Module):
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(withPdl.body):
sequence = transform.SequenceOp(
@@ -154,9 +153,24 @@ def testTransformPDLOps():
# CHECK: }
# CHECK: }
+@run
+def testNamedSequenceOp(module: Module):
+ module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
+ named_sequence = transform.NamedSequenceOp(
+ "__transform_main",
+ [transform.AnyOpType.get()],
+ [transform.AnyOpType.get()],
+ )
+ with InsertionPoint(named_sequence.body):
+ transform.YieldOp([named_sequence.bodyTarget])
+ # CHECK-LABEL: TEST: testNamedSequenceOp
+ # CHECK: module attributes {transform.with_named_sequence} {
+ # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
+ # CHECK: yield %[[ARG0]] : !transform.any_op
+
@run
-def testGetParentOp():
+def testGetParentOp(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -175,7 +189,7 @@ def testGetParentOp():
@run
-def testMergeHandlesOp():
+def testMergeHandlesOp(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -189,7 +203,7 @@ def testMergeHandlesOp():
@run
-def testApplyPatternsOpCompact():
+def testApplyPatternsOpCompact(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -204,7 +218,7 @@ def testApplyPatternsOpCompact():
@run
-def testApplyPatternsOpWithType():
+def testApplyPatternsOpWithType(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [],
transform.OperationType.get('test.dummy')
@@ -220,7 +234,7 @@ def testApplyPatternsOpWithType():
@run
-def testReplicateOp():
+def testReplicateOp(module: Module):
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(
More information about the Mlir-commits
mailing list