[Mlir-commits] [mlir] [mlir][vector] Add special lowering for 2D transpose on 1D broadcast (PR #150562)
Min-Yih Hsu
llvmlistbot at llvm.org
Thu Jul 24 18:22:49 PDT 2025
https://github.com/mshockwave created https://github.com/llvm/llvm-project/pull/150562
A 2D transpose of a 1D broadcast like this:
```
%b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
%t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
```
could be lowered into the following code:
```
%cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
%0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
%1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
%2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
%t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>
```
This is more efficient on most platforms than a single shuffle on a flattened 2-D vector, because each of these 1-D shuffles is a splat and is likely to be lowered as one.
>From 97eafe6810fc5e20e7f3f2e6acce8bbc9c159bde Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 24 Jul 2025 18:09:30 -0700
Subject: [PATCH] [mlir][vector] Add special lowering for 2D transpose on 1D
broadcast
---
.../Transforms/LowerVectorTranspose.cpp | 73 +++++++++++++++++++
.../Vector/vector-transpose-lowering.mlir | 63 ++++++++++++++++
2 files changed, 136 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 9e7d0ced3e6d1..e7521c1708a42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -423,6 +423,75 @@ class Transpose2DWithUnitDimToShapeCast
}
};
+// Lowers a 2-D transpose whose operand is a broadcast of a 1-D vector, e.g.
+// ```
+// %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+// %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+// ```
+// Row `i` of %t is a splat of element `i` of %arg0. No single op broadcasts
+// vector<2xf32> to vector<2x32xf32> along that axis, so each row is built
+// with a 1-D splat shuffle and inserted into the result:
+// ```
+// %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
+// %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
+// %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
+// %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>
+// ```
+// The splat shuffles are expected to be cheaper on most targets than a single
+// shuffle on the flattened 2-D vector.
+//
+// `srcDim0` and `srcDim1` are the sizes of the outer and inner non-unit
+// dimensions of the transpose operand. Unit dimensions are tolerated: the
+// broadcast source is shape_cast to 1-D, and the 2-D result is shape_cast back
+// to the transpose result type. Returns failure, without creating any ops,
+// when the operand is not a matching broadcast.
+static LogicalResult
+lowerTranspose2DOfBroadcast1D(vector::TransposeOp transpose, int64_t srcDim0,
+                              int64_t srcDim1, PatternRewriter &rewriter) {
+  // The shuffle masks and normalized types built below have fixed lengths,
+  // so scalable vectors are not handled.
+  VectorType operandType = transpose.getSourceVectorType();
+  if (operandType.isScalable())
+    return failure();
+
+  auto broadcast = transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcast || !broadcast.getResult().hasOneUse())
+    return failure();
+
+  // Broadcasts from a scalar are not handled.
+  Value broadcastSrc = broadcast.getSource();
+  auto srcType = dyn_cast<VectorType>(broadcastSrc.getType());
+  if (!srcType)
+    return failure();
+  Type elementType = srcType.getElementType();
+
+  // The broadcast source must have exactly one non-unit dimension, and its
+  // size must match the inner transposed dimension.
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  int64_t srcRank = srcType.getRank();
+  int64_t srcNonUnitPos = -1;
+  for (int64_t pos = 0; pos < srcRank; ++pos) {
+    if (srcShape[pos] <= 1)
+      continue;
+    if (srcNonUnitPos != -1)
+      return failure();
+    srcNonUnitPos = pos;
+  }
+  if (srcNonUnitPos == -1 || srcShape[srcNonUnitPos] != srcDim1)
+    return failure();
+
+  // Matching the size alone is not enough when srcDim0 == srcDim1.
+  // vector.broadcast aligns its source with the trailing dimensions of the
+  // result, so the source's non-unit dimension must land on the *last*
+  // non-unit dimension of the transpose operand. Otherwise the broadcast
+  // replicates along the inner dimension (e.g. vector<4x1xf32> to
+  // vector<4x4xf32>). In that case every row of the transpose is a copy of
+  // the source, not a splat of one of its elements.
+  ArrayRef<int64_t> operandShape = operandType.getShape();
+  int64_t operandRank = operandType.getRank();
+  int64_t alignedPos = operandRank - srcRank + srcNonUnitPos;
+  for (int64_t pos = alignedPos + 1; pos < operandRank; ++pos) {
+    if (operandShape[pos] > 1)
+      return failure();
+  }
+
+  Location loc = transpose.getLoc();
+  // Normalize the broadcast source into a plain 1-D vector. This cast folds
+  // away when the source is already 1-D.
+  Value src1D =
+      rewriter
+          .create<vector::ShapeCastOp>(
+              loc, VectorType::get({srcDim1}, elementType), broadcastSrc)
+          .getResult();
+
+  // The transpose result with its unit dimensions dropped, and the 1-D type
+  // of one of its rows.
+  auto normalizedResultType = VectorType::get({srcDim1, srcDim0}, elementType);
+  auto rowType = VectorType::get({srcDim0}, elementType);
+
+  // Row `idx` of the result is a splat of element `idx` of the source.
+  Value resultVec = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(normalizedResultType));
+  SmallVector<int64_t> shuffleMask(srcDim0);
+  for (int64_t idx = 0; idx < srcDim1; ++idx) {
+    std::fill(shuffleMask.begin(), shuffleMask.end(), idx);
+    Value row = rewriter.create<vector::ShuffleOp>(
+        loc, rowType, src1D, src1D, rewriter.getDenseI64ArrayAttr(shuffleMask));
+    resultVec = rewriter.create<vector::InsertOp>(loc, row, resultVec,
+                                                  /*position=*/idx);
+  }
+
+  // Restore the unit dimensions of the original result type.
+  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+      transpose, transpose.getResultVectorType(), resultVec);
+  return success();
+}
+
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -460,6 +529,10 @@ class TransposeOp2DToShuffleLowering
int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
+ if (vectorTransposeLowering == VectorTransposeLowering::Shuffle1D &&
+ succeeded(lowerTranspose2DOfBroadcast1D(op, m, n, rewriter)))
+ return success();
+
// Reshape the n-D input vector with only two dimensions greater than one
// to a 2-D vector.
Location loc = op.getLoc();
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 7838aad1825bc..9c96a6270d504 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -365,3 +365,66 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func.func @transpose_of_broadcast(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// CHECK: %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32>
+// CHECK: return %[[VAL_4]] : vector<2x32xf32>
+// CHECK: }
+// A 1-D broadcast followed by a 2-D transpose: row i of the result is a splat
+// of element i of %arg0, so one splat shuffle and one insert are expected per
+// row.
+func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
+ %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+ %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+ return %t : vector<2x32xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_of_broadcast2(
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xf32>) -> vector<4x32xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4x32xf32>
+// CHECK: %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<4x32xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<4x32xf32>
+// CHECK: %[[VAL_5:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_6:.*]] = vector.insert %[[VAL_5]], %[[VAL_4]] [2] : vector<32xf32> into vector<4x32xf32>
+// CHECK: %[[VAL_7:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.insert %[[VAL_7]], %[[VAL_6]] [3] : vector<32xf32> into vector<4x32xf32>
+// CHECK: return %[[VAL_8]] : vector<4x32xf32>
+// CHECK: }
+// Same pattern with a four-element source: four shuffle/insert pairs, one per
+// source element.
+func.func @transpose_of_broadcast2(%arg0 : vector<4xf32>) -> vector<4x32xf32> {
+ %b = vector.broadcast %arg0 : vector<4xf32> to vector<32x4xf32>
+ %t = vector.transpose %b, [1, 0] : vector<32x4xf32> to vector<4x32xf32>
+ return %t : vector<4x32xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_of_broadcast_odd_shape(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2xf32>
+// CHECK: %[[VAL_2:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
+// CHECK: %[[VAL_4:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[VAL_5:.*]] = vector.insert %[[VAL_4]], %[[VAL_3]] [1] : vector<32xf32> into vector<2x32xf32>
+// CHECK: %[[VAL_6:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x32xf32> to vector<2x1x32xf32>
+// CHECK: return %[[VAL_6]] : vector<2x1x32xf32>
+// CHECK: }
+// The broadcast source and the transpose result carry unit dimensions. The
+// lowering strips them with a shape_cast on the source and restores them with
+// a shape_cast on the 2-D result.
+func.func @transpose_of_broadcast_odd_shape(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
+ %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
+ %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
+ return %t : vector<2x1x32xf32>
+}
+
+// Apply the transpose lowering patterns with the "shuffle_1d" strategy; the
+// broadcast special case is only attempted under that strategy.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list