[Mlir-commits] [mlir] 38098b4 - [mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors (#117538)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 29 08:13:12 PST 2024


Author: Andrzej Warzyński
Date: 2024-11-29T16:13:09Z
New Revision: 38098b486e44ad077b674e512eee399fc6f5a30c

URL: https://github.com/llvm/llvm-project/commit/38098b486e44ad077b674e512eee399fc6f5a30c
DIFF: https://github.com/llvm/llvm-project/commit/38098b486e44ad077b674e512eee399fc6f5a30c.diff

LOG: [mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors (#117538)

These operations were introduced as counterparts to the following LLVM
intrinsics:

  * `@llvm.masked.expandload.*`,
  * `@llvm.masked.compressstore.*`.

Currently, there is minimal test coverage for scalable-vector use cases
involving these Ops (in both LLVM and MLIR). Additionally, the verifier is
flawed: it incorrectly allows mixing fixed-width and scalable vectors.

To address these issues, scalable vector support for these Ops is
disabled for now. This decision can be revisited if a clear need to use
them with scalable vectors arises in the future.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/IR/CommonTypeConstraints.td
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index a1c2dad8c2b8b3..d35847034cb125 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2086,7 +2086,7 @@ def Vector_ExpandLoadOp :
   Vector_Op<"expandload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfNonZeroRankOf<[I1]>:$mask,
+               FixedVectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$pass_thru)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
@@ -2119,6 +2119,8 @@ def Vector_ExpandLoadOp :
     correspond to those of the `llvm.masked.expandload`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
 
+    Note, at the moment this Op is only available for fixed-width vectors.
+
     Examples:
 
     ```mlir
@@ -2153,7 +2155,7 @@ def Vector_CompressStoreOp :
   Vector_Op<"compressstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfNonZeroRankOf<[I1]>:$mask,
+               FixedVectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
@@ -2185,6 +2187,8 @@ def Vector_CompressStoreOp :
     correspond to those of the `llvm.masked.compressstore`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
 
+    Note, at the moment this Op is only available for fixed-width vectors.
+
     Examples:
 
     ```mlir

diff  --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index fc4383d08422cb..7db095d0ae5af6 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,6 +24,9 @@ include "mlir/IR/DialectBase.td"
 // Explicitly disallow 0-D vectors for now until we have good enough coverage.
 def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                                          CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+                                              CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
@@ -432,6 +435,10 @@ class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
                       "::mlir::VectorType">;
 
+class FixedVectorOfNonZeroRankOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsFixedVectorOfNonZeroRankTypePred,
+                      "fixed-length vector", "::mlir::VectorType">;
+
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 class VectorOfAnyRankOf<list<Type> allowedTypes> :
@@ -660,6 +667,8 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
 // Unlike the following definitions, this one excludes 0-D vectors
 def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;
 
+def AnyFixedVectorOfNonZeroRank : FixedVectorOfNonZeroRankOf<[AnyType]>;
+
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
 
 def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c244fe7df3e94b..9f7efa15ed5207 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1519,6 +1519,14 @@ func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
 
 // -----
 
+func.func @expand_base_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %pass_thru: vector<[16]xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{'vector.expandload' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
+}
+
+// -----
+
 func.func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{'vector.expandload' op expected result dim to match mask dim}}
@@ -1551,6 +1559,14 @@ func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1
 
 // -----
 
+func.func @compress_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %value: vector<[16]xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{'vector.compressstore' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
+  vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
+}
+
+// -----
+
 func.func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}


        


More information about the Mlir-commits mailing list