[Mlir-commits] [mlir] cb38805 - Revert "[mlir][python]Add sugared buider for transform.named_sequence (#71597)"
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 8 01:34:42 PST 2023
Author: Nicolas Vasilache
Date: 2023-11-08T09:34:29Z
New Revision: cb3880515fe9065ae3378f3d145a17bf75fa6740
URL: https://github.com/llvm/llvm-project/commit/cb3880515fe9065ae3378f3d145a17bf75fa6740
DIFF: https://github.com/llvm/llvm-project/commit/cb3880515fe9065ae3378f3d145a17bf75fa6740.diff
LOG: Revert "[mlir][python]Add sugared buider for transform.named_sequence (#71597)"
This reverts commit 4f51b2bfe3ec11b597272be6caa00efb575bc59f.
Added:
Modified:
mlir/python/mlir/dialects/transform/__init__.py
mlir/test/python/dialects/transform.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 23b278d374332b5..166c5c5ca4ec344 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,34 +165,6 @@ def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
-@_ods_cext.register_operation(_Dialect, replace=True)
-class NamedSequenceOp(NamedSequenceOp):
- def __init__(
- self,
- sym_name,
- input_types: Sequence[Type],
- result_types: Sequence[Type],
- ):
- function_type = FunctionType.get(input_types, result_types)
- super().__init__(
- sym_name=sym_name,
- function_type=TypeAttr.get(function_type),
- )
- self.regions[0].blocks.append(*input_types)
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTarget(self) -> Value:
- return self.body.arguments[0]
-
- @property
- def bodyExtraArgs(self) -> BlockArgumentList:
- return self.body.arguments[1:]
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index e7f448850a66aa1..d778172a607a360 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
- f(module)
+ f()
print(module)
return f
@run
-def testTypes(module: Module):
+def testTypes():
# CHECK-LABEL: TEST: testTypes
# CHECK: !transform.any_op
any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes(module: Module):
@run
-def testSequenceOp(module: Module):
+def testSequenceOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
@@ -60,23 +60,103 @@ def testSequenceOp(module: Module):
@run
-def testNamedSequenceOp(module: Module):
- module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
- named_sequence = transform.NamedSequenceOp(
- '__transform_main',
- [transform.AnyOpType.get()],
+def testNestedSequenceOp():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
+ )
+ with InsertionPoint(nested.body):
+ doubly_nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [transform.AnyOpType.get()],
+ nested.bodyTarget,
+ )
+ with InsertionPoint(doubly_nested.body):
+ transform.YieldOp([doubly_nested.bodyTarget])
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOp
+ # CHECK: transform.sequence failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+ # CHECK: yield %[[ARG2]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
+ # CHECK: }
+
+
+@run
+def testSequenceOpWithExtras():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+
+
+@run
+def testNestedSequenceOpWithExtras():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+ )
+ with InsertionPoint(sequence.body):
+ nested = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ sequence.bodyTarget,
+ sequence.bodyExtraArgs,
+ )
+ with InsertionPoint(nested.body):
+ transform.YieldOp()
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+ # CHECK: transform.sequence failures(propagate)
+ # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+ # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+
+
+@run
+def testTransformPDLOps():
+ withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+ with InsertionPoint(withPdl.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
+ withPdl.bodyTarget,
)
- with InsertionPoint(named_sequence.body):
- transform.YieldOp([named_sequence.bodyTarget])
- # CHECK-LABEL: TEST: testNamedSequenceOp
- # CHECK: module attributes {transform.with_named_sequence} {
- # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
- # CHECK: yield %[[ARG0]] : !transform.any_op
+ with InsertionPoint(sequence.body):
+ match = transform_pdl.PDLMatchOp(
+ transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+ )
+ transform.YieldOp(match)
+ # CHECK-LABEL: TEST: testTransformPDLOps
+ # CHECK: transform.with_pdl_patterns {
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+ # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+ # CHECK: yield %[[RES]] : !transform.any_op
+ # CHECK: }
+ # CHECK: }
@run
-def testGetParentOp(module: Module):
+def testGetParentOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -95,7 +175,7 @@ def testGetParentOp(module: Module):
@run
-def testMergeHandlesOp(module: Module):
+def testMergeHandlesOp():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -109,7 +189,7 @@ def testMergeHandlesOp(module: Module):
@run
-def testApplyPatternsOpCompact(module: Module):
+def testApplyPatternsOpCompact():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
@@ -124,7 +204,7 @@ def testApplyPatternsOpCompact(module: Module):
@run
-def testApplyPatternsOpWithType(module: Module):
+def testApplyPatternsOpWithType():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [],
transform.OperationType.get('test.dummy')
@@ -140,7 +220,7 @@ def testApplyPatternsOpWithType(module: Module):
@run
-def testReplicateOp(module: Module):
+def testReplicateOp():
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(
More information about the Mlir-commits
mailing list