[Mlir-commits] [mlir] [mlir][vector] Replace unused shuffle operands / results with poison (PR #190763)
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 7 02:54:43 PDT 2026
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/190763
If a shuffle operand is not selected by any mask element, replace it with `ub.poison`. This may make the value that previously fed the operand dead and enable additional DCE. Also replace the entire shuffle op with `ub.poison` if every mask element is the poison index, i.e., all selected values are poison.
>From 668bf98bff31c9863a220cef4f33c22698ccd804 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 7 Apr 2026 09:51:15 +0000
Subject: [PATCH] [mlir][vector] Replace unused shuffle operands / results with
poison
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 49 ++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 70 +++++++++++++++++++
.../Dialect/XeGPU/xegpu-vector-linearize.mlir | 10 +--
3 files changed, 123 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 939816262a2b1..0e4d3013c496e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3472,12 +3472,57 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
}
};
+/// Pattern to replace unused shuffle operands / results with poison.
+class FoldUnusedShuffleOperand final : public OpRewritePattern<ShuffleOp> {
+public:
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ // Replace with poison if all mask elements are poison.
+ if (llvm::all_of(op.getMask(), [](int64_t mask) {
+ return mask == ShuffleOp::kPoisonIndex;
+ })) {
+ rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType());
+ return success();
+ }
+
+ // Check if elements from V1 / V2 are used.
+ int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
+ ? op.getV1VectorType().getDimSize(0)
+ : 1;
+ bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
+ return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
+ });
+ bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
+ return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
+ });
+
+ // Replace V1 with poison if it is not used.
+ if (!isV1Used && !op.getV1().getDefiningOp<ub::PoisonOp>()) {
+ Value poison =
+ ub::PoisonOp::create(rewriter, op.getLoc(), op.getV1VectorType());
+ rewriter.modifyOpInPlace(op, [&]() { op.getV1Mutable().assign(poison); });
+ return success();
+ }
+
+ // Replace V2 with poison if it is not used.
+ if (!isV2Used && !op.getV2().getDefiningOp<ub::PoisonOp>()) {
+ Value poison =
+ ub::PoisonOp::create(rewriter, op.getLoc(), op.getV2VectorType());
+ rewriter.modifyOpInPlace(op, [&]() { op.getV2Mutable().assign(poison); });
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
- context);
+ results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
+ FoldUnusedShuffleOperand>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index acde86d20d346..e7fded4846811 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2736,6 +2736,76 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
// -----
+// All mask elements are poison: replace shuffle with poison.
+// CHECK-LABEL: func @shuffle_all_poison_mask
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[P:.*]] = ub.poison : vector<2xi32>
+// CHECK: return %[[P]]
+func.func @shuffle_all_poison_mask(%v0 : vector<3xi32>, %v1 : vector<3xi32>) -> vector<2xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [-1, -1] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<2xi32>
+}
+
+// -----
+
+// V1 is unused: replace V1 operand with poison.
+// CHECK-LABEL: func @shuffle_unused_v1
+// CHECK-SAME: %[[A:.*]]: vector<3xi32>, %[[B:.*]]: vector<3xi32>
+// CHECK: %[[P:.*]] = ub.poison : vector<3xi32>
+// CHECK: vector.shuffle %[[P]], %[[B]] [4, 3] : vector<3xi32>, vector<3xi32>
+func.func @shuffle_unused_v1(%v0 : vector<3xi32>, %v1 : vector<3xi32>) -> vector<2xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [4, 3] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<2xi32>
+}
+
+// -----
+
+// V2 is unused: replace V2 operand with poison.
+// CHECK-LABEL: func @shuffle_unused_v2
+// CHECK-SAME: %[[A:.*]]: vector<3xi32>, %[[B:.*]]: vector<3xi32>
+// CHECK: %[[P:.*]] = ub.poison : vector<3xi32>
+// CHECK: vector.shuffle %[[A]], %[[P]] [2, 0] : vector<3xi32>, vector<3xi32>
+func.func @shuffle_unused_v2(%v0 : vector<3xi32>, %v1 : vector<3xi32>) -> vector<2xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [2, 0] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<2xi32>
+}
+
+// -----
+
+// V1 is unused (mask has poison indices mixed with V2 references).
+// CHECK-LABEL: func @shuffle_unused_v1_with_poison_idx
+// CHECK-SAME: %[[A:.*]]: vector<3xi32>, %[[B:.*]]: vector<3xi32>
+// CHECK: %[[P:.*]] = ub.poison : vector<3xi32>
+// CHECK: vector.shuffle %[[P]], %[[B]] [4, -1, 3] : vector<3xi32>, vector<3xi32>
+func.func @shuffle_unused_v1_with_poison_idx(%v0 : vector<3xi32>, %v1 : vector<3xi32>) -> vector<3xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [4, -1, 3] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<3xi32>
+}
+
+// -----
+
+// V2 is unused (multidimensional vectors).
+// CHECK-LABEL: func @shuffle_unused_v2_multidim
+// CHECK-SAME: %[[A:.*]]: vector<4x2xf32>, %[[B:.*]]: vector<3x2xf32>
+// CHECK: %[[P:.*]] = ub.poison : vector<3x2xf32>
+// CHECK: vector.shuffle %[[A]], %[[P]] [2, 0, 3] : vector<4x2xf32>, vector<3x2xf32>
+func.func @shuffle_unused_v2_multidim(%v0 : vector<4x2xf32>, %v1 : vector<3x2xf32>) -> vector<3x2xf32> {
+ %shuffle = vector.shuffle %v0, %v1 [2, 0, 3] : vector<4x2xf32>, vector<3x2xf32>
+ return %shuffle : vector<3x2xf32>
+}
+
+// -----
+
+// Both operands are used: no folding.
+// CHECK-LABEL: func @shuffle_both_operands_used
+// CHECK: vector.shuffle %arg0, %arg1 [0, 3, 1, 4] : vector<3xi32>, vector<3xi32>
+func.func @shuffle_both_operands_used(%v0 : vector<3xi32>, %v1 : vector<3xi32>) -> vector<4xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [0, 3, 1, 4] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_splatlike_constant
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
// CHECK: return %[[CST]]
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 94205a6c26ba2..1dcec16f7ad52 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -17,8 +17,9 @@ func.func @test_vector_insert_2d_idx(%arg0: vector<2x8x4xf32>, %arg1: vector<4xf
// -----
// CHECK-LABEL: test_vector_transpose
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8xf32>) -> vector<8x2xf32>
+// CHECK: %[[POISON:.*]] = ub.poison : vector<16xf32>
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8xf32> to vector<16xf32>
-// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[POISON]]
// CHECK: [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<16xf32>, vector<16xf32>
// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
// CHECK: return %[[RES]] : vector<8x2xf32>
@@ -163,7 +164,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
// CHECK: %[[PASS_CAST:.*]] = vector.shape_cast %[[PASS]] : vector<2x3xf32> to vector<6xf32>
// First shuffle + if ladder for row 0
-// CHECK: %[[ROW0_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[PASS_CAST]] [0, 1, 2]
+// CHECK: %[[ROW0_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[POISON]] [0, 1, 2]
// CHECK: %[[DIM0:.*]] = memref.dim %[[BASE]], %[[C0]]
// CHECK: %[[DIM1:.*]] = memref.dim %[[BASE]], %[[C1]]
// CHECK: %[[MASK_0_0:.*]] = vector.extract %[[MASK]][0, 0]
@@ -188,7 +189,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
// … (similar checks for the rest of row 0, then row 1)
// CHECK: %[[ROW_SHUFFLE:.*]] = vector.shuffle %[[POISON]], {{.*}} [6, 7, 8, 3, 4, 5]
-// CHECK: %[[ROW1_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[PASS_CAST]] [3, 4, 5]
+// CHECK: %[[ROW1_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[POISON]] [3, 4, 5]
// Row 1 if ladder checks
// CHECK: %[[MASK_1_0:.*]] = vector.extract %[[MASK]][1, 0]
@@ -215,6 +216,7 @@ func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask
// The `xegpu-vector-linearize` pass does not itself affect the XeGPU ops.
// CHECK: gpu.func @test_kernel(%[[A:.*]]: memref<8x16xf16>, %[[B:.*]]: memref<16x16xf16>, %[[C:.*]]: memref<8x16xf32>) kernel {
+// CHECK: %[[POISON_F32:.*]] = ub.poison : vector<128xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CST_A:.*]] = arith.constant dense<0.000000e+00> : vector<64xf16>
// CHECK: %[[CST_C:.*]] = arith.constant dense<5.000000e+00> : vector<64xf32>
@@ -233,7 +235,7 @@ func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask
// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[A_RESULT]], %[[B_RESULT]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK: %[[DPAS_CAST:.*]] = vector.shape_cast %[[DPAS]] : vector<8x16xf32> to vector<128xf32>
-// CHECK: %[[EXTRACT_SHUFFLE:.*]] = vector.shuffle %[[DPAS_CAST]], %[[DPAS_CAST]] {{.*}} : vector<128xf32>, vector<128xf32>
+// CHECK: %[[EXTRACT_SHUFFLE:.*]] = vector.shuffle %[[DPAS_CAST]], %[[POISON_F32]] {{.*}} : vector<128xf32>, vector<128xf32>
// CHECK: %[[ADDF:.*]] = arith.addf %[[EXTRACT_SHUFFLE]], %[[CST_C]] : vector<64xf32>
// CHECK: %[[INSERT_SHUFFLE:.*]] = vector.shuffle %[[DPAS_CAST]], %[[ADDF]] {{.*}} : vector<128xf32>, vector<64xf32>
// CHECK: %[[C_RESULT:.*]] = vector.shape_cast %[[INSERT_SHUFFLE]] : vector<128xf32> to vector<8x16xf32>
More information about the Mlir-commits
mailing list