[Mlir-commits] [mlir] af3d856 - [mlir][python] Reland - Add sugared builder for transform.named_sequence

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 8 01:34:43 PST 2023


Author: Nicolas Vasilache
Date: 2023-11-08T09:34:29Z
New Revision: af3d85694427257ee27a71007af72ea11d54a93a

URL: https://github.com/llvm/llvm-project/commit/af3d85694427257ee27a71007af72ea11d54a93a
DIFF: https://github.com/llvm/llvm-project/commit/af3d85694427257ee27a71007af72ea11d54a93a.diff

LOG: [mlir][python] Reland - Add sugared builder for transform.named_sequence

Address issues with #71597 post-revert and reland

Added: 
    

Modified: 
    mlir/python/mlir/dialects/transform/__init__.py
    mlir/test/python/dialects/transform.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 166c5c5ca4ec344..23b278d374332b5 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class NamedSequenceOp(NamedSequenceOp):
+    def __init__(
+        self,
+        sym_name,
+        input_types: Sequence[Type],
+        result_types: Sequence[Type],
+    ):
+        function_type = FunctionType.get(input_types, result_types)
+        super().__init__(
+            sym_name=sym_name,
+            function_type=TypeAttr.get(function_type),
+        )
+        self.regions[0].blocks.append(*input_types)
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTarget(self) -> Value:
+        return self.body.arguments[0]
+
+    @property
+    def bodyExtraArgs(self) -> BlockArgumentList:
+        return self.body.arguments[1:]
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class YieldOp(YieldOp):
     def __init__(

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index d778172a607a360..8212739c04a8777 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -10,13 +10,13 @@ def run(f):
         module = Module.create()
         with InsertionPoint(module.body):
             print("\nTEST:", f.__name__)
-            f()
+            f(module)
         print(module)
     return f
 
 
 @run
-def testTypes():
+def testTypes(module: Module):
     # CHECK-LABEL: TEST: testTypes
     # CHECK: !transform.any_op
     any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes():
 
 
 @run
-def testSequenceOp():
+def testSequenceOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
@@ -58,9 +58,8 @@ def testSequenceOp():
     # CHECK:   yield %[[ARG0]] : !transform.any_op
     # CHECK: }
 
-
 @run
-def testNestedSequenceOp():
+def testNestedSequenceOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
@@ -92,7 +91,7 @@ def testNestedSequenceOp():
 
 
 @run
-def testSequenceOpWithExtras():
+def testSequenceOpWithExtras(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
@@ -107,7 +106,7 @@ def testSequenceOpWithExtras():
 
 
 @run
-def testNestedSequenceOpWithExtras():
+def testNestedSequenceOpWithExtras(module: Module):
   sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
@@ -131,7 +130,7 @@ def testNestedSequenceOpWithExtras():
 
 
 @run
-def testTransformPDLOps():
+def testTransformPDLOps(module: Module):
   withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
   with InsertionPoint(withPdl.body):
     sequence = transform.SequenceOp(
@@ -154,9 +153,24 @@ def testTransformPDLOps():
   # CHECK:   }
   # CHECK: }
 
+@run
+def testNamedSequenceOp(module: Module):
+    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
+    named_sequence = transform.NamedSequenceOp(
+        "__transform_main",
+        [transform.AnyOpType.get()],
+        [transform.AnyOpType.get()],
+    )
+    with InsertionPoint(named_sequence.body):
+        transform.YieldOp([named_sequence.bodyTarget])
+    # CHECK-LABEL: TEST: testNamedSequenceOp
+    # CHECK: module attributes {transform.with_named_sequence} {
+    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
+    # CHECK:   yield %[[ARG0]] : !transform.any_op
+
 
 @run
-def testGetParentOp():
+def testGetParentOp(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -175,7 +189,7 @@ def testGetParentOp():
 
 
 @run
-def testMergeHandlesOp():
+def testMergeHandlesOp(module: Module):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
@@ -189,7 +203,7 @@ def testMergeHandlesOp():
 
 
 @run
-def testApplyPatternsOpCompact():
+def testApplyPatternsOpCompact(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
@@ -204,7 +218,7 @@ def testApplyPatternsOpCompact():
 
 
 @run
-def testApplyPatternsOpWithType():
+def testApplyPatternsOpWithType(module: Module):
   sequence = transform.SequenceOp(
       transform.FailurePropagationMode.Propagate, [],
       transform.OperationType.get('test.dummy')
@@ -220,7 +234,7 @@ def testApplyPatternsOpWithType():
 
 
 @run
-def testReplicateOp():
+def testReplicateOp(module: Module):
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
     with InsertionPoint(with_pdl.body):
         sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list