[Mlir-commits] [mlir] bb00f5b - [mlir][vector] Remove unneeded mask restriction (#113742)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 25 20:45:48 PDT 2024


Author: Jacques Pienaar
Date: 2024-10-25T20:45:44-07:00
New Revision: bb00f5b1edd0ed77b7e7b0113dad223505564b18

URL: https://github.com/llvm/llvm-project/commit/bb00f5b1edd0ed77b7e7b0113dad223505564b18
DIFF: https://github.com/llvm/llvm-project/commit/bb00f5b1edd0ed77b7e7b0113dad223505564b18.diff

LOG: [mlir][vector] Remove unneeded mask restriction (#113742)

These were added when the only mapping was to LLVM.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c02b16ea931706..e859270cf9a5e5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1819,17 +1819,17 @@ def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$pass_thru)>,
-    Results<(outs VectorOfRank<[1]>:$result)> {
+               VectorOf<[I1]>:$mask,
+               AnyVector:$pass_thru)>,
+    Results<(outs AnyVector:$result)> {
 
   let summary = "loads elements from memory into a vector as defined by a mask vector";
 
   let description = [{
-    The masked load reads elements from memory into a 1-D vector as defined
-    by a base with indices and a 1-D mask vector. When the mask is set, the
+    The masked load reads elements from memory into a vector as defined
+    by a base with indices and a mask vector. When the mask is set, the
     element is read from memory. Otherwise, the corresponding element is taken
-    from a 1-D pass-through vector. Informally the semantics are:
+    from a pass-through vector. Informally the semantics are:
     ```
     result[0] := if mask[0] then base[i + 0] else pass_thru[0]
     result[1] := if mask[1] then base[i + 1] else pass_thru[1]
@@ -1882,14 +1882,14 @@ def Vector_MaskedStoreOp :
   Vector_Op<"maskedstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOf<[I1]>:$mask,
+               AnyVector:$valueToStore)> {
 
   let summary = "stores elements from a vector into memory as defined by a mask vector";
 
   let description = [{
-    The masked store operation writes elements from a 1-D vector into memory
-    as defined by a base with indices and a 1-D mask vector. When the mask is
+    The masked store operation writes elements from a vector into memory
+    as defined by a base with indices and a mask vector. When the mask is
     set, the corresponding element from the vector is written to memory. Otherwise,
     no action is taken for the element. Informally the semantics are:
     ```
@@ -2076,23 +2076,26 @@ def Vector_ExpandLoadOp :
   Vector_Op<"expandload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$pass_thru)>,
-    Results<(outs VectorOfRank<[1]>:$result)> {
+               VectorOf<[I1]>:$mask,
+               AnyVector:$pass_thru)>,
+    Results<(outs AnyVector:$result)> {
 
   let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
 
   let description = [{
-    The expand load reads elements from memory into a 1-D vector as defined
-    by a base with indices and a 1-D mask vector. When the mask is set, the
-    next element is read from memory. Otherwise, the corresponding element
-    is taken from a 1-D pass-through vector. Informally the semantics are:
+    The expand load reads elements from memory into a vector as defined by a
+    base with indices and a mask vector. Expansion only applies to the innermost
+    dimension. When the mask is set, the next element is read from memory.
+    Otherwise, the corresponding element is taken from a pass-through vector.
+    Informally the semantics are:
+
     ```
     index = i
     result[0] := if mask[0] then base[index++] else pass_thru[0]
     result[1] := if mask[1] then base[index++] else pass_thru[1]
     etc.
     ```
+
     Note that the index increment is done conditionally.
 
     If a mask bit is set and the corresponding index is out-of-bounds for the
@@ -2140,22 +2143,25 @@ def Vector_CompressStoreOp :
   Vector_Op<"compressstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOf<[I1]>:$mask,
+               AnyVector:$valueToStore)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
 
   let description = [{
-    The compress store operation writes elements from a 1-D vector into memory
-    as defined by a base with indices and a 1-D mask vector. When the mask is
-    set, the corresponding element from the vector is written next to memory.
-    Otherwise, no action is taken for the element. Informally the semantics are:
+    The compress store operation writes elements from a vector into memory as
+    defined by a base with indices and a mask vector. Compression only applies
+    to the innermost dimension. When the mask is set, the corresponding element
+    from the vector is written next to memory.  Otherwise, no action is taken
+    for the element. Informally the semantics are:
+
     ```
     index = i
     if (mask[0]) base[index++] = value[0]
     if (mask[1]) base[index++] = value[1]
     etc.
     ```
+
     Note that the index increment is done conditionally.
 
     If a mask bit is set and the corresponding index is out-of-bounds for the

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..d71a236f62f454 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4977,8 +4977,8 @@ LogicalResult MaskedLoadOp::verify() {
     return emitOpError("base and result element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
-    return emitOpError("expected result dim to match mask dim");
+  if (resVType.getShape() != maskVType.getShape())
+    return emitOpError("expected result shape to match mask shape");
   if (resVType != passVType)
     return emitOpError("expected pass_thru of same type as result type");
   return success();
@@ -5030,8 +5030,8 @@ LogicalResult MaskedStoreOp::verify() {
     return emitOpError("base and valueToStore element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
-    return emitOpError("expected valueToStore dim to match mask dim");
+  if (valueVType.getShape() != maskVType.getShape())
+    return emitOpError("expected valueToStore shape to match mask shape");
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 36d04bb77e3b96..5b0fb537b35655 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1356,7 +1356,7 @@ func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16x
 
 func.func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
+  // expected-error@+1 {{'vector.maskedload' op expected result shape to match mask shape}}
   %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1387,7 +1387,7 @@ func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16
 
 func.func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}}
+  // expected-error@+1 {{'vector.maskedstore' op expected valueToStore shape to match mask shape}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
 }
 


        


More information about the Mlir-commits mailing list