[Mlir-commits] [mlir] 5967375 - [mlir][python] Add support for arg_attrs and other attrs to NamedSequenceOp
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 8 05:42:27 PST 2023
Author: Nicolas Vasilache
Date: 2023-11-08T13:42:16Z
New Revision: 5967375fcf3563b74aa7ffef45adb642b514c115
URL: https://github.com/llvm/llvm-project/commit/5967375fcf3563b74aa7ffef45adb642b514c115
DIFF: https://github.com/llvm/llvm-project/commit/5967375fcf3563b74aa7ffef45adb642b514c115.diff
LOG: [mlir][python] Add support for arg_attrs and other attrs to NamedSequenceOp
Added:
Modified:
mlir/python/mlir/dialects/transform/__init__.py
mlir/test/python/dialects/transform.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 23b278d374332b5..1dca1a66bc35707 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -172,11 +172,17 @@ def __init__(
sym_name,
input_types: Sequence[Type],
result_types: Sequence[Type],
+ sym_visibility=None,
+ arg_attrs=None,
+ res_attrs=None
):
function_type = FunctionType.get(input_types, result_types)
super().__init__(
sym_name=sym_name,
function_type=TypeAttr.get(function_type),
+ sym_visibility=sym_visibility,
+ arg_attrs=arg_attrs,
+ res_attrs=res_attrs
)
self.regions[0].blocks.append(*input_types)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 8212739c04a8777..6ed4818fc9d2f0e 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -153,6 +153,7 @@ def testTransformPDLOps(module: Module):
# CHECK: }
# CHECK: }
+
@run
def testNamedSequenceOp(module: Module):
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
@@ -160,12 +161,12 @@ def testNamedSequenceOp(module: Module):
"__transform_main",
[transform.AnyOpType.get()],
[transform.AnyOpType.get()],
- )
+ arg_attrs = [{"transform.consumed": UnitAttr.get()}])
with InsertionPoint(named_sequence.body):
transform.YieldOp([named_sequence.bodyTarget])
# CHECK-LABEL: TEST: testNamedSequenceOp
# CHECK: module attributes {transform.with_named_sequence} {
- # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
+ # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
# CHECK: yield %[[ARG0]] : !transform.any_op
More information about the Mlir-commits mailing list