[Mlir-commits] [mlir] eae5ca9 - [mlir][Vector] Support poison in `vector.shuffle` mask (#122188)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 18 15:16:55 PST 2025
Author: Diego Caballero
Date: 2025-01-18T15:16:51-08:00
New Revision: eae5ca9b45bf1232f30d92ce50c19c1ea82c0f0b
URL: https://github.com/llvm/llvm-project/commit/eae5ca9b45bf1232f30d92ce50c19c1ea82c0f0b
DIFF: https://github.com/llvm/llvm-project/commit/eae5ca9b45bf1232f30d92ce50c19c1ea82c0f0b.diff
LOG: [mlir][Vector] Support poison in `vector.shuffle` mask (#122188)
This PR extends the existing poison support in
https://mlir.llvm.org/docs/Dialects/UBOps/ by representing poison mask
values in `vector.shuffle`. Similar to LLVM (see
https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/Instructions.h#L1884),
this requires defining an integer value (`-1`) to represent poison in
the `vector.shuffle` mask.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 30a5b06374fad1..4331eda1661960 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -434,10 +434,9 @@ def Vector_ShuffleOp
The shuffle operation constructs a permutation (or duplication) of elements
from two input vectors, returning a vector with the same element type as
the input and a length that is the same as the shuffle mask. The two input
- vectors must have the same element type, same rank , and trailing dimension
- sizes and shuffles their values in the
- leading dimension (which may differ in size) according to the given mask.
- The legality rules are:
+ vectors must have the same element type, same rank, and trailing dimension
+ sizes and shuffles their values in the leading dimension (which may differ
+ in size) according to the given mask. The legality rules are:
* the two operands must have the same element type as the result
- Either, the two operands and the result must have the same
rank and trailing dimension sizes, viz. given two k-D operands
@@ -448,7 +447,9 @@ def Vector_ShuffleOp
* the mask length equals the leading dimension size of the result
* numbering the input vector indices left to right across the operands, all
mask values must be within range, viz. given two k-D operands v1 and v2
- above, all mask values are in the range [0,s_1+t_1)
+ above, all mask values are in the range [0,s_1+t_1). The value `-1`
+ represents a poison mask value, which specifies that the selected element
+ is poison.
Note, scalable vectors are not supported.
@@ -463,10 +464,15 @@ def Vector_ShuffleOp
: vector<2xf32>, vector<2xf32> ; yields vector<4xf32>
%3 = vector.shuffle %a, %b[0, 1]
: vector<f32>, vector<f32> ; yields vector<2xf32>
+ %4 = vector.shuffle %a, %b[0, 4, -1, -1, -1, -1]
+ : vector<4xf32>, vector<4xf32> ; yields vector<6xf32>
```
}];
let extraClassDeclaration = [{
+ // Integer to represent a poison value in a vector shuffle mask.
+ static constexpr int64_t kMaskPoisonValue = -1;
+
VectorType getV1VectorType() {
return ::llvm::cast<VectorType>(getV1().getType());
}
@@ -700,6 +706,8 @@ def Vector_ExtractOp :
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
```
+
+ TODO: Implement support for poison indices.
}];
let arguments = (ins
@@ -890,6 +898,8 @@ def Vector_InsertOp :
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
```
+
+ TODO: Implement support for poison indices.
}];
let arguments = (ins
@@ -980,6 +990,8 @@ def Vector_ScalableInsertOp :
```mlir
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
```
+
+ TODO: Implement support for poison indices.
}];
let assemblyFormat = [{
@@ -1031,6 +1043,8 @@ def Vector_ScalableExtractOp :
```mlir
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
```
+
+ TODO: Implement support for poison indices.
}];
let assemblyFormat = [{
@@ -1075,6 +1089,8 @@ def Vector_InsertStridedSliceOp :
{offsets = [0, 0, 2], strides = [1, 1]}:
vector<2x4xf32> into vector<16x4x8xf32>
```
+
+ TODO: Implement support for poison indices.
}];
let assemblyFormat = [{
@@ -1220,6 +1236,8 @@ def Vector_ExtractStridedSliceOp :
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
vector<4x8x16xf32> to vector<2x4x16xf32>
```
+
+ TODO: Implement support for poison indices.
}];
let builders = [
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$offsets,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..696d1e0f9b1e68 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2600,7 +2600,7 @@ LogicalResult ShuffleOp::verify() {
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
- if (maskPos < 0 || maskPos >= indexSize)
+ if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
return emitOpError("mask index #") << (idx + 1) << " out of range";
}
return success();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f95e943250bd44..931cc36c9d4a88 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1105,6 +1105,16 @@ func.func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex
// -----
+func.func @shuffle_poison_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<4xf32> {
+ %1 = vector.shuffle %arg0, %arg1 [0, -1, 3, -1] : vector<2xf32>, vector<2xf32>
+ return %1 : vector<4xf32>
+}
+// CHECK-LABEL: @shuffle_poison_mask(
+// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>)
+// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, -1, 3, -1] : vector<2xf32>
+
+// -----
+
func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
%1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
return %1 : vector<5xf32>
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 103148633bf97c..fd73cea5e4f306 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
// -----
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32>
+// CHECK: return %[[SHUFFLE]] : vector<4xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<4xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [1, -1, 5, -1] : vector<4xi32>, vector<4xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @interleave
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 961f1b5ffeabec..cd6f3f518a1c07 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -190,6 +190,13 @@ func.func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32
return %1 : vector<3x4xf32>
}
+// CHECK-LABEL: @shuffle_poison_mask
+func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: vector.shuffle %{{.*}}, %{{.*}}[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
+ %1 = vector.shuffle %a, %a[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
+ return %1 : vector<4xf32>
+}
+
// CHECK-LABEL: @extract_element_0d
func.func @extract_element_0d(%a: vector<f32>) -> f32 {
// CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>
More information about the Mlir-commits
mailing list