[Mlir-commits] [mlir] [mlir][vector] Fold poison operands into vector.shuffle mask (PR #190932)
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 8 01:53:24 PDT 2026
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/190932
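Fold poison operands into the `vector.shuffle` mask: every mask index that selects an element from a poison operand is replaced with the poison index (`-1`). For example, `vector.shuffle %poison, %v [0, 1, 2, 1]` becomes `vector.shuffle %poison, %v [-1, -1, 2, -1]`. This commit also splits the `vector::ShuffleOp::fold` implementation into multiple helper functions.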
From ca437cbcb10c8ae98372d89f491cfeb7fefd58b2 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 8 Apr 2026 08:51:51 +0000
Subject: [PATCH] [mlir][vector] Fold poison operands into vector.shuffle mask
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 107 +++++++++++++-----
mlir/test/Dialect/Vector/canonicalize.mlir | 10 ++
...tract-to-matrix-intrinsics-transforms.mlir | 4 +-
.../Dialect/XeGPU/xegpu-vector-linearize.mlir | 4 +-
4 files changed, 95 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d866d4011a09..08338ed90f00b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3289,41 +3289,66 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
});
}
-OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
- auto v1Type = getV1VectorType();
- auto v2Type = getV2VectorType();
-
- assert(!v1Type.isScalable() && !v2Type.isScalable() &&
- "Vector shuffle does not support scalable vectors");
+/// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
+/// Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
+static OpFoldResult foldShuffleIdentityMask(ShuffleOp op) {
+ auto v1Type = op.getV1VectorType();
+ auto v2Type = op.getV2VectorType();
+ auto mask = op.getMask();
+ if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
+ return op.getV1();
+ if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
+ return op.getV2();
+ return {};
+}
- // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
- // but must be a canonicalization into a vector.broadcast.
- if (v1Type.getRank() == 0)
+/// If a shuffle operand is poison, replace all mask indices that reference it
+/// with kPoisonIndex. This is an in-place fold.
+static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op) {
+ bool isV1Poison = matchPattern(op.getV1(), ub::m_Poison());
+ bool isV2Poison = matchPattern(op.getV2(), ub::m_Poison());
+ if (!isV1Poison && !isV2Poison)
return {};
- // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
- auto mask = getMask();
- if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
- return getV1();
- // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
- if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
- return getV2();
+ int64_t v1Size = op.getV1VectorType().getDimSize(0);
+ bool changed = false;
+ SmallVector<int64_t> newMask = llvm::to_vector(op.getMask());
+ for (int64_t &idx : newMask) {
+ if (idx == ShuffleOp::kPoisonIndex)
+ continue;
+ if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
+ idx = ShuffleOp::kPoisonIndex;
+ changed = true;
+ }
+ }
- Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
- if (!v1Attr || !v2Attr)
+ if (!changed)
return {};
- // Fold shuffle poison, poison -> poison.
- bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
- bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
- if (isV1Poison && isV2Poison)
- return ub::PoisonAttr::get(getContext());
+ op.setMask(newMask);
+ return op.getResult();
+}
+
+/// Fold shuffle poison, poison -> poison.
+static OpFoldResult foldShufflePoisonInputs(MLIRContext *context,
+ Attribute v1Attr,
+ Attribute v2Attr) {
+ if (matchPattern(v1Attr, ub::m_Poison()) &&
+ matchPattern(v2Attr, ub::m_Poison()))
+ return ub::PoisonAttr::get(context);
+ return {};
+}
- // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
- // manipulation.
+/// Fold a shuffle of constant 1-D inputs by evaluating the mask.
+static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr,
+ Attribute v2Attr) {
+ auto v1Type = op.getV1VectorType();
if (v1Type.getRank() != 1)
return {};
+ bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
+ bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
+
// Poison input attributes need special handling as they are not
// DenseElementsAttr. If an index is poison, we select the first element of
// the first non-poison input.
@@ -3344,6 +3369,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
poisonElement = v1Elements[0];
}
+ ArrayRef<int64_t> mask = op.getMask();
SmallVector<Attribute> results;
int64_t v1Size = v1Type.getDimSize(0);
for (int64_t maskIdx : mask) {
@@ -3361,7 +3387,36 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
results.push_back(indexedElm);
}
- return DenseElementsAttr::get(getResultVectorType(), results);
+ return DenseElementsAttr::get(op.getResultVectorType(), results);
+}
+
+OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
+ auto v1Type = getV1VectorType();
+ auto v2Type = getV2VectorType();
+
+ assert(!v1Type.isScalable() && !v2Type.isScalable() &&
+ "Vector shuffle does not support scalable vectors");
+
+ // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
+ // but must be a canonicalization into a vector.broadcast.
+ if (v1Type.getRank() == 0)
+ return {};
+
+ if (auto res = foldShuffleIdentityMask(*this))
+ return res;
+ if (auto res = foldShufflePoisonOperandToMask(*this))
+ return res;
+
+ Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
+ if (!v1Attr || !v2Attr)
+ return {};
+
+ if (auto res = foldShufflePoisonInputs(getContext(), v1Attr, v2Attr))
+ return res;
+ if (auto res = foldShuffleConstantInputs(*this, v1Attr, v2Attr))
+ return res;
+
+ return {};
}
namespace {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 71a18f796c27d..b5655c547f4b1 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2819,6 +2819,16 @@ func.func @shuffle_poison_unused(%1: vector<2xi32>) -> vector<4xi32> {
// -----
+// CHECK-LABEL: @fold_poison_into_mask
+// CHECK: vector.shuffle %{{.*}}, %{{.*}} [-1, -1, 2, -1] : vector<2xi32>, vector<2xi32>
+func.func @fold_poison_into_mask(%1: vector<2xi32>) -> vector<4xi32> {
+ %0 = ub.poison : vector<2xi32>
+ %r = vector.shuffle %0, %1 [0, 1, 2, 1] : vector<2xi32>, vector<2xi32>
+ return %r : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_splatlike_constant
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
// CHECK: return %[[CST]]
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index bf4f094263545..11ca563c82d3f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -28,7 +28,7 @@
//
// CHECK: %[[LHS_ROW_1:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[TP_1:.*]] = llvm.shufflevector %[[LHS_ROW_1]], %[[LHS_ROW_1]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>
-// CHECK: %[[TP_2:.*]] = llvm.shufflevector %[[TP_1]], %[[POISON_LHS]] [0, 1, 2, 3, 12, 13, 14, 15] : vector<8xf32>
+// CHECK: %[[TP_2:.*]] = llvm.shufflevector %[[TP_1]], %[[POISON_LHS]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<8xf32>
// CHECK: %[[LHS_ROW_2:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[TP_3:.*]] = llvm.shufflevector %[[LHS_ROW_2]], %[[LHS_ROW_2]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>
// CHECK: %[[LHS:.*]] = llvm.shufflevector %[[TP_3]], %[[TP_2]] [8, 9, 10, 11, 0, 1, 2, 3] : vector<8xf32>
@@ -43,7 +43,7 @@
// | ROW_4 |
// CHECK: %[[RHS_ROW_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.array<4 x vector<3xf32>>
// CHECK: %[[TP_4:.*]] = llvm.shufflevector %[[RHS_ROW_1]], %[[RHS_ROW_1]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32>
-// CHECK: %[[TP_5:.*]] = llvm.shufflevector %[[TP_4]], %[[POISON_RHS]] [0, 1, 2, 15, 16, 17, 18, 19, 20, 21, 22, 23] : vector<12xf32>
+// CHECK: %[[TP_5:.*]] = llvm.shufflevector %[[TP_4]], %[[POISON_RHS]] [0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1] : vector<12xf32>
// CHECK: %[[RHS_ROW_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.array<4 x vector<3xf32>>
// CHECK: %[[TP_6:.*]] = llvm.shufflevector %[[RHS_ROW_2]], %[[RHS_ROW_2]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32>
// CHECK: %[[TP_7:.*]] = llvm.shufflevector %[[TP_6]], %[[TP_5]] [12, 13, 14, 0, 1, 2, 18, 19, 20, 21, 22, 23] : vector<12xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 1dcec16f7ad52..1f74b20819b11 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -110,7 +110,7 @@ func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<1x4xf32> to vector<4xf32>
-// CHECK: %[[SHUFFLE1:.*]] = vector.shuffle %[[POISON]], %[[CAST]] [12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
+// CHECK: %[[SHUFFLE1:.*]] = vector.shuffle %[[POISON]], %[[CAST]] [12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1] : vector<12xf32>, vector<4xf32>
// CHECK: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[CAST]] [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
// CHECK: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[CAST]] [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] : vector<12xf32>, vector<4xf32>
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<12xf32> to vector<3x4xf32>
@@ -188,7 +188,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
// … (similar checks for the rest of row 0, then row 1)
-// CHECK: %[[ROW_SHUFFLE:.*]] = vector.shuffle %[[POISON]], {{.*}} [6, 7, 8, 3, 4, 5]
+// CHECK: %[[ROW_SHUFFLE:.*]] = vector.shuffle %[[POISON]], {{.*}} [6, 7, 8, -1, -1, -1]
// CHECK: %[[ROW1_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[POISON]] [3, 4, 5]
// Row 1 if ladder checks
More information about the Mlir-commits mailing list