[Mlir-commits] [mlir] [mlir][vector] Fold poison operands into vector.shuffle mask (PR #190932)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 8 01:53:58 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

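Fold poison operands into the `vector.shuffle` mask: every mask index that refers to an element of a `ub.poison` operand is replaced with the poison index (-1), updating the shuffle in place. This commit also splits the `vector::ShuffleOp::fold` implementation into multiple helper functions.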


---
Full diff: https://github.com/llvm/llvm-project/pull/190932.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+81-26) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+10) 
- (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+2-2) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir (+2-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d866d4011a09..08338ed90f00b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3289,41 +3289,66 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
          });
 }
 
-OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
-  auto v1Type = getV1VectorType();
-  auto v2Type = getV2VectorType();
-
-  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
-         "Vector shuffle does not support scalable vectors");
+/// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
+/// Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
+static OpFoldResult foldShuffleIdentityMask(ShuffleOp op) {
+  auto v1Type = op.getV1VectorType();
+  auto v2Type = op.getV2VectorType();
+  auto mask = op.getMask();
+  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
+    return op.getV1();
+  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
+    return op.getV2();
+  return {};
+}
 
-  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
-  // but must be a canonicalization into a vector.broadcast.
-  if (v1Type.getRank() == 0)
+/// If a shuffle operand is poison, replace all mask indices that reference it
+/// with kPoisonIndex. This is an in-place fold.
+static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op) {
+  bool isV1Poison = matchPattern(op.getV1(), ub::m_Poison());
+  bool isV2Poison = matchPattern(op.getV2(), ub::m_Poison());
+  if (!isV1Poison && !isV2Poison)
     return {};
 
-  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
-  auto mask = getMask();
-  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
-    return getV1();
-  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
-  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
-    return getV2();
+  int64_t v1Size = op.getV1VectorType().getDimSize(0);
+  bool changed = false;
+  SmallVector<int64_t> newMask = llvm::to_vector(op.getMask());
+  for (int64_t &idx : newMask) {
+    if (idx == ShuffleOp::kPoisonIndex)
+      continue;
+    if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
+      idx = ShuffleOp::kPoisonIndex;
+      changed = true;
+    }
+  }
 
-  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
-  if (!v1Attr || !v2Attr)
+  if (!changed)
     return {};
 
-  // Fold shuffle poison, poison -> poison.
-  bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
-  bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
-  if (isV1Poison && isV2Poison)
-    return ub::PoisonAttr::get(getContext());
+  op.setMask(newMask);
+  return op.getResult();
+}
+
+/// Fold shuffle poison, poison -> poison.
+static OpFoldResult foldShufflePoisonInputs(MLIRContext *context,
+                                            Attribute v1Attr,
+                                            Attribute v2Attr) {
+  if (matchPattern(v1Attr, ub::m_Poison()) &&
+      matchPattern(v2Attr, ub::m_Poison()))
+    return ub::PoisonAttr::get(context);
+  return {};
+}
 
-  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
-  // manipulation.
+/// Fold a shuffle of constant 1-D inputs by evaluating the mask.
+static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr,
+                                              Attribute v2Attr) {
+  auto v1Type = op.getV1VectorType();
   if (v1Type.getRank() != 1)
     return {};
 
+  bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
+  bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
+
   // Poison input attributes need special handling as they are not
   // DenseElementsAttr. If an index is poison, we select the first element of
   // the first non-poison input.
@@ -3344,6 +3369,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
     poisonElement = v1Elements[0];
   }
 
+  ArrayRef<int64_t> mask = op.getMask();
   SmallVector<Attribute> results;
   int64_t v1Size = v1Type.getDimSize(0);
   for (int64_t maskIdx : mask) {
@@ -3361,7 +3387,36 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
     results.push_back(indexedElm);
   }
 
-  return DenseElementsAttr::get(getResultVectorType(), results);
+  return DenseElementsAttr::get(op.getResultVectorType(), results);
+}
+
+OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
+  auto v1Type = getV1VectorType();
+  auto v2Type = getV2VectorType();
+
+  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
+         "Vector shuffle does not support scalable vectors");
+
+  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
+  // but must be a canonicalization into a vector.broadcast.
+  if (v1Type.getRank() == 0)
+    return {};
+
+  if (auto res = foldShuffleIdentityMask(*this))
+    return res;
+  if (auto res = foldShufflePoisonOperandToMask(*this))
+    return res;
+
+  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
+  if (!v1Attr || !v2Attr)
+    return {};
+
+  if (auto res = foldShufflePoisonInputs(getContext(), v1Attr, v2Attr))
+    return res;
+  if (auto res = foldShuffleConstantInputs(*this, v1Attr, v2Attr))
+    return res;
+
+  return {};
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 71a18f796c27d..b5655c547f4b1 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2819,6 +2819,16 @@ func.func @shuffle_poison_unused(%1: vector<2xi32>) -> vector<4xi32> {
 
 // -----
 
+// CHECK-LABEL: @fold_poison_into_mask
+//       CHECK:   vector.shuffle %{{.*}}, %{{.*}} [-1, -1, 2, -1] : vector<2xi32>, vector<2xi32>
+func.func @fold_poison_into_mask(%1: vector<2xi32>) -> vector<4xi32> {
+  %0 = ub.poison : vector<2xi32>
+  %r = vector.shuffle %0, %1 [0, 1, 2, 1] : vector<2xi32>, vector<2xi32>
+  return %r : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @transpose_splatlike_constant
 //       CHECK:   %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
 //       CHECK:   return %[[CST]]
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index bf4f094263545..11ca563c82d3f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -28,7 +28,7 @@
 //
 // CHECK:           %[[LHS_ROW_1:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<2 x vector<4xf32>>
 // CHECK:           %[[TP_1:.*]] = llvm.shufflevector %[[LHS_ROW_1]], %[[LHS_ROW_1]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>
-// CHECK:           %[[TP_2:.*]] = llvm.shufflevector %[[TP_1]], %[[POISON_LHS]] [0, 1, 2, 3, 12, 13, 14, 15] : vector<8xf32>
+// CHECK:           %[[TP_2:.*]] = llvm.shufflevector %[[TP_1]], %[[POISON_LHS]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<8xf32>
 // CHECK:           %[[LHS_ROW_2:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.array<2 x vector<4xf32>>
 // CHECK:           %[[TP_3:.*]] = llvm.shufflevector %[[LHS_ROW_2]], %[[LHS_ROW_2]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>
 // CHECK:           %[[LHS:.*]] = llvm.shufflevector %[[TP_3]], %[[TP_2]] [8, 9, 10, 11, 0, 1, 2, 3] : vector<8xf32>
@@ -43,7 +43,7 @@
 //       | ROW_4 |
 // CHECK:           %[[RHS_ROW_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.array<4 x vector<3xf32>>
 // CHECK:           %[[TP_4:.*]] = llvm.shufflevector %[[RHS_ROW_1]], %[[RHS_ROW_1]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32>
-// CHECK:           %[[TP_5:.*]] = llvm.shufflevector %[[TP_4]], %[[POISON_RHS]] [0, 1, 2, 15, 16, 17, 18, 19, 20, 21, 22, 23] : vector<12xf32>
+// CHECK:           %[[TP_5:.*]] = llvm.shufflevector %[[TP_4]], %[[POISON_RHS]] [0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1] : vector<12xf32>
 // CHECK:           %[[RHS_ROW_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.array<4 x vector<3xf32>>
 // CHECK:           %[[TP_6:.*]] = llvm.shufflevector %[[RHS_ROW_2]], %[[RHS_ROW_2]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32>
 // CHECK:           %[[TP_7:.*]] = llvm.shufflevector %[[TP_6]], %[[TP_5]] [12, 13, 14, 0, 1, 2, 18, 19, 20, 21, 22, 23] : vector<12xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 1dcec16f7ad52..1f74b20819b11 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -110,7 +110,7 @@ func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>
 // CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>) -> vector<3x4xf32>
 // CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32>
 // CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<1x4xf32> to vector<4xf32>
-// CHECK: %[[SHUFFLE1:.*]] = vector.shuffle %[[POISON]], %[[CAST]] [12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
+// CHECK: %[[SHUFFLE1:.*]] = vector.shuffle %[[POISON]], %[[CAST]] [12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1] : vector<12xf32>, vector<4xf32>
 // CHECK: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[CAST]] [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
 // CHECK: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[CAST]] [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] : vector<12xf32>, vector<4xf32>
 // CHECK: %[[RESULT:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<12xf32> to vector<3x4xf32>
@@ -188,7 +188,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
 
 // … (similar checks for the rest of row 0, then row 1)
 
-// CHECK: %[[ROW_SHUFFLE:.*]] = vector.shuffle %[[POISON]], {{.*}} [6, 7, 8, 3, 4, 5]
+// CHECK: %[[ROW_SHUFFLE:.*]] = vector.shuffle %[[POISON]], {{.*}} [6, 7, 8, -1, -1, -1]
 // CHECK: %[[ROW1_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[POISON]] [3, 4, 5]
 
 // Row 1 if ladder checks

``````````

</details>


https://github.com/llvm/llvm-project/pull/190932


More information about the Mlir-commits mailing list