[Mlir-commits] [mlir] [mlir][vector] Add special lowering for 2D transpose on 1D broadcast (PR #150562)

Min-Yih Hsu llvmlistbot at llvm.org
Thu Jul 24 18:22:49 PDT 2025


https://github.com/mshockwave created https://github.com/llvm/llvm-project/pull/150562

A 2D transpose of a 1D broadcast like this:
```
  %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
  %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
```
could be lowered into the following code:
```
  %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
  %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
  %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
  %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
  %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>
```
This is more efficient than a single shuffle on a flattened 2D vector on most platforms, as the 1-D shuffles above are likely to be lowered into a bunch of splats.

>From 97eafe6810fc5e20e7f3f2e6acce8bbc9c159bde Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 24 Jul 2025 18:09:30 -0700
Subject: [PATCH] [mlir][vector] Add special lowering for 2D transpose on 1D
 broadcast

---
 .../Transforms/LowerVectorTranspose.cpp       | 73 +++++++++++++++++++
 .../Vector/vector-transpose-lowering.mlir     | 63 ++++++++++++++++
 2 files changed, 136 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 9e7d0ced3e6d1..e7521c1708a42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -423,6 +423,75 @@ class Transpose2DWithUnitDimToShapeCast
   }
 };
 
+// Lowers a 2-D transpose whose operand is a broadcast of a 1-D vector, e.g.
+// ```
+//  %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+//  %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+// ```
+// While we can't directly broadcast from vector<2xf32> to vector<2x32xf32>,
+// we can emit one shuffle + insert per row of the result instead:
+// ```
+//  %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+//  %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
+//  %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
+//  %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
+//  %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>
+// ```
+// Each shuffle is effectively a 1-D broadcast (splat), which is more efficient
+// than one shuffle on a flattened 2-D vector. Returns failure() on no match.
+static LogicalResult
+lowerTranspose2DOfBroadcast1D(vector::TransposeOp transpose, int64_t srcDim0,
+                              int64_t srcDim1, PatternRewriter &rewriter) {
+  auto loc = transpose.getLoc();
+  auto broadcast = transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcast || !broadcast.getResult().hasOneUse())
+    return failure();
+
+  Value broadcastSrc = broadcast.getSource();
+  auto srcType = dyn_cast<VectorType>(broadcastSrc.getType());
+  if (!srcType)
+    return failure();
+  Type elementType = srcType.getElementType();
+  // Collect non-unit dim sizes. NOTE(review): dim positions are not checked.
+  SmallVector<int64_t> broadcastSrcDims;
+  for (int64_t size : srcType.getShape()) {
+    if (size > 1)
+      broadcastSrcDims.push_back(size);
+  }
+  if (broadcastSrcDims.size() != 1 || broadcastSrcDims[0] != srcDim1)
+    return failure();
+  // Collapse the source to a true 1-D vector by dropping its unit dims.
+  broadcastSrc =
+      rewriter
+          .create<vector::ShapeCastOp>(
+              loc, VectorType::get({broadcastSrcDims[0]}, elementType),
+              broadcastSrc)
+          .getResult();
+
+  // 2-D result type of the transpose (srcDim1 x srcDim0), ignoring unit dims.
+  auto normalizedResultType = VectorType::get({srcDim1, srcDim0}, elementType);
+  // 1-D type of each shuffle result, i.e. one row of the normalized result.
+  auto shuffleType = VectorType::get({srcDim0}, elementType);
+  SmallVector<int64_t> shuffleMask(srcDim0);
+
+  Value resultVec = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(normalizedResultType));
+  // Row idx of the result is element idx of the source repeated srcDim0 times.
+  for (int64_t idx = 0; idx < srcDim1; ++idx) {
+    std::fill(shuffleMask.begin(), shuffleMask.end(), idx);
+    auto shuffle = rewriter.create<vector::ShuffleOp>(
+        loc, shuffleType, broadcastSrc, broadcastSrc,
+        rewriter.getDenseI64ArrayAttr(shuffleMask));
+    resultVec = rewriter.create<vector::InsertOp>(loc, shuffle, resultVec,
+                                                  /*position=*/idx);
+  }
+
+  // Cast the result back to the transpose's original result type.
+  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+      transpose, transpose.getResultVectorType(), resultVec);
+  return success();
+}
+
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -460,6 +529,10 @@ class TransposeOp2DToShuffleLowering
     int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
     int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
 
+    if (vectorTransposeLowering == VectorTransposeLowering::Shuffle1D &&
+        succeeded(lowerTranspose2DOfBroadcast1D(op, m, n, rewriter)))
+      return success();
+
     // Reshape the n-D input vector with only two dimensions greater than one
     // to a 2-D vector.
     Location loc = op.getLoc();
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 7838aad1825bc..9c96a6270d504 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -365,3 +365,66 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @transpose_of_broadcast(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK:           %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32>
+// CHECK:           return %[[VAL_4]] : vector<2x32xf32>
+// CHECK:         }
+func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
+  %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+  %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+  return %t : vector<2x32xf32>
+}
+
+// CHECK-LABEL:   func.func @transpose_of_broadcast2(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<4xf32>) -> vector<4x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4x32xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<4x32xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<4x32xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.insert %[[VAL_5]], %[[VAL_4]] [2] : vector<32xf32> into vector<4x32xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.insert %[[VAL_7]], %[[VAL_6]] [3] : vector<32xf32> into vector<4x32xf32>
+// CHECK:           return %[[VAL_8]] : vector<4x32xf32>
+// CHECK:         }
+func.func @transpose_of_broadcast2(%arg0 : vector<4xf32>) -> vector<4x32xf32> {
+  %b = vector.broadcast %arg0 : vector<4xf32> to vector<32x4xf32>
+  %t = vector.transpose %b, [1, 0] : vector<32x4xf32> to vector<4x32xf32>
+  return %t : vector<4x32xf32>
+}
+
+// CHECK-LABEL:   func.func @transpose_of_broadcast_odd_shape(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2xf32>
+// CHECK:           %[[VAL_2:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.insert %[[VAL_4]], %[[VAL_3]] [1] : vector<32xf32> into vector<2x32xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x32xf32> to vector<2x1x32xf32>
+// CHECK:           return %[[VAL_6]] : vector<2x1x32xf32>
+// CHECK:         }
+func.func @transpose_of_broadcast_odd_shape(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
+  %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
+  %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
+  return %t : vector<2x1x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list