[Mlir-commits] [mlir] eae5ca9 - [mlir][Vector] Support poison in `vector.shuffle` mask (#122188)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 18 15:16:55 PST 2025


Author: Diego Caballero
Date: 2025-01-18T15:16:51-08:00
New Revision: eae5ca9b45bf1232f30d92ce50c19c1ea82c0f0b

URL: https://github.com/llvm/llvm-project/commit/eae5ca9b45bf1232f30d92ce50c19c1ea82c0f0b
DIFF: https://github.com/llvm/llvm-project/commit/eae5ca9b45bf1232f30d92ce50c19c1ea82c0f0b.diff

LOG: [mlir][Vector] Support poison in `vector.shuffle` mask (#122188)

This PR extends the existing poison support in
https://mlir.llvm.org/docs/Dialects/UBOps/ by representing poison mask
values in `vector.shuffle`. Similar to LLVM (see
https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/Instructions.h#L1884)
this requires defining an integer value (`-1`) to represent poison in
the `vector.shuffle` mask.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 30a5b06374fad1..4331eda1661960 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -434,10 +434,9 @@ def Vector_ShuffleOp
     The shuffle operation constructs a permutation (or duplication) of elements
     from two input vectors, returning a vector with the same element type as
     the input and a length that is the same as the shuffle mask. The two input
-    vectors must have the same element type, same rank , and trailing dimension
-    sizes and shuffles their values in the
-    leading dimension (which may differ in size) according to the given mask.
-    The legality rules are:
+    vectors must have the same element type, same rank, and trailing dimension
+    sizes and shuffles their values in the leading dimension (which may differ
+    in size) according to the given mask. The legality rules are:
     * the two operands must have the same element type as the result
       - Either, the two operands and the result must have the same
         rank and trailing dimension sizes, viz. given two k-D operands
@@ -448,7 +447,9 @@ def Vector_ShuffleOp
     * the mask length equals the leading dimension size of the result
     * numbering the input vector indices left to right across the operands, all
       mask values must be within range, viz. given two k-D operands v1 and v2
-      above, all mask values are in the range [0,s_1+t_1)
+      above, all mask values are in the range [0,s_1+t_1). The value `-1`
+      represents a poison mask value, which specifies that the selected element
+      is poison.
 
     Note, scalable vectors are not supported.
 
@@ -463,10 +464,15 @@ def Vector_ShuffleOp
                : vector<2xf32>, vector<2xf32>       ; yields vector<4xf32>
     %3 = vector.shuffle %a, %b[0, 1]
                : vector<f32>, vector<f32>           ; yields vector<2xf32>
+    %4 = vector.shuffle %a, %b[0, 4, -1, -1, -1, -1]
+               : vector<4xf32>, vector<4xf32>       ; yields vector<6xf32>
     ```
   }];
 
   let extraClassDeclaration = [{
+    // Integer to represent a poison value in a vector shuffle mask.
+    static constexpr int64_t kMaskPoisonValue = -1;
+
     VectorType getV1VectorType() {
       return ::llvm::cast<VectorType>(getV1().getType());
     }
@@ -700,6 +706,8 @@ def Vector_ExtractOp :
     %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
     %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
     ```
+
+    TODO: Implement support for poison indices.
   }];
 
   let arguments = (ins
@@ -890,6 +898,8 @@ def Vector_InsertOp :
     %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
     %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
     ```
+
+    TODO: Implement support for poison indices.
   }];
 
   let arguments = (ins
@@ -980,6 +990,8 @@ def Vector_ScalableInsertOp :
     ```mlir
     %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
     ```
+
+    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
@@ -1031,6 +1043,8 @@ def Vector_ScalableExtractOp :
     ```mlir
     %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
     ```
+
+    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
@@ -1075,6 +1089,8 @@ def Vector_InsertStridedSliceOp :
         {offsets = [0, 0, 2], strides = [1, 1]}:
       vector<2x4xf32> into vector<16x4x8xf32>
     ```
+
+    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
@@ -1220,6 +1236,8 @@ def Vector_ExtractStridedSliceOp :
     %1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
       vector<4x8x16xf32> to vector<2x4x16xf32>
     ```
+
+    TODO: Implement support for poison indices.
   }];
   let builders = [
     OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$offsets,

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..696d1e0f9b1e68 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2600,7 +2600,7 @@ LogicalResult ShuffleOp::verify() {
   int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                       (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
   for (auto [idx, maskPos] : llvm::enumerate(mask)) {
-    if (maskPos < 0 || maskPos >= indexSize)
+    if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
       return emitOpError("mask index #") << (idx + 1) << " out of range";
   }
   return success();

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f95e943250bd44..931cc36c9d4a88 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1105,6 +1105,16 @@ func.func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex
 
 // -----
 
+func.func @shuffle_poison_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<4xf32> {
+  %1 = vector.shuffle %arg0, %arg1 [0, -1, 3, -1] : vector<2xf32>, vector<2xf32>
+  return %1 : vector<4xf32>
+}
+// CHECK-LABEL: @shuffle_poison_mask(
+//  CHECK-SAME:   %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>)
+//       CHECK:     %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, -1, 3, -1] : vector<2xf32>
+
+// -----
+
 func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
   %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
   return %1 : vector<5xf32>

diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 103148633bf97c..fd73cea5e4f306 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
 
 // -----
 
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32>
+//       CHECK:    return %[[SHUFFLE]] : vector<4xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<4xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [1, -1, 5, -1] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @interleave
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
 //       CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>

diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 961f1b5ffeabec..cd6f3f518a1c07 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -190,6 +190,13 @@ func.func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32
   return %1 : vector<3x4xf32>
 }
 
+// CHECK-LABEL: @shuffle_poison_mask
+func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: vector.shuffle %{{.*}}, %{{.*}}[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
+  %1 = vector.shuffle %a, %a[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
 // CHECK-LABEL: @extract_element_0d
 func.func @extract_element_0d(%a: vector<f32>) -> f32 {
   // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>


        


More information about the Mlir-commits mailing list