[Mlir-commits] [mlir] 7bdd88c - [mlir][Vector] Add patterns to lower `vector.shuffle` (#157611)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Tue Sep 16 18:01:34 PDT 2025
    
    
  
Author: Diego Caballero
Date: 2025-09-16T18:01:30-07:00
New Revision: 7bdd88c1e3a70d8213f8bc68403fbd844f11b00c
URL: https://github.com/llvm/llvm-project/commit/7bdd88c1e3a70d8213f8bc68403fbd844f11b00c
DIFF: https://github.com/llvm/llvm-project/commit/7bdd88c1e3a70d8213f8bc68403fbd844f11b00c.diff
LOG: [mlir][Vector] Add patterns to lower `vector.shuffle` (#157611)
This PR adds patterns to more efficiently lower `vector.shuffle` ops whose
inputs have different vector sizes. The current LLVM lowering for these
cases degenerates to a sequence of `vector.extract` and `vector.insert`
operations. With this PR, the smaller input is promoted to the larger
vector size by introducing an extra `vector.shuffle`.
Added: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
    mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index f56124cb4fb95..b896506f29eef 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -293,6 +293,9 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                            int64_t targetRank = 1,
                                            PatternBenefit benefit = 1);
 
+void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns,
+                                           PatternBenefit benefit = 1);
+
 /// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
 /// n > 1.
 void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
diff  --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index d74007f13a95b..8e36ead6993a8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
+  LowerVectorShuffle.cpp
   LowerVectorStep.cpp
   LowerVectorToElements.cpp
   LowerVectorToFromElementsToShuffleTree.cpp
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
new file mode 100644
index 0000000000000..78102f7325b9f
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -0,0 +1,110 @@
+//===- LowerVectorShuffle.cpp - Lower 'vector.shuffle' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering of complex `vector.shuffle` operations to
+// a set of simpler operations supported by LLVM/SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-shuffle-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Rewrites a 1-D `vector.shuffle` whose two inputs have different sizes into
+/// two shuffles with same-size inputs: the first widens the smaller input to
+/// the size of the larger one, padding with poison; the second is the original
+/// shuffle with its mask remapped onto the widened operands.
+///
+/// Example:
+///
+///     %0 = vector.shuffle %v1, %v2 [0, 2, 1, 3] : vector<2xf32>, vector<4xf32>
+///
+///   is lowered to:
+///
+///     %p = vector.shuffle %v1, %v1 [0, 1, -1, -1] :
+///       vector<2xf32>, vector<2xf32>
+///     %0 = vector.shuffle %p, %v2 [0, 4, 1, 5] :
+///       vector<4xf32>, vector<4xf32>
+///
+/// Note: This transformation helps legalize vector.shuffle ops when lowering
+/// to SPIR-V/LLVM, which don't support shuffle operations with mixed-size
+/// inputs.
+struct MixedSizeInputShuffleOpRewrite final
+    : OpRewritePattern<vector::ShuffleOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
+                                PatternRewriter &rewriter) const override {
+    auto v1Type = shuffleOp.getV1VectorType();
+    auto v2Type = shuffleOp.getV2VectorType();
+
+    // Only support 1-D shuffle for now.
+    if (v1Type.getRank() != 1 || v2Type.getRank() != 1)
+      return failure();
+
+    // Bail out if inputs don't have mixed sizes.
+    int64_t v1OrigNumElems = v1Type.getNumElements();
+    int64_t v2OrigNumElems = v2Type.getNumElements();
+    if (v1OrigNumElems == v2OrigNumElems)
+      return failure();
+
+    // Determine which input needs promotion: the smaller one is widened to the
+    // type of the larger one.
+    bool promoteV1 = v1OrigNumElems < v2OrigNumElems;
+    Value inputToPromote = promoteV1 ? shuffleOp.getV1() : shuffleOp.getV2();
+    VectorType promotedType = promoteV1 ? v2Type : v1Type;
+    int64_t origNumElems = promoteV1 ? v1OrigNumElems : v2OrigNumElems;
+    int64_t promotedNumElems = promoteV1 ? v2OrigNumElems : v1OrigNumElems;
+
+    // Create a shuffle with a mask that preserves existing elements and fills
+    // up with poison.
+    SmallVector<int64_t> promoteMask(promotedNumElems, ShuffleOp::kPoisonIndex);
+    for (int64_t i = 0; i < origNumElems; ++i)
+      promoteMask[i] = i;
+
+    Value promotedInput = vector::ShuffleOp::create(
+        rewriter, shuffleOp.getLoc(), promotedType, inputToPromote,
+        inputToPromote, promoteMask);
+
+    // Create the final shuffle with the promoted inputs.
+    Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
+    Value promotedV2 = promoteV1 ? shuffleOp.getV2() : promotedInput;
+
+    // Mask indices address the concatenation [V1, V2]. Promoting V2 leaves all
+    // indices valid, so the mask is reused as is. Promoting V1 moves V2 back by
+    // the number of lanes added to V1, so indices selecting from V2 are shifted
+    // by that amount; V1 indices and poison (-1) entries compare below
+    // `v1OrigNumElems` and are kept.
+    SmallVector<int64_t> newMask = to_vector(shuffleOp.getMask());
+    if (promoteV1) {
+      int64_t v2Offset = promotedNumElems - v1OrigNumElems;
+      for (int64_t &idx : newMask)
+        if (idx >= v1OrigNumElems)
+          idx += v2Offset;
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        shuffleOp, shuffleOp.getResultVectorType(), promotedV1, promotedV2,
+        newMask);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorShuffleLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<MixedSizeInputShuffleOpRewrite>(patterns.getContext(), benefit);
+}
diff  --git a/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir b/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
new file mode 100644
index 0000000000000..a137811fa367c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt %s --test-vector-shuffle-lowering --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @shuffle_smaller_lhs_arbitrary
+// CHECK-SAME:    %[[LHS:.*]]: vector<2xf32>, %[[RHS:.*]]: vector<4xf32>
+func.func @shuffle_smaller_lhs_arbitrary(%lhs: vector<2xf32>, %rhs: vector<4xf32>) -> vector<5xf32> {
+  // CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
+  // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [1, 5, 0, 6, 7] : vector<4xf32>, vector<4xf32>
+  // CHECK: return %[[RESULT]] : vector<5xf32>
+  %0 = vector.shuffle %lhs, %rhs [1, 3, 0, 4, 5] : vector<2xf32>, vector<4xf32>
+  return %0 : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @shuffle_smaller_rhs_arbitrary
+// CHECK-SAME:    %[[LHS:.*]]: vector<4xi32>, %[[RHS:.*]]: vector<2xi32>
+func.func @shuffle_smaller_rhs_arbitrary(%lhs: vector<4xi32>, %rhs: vector<2xi32>) -> vector<6xi32> {
+  // CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi32>, vector<2xi32>
+  // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<4xi32>
+  // CHECK: return %[[RESULT]] : vector<6xi32>
+  %0 = vector.shuffle %lhs, %rhs [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<2xi32>
+  return %0 : vector<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @shuffle_smaller_lhs_concat
+// CHECK-SAME:    %[[LHS:.*]]: vector<3xf64>, %[[RHS:.*]]: vector<5xf64>
+func.func @shuffle_smaller_lhs_concat(%lhs: vector<3xf64>, %rhs: vector<5xf64>) -> vector<8xf64> {
+  // CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, 2, -1, -1] : vector<3xf64>, vector<3xf64>
+  // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [0, 1, 2, 5, 6, 7, 8, 9] : vector<5xf64>, vector<5xf64>
+  // CHECK: return %[[RESULT]] : vector<8xf64>
+  %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5, 6, 7] : vector<3xf64>, vector<5xf64>
+  return %0 : vector<8xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @shuffle_smaller_rhs_concat
+// CHECK-SAME:    %[[LHS:.*]]: vector<4xi16>, %[[RHS:.*]]: vector<2xi16>
+func.func @shuffle_smaller_rhs_concat(%lhs: vector<4xi16>, %rhs: vector<2xi16>) -> vector<6xi16> {
+  // CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi16>, vector<2xi16>
+  // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<4xi16>
+  // CHECK: return %[[RESULT]] : vector<6xi16>
+  %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<2xi16>
+  return %0 : vector<6xi16>
+}
+
+// -----
+
+// Test that shuffles with same size inputs are not modified.
+
+// CHECK-LABEL: func.func @negative_shuffle_same_input_sizes
+// CHECK-SAME:    %[[LHS:.*]]: vector<4xf32>, %[[RHS:.*]]: vector<4xf32>
+func.func @negative_shuffle_same_input_sizes(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<6xf32> {
+  // CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]]
+  // CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]]
+  // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
+  // CHECK: return %[[RESULT]] : vector<6xf32>
+  %0 = vector.shuffle %lhs, %rhs [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
+  return %0 : vector<6xf32>
+}
+
+// -----
+
+// Test that multi-dimensional shuffles are not modified.
+
+// CHECK-LABEL: func.func @negative_shuffle_2d_vectors
+// CHECK-SAME:    %[[LHS:.*]]: vector<2x4xf32>, %[[RHS:.*]]: vector<3x4xf32>
+func.func @negative_shuffle_2d_vectors(%lhs: vector<2x4xf32>, %rhs: vector<3x4xf32>) -> vector<4x4xf32> {
+  // CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]]
+  // CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]]
+  // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
+  // CHECK: return %[[RESULT]] : vector<4x4xf32>
+  %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
+  return %0 : vector<4x4xf32>
+}
diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 72dd103b33f75..79bfc9bbcda71 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -994,6 +994,22 @@ struct TestEliminateVectorMasks
                          VscaleRange{vscaleMin, vscaleMax});
   }
 };
+
+struct TestVectorShuffleLowering
+    : public PassWrapper<TestVectorShuffleLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorShuffleLowering)
+
+  StringRef getArgument() const final { return "test-vector-shuffle-lowering"; }
+  StringRef getDescription() const final {
+    return "Test lowering patterns for vector.shuffle with mixed-size inputs";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorShuffleLoweringPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -1023,6 +1039,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorScanLowering>();
 
+  PassRegistration<TestVectorShuffleLowering>();
+
   PassRegistration<TestVectorDistribution>();
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
        
    
    
More information about the Mlir-commits
mailing list