[Mlir-commits] [mlir] dac19b4 - [mlir][linalg][transform][python] Add mix-in for MaskedVectorize.

Ingo Müller llvmlistbot at llvm.org
Wed Aug 16 08:08:45 PDT 2023


Author: Ingo Müller
Date: 2023-08-16T15:07:46Z
New Revision: dac19b457e2cfd139e0e5cc29872ba3c65b7510f

URL: https://github.com/llvm/llvm-project/commit/dac19b457e2cfd139e0e5cc29872ba3c65b7510f
DIFF: https://github.com/llvm/llvm-project/commit/dac19b457e2cfd139e0e5cc29872ba3c65b7510f.diff

LOG: [mlir][linalg][transform][python] Add mix-in for MaskedVectorize.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D157735

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index e34451af429c1e..de5161eb19b167 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -11,19 +11,67 @@
 
 from typing import List, Optional, Sequence, Tuple, Union, overload
 
+ValueLike = Union[Operation, OpView, Value]  # Dynamic forms (op/value handles).
+StaticIntLike = Union[int, IntegerAttr]  # Static integer forms.
+MixedInt = Union[StaticIntLike, ValueLike]  # A static or a dynamic integer.
+
 IntOrAttrList = Sequence[Union[IntegerAttr, int]]
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 
 BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
 OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
 
-MixedValues = Union[
-    Sequence[Union[int, IntegerAttr, Operation, Value, OpView]],
-    ArrayAttr,
-    Operation,
-    Value,
-    OpView,
-]
+MixedValues = Union[Sequence[MixedInt], ArrayAttr, ValueLike]  # One or many ints.
+
+DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]  # `[x]`: scalable.
+
+
+def _dispatch_dynamic_index_list(
+    indices: Union[DynamicIndexList, ArrayAttr],
+) -> Tuple[List[ValueLike], List[int], List[bool]]:
+    """Dispatches a list of indices to the appropriate form.
+
+    This is similar to the custom `DynamicIndexList` directive upstream:
+    provided indices may be in the form of dynamic SSA values or static values,
+    and they may be scalable (i.e., as a singleton list) or not. This function
+    dispatches each index into its respective form. It also extracts the SSA
+    values and static indices from various similar structures, respectively.
+
+    Returns `(dynamic, static, scalable)`: the op/value handles in order; one
+    static value per index (`ShapedType.get_dynamic_size()` at dynamic ones);
+    and one scalability flag per index. Raises ValueError if a scalable entry
+    is not a one-element sequence holding a valid index.
+    """
+    dynamic_indices = []
+    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
+    scalable_indices = [False] * len(indices)
+
+    # ArrayAttr: Extract index values.
+    if isinstance(indices, ArrayAttr):
+        indices = list(indices)
+
+    def process_nonscalable_index(i, index):
+        """Records a non-scalable index; returns False if `index` is not one."""
+        if isinstance(index, int):
+            static_indices[i] = index
+        elif isinstance(index, IntegerAttr):
+            static_indices[i] = index.value  # pytype: disable=attribute-error
+        elif isinstance(index, (Operation, Value, OpView)):
+            dynamic_indices.append(index)
+        else:
+            return False
+        return True
+
+    # Process each index at a time.
+    for i, index in enumerate(indices):
+        if not process_nonscalable_index(i, index):
+            # Otherwise it must be a scalable index, given as a one-element
+            # Sequence. Validate with an exception (asserts vanish under -O).
+            scalable_indices[i] = True
+            if len(index) != 1 or not process_nonscalable_index(i, index[0]):
+                raise ValueError(f"Invalid scalable index at position {i}: {index}")
+
+    return dynamic_indices, static_indices, scalable_indices
 
 
 # Dispatches `MixedValues` that all represents integers in various forms into
@@ -281,6 +329,43 @@ def __init__(
         )
 
 
+class MaskedVectorizeOp:
+    """Specialization for MaskedVectorizeOp class (mixed/scalable `vector_sizes`)."""
+
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        vector_sizes: Union[DynamicIndexList, ArrayAttr],
+        *,
+        vectorize_nd_extract: Optional[bool] = None,
+        scalable_sizes: OptionalBoolList = None,
+        static_vector_sizes: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        # `vector_sizes` is split into dynamic/static/scalable parts here unless the
+        # caller passes `static_vector_sizes` and `scalable_sizes` (both or neither).
+        dynamic_vector_sizes = vector_sizes
+        if scalable_sizes is None and static_vector_sizes is None:
+            dynamic_vector_sizes, static_vector_sizes, scalable_sizes = (
+                _dispatch_dynamic_index_list(vector_sizes)
+            )
+        elif scalable_sizes is None or static_vector_sizes is None:
+            raise TypeError(
+                "'scalable_sizes' and 'static_vector_sizes' must either both "
+                "be given explicitly or both be given as part of 'vector_sizes'."
+            )
+        super().__init__(
+            target,
+            vector_sizes=dynamic_vector_sizes,
+            static_vector_sizes=static_vector_sizes,
+            scalable_sizes=scalable_sizes,
+            vectorize_nd_extract=vectorize_nd_extract,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class MatchOp:
     """Specialization for MatchOp class."""
 

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index e9ad3f0e8fde46..12422ba29e2fef 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -199,6 +199,85 @@ def testMatchOpNamesList():
     # CHECK-SAME: (!transform.any_op) -> !transform.any_op
 
 
+@run
+def testMaskedVectorizeStatic():  # Sizes given as plain Python ints.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeStatic
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.masked_vectorize
+    # CHECK-SAME:     vector_sizes [16, 4]
+
+
+@run
+def testMaskedVectorizeArray():  # Sizes given as a single ArrayAttr.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        sizes = Attribute.parse("[16, 4]")
+        structured.MaskedVectorizeOp(sequence.bodyTarget, sizes)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeArray
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.masked_vectorize
+    # CHECK-SAME:     vector_sizes [16, 4]
+
+
+@run
+def testMaskedVectorizeMixed():  # One handle-valued size, one static attribute.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
+        sz2 = Attribute.parse("4")
+        structured.MaskedVectorizeOp(sequence.bodyTarget, [sz1, sz2])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeMixed
+    # CHECK: transform.sequence
+    # CHECK: %[[V0:.*]] = transform.structured.match
+    # CHECK: transform.structured.masked_vectorize
+    # CHECK-SAME:     vector_sizes [%[[V0]] : !transform.any_op, 4]
+
+
+@run
+def testMaskedVectorizeScalable():  # One-element lists mark scalable sizes.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
+        sz2 = Attribute.parse("4")
+        structured.MaskedVectorizeOp(sequence.bodyTarget, [16, [sz1], [sz2], [8]])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeScalable
+    # CHECK: transform.sequence
+    # CHECK-DAG: %[[V0:.*]] = transform.structured.match
+    # CHECK-DAG: transform.structured.masked_vectorize
+    # CHECK-SAME:     vector_sizes [16, [%[[V0]] : !transform.any_op], [4], [8]]
+
+
+@run
+def testMaskedVectorizeArgs():  # `vectorize_nd_extract=True` appears on the op.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MaskedVectorizeOp(
+            sequence.bodyTarget, [16, 4], vectorize_nd_extract=True
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeArgs
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.masked_vectorize
+    # CHECK-SAME: vectorize_nd_extract
+
+
 @run
 def testMatchOpNamesTyped():
     sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list