[Mlir-commits] [mlir] 5967375 - [mlir][python] Add support for arg_attrs and other attrs to NamedSequenceOp

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 8 05:42:27 PST 2023


Author: Nicolas Vasilache
Date: 2023-11-08T13:42:16Z
New Revision: 5967375fcf3563b74aa7ffef45adb642b514c115

URL: https://github.com/llvm/llvm-project/commit/5967375fcf3563b74aa7ffef45adb642b514c115
DIFF: https://github.com/llvm/llvm-project/commit/5967375fcf3563b74aa7ffef45adb642b514c115.diff

LOG: [mlir][python] Add support for arg_attrs and other attrs to NamedSequenceOp

Added: 
    

Modified: 
    mlir/python/mlir/dialects/transform/__init__.py
    mlir/test/python/dialects/transform.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 23b278d374332b5..1dca1a66bc35707 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -172,11 +172,17 @@ def __init__(
         sym_name,
         input_types: Sequence[Type],
         result_types: Sequence[Type],
+        sym_visibility=None,
+        arg_attrs=None,
+        res_attrs=None
     ):
         function_type = FunctionType.get(input_types, result_types)
         super().__init__(
             sym_name=sym_name,
             function_type=TypeAttr.get(function_type),
+            sym_visibility=sym_visibility,
+            arg_attrs=arg_attrs,
+            res_attrs=res_attrs
         )
         self.regions[0].blocks.append(*input_types)
 

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 8212739c04a8777..6ed4818fc9d2f0e 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -153,6 +153,7 @@ def testTransformPDLOps(module: Module):
   # CHECK:   }
   # CHECK: }
 
+
 @run
 def testNamedSequenceOp(module: Module):
     module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
@@ -160,12 +161,12 @@ def testNamedSequenceOp(module: Module):
         "__transform_main",
         [transform.AnyOpType.get()],
         [transform.AnyOpType.get()],
-    )
+        arg_attrs = [{"transform.consumed": UnitAttr.get()}])
     with InsertionPoint(named_sequence.body):
         transform.YieldOp([named_sequence.bodyTarget])
     # CHECK-LABEL: TEST: testNamedSequenceOp
     # CHECK: module attributes {transform.with_named_sequence} {
-    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
+    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
     # CHECK:   yield %[[ARG0]] : !transform.any_op
 
 

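The change passes three optional keyword arguments (`sym_visibility`, `arg_attrs`, `res_attrs`) from `NamedSequenceOp.__init__` through to the generated builder. `arg_attrs` holds one attribute dictionary per entry in `input_types`. That is why the updated CHECK line expects `{transform.consumed}` immediately after the first argument's type.

The sketch below is a toy, pure-Python model of that printed form. It does not use the MLIR Python bindings. `render_signature`, `render_attr_dict`, and the `UNIT` sentinel are hypothetical names invented for illustration. The one-dictionary-per-argument check is an assumption modeled on the verification that function-like ops perform, not code from this commit.

```python
# Toy model only: no MLIR bindings are imported. It mimics how a function-like
# op prints its signature when each argument carries an attribute dictionary.
from typing import Dict, List, Optional, Sequence

# Hypothetical sentinel standing in for UnitAttr.get(): a unit attribute
# prints as its bare name inside the dictionary.
UNIT = object()


def render_attr_dict(attrs: Dict[str, object]) -> str:
    """Render one argument's attribute dictionary; empty dicts print nothing."""
    if not attrs:
        return ""
    parts = [
        name if value is UNIT else f"{name} = {value}"
        for name, value in attrs.items()
    ]
    return " {" + ", ".join(parts) + "}"


def render_signature(
    sym_name: str,
    input_types: Sequence[str],
    result_types: Sequence[str],
    arg_attrs: Optional[List[Dict[str, object]]] = None,
) -> str:
    """Hypothetical helper: render a transform.named_sequence header line."""
    if arg_attrs is None:
        # Like passing arg_attrs=None in the diff: no per-argument attributes.
        arg_attrs = [{} for _ in input_types]
    # Assumed invariant: exactly one dictionary per input type.
    if len(arg_attrs) != len(input_types):
        raise ValueError("arg_attrs needs one dictionary per input type")
    args = ", ".join(
        f"%arg{i}: {ty}{render_attr_dict(attrs)}"
        for i, (ty, attrs) in enumerate(zip(input_types, arg_attrs))
    )
    header = f"transform.named_sequence @{sym_name}({args})"
    if len(result_types) == 1:
        header += f" -> {result_types[0]}"
    elif result_types:
        header += " -> (" + ", ".join(result_types) + ")"
    return header


# Mirrors the updated test: one !transform.any_op input marked as consumed.
print(
    render_signature(
        "__transform_main",
        ["!transform.any_op"],
        ["!transform.any_op"],
        arg_attrs=[{"transform.consumed": UNIT}],
    )
)
# transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) -> !transform.any_op
```

In the real bindings, the test in the diff passes `arg_attrs=[{"transform.consumed": UnitAttr.get()}]`. This is a list holding a single dictionary, one for each of its `!transform.any_op` inputs. Omitting the argument keeps the old signature, which has no attribute dictionary after the type.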