[Mlir-commits] [mlir] 8fd207f - [mlir][transform][structured][python] Allow str arg in match_op_names.
Ingo Müller
llvmlistbot at llvm.org
Fri Jul 21 02:36:59 PDT 2023
Author: Ingo Müller
Date: 2023-07-21T09:36:55Z
New Revision: 8fd207fd0dcc398c2fcfd953d7e3ebe7cb53f188
URL: https://github.com/llvm/llvm-project/commit/8fd207fd0dcc398c2fcfd953d7e3ebe7cb53f188
DIFF: https://github.com/llvm/llvm-project/commit/8fd207fd0dcc398c2fcfd953d7e3ebe7cb53f188.diff
LOG: [mlir][transform][structured][python] Allow str arg in match_op_names.
Allow the `names` argument in `MatchOp.match_op_names` to be of type
`str` in addition to `Sequence[str]`. In this case, the argument is
treated as a list with one name, i.e., it is possible to write
`MatchOp.match_op_names(..., "test.dummy")` instead of
`MatchOp.match_op_names(..., ["test.dummy"])`.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D155807
Added:
Modified:
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 1936f4b0e0da7e..9f623efb500173 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -195,7 +195,7 @@ class MatchOp:
def match_op_names(
cls,
target: Union[Operation, Value],
- names: Sequence[str],
+ names: Union[str, Sequence[str]],
*,
loc=None,
ip=None,
@@ -208,7 +208,7 @@ def match_op_names(
cls,
result_type: Type,
target: Union[Operation, Value],
- names: Sequence[str],
+ names: Union[str, Sequence[str]],
*,
loc=None,
ip=None,
@@ -219,8 +219,8 @@ def match_op_names(
def match_op_names(
cls,
result_type_or_target: Union[Type, Operation, Value],
- target_or_names: Union[Operation, Value, Sequence[str]],
- names_or_none: Optional[Sequence[str]] = None,
+ target_or_names: Union[Operation, Value, Sequence[str], str],
+ names_or_none: Optional[Union[Sequence[str], str]] = None,
*,
loc=None,
ip=None,
@@ -234,6 +234,9 @@ def match_op_names(
target = result_type_or_target
names = target_or_names
+ if isinstance(names, str):
+ names = [names]
+
return cls(
result_type,
_get_op_result_or_value(target),
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 0bcfd81d75ff11..1da55edf777e00 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -97,14 +97,28 @@ def testInterchange():
@run
-def testMatchOpNames():
+def testMatchOpNamesString():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy")
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMatchOpNamesString
+ # CHECK: transform.structured.match ops
+ # CHECK-SAME: ["test.dummy"]
+ # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testMatchOpNamesList():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
transform.YieldOp()
- # CHECK-LABEL: TEST: testMatchOpNames
+ # CHECK-LABEL: TEST: testMatchOpNamesList
# CHECK: transform.structured.match ops
# CHECK-SAME: ["test.dummy"]
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
More information about the Mlir-commits
mailing list