[Mlir-commits] [mlir] dac19b4 - [mlir][linalg][transform][python] Add mix-in for MaskedVectorize.
Ingo Müller
llvmlistbot at llvm.org
Wed Aug 16 08:08:45 PDT 2023
Author: Ingo Müller
Date: 2023-08-16T15:07:46Z
New Revision: dac19b457e2cfd139e0e5cc29872ba3c65b7510f
URL: https://github.com/llvm/llvm-project/commit/dac19b457e2cfd139e0e5cc29872ba3c65b7510f
DIFF: https://github.com/llvm/llvm-project/commit/dac19b457e2cfd139e0e5cc29872ba3c65b7510f.diff
LOG: [mlir][linalg][transform][python] Add mix-in for MaskedVectorize.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D157735
Added:
Modified:
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index e34451af429c1e..de5161eb19b167 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -11,19 +11,67 @@
from typing import List, Optional, Sequence, Tuple, Union, overload
+# An integer whose value is known statically: a Python int or an IntegerAttr.
+StaticIntLike = Union[int, IntegerAttr]
+# An SSA value, or an op / op view whose result is used as one (a MatchOp
+# result is passed this way in the tests of this change).
+ValueLike = Union[Operation, OpView, Value]
+# An integer given either statically or dynamically (as an SSA value).
+MixedInt = Union[StaticIntLike, ValueLike]
+
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
-MixedValues = Union[
- Sequence[Union[int, IntegerAttr, Operation, Value, OpView]],
- ArrayAttr,
- Operation,
- Value,
- OpView,
-]
+# Integers given as a sequence of static/dynamic entries, as an ArrayAttr, or
+# as a single value-like object.
+MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
+
+# A list of indices in which each entry is a MixedInt or, to mark that index
+# as scalable, a singleton sequence wrapping a MixedInt (this is the form
+# consumed by _dispatch_dynamic_index_list).
+DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
+
+
+def _dispatch_dynamic_index_list(
+ indices: Union[DynamicIndexList, ArrayAttr],
+) -> tuple[list[ValueLike], list[int] | ArrayAttr, list[bool]]:
+ """Dispatches a list of indices to the appropriate form.
+
+ This is similar to the custom `DynamicIndexList` directive upstream:
+ provided indices may be in the form of dynamic SSA values or static values,
+ and they may be scalable (i.e., as a singleton list) or not. This function
+ dispatches each index into its respective form. It also extracts the SSA
+ values and static indices from various similar structures, respectively.
+ """
+ dynamic_indices = []
+ static_indices = [ShapedType.get_dynamic_size()] * len(indices)
+ scalable_indices = [False] * len(indices)
+
+ # ArrayAttr: Extract index values.
+ if isinstance(indices, ArrayAttr):
+ indices = [idx for idx in indices]
+
+ def process_nonscalable_index(i, index):
+ """Processes any form of non-scalable index.
+
+ Returns False if the given index was scalable and thus remains
+ unprocessed; True otherwise.
+ """
+ if isinstance(index, int):
+ static_indices[i] = index
+ elif isinstance(index, IntegerAttr):
+ static_indices[i] = index.value # pytype: disable=attribute-error
+ elif isinstance(index, (Operation, Value, OpView)):
+ dynamic_indices.append(index)
+ else:
+ return False
+ return True
+
+ # Process each index at a time.
+ for i, index in enumerate(indices):
+ if not process_nonscalable_index(i, index):
+ # If it wasn't processed, it must be a scalable index, which is
+ # provided as a Sequence of one value, so extract and process that.
+ scalable_indices[i] = True
+ assert len(index) == 1
+ ret = process_nonscalable_index(i, index[0])
+ assert ret
+
+ return dynamic_indices, static_indices, scalable_indices
# Dispatches `MixedValues` that all represents integers in various forms into
@@ -281,6 +329,43 @@ def __init__(
)
+class MaskedVectorizeOp:
+ """Specialization for MaskedVectorizeOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ vector_sizes: Union[DynamicIndexList, ArrayAttr],
+ *,
+ vectorize_nd_extract: Optional[bool] = None,
+ scalable_sizes: OptionalBoolList = None,
+ static_vector_sizes: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ if scalable_sizes is None and static_vector_sizes is None:
+ (
+ dynamic_vector_sizes,
+ static_vector_sizes,
+ scalable_sizes,
+ ) = _dispatch_dynamic_index_list(vector_sizes)
+ elif scalable_sizes is None or static_vector_sizes is None:
+ raise TypeError(
+ "'scalable_sizes' and 'static_vector_sizes' must either both "
+ "be given explicitly or both be given as part of 'vector_sizes'."
+ )
+ else:
+ dynamic_vector_sizes = vector_sizes
+
+ super().__init__(
+ target,
+ vector_sizes=dynamic_vector_sizes,
+ static_vector_sizes=static_vector_sizes,
+ scalable_sizes=scalable_sizes,
+ vectorize_nd_extract=vectorize_nd_extract,
+ )
+
+
class MatchOp:
"""Specialization for MatchOp class."""
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index e9ad3f0e8fde46..12422ba29e2fef 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -199,6 +199,85 @@ def testMatchOpNamesList():
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
+ at run
+def testMaskedVectorizeStatic():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMaskedVectorizeStatic
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.masked_vectorize
+ # CHECK-SAME: vector_sizes [16, 4]
+
+
+ at run
+def testMaskedVectorizeArray():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ sizes = Attribute.parse("[16, 4]")
+ structured.MaskedVectorizeOp(sequence.bodyTarget, sizes)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMaskedVectorizeArray
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.masked_vectorize
+ # CHECK-SAME: vector_sizes [16, 4]
+
+
+ at run
+def testMaskedVectorizeMixed():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
+ sz2 = Attribute.parse("4")
+ structured.MaskedVectorizeOp(sequence.bodyTarget, [sz1, sz2])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMaskedVectorizeMixed
+ # CHECK: transform.sequence
+ # CHECK: %[[V0:.*]] = transform.structured.match
+ # CHECK: transform.structured.masked_vectorize
+ # CHECK-SAME: vector_sizes [%[[V0]] : !transform.any_op, 4]
+
+
+ at run
+def testMaskedVectorizeScalable():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
+ sz2 = Attribute.parse("4")
+ structured.MaskedVectorizeOp(sequence.bodyTarget, [16, [sz1], [sz2], [8]])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMaskedVectorizeScalable
+ # CHECK: transform.sequence
+ # CHECK-DAG: %[[V0:.*]] = transform.structured.match
+ # CHECK-DAG: transform.structured.masked_vectorize
+ # CHECK-SAME: vector_sizes [16, [%[[V0]] : !transform.any_op], [4], [8]]
+
+
+ at run
+def testMaskedVectorizeArgs():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+ )
+ with InsertionPoint(sequence.body):
+ structured.MaskedVectorizeOp(
+ sequence.bodyTarget, [16, 4], vectorize_nd_extract=True
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMaskedVectorizeArgs
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.masked_vectorize
+ # CHECK-SAME: vectorize_nd_extract
+
+
@run
def testMatchOpNamesTyped():
sequence = transform.SequenceOp(
More information about the Mlir-commits
mailing list