[Mlir-commits] [mlir] andrzej/disable compress expand (PR #117538)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Nov 25 02:52:27 PST 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/117538

- **[mlir][vector] Rename vector type TD definitions (nfc)**
- **[mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors**


>From 55b8bfba942531ea7b3fd2c5987b80374ff2e5bc Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 21 Nov 2024 10:59:10 +0000
Subject: [PATCH 1/2] [mlir][vector] Rename vector type TD definitions (nfc)

Currently, the Vector dialect TD file includes the following "vector"
type definitions:

```mlir
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

def AnyFixedVector : FixedVectorOf<[AnyType]>;

def AnyScalableVector : ScalableVectorOf<[AnyType]>;
```

In short:

  * `AnyVector` _excludes_ 0-D vectors.
  * `AnyVectorOfAnyRank`, `AnyFixedVector`, and `AnyScalableVector`
    _include_ 0-D vectors.

The naming for "groups" that include 0-D vectors is inconsistent and can
be misleading. This patch renames the definitions as follows:

```mlir
def AnyVector : VectorOf<[AnyType]>;

def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;

def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
```

Rationale:
* The updated names are more explicit about 0-D vector support.
* It becomes clearer that scalable vectors currently allow 0-D vectors -
  this might warrant a revisit.
* The renaming paves the way for adding a new group for "fixed-width
  vectors excluding 0-D vectors" (e.g., AnyFixedVector), which I plan to
  introduce in a follow-up patch.
---
 mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td  |  2 +-
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 54 +++++++++----------
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 22 ++++----
 mlir/include/mlir/IR/CommonTypeConstraints.td | 35 ++++++------
 4 files changed, 58 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 9cc792093bf836..475b11f12c5f01 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect {
 //===----------------------------------------------------------------------===//
 
 class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
-  [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
+  [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
   "a vector with length " # length,
   "::mlir::VectorType">;
 
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d7e8b22fbd2d35..cdcf4d8752e874 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -100,11 +100,11 @@ class ScalableMaskedFOp<string mnemonic, string op_description,
     op_description # [{ on active lanes. Inactive lanes will keep the value of
     the first operand.}];
   let arguments = (ins
-          ScalableVectorOf<[I1]>:$mask,
-          ScalableVectorOf<[AnyFloat]>:$src1,
-          ScalableVectorOf<[AnyFloat]>:$src2
+          ScalableVectorOfAnyRank<[I1]>:$mask,
+          ScalableVectorOfAnyRank<[AnyFloat]>:$src1,
+          ScalableVectorOfAnyRank<[AnyFloat]>:$src2
   );
-  let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
+  let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res);
   let assemblyFormat =
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
@@ -123,11 +123,11 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
     op_description # [{ on active lanes. Inactive lanes will keep the value of
     the first operand.}];
   let arguments = (ins
-          ScalableVectorOf<[I1]>:$mask,
-          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
-          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+          ScalableVectorOfAnyRank<[I1]>:$mask,
+          ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2
   );
-  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
+  let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res);
   let assemblyFormat =
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
@@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def UdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udot">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedAddIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"add">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedAddFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fadd">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedMulIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"mul">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedMulFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fmul">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSubIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sub">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSubFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fsub">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSDivIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedUDivIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedDivFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ConvertFromSvboolIntrOp :
   ArmSVE_IntrOp<"convert.from.svbool",
@@ -581,8 +581,8 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
     /*overloadedOperands=*/[0],
     /*overloadedResults=*/[],
     /*numResults=*/2>,
-    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
-                   Arg<AnyScalableVector, "v2">:$v2)>;
+    Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+                   Arg<AnyScalableVectorOfAnyRank, "v2">:$v2)>;
 
 // Note: This multi-vector intrinsic requires SME2.
 def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
@@ -590,10 +590,10 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
     /*overloadedOperands=*/[0],
     /*overloadedResults=*/[],
     /*numResults=*/4>,
-    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
-                   Arg<AnyScalableVector, "v2">:$v2,
-                   Arg<AnyScalableVector, "v3">:$v3,
-                   Arg<AnyScalableVector, "v3">:$v4)>;
+    Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+                   Arg<AnyScalableVectorOfAnyRank, "v2">:$v2,
+                   Arg<AnyScalableVectorOfAnyRank, "v3">:$v3,
+                   Arg<AnyScalableVectorOfAnyRank, "v4">:$v4)>;
 
 // Note: This intrinsic requires SME or SVE2.1.
 def PselIntrOp : ArmSVE_IntrOp<"psel",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cc4cafa869e63a..df02b242f51d67 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -417,16 +417,18 @@ def Vector_BroadcastOp :
   let hasVerifier = 1;
 }
 
-def Vector_ShuffleOp :
-  Vector_Op<"shuffle", [Pure,
-     PredOpTrait<"first operand v1 and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>,
-     PredOpTrait<"second operand v2 and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 1>>,
-     InferTypeOpAdaptor]>,
-     Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
-                    DenseI64ArrayAttr:$mask)>,
-     Results<(outs AnyVector:$vector)> {
+def Vector_ShuffleOp
+    : Vector_Op<
+          "shuffle",
+          [Pure,
+           PredOpTrait<"first operand v1 and result have same element type",
+                       TCresVTEtIsSameAsOpBase<0, 0>>,
+           PredOpTrait<"second operand v2 and result have same element type",
+                       TCresVTEtIsSameAsOpBase<0, 1>>,
+           InferTypeOpAdaptor]>,
+      Arguments<(ins AnyFixedVectorOfAnyRank:$v1, AnyFixedVectorOfAnyRank:$v2,
+          DenseI64ArrayAttr:$mask)>,
+      Results<(outs AnyVector:$vector)> {
   let summary = "shuffle operation";
   let description = [{
     The shuffle operation constructs a permutation (or duplication) of elements
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 48e4c24f838652..874d96e99ec678 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -30,7 +30,7 @@ def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
 def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
 
 // Whether a type is a fixed-length VectorType.
-def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                   !::llvm::cast<VectorType>($_self).isScalable()}]>;
 
 // Whether a type is a scalable VectorType.
@@ -438,11 +438,11 @@ class VectorOfAnyRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
                       "::mlir::VectorType">;
 
-class FixedVectorOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
+class FixedVectorOfAnyRank<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
           "fixed-length vector", "::mlir::VectorType">;
 
-class ScalableVectorOf<list<Type> allowedTypes> :
+class ScalableVectorOfAnyRank<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
           "scalable vector", "::mlir::VectorType">;
 
@@ -467,7 +467,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
 // Whether the number of elements of a fixed-length vector is from the given
 // `allowedRanks` list
 class IsFixedVectorOfRankPred<list<int> allowedRanks> :
-  And<[IsFixedVectorTypePred,
+  And<[IsFixedVectorOfAnyRankTypePred,
        Or<!foreach(allowedlength, allowedRanks,
                    CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                            == }]
@@ -509,8 +509,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
 // the type is from the given `allowedTypes` list
 class FixedVectorOfRankAndType<list<int> allowedRanks,
                           list<Type> allowedTypes> : AllOfType<
-  [FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
-  FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
+  [FixedVectorOfAnyRank<allowedTypes>, VectorOfRank<allowedRanks>],
+  FixedVectorOfAnyRank<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
   "::mlir::VectorType">;
 
 // Whether the number of elements of a vector is from the given
@@ -525,7 +525,7 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
 // Whether the number of elements of a fixed-length vector is from the given
 // `allowedLengths` list
 class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
-  And<[IsFixedVectorTypePred,
+  And<[IsFixedVectorOfAnyRankTypePred,
        Or<!foreach(allowedlength, allowedLengths,
                    CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                            == }]
@@ -612,8 +612,8 @@ class VectorOfLengthAndType<list<int> allowedLengths,
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class FixedVectorOfLengthAndType<list<int> allowedLengths,
                                  list<Type> allowedTypes> : AllOfType<
-  [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
-  FixedVectorOf<allowedTypes>.summary #
+  [FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
+  FixedVectorOfAnyRank<allowedTypes>.summary #
   FixedVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -621,8 +621,8 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class ScalableVectorOfLengthAndType<list<int> allowedLengths,
                                     list<Type> allowedTypes> : AllOfType<
-  [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
-  ScalableVectorOf<allowedTypes>.summary #
+  [ScalableVectorOfAnyRank<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
+  ScalableVectorOfAnyRank<allowedTypes>.summary #
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -632,10 +632,10 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
 class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
                                            list<int> allowedLengths,
                                            list<Type> allowedTypes> : AllOfType<
-  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
+  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOfAnyRank<allowedTypes>,
    ScalableVectorOfLength<allowedLengths>],
   ScalableVectorOfRank<allowedRanks>.summary #
-  ScalableVectorOf<allowedTypes>.summary #
+  ScalableVectorOfAnyRank<allowedTypes>.summary #
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -657,13 +657,14 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
    ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
   "::mlir::VectorType">;
 
+// Unlike the following definitions, this one excludes 0-D vectors
 def AnyVector : VectorOf<[AnyType]>;
-// Temporary vector type clone that allows gradual transition to 0-D vectors.
+
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
 
-def AnyFixedVector : FixedVectorOf<[AnyType]>;
+def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
 
-def AnyScalableVector : ScalableVectorOf<[AnyType]>;
+def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
 
 // Shaped types.
 

>From fcc429aded190d262bfd1d22573f07f35307dd6e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 25 Nov 2024 10:42:43 +0000
Subject: [PATCH 2/2] [mlir][vector] Disable CompressStoreOp/ExpandLoadOp for
 scalable vectors
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

These operations were introduced as counterparts to the following LLVM
intrinsics:

  * `@llvm.masked.expandload.*`,
  * `@llvm.masked.compressstore.*`.

Currently, there is minimal test coverage for scalable vector use cases
involving these Ops (both LLVM and MLIR). Additionally, the verifier is
flawed — it incorrectly allows mixing fixed-width and scalable vectors.

To address these issues, scalable vector support for these Ops is being
disabled for now. This decision can be revisited if a clear need arises
for their use with scalable vectors in the future.

**NOTE:** Depends on #117150 - please, only review the top commit.
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 14 +++++++++-----
 mlir/include/mlir/IR/CommonTypeConstraints.td    |  9 +++++++++
 mlir/test/Dialect/Vector/invalid.mlir            | 16 ++++++++++++++++
 3 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index df02b242f51d67..5911355abd5146 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2084,9 +2084,9 @@ def Vector_ExpandLoadOp :
   Vector_Op<"expandload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$pass_thru)>,
-    Results<(outs AnyVector:$result)> {
+               FixedVectorOf<[I1]>:$mask,
+               AnyFixedVector:$pass_thru)>,
+    Results<(outs AnyFixedVector:$result)> {
 
   let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
 
@@ -2117,6 +2117,8 @@ def Vector_ExpandLoadOp :
     correspond to those of the `llvm.masked.expandload`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
 
+    Note, at the moment this Op is only available for fixed-width vectors.
+
     Examples:
 
     ```mlir
@@ -2151,8 +2153,8 @@ def Vector_CompressStoreOp :
   Vector_Op<"compressstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$valueToStore)> {
+               FixedVectorOf<[I1]>:$mask,
+               AnyFixedVector:$valueToStore)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
 
@@ -2183,6 +2185,8 @@ def Vector_CompressStoreOp :
     correspond to those of the `llvm.masked.compressstore`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
 
+    Note, at the moment this Op is only available for fixed-width vectors.
+
     Examples:
 
     ```mlir
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 874d96e99ec678..ef1645117a7280 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,6 +24,9 @@ include "mlir/IR/DialectBase.td"
 // Explicitly disallow 0-D vectors for now until we have good enough coverage.
 def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                             CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
+def IsFixedVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+                                 CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+                                 CPred<"!::llvm::cast<::mlir::VectorType>($_self).isScalable()">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
@@ -432,6 +435,10 @@ class VectorOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
                       "::mlir::VectorType">;
 
+class FixedVectorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
+          "fixed-length vector", "::mlir::VectorType">;
+
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 class VectorOfAnyRankOf<list<Type> allowedTypes> :
@@ -660,6 +667,8 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
 // Unlike the following definitions, this one excludes 0-D vectors
 def AnyVector : VectorOf<[AnyType]>;
 
+def AnyFixedVector : FixedVectorOf<[AnyType]>;
+
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
 
 def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 0c093b0ccff141..ae336d4f5ddb8d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1519,6 +1519,14 @@ func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
 
 // -----
 
+func.func @expand_base_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %pass_thru: vector<[16]xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{'vector.expandload' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
+}
+
+// -----
+
 func.func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{'vector.expandload' op expected result dim to match mask dim}}
@@ -1551,6 +1559,14 @@ func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1
 
 // -----
 
+func.func @compress_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %value: vector<[16]xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{'vector.compressstore' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
+  vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
+}
+
+// -----
+
 func.func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}



More information about the Mlir-commits mailing list