[Mlir-commits] [mlir] [mlir][vector] Remove unneeded mask restriction (PR #113742)

Jacques Pienaar llvmlistbot at llvm.org
Fri Oct 25 20:45:09 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/113742

>From 8a43c900d0c9cefbddf23e50dd2a85ad06323301 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sat, 26 Oct 2024 03:26:52 +0000
Subject: [PATCH] [mlir][vector] Remove unneeded restrictions

These restrictions were introduced when LLVM was the only lowering target; they reflected a limitation of that lowering rather than an intrinsic property of the op.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 52 +++++++++++--------
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  8 +--
 mlir/test/Dialect/Vector/invalid.mlir         |  4 +-
 3 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c02b16ea931706..e859270cf9a5e5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1819,17 +1819,17 @@ def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$pass_thru)>,
-    Results<(outs VectorOfRank<[1]>:$result)> {
+               VectorOf<[I1]>:$mask,
+               AnyVector:$pass_thru)>,
+    Results<(outs AnyVector:$result)> {
 
   let summary = "loads elements from memory into a vector as defined by a mask vector";
 
   let description = [{
-    The masked load reads elements from memory into a 1-D vector as defined
-    by a base with indices and a 1-D mask vector. When the mask is set, the
+    The masked load reads elements from memory into a vector as defined
+    by a base with indices and a mask vector. When the mask is set, the
     element is read from memory. Otherwise, the corresponding element is taken
-    from a 1-D pass-through vector. Informally the semantics are:
+    from a pass-through vector. Informally the semantics are:
     ```
     result[0] := if mask[0] then base[i + 0] else pass_thru[0]
     result[1] := if mask[1] then base[i + 1] else pass_thru[1]
@@ -1882,14 +1882,14 @@ def Vector_MaskedStoreOp :
   Vector_Op<"maskedstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOf<[I1]>:$mask,
+               AnyVector:$valueToStore)> {
 
   let summary = "stores elements from a vector into memory as defined by a mask vector";
 
   let description = [{
-    The masked store operation writes elements from a 1-D vector into memory
-    as defined by a base with indices and a 1-D mask vector. When the mask is
+    The masked store operation writes elements from a vector into memory
+    as defined by a base with indices and a mask vector. When the mask is
     set, the corresponding element from the vector is written to memory. Otherwise,
     no action is taken for the element. Informally the semantics are:
     ```
@@ -2076,23 +2076,26 @@ def Vector_ExpandLoadOp :
   Vector_Op<"expandload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$pass_thru)>,
-    Results<(outs VectorOfRank<[1]>:$result)> {
+               VectorOf<[I1]>:$mask,
+               AnyVector:$pass_thru)>,
+    Results<(outs AnyVector:$result)> {
 
   let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
 
   let description = [{
-    The expand load reads elements from memory into a 1-D vector as defined
-    by a base with indices and a 1-D mask vector. When the mask is set, the
-    next element is read from memory. Otherwise, the corresponding element
-    is taken from a 1-D pass-through vector. Informally the semantics are:
+    The expand load reads elements from memory into a vector as defined by a
+    base with indices and a mask vector. Expansion only applies to the innermost
+    dimension. When the mask is set, the next element is read from memory.
+    Otherwise, the corresponding element is taken from a pass-through vector.
+    Informally the semantics are:
+
     ```
     index = i
     result[0] := if mask[0] then base[index++] else pass_thru[0]
     result[1] := if mask[1] then base[index++] else pass_thru[1]
     etc.
     ```
+
     Note that the index increment is done conditionally.
 
     If a mask bit is set and the corresponding index is out-of-bounds for the
@@ -2140,22 +2143,25 @@ def Vector_CompressStoreOp :
   Vector_Op<"compressstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOf<[I1]>:$mask,
+               AnyVector:$valueToStore)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
 
   let description = [{
-    The compress store operation writes elements from a 1-D vector into memory
-    as defined by a base with indices and a 1-D mask vector. When the mask is
-    set, the corresponding element from the vector is written next to memory.
-    Otherwise, no action is taken for the element. Informally the semantics are:
+    The compress store operation writes elements from a vector into memory as
+    defined by a base with indices and a mask vector. Compression only applies
+    to the innermost dimension. When the mask is set, the corresponding element
+    from the vector is written to the next memory location. Otherwise, no
+    action is taken for the element. Informally the semantics are:
+
     ```
     index = i
     if (mask[0]) base[index++] = value[0]
     if (mask[1]) base[index++] = value[1]
     etc.
     ```
+
     Note that the index increment is done conditionally.
 
     If a mask bit is set and the corresponding index is out-of-bounds for the
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..d71a236f62f454 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4977,8 +4977,8 @@ LogicalResult MaskedLoadOp::verify() {
     return emitOpError("base and result element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
-    return emitOpError("expected result dim to match mask dim");
+  if (resVType.getShape() != maskVType.getShape())
+    return emitOpError("expected result shape to match mask shape");
   if (resVType != passVType)
     return emitOpError("expected pass_thru of same type as result type");
   return success();
@@ -5030,8 +5030,8 @@ LogicalResult MaskedStoreOp::verify() {
     return emitOpError("base and valueToStore element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
-    return emitOpError("expected valueToStore dim to match mask dim");
+  if (valueVType.getShape() != maskVType.getShape())
+    return emitOpError("expected valueToStore shape to match mask shape");
   return success();
 }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 36d04bb77e3b96..5b0fb537b35655 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1356,7 +1356,7 @@ func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16x
 
 func.func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error at +1 {{'vector.maskedload' op expected result dim to match mask dim}}
+  // expected-error at +1 {{'vector.maskedload' op expected result shape to match mask shape}}
   %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1387,7 +1387,7 @@ func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16
 
 func.func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error at +1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}}
+  // expected-error at +1 {{'vector.maskedstore' op expected valueToStore shape to match mask shape}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
 }
 



More information about the Mlir-commits mailing list