[Mlir-commits] [mlir] [mlir][Vector] Add patterns to lower `vector.shuffle` (PR #157611)
Diego Caballero
llvmlistbot at llvm.org
Tue Sep 16 17:17:22 PDT 2025
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/157611
>From 94a102049ffb5ecaf388f3f97cca4b0e77e7475c Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 9 Sep 2025 03:56:59 +0000
Subject: [PATCH 1/3] [mlir][Vector] Add patterns to lower `vector.shuffle`
This PR adds patterns to lower `vector.shuffle` ops whose inputs have
different vector sizes more efficiently. The current LLVM lowering
for these cases degenerates to a sequence of `vector.extract` and
`vector.insert` operations. With this PR, the smaller input is promoted
to the larger vector size by introducing an extra `vector.shuffle`.
---
.../Vector/Transforms/LoweringPatterns.h | 3 +
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Vector/Transforms/LowerVectorShuffle.cpp | 106 ++++++++++++++++++
.../Vector/vector-shuffle-lowering.mlir | 77 +++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 18 +++
5 files changed, 205 insertions(+)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
create mode 100644 mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index f56124cb4fb95..b896506f29eef 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -293,6 +293,9 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
int64_t targetRank = 1,
PatternBenefit benefit = 1);
+void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index d74007f13a95b..8e36ead6993a8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorMultiReduction.cpp
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
+ LowerVectorShuffle.cpp
LowerVectorStep.cpp
LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
new file mode 100644
index 0000000000000..0adfd256a8498
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -0,0 +1,106 @@
+//===- LowerVectorShuffle.cpp - Lower 'vector.shuffle' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering of complex `vector.shuffle` operations to a
+// set of simpler operations supported by LLVM/SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-shuffle-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Lowers a `vector.shuffle` operation with mix-size inputs to a new
+/// `vector.shuffle` which promotes the smaller input to the larger vector size
+/// and an updated version of the original `vector.shuffle`.
+///
+/// Example:
+///
+/// %0 = vector.shuffle %v1, %v2 [0, 2, 1, 3] : vector<2xf32>, vector<4xf32>
+///
+/// is lowered to:
+///
+/// %0 = vector.shuffle %v1, %v1 [0, 1, -1, -1] :
+/// vector<2xf32>, vector<2xf32>
+/// %1 = vector.shuffle %0, %v2 [0, 4, 1, 5] :
+/// vector<4xf32>, vector<4xf32>
+///
+struct MixSizeInputShuffleOpRewrite final
+ : OpRewritePattern<vector::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
+ PatternRewriter &rewriter) const override {
+ auto v1Type = shuffleOp.getV1VectorType();
+ auto v2Type = shuffleOp.getV2VectorType();
+
+ // Only support 1-D shuffle for now.
+ if (v1Type.getRank() != 1 || v2Type.getRank() != 1)
+ return failure();
+
+ // No mix-size inputs.
+ int64_t v1OrigNumElems = v1Type.getNumElements();
+ int64_t v2OrigNumElems = v2Type.getNumElements();
+ if (v1OrigNumElems == v2OrigNumElems)
+ return failure();
+
+ // Determine which input needs promotion.
+ bool promoteV1 = v1OrigNumElems < v2OrigNumElems;
+ Value inputToPromote = promoteV1 ? shuffleOp.getV1() : shuffleOp.getV2();
+ VectorType promotedType = promoteV1 ? v2Type : v1Type;
+ int64_t origNumElems = promoteV1 ? v1OrigNumElems : v2OrigNumElems;
+ int64_t promotedNumElems = promoteV1 ? v2OrigNumElems : v1OrigNumElems;
+
+ // Create a shuffle with a mask that preserves existing elements and fills
+ // up with poison.
+ SmallVector<int64_t> promoteMask(promotedNumElems, ShuffleOp::kPoisonIndex);
+ for (int64_t i = 0; i < origNumElems; ++i)
+ promoteMask[i] = i;
+
+ Value promotedInput = rewriter.create<vector::ShuffleOp>(
+ shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
+ promoteMask);
+
+ // Create the final shuffle with the promoted inputs.
+ Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
+ Value promotedV2 = promoteV1 ? shuffleOp.getV2() : promotedInput;
+
+ SmallVector<int64_t> newMask;
+ if (!promoteV1) {
+ newMask = to_vector(shuffleOp.getMask());
+ } else {
+ // Adjust V2 indices to account for the new V1 size.
+ for (auto idx : shuffleOp.getMask()) {
+ int64_t newIdx = idx;
+ if (idx >= v1OrigNumElems) {
+ newIdx += promotedNumElems - v1OrigNumElems;
+ }
+ newMask.push_back(newIdx);
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+ shuffleOp, shuffleOp.getResultVectorType(), promotedV1, promotedV2,
+ newMask);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorShuffleLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<MixSizeInputShuffleOpRewrite>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir b/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
new file mode 100644
index 0000000000000..3acc9aa9bffa0
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt %s --test-vector-shuffle-lowering --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @shuffle_v1_smaller_arbitrary
+// CHECK-SAME: %[[V1:.*]]: vector<2xf32>, %[[V2:.*]]: vector<4xf32>
+func.func @shuffle_v1_smaller_arbitrary(%v1: vector<2xf32>, %v2: vector<4xf32>) -> vector<5xf32> {
+ // CHECK: %[[PROMOTE_V1:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_V1]], %[[V2]] [1, 5, 0, 6, 7] : vector<4xf32>, vector<4xf32>
+ // CHECK: return %[[RESULT]] : vector<5xf32>
+ %0 = vector.shuffle %v1, %v2 [1, 3, 0, 4, 5] : vector<2xf32>, vector<4xf32>
+ return %0 : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @shuffle_v2_smaller_arbitrary
+// CHECK-SAME: %[[V1:.*]]: vector<4xi32>, %[[V2:.*]]: vector<2xi32>
+func.func @shuffle_v2_smaller_arbitrary(%v1: vector<4xi32>, %v2: vector<2xi32>) -> vector<6xi32> {
+ // CHECK: %[[PROMOTE_V2:.*]] = vector.shuffle %[[V2]], %[[V2]] [0, 1, -1, -1] : vector<2xi32>, vector<2xi32>
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[PROMOTE_V2]] [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<4xi32>
+ // CHECK: return %[[RESULT]] : vector<6xi32>
+ %0 = vector.shuffle %v1, %v2 [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<2xi32>
+ return %0 : vector<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @shuffle_v1_smaller_concat
+// CHECK-SAME: %[[V1:.*]]: vector<3xf64>, %[[V2:.*]]: vector<5xf64>
+func.func @shuffle_v1_smaller_concat(%v1: vector<3xf64>, %v2: vector<5xf64>) -> vector<8xf64> {
+ // CHECK: %[[PROMOTE_V1:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, 2, -1, -1] : vector<3xf64>, vector<3xf64>
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_V1]], %[[V2]] [0, 1, 2, 5, 6, 7, 8, 9] : vector<5xf64>, vector<5xf64>
+ // CHECK: return %[[RESULT]] : vector<8xf64>
+ %0 = vector.shuffle %v1, %v2 [0, 1, 2, 3, 4, 5, 6, 7] : vector<3xf64>, vector<5xf64>
+ return %0 : vector<8xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @shuffle_v2_smaller_concat
+// CHECK-SAME: %[[V1:.*]]: vector<4xi16>, %[[V2:.*]]: vector<2xi16>
+func.func @shuffle_v2_smaller_concat(%v1: vector<4xi16>, %v2: vector<2xi16>) -> vector<6xi16> {
+ // CHECK: %[[PROMOTE_V2:.*]] = vector.shuffle %[[V2]], %[[V2]] [0, 1, -1, -1] : vector<2xi16>, vector<2xi16>
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[PROMOTE_V2]] [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<4xi16>
+ // CHECK: return %[[RESULT]] : vector<6xi16>
+ %0 = vector.shuffle %v1, %v2 [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<2xi16>
+ return %0 : vector<6xi16>
+}
+
+// -----
+
+// Test that shuffles with same size inputs are not modified.
+
+// CHECK-LABEL: func.func @shuffle_same_input_sizes
+// CHECK-SAME: %[[V1:.*]]: vector<4xf32>, %[[V2:.*]]: vector<4xf32>
+func.func @shuffle_same_input_sizes(%v1: vector<4xf32>, %v2: vector<4xf32>) -> vector<6xf32> {
+ // CHECK-NOT: vector.shuffle %[[V1]], %[[V1]]
+ // CHECK-NOT: vector.shuffle %[[V2]], %[[V2]]
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[V2]] [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
+ // CHECK: return %[[RESULT]] : vector<6xf32>
+ %0 = vector.shuffle %v1, %v2 [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
+ return %0 : vector<6xf32>
+}
+
+// -----
+
+// Test that multi-dimensional shuffles are not modified.
+
+// CHECK-LABEL: func.func @shuffle_2d_vectors_no_change
+// CHECK-SAME: %[[V1:.*]]: vector<2x4xf32>, %[[V2:.*]]: vector<3x4xf32>
+func.func @shuffle_2d_vectors_no_change(%v1: vector<2x4xf32>, %v2: vector<3x4xf32>) -> vector<4x4xf32> {
+ // CHECK-NOT: vector.shuffle %[[V1]], %[[V1]]
+ // CHECK-NOT: vector.shuffle %[[V2]], %[[V2]]
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[V2]] [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
+ // CHECK: return %[[RESULT]] : vector<4x4xf32>
+ %0 = vector.shuffle %v1, %v2 [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
+ return %0 : vector<4x4xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 72dd103b33f75..79bfc9bbcda71 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -994,6 +994,22 @@ struct TestEliminateVectorMasks
VscaleRange{vscaleMin, vscaleMax});
}
};
+
+struct TestVectorShuffleLowering
+ : public PassWrapper<TestVectorShuffleLowering,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorShuffleLowering)
+
+ StringRef getArgument() const final { return "test-vector-shuffle-lowering"; }
+ StringRef getDescription() const final {
+ return "Test lowering patterns for vector.shuffle with mixed-size inputs";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorShuffleLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
} // namespace
namespace mlir {
@@ -1023,6 +1039,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorScanLowering>();
+ PassRegistration<TestVectorShuffleLowering>();
+
PassRegistration<TestVectorDistribution>();
PassRegistration<TestVectorExtractStridedSliceLowering>();
>From 76cf1389f4ac1fb553b9bb49135c193fec42e86a Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Mon, 15 Sep 2025 19:05:58 +0000
Subject: [PATCH 2/3] Feedback
---
.../Vector/Transforms/LowerVectorShuffle.cpp | 2 +-
.../Vector/vector-shuffle-lowering.mlir | 76 +++++++++----------
2 files changed, 39 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
index 0adfd256a8498..8ca74cb49d4c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -51,7 +51,7 @@ struct MixSizeInputShuffleOpRewrite final
if (v1Type.getRank() != 1 || v2Type.getRank() != 1)
return failure();
- // No mix-size inputs.
+ // Bail out if inputs don't have mixed sized.
int64_t v1OrigNumElems = v1Type.getNumElements();
int64_t v2OrigNumElems = v2Type.getNumElements();
if (v1OrigNumElems == v2OrigNumElems)
diff --git a/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir b/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
index 3acc9aa9bffa0..a137811fa367c 100644
--- a/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
@@ -1,48 +1,48 @@
// RUN: mlir-opt %s --test-vector-shuffle-lowering --split-input-file | FileCheck %s
-// CHECK-LABEL: func.func @shuffle_v1_smaller_arbitrary
-// CHECK-SAME: %[[V1:.*]]: vector<2xf32>, %[[V2:.*]]: vector<4xf32>
-func.func @shuffle_v1_smaller_arbitrary(%v1: vector<2xf32>, %v2: vector<4xf32>) -> vector<5xf32> {
- // CHECK: %[[PROMOTE_V1:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
- // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_V1]], %[[V2]] [1, 5, 0, 6, 7] : vector<4xf32>, vector<4xf32>
+// CHECK-LABEL: func.func @shuffle_smaller_lhs_arbitrary
+// CHECK-SAME: %[[LHS:.*]]: vector<2xf32>, %[[RHS:.*]]: vector<4xf32>
+func.func @shuffle_smaller_lhs_arbitrary(%lhs: vector<2xf32>, %rhs: vector<4xf32>) -> vector<5xf32> {
+ // CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [1, 5, 0, 6, 7] : vector<4xf32>, vector<4xf32>
// CHECK: return %[[RESULT]] : vector<5xf32>
- %0 = vector.shuffle %v1, %v2 [1, 3, 0, 4, 5] : vector<2xf32>, vector<4xf32>
+ %0 = vector.shuffle %lhs, %rhs [1, 3, 0, 4, 5] : vector<2xf32>, vector<4xf32>
return %0 : vector<5xf32>
}
// -----
-// CHECK-LABEL: func.func @shuffle_v2_smaller_arbitrary
-// CHECK-SAME: %[[V1:.*]]: vector<4xi32>, %[[V2:.*]]: vector<2xi32>
-func.func @shuffle_v2_smaller_arbitrary(%v1: vector<4xi32>, %v2: vector<2xi32>) -> vector<6xi32> {
- // CHECK: %[[PROMOTE_V2:.*]] = vector.shuffle %[[V2]], %[[V2]] [0, 1, -1, -1] : vector<2xi32>, vector<2xi32>
- // CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[PROMOTE_V2]] [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<4xi32>
+// CHECK-LABEL: func.func @shuffle_smaller_rhs_arbitrary
+// CHECK-SAME: %[[LHS:.*]]: vector<4xi32>, %[[RHS:.*]]: vector<2xi32>
+func.func @shuffle_smaller_rhs_arbitrary(%lhs: vector<4xi32>, %rhs: vector<2xi32>) -> vector<6xi32> {
+ // CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi32>, vector<2xi32>
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<4xi32>
// CHECK: return %[[RESULT]] : vector<6xi32>
- %0 = vector.shuffle %v1, %v2 [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<2xi32>
+ %0 = vector.shuffle %lhs, %rhs [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<2xi32>
return %0 : vector<6xi32>
}
// -----
-// CHECK-LABEL: func.func @shuffle_v1_smaller_concat
-// CHECK-SAME: %[[V1:.*]]: vector<3xf64>, %[[V2:.*]]: vector<5xf64>
-func.func @shuffle_v1_smaller_concat(%v1: vector<3xf64>, %v2: vector<5xf64>) -> vector<8xf64> {
- // CHECK: %[[PROMOTE_V1:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, 2, -1, -1] : vector<3xf64>, vector<3xf64>
- // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_V1]], %[[V2]] [0, 1, 2, 5, 6, 7, 8, 9] : vector<5xf64>, vector<5xf64>
+// CHECK-LABEL: func.func @shuffle_smaller_lhs_concat
+// CHECK-SAME: %[[LHS:.*]]: vector<3xf64>, %[[RHS:.*]]: vector<5xf64>
+func.func @shuffle_smaller_lhs_concat(%lhs: vector<3xf64>, %rhs: vector<5xf64>) -> vector<8xf64> {
+ // CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, 2, -1, -1] : vector<3xf64>, vector<3xf64>
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [0, 1, 2, 5, 6, 7, 8, 9] : vector<5xf64>, vector<5xf64>
// CHECK: return %[[RESULT]] : vector<8xf64>
- %0 = vector.shuffle %v1, %v2 [0, 1, 2, 3, 4, 5, 6, 7] : vector<3xf64>, vector<5xf64>
+ %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5, 6, 7] : vector<3xf64>, vector<5xf64>
return %0 : vector<8xf64>
}
// -----
-// CHECK-LABEL: func.func @shuffle_v2_smaller_concat
-// CHECK-SAME: %[[V1:.*]]: vector<4xi16>, %[[V2:.*]]: vector<2xi16>
-func.func @shuffle_v2_smaller_concat(%v1: vector<4xi16>, %v2: vector<2xi16>) -> vector<6xi16> {
- // CHECK: %[[PROMOTE_V2:.*]] = vector.shuffle %[[V2]], %[[V2]] [0, 1, -1, -1] : vector<2xi16>, vector<2xi16>
- // CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[PROMOTE_V2]] [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<4xi16>
+// CHECK-LABEL: func.func @shuffle_smaller_rhs_concat
+// CHECK-SAME: %[[LHS:.*]]: vector<4xi16>, %[[RHS:.*]]: vector<2xi16>
+func.func @shuffle_smaller_rhs_concat(%lhs: vector<4xi16>, %rhs: vector<2xi16>) -> vector<6xi16> {
+ // CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi16>, vector<2xi16>
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<4xi16>
// CHECK: return %[[RESULT]] : vector<6xi16>
- %0 = vector.shuffle %v1, %v2 [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<2xi16>
+ %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<2xi16>
return %0 : vector<6xi16>
}
@@ -50,14 +50,14 @@ func.func @shuffle_v2_smaller_concat(%v1: vector<4xi16>, %v2: vector<2xi16>) ->
// Test that shuffles with same size inputs are not modified.
-// CHECK-LABEL: func.func @shuffle_same_input_sizes
-// CHECK-SAME: %[[V1:.*]]: vector<4xf32>, %[[V2:.*]]: vector<4xf32>
-func.func @shuffle_same_input_sizes(%v1: vector<4xf32>, %v2: vector<4xf32>) -> vector<6xf32> {
- // CHECK-NOT: vector.shuffle %[[V1]], %[[V1]]
- // CHECK-NOT: vector.shuffle %[[V2]], %[[V2]]
- // CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[V2]] [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
+// CHECK-LABEL: func.func @negative_shuffle_same_input_sizes
+// CHECK-SAME: %[[LHS:.*]]: vector<4xf32>, %[[RHS:.*]]: vector<4xf32>
+func.func @negative_shuffle_same_input_sizes(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<6xf32> {
+ // CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]]
+ // CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]]
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
// CHECK: return %[[RESULT]] : vector<6xf32>
- %0 = vector.shuffle %v1, %v2 [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
+ %0 = vector.shuffle %lhs, %rhs [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
return %0 : vector<6xf32>
}
@@ -65,13 +65,13 @@ func.func @shuffle_same_input_sizes(%v1: vector<4xf32>, %v2: vector<4xf32>) -> v
// Test that multi-dimensional shuffles are not modified.
-// CHECK-LABEL: func.func @shuffle_2d_vectors_no_change
-// CHECK-SAME: %[[V1:.*]]: vector<2x4xf32>, %[[V2:.*]]: vector<3x4xf32>
-func.func @shuffle_2d_vectors_no_change(%v1: vector<2x4xf32>, %v2: vector<3x4xf32>) -> vector<4x4xf32> {
- // CHECK-NOT: vector.shuffle %[[V1]], %[[V1]]
- // CHECK-NOT: vector.shuffle %[[V2]], %[[V2]]
- // CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[V2]] [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
+// CHECK-LABEL: func.func @negative_shuffle_2d_vectors
+// CHECK-SAME: %[[LHS:.*]]: vector<2x4xf32>, %[[RHS:.*]]: vector<3x4xf32>
+func.func @negative_shuffle_2d_vectors(%lhs: vector<2x4xf32>, %rhs: vector<3x4xf32>) -> vector<4x4xf32> {
+ // CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]]
+ // CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]]
+ // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
// CHECK: return %[[RESULT]] : vector<4x4xf32>
- %0 = vector.shuffle %v1, %v2 [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
+ %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
return %0 : vector<4x4xf32>
}
>From 82ac739a91ea64d2b87039fcf8619b83026f584c Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Wed, 17 Sep 2025 00:03:54 +0000
Subject: [PATCH 3/3] Feedback
---
.../Vector/Transforms/LowerVectorShuffle.cpp | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
index 8ca74cb49d4c0..78102f7325b9f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -23,7 +23,7 @@ using namespace mlir::vector;
namespace {
-/// Lowers a `vector.shuffle` operation with mix-size inputs to a new
+/// Lowers a `vector.shuffle` operation with mixed-size inputs to a new
/// `vector.shuffle` which promotes the smaller input to the larger vector size
/// and an updated version of the original `vector.shuffle`.
///
@@ -33,12 +33,16 @@ namespace {
///
/// is lowered to:
///
-/// %0 = vector.shuffle %v1, %v1 [0, 1, -1, -1] :
-/// vector<2xf32>, vector<2xf32>
-/// %1 = vector.shuffle %0, %v2 [0, 4, 1, 5] :
+/// %0 = vector.shuffle %v1, %v1 [0, 1, -1, -1] :
+/// vector<2xf32>, vector<2xf32>
+/// %1 = vector.shuffle %0, %v2 [0, 4, 1, 5] :
/// vector<4xf32>, vector<4xf32>
///
-struct MixSizeInputShuffleOpRewrite final
+/// Note: This transformation helps legalize vector.shuffle ops when lowering
+/// to SPIR-V/LLVM, which don't support shuffle operations with mixed-size
+/// inputs.
+///
+struct MixedSizeInputShuffleOpRewrite final
: OpRewritePattern<vector::ShuffleOp> {
using OpRewritePattern::OpRewritePattern;
@@ -51,7 +55,7 @@ struct MixSizeInputShuffleOpRewrite final
if (v1Type.getRank() != 1 || v2Type.getRank() != 1)
return failure();
- // Bail out if inputs don't have mixed sized.
+ // Bail out if inputs don't have mixed sizes.
int64_t v1OrigNumElems = v1Type.getNumElements();
int64_t v2OrigNumElems = v2Type.getNumElements();
if (v1OrigNumElems == v2OrigNumElems)
@@ -102,5 +106,5 @@ struct MixSizeInputShuffleOpRewrite final
void mlir::vector::populateVectorShuffleLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<MixSizeInputShuffleOpRewrite>(patterns.getContext(), benefit);
+ patterns.add<MixedSizeInputShuffleOpRewrite>(patterns.getContext(), benefit);
}
More information about the Mlir-commits
mailing list