[Mlir-commits] [mlir] [mlir][vector] Remove unneeded mask restriction (PR #113742)
Jacques Pienaar
llvmlistbot at llvm.org
Fri Oct 25 20:45:09 PDT 2024
https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/113742
>From 8a43c900d0c9cefbddf23e50dd2a85ad06323301 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sat, 26 Oct 2024 03:26:52 +0000
Subject: [PATCH] [mlir][vector] Remove unneeded restrictions
These restrictions were added when LLVM was the only lowering target; they do not reflect an intrinsic property of the ops.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 52 +++++++++++--------
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 +--
mlir/test/Dialect/Vector/invalid.mlir | 4 +-
3 files changed, 35 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c02b16ea931706..e859270cf9a5e5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1819,17 +1819,17 @@ def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$pass_thru)>,
- Results<(outs VectorOfRank<[1]>:$result)> {
+ VectorOf<[I1]>:$mask,
+ AnyVector:$pass_thru)>,
+ Results<(outs AnyVector:$result)> {
let summary = "loads elements from memory into a vector as defined by a mask vector";
let description = [{
- The masked load reads elements from memory into a 1-D vector as defined
- by a base with indices and a 1-D mask vector. When the mask is set, the
+ The masked load reads elements from memory into a vector as defined
+ by a base with indices and a mask vector. When the mask is set, the
element is read from memory. Otherwise, the corresponding element is taken
- from a 1-D pass-through vector. Informally the semantics are:
+ from a pass-through vector. Informally the semantics are:
```
result[0] := if mask[0] then base[i + 0] else pass_thru[0]
result[1] := if mask[1] then base[i + 1] else pass_thru[1]
@@ -1882,14 +1882,14 @@ def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$valueToStore)> {
+ VectorOf<[I1]>:$mask,
+ AnyVector:$valueToStore)> {
let summary = "stores elements from a vector into memory as defined by a mask vector";
let description = [{
- The masked store operation writes elements from a 1-D vector into memory
- as defined by a base with indices and a 1-D mask vector. When the mask is
+ The masked store operation writes elements from a vector into memory
+ as defined by a base with indices and a mask vector. When the mask is
set, the corresponding element from the vector is written to memory. Otherwise,
no action is taken for the element. Informally the semantics are:
```
@@ -2076,23 +2076,26 @@ def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$pass_thru)>,
- Results<(outs VectorOfRank<[1]>:$result)> {
+ VectorOf<[I1]>:$mask,
+ AnyVector:$pass_thru)>,
+ Results<(outs AnyVector:$result)> {
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
let description = [{
- The expand load reads elements from memory into a 1-D vector as defined
- by a base with indices and a 1-D mask vector. When the mask is set, the
- next element is read from memory. Otherwise, the corresponding element
- is taken from a 1-D pass-through vector. Informally the semantics are:
+ The expand load reads elements from memory into a vector as defined by a
+ base with indices and a mask vector. Expansion only applies to the innermost
+ dimension. When the mask is set, the next element is read from memory.
+ Otherwise, the corresponding element is taken from a pass-through vector.
+ Informally the semantics are:
+
```
index = i
result[0] := if mask[0] then base[index++] else pass_thru[0]
result[1] := if mask[1] then base[index++] else pass_thru[1]
etc.
```
+
Note that the index increment is done conditionally.
If a mask bit is set and the corresponding index is out-of-bounds for the
@@ -2140,22 +2143,25 @@ def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$valueToStore)> {
+ VectorOf<[I1]>:$mask,
+ AnyVector:$valueToStore)> {
let summary = "writes elements selectively from a vector as defined by a mask";
let description = [{
- The compress store operation writes elements from a 1-D vector into memory
- as defined by a base with indices and a 1-D mask vector. When the mask is
- set, the corresponding element from the vector is written next to memory.
- Otherwise, no action is taken for the element. Informally the semantics are:
+ The compress store operation writes elements from a vector into memory as
+ defined by a base with indices and a mask vector. Compression only applies
+ to the innermost dimension. When the mask is set, the corresponding element
+ from the vector is written to the next memory location. Otherwise, no
+ action is taken for the element. Informally the semantics are:
+
```
index = i
if (mask[0]) base[index++] = value[0]
if (mask[1]) base[index++] = value[1]
etc.
```
+
Note that the index increment is done conditionally.
If a mask bit is set and the corresponding index is out-of-bounds for the
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..d71a236f62f454 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4977,8 +4977,8 @@ LogicalResult MaskedLoadOp::verify() {
return emitOpError("base and result element type should match");
if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
- if (resVType.getDimSize(0) != maskVType.getDimSize(0))
- return emitOpError("expected result dim to match mask dim");
+ if (resVType.getShape() != maskVType.getShape())
+ return emitOpError("expected result shape to match mask shape");
if (resVType != passVType)
return emitOpError("expected pass_thru of same type as result type");
return success();
@@ -5030,8 +5030,8 @@ LogicalResult MaskedStoreOp::verify() {
return emitOpError("base and valueToStore element type should match");
if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
- if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
- return emitOpError("expected valueToStore dim to match mask dim");
+ if (valueVType.getShape() != maskVType.getShape())
+ return emitOpError("expected valueToStore shape to match mask shape");
return success();
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 36d04bb77e3b96..5b0fb537b35655 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1356,7 +1356,7 @@ func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16x
func.func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
+ // expected-error@+1 {{'vector.maskedload' op expected result shape to match mask shape}}
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
}
@@ -1387,7 +1387,7 @@ func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16
func.func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}}
+ // expected-error@+1 {{'vector.maskedstore' op expected valueToStore shape to match mask shape}}
vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
}
More information about the Mlir-commits
mailing list