[Mlir-commits] [mlir] [mlir][vector] Generalize the canonicalization of transpose(broadcast(x)) (PR #153056)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 11 10:43:54 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Min-Yih Hsu (mshockwave)

<details>
<summary>Changes</summary>

Previously, transpose(broadcast(x)) was canonicalized into broadcast(x) only if the transpose was order preserving. This rule can be generalized to canonicalize transpose(broadcast(x)) into broadcast(shape_cast(x)).

The rationale can be broken down into two steps. First, transpose(broadcast(x)) can be turned into broadcast(transpose(x')), where x' is the normalized form of x (x with leading unit dimensions prepended so that it has the same rank as broadcast(x)), provided that the dimensions broadcast from x to broadcast(x) are the same, after permutation, as those broadcast from transpose(x') to broadcast(transpose(x')). Second, letting x' = shape_cast(x), transpose(x') can be simplified into just a shape_cast of x whenever the transpose is order preserving, which yields the final form broadcast(shape_cast(x)).

------

This patch was inspired by #150562, where I attempted to lower the following snippet:
```
%b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
%t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
```
using a series of 1-D vector.shuffle ops, whereas a better approach is to turn it into broadcast(shape_cast(%arg0)), as this patch does.

---
Full diff: https://github.com/llvm/llvm-project/pull/153056.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+85-55) 
- (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+84-24) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+4-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cb4783d26a114..021a081ccb1c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5923,9 +5923,8 @@ LogicalResult ShapeCastOp::verify() {
 /// By `order preserving` we mean that the flattened versions of the input and
 /// output vectors are (numerically) identical. In other words `transpose` is
 /// effectively a shape cast.
-static bool isOrderPreserving(TransposeOp transpose) {
-  ArrayRef<int64_t> permutation = transpose.getPermutation();
-  VectorType sourceType = transpose.getSourceVectorType();
+static bool isOrderPreserving(ArrayRef<int64_t> permutation,
+                              VectorType sourceType) {
   ArrayRef<int64_t> inShape = sourceType.getShape();
   ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
   auto isNonScalableUnitDim = [&](int64_t dim) {
@@ -5943,6 +5942,11 @@ static bool isOrderPreserving(TransposeOp transpose) {
   return true;
 }
 
+static bool isOrderPreserving(TransposeOp transpose) {
+  return isOrderPreserving(transpose.getPermutation(),
+                           transpose.getSourceVectorType());
+}
+
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
@@ -6492,31 +6496,20 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
-/// 'order preserving', where 'order preserving' means the flattened
-/// inputs and outputs of the transpose have identical (numerical) values.
+/// Canonicalize transpose(broadcast(x)) into broadcast(transpose(x')),
+/// where x' is the normalized x, if all of the following steps succeed:
+/// (1) Normalize x to x' such that x' has the same rank as broadcast(x)
 ///
-/// Example:
-/// ```
-///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
-///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
-///                                                 to vector<8x1xi32>
-/// ```
-/// can be rewritten as the equivalent
-/// ```
-///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
-/// ```
-/// The algorithm works by partitioning dimensions into groups that can be
-/// locally permuted while preserving order, and checks that the transpose
-/// only permutes within these groups.
+/// (2) Check if transpose(x') is broadcastable to the original output type.
 ///
-/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
-/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
-/// broadcasting from 1x1x4x1x1x7.
-///                   ^^^ ^ ^^^ ^
-///          groups:   0  1  2  3
-/// Order preserving permutations for this example are ones that only permute
-/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
+/// (3) Check if the broadcasted dimensions in x -> broadcast(x) match those
+/// in transpose(x') -> broadcast(transpose(x')) after permutation.
+///
+/// (4) If the above conditions hold, we can generate broadcast(transpose(x')),
+/// where x' = shape_cast(x). However, this won't be profitable if
+/// transpose(shape_cast(x)) cannot be folded into shape_cast(x), so check if
+/// such folding is possible by checking whether such transpose preserves the
+/// order.
 class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -6525,7 +6518,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 
   LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                 PatternRewriter &rewriter) const override {
-
+    auto loc = transpose.getLoc();
     vector::BroadcastOp broadcast =
         transpose.getVector().getDefiningOp<vector::BroadcastOp>();
     if (!broadcast) {
@@ -6544,44 +6537,81 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
       return success();
     }
 
+    VectorType transposeInputType = transpose.getSourceVectorType();
     ArrayRef<int64_t> permutation = transpose.getPermutation();
     ArrayRef<int64_t> inputShape = inputType.getShape();
+    // This is also the shape of broadcast result.
+    ArrayRef<int64_t> transposeInputShape = transposeInputType.getShape();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
     int64_t inputRank = inputType.getRank();
-    int64_t outputRank = transpose.getType().getRank();
+    int64_t outputRank = outputShape.size();
     int64_t deltaRank = outputRank - inputRank;
+    assert(deltaRank >= 0);
+
+    // Normalize the input type.
+    VectorType normalizedInputType = inputType;
+    if (deltaRank > 0) {
+      // Fill leading dimensions with ones.
+      SmallVector<int64_t> newShape(deltaRank, 1);
+      newShape.append(inputShape.begin(), inputShape.end());
+      normalizedInputType =
+          VectorType::get(newShape, inputType.getElementType());
+    }
 
-    int low = 0;
-    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
-      bool notOne = inputShape[inputIndex] != 1;
-      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
-      bool groupEndFound = notOne || prevNotOne;
-      if (groupEndFound) {
-        int high = inputIndex + deltaRank;
-        // Return failure if not all permutation destinations for indices in
-        // [low, high) are in [low, high), i.e. the permutation is not local to
-        // the group.
-        for (int i = low; i < high; ++i) {
-          if (permutation[i] < low || permutation[i] >= high) {
-            return rewriter.notifyMatchFailure(
-                transpose, "permutation not local to group");
-          }
-        }
-        low = high;
-      }
+    ArrayRef<int64_t> normalizedInputShape = normalizedInputType.getShape();
+    // Retrieve the original broadcasted dimensions.
+    BitVector origBroadcastDims(outputRank);
+    for (int64_t i = 0; i < outputRank; ++i) {
+      if (normalizedInputShape[i] == 1 && transposeInputShape[i] > 1)
+        origBroadcastDims.set(i);
     }
 
-    // We don't need to check the final group [low, outputRank) because if it is
-    // not locally bound, there must be a preceding group that already failed
-    // the check (impossible to have just 1 non-locally bound group).
+    // Transpose the normalized input type
+    VectorType::Builder builder(normalizedInputType);
+    for (auto [idx, idxNew] : enumerate(permutation))
+      builder.setDim(idx, normalizedInputShape[idxNew]);
+    VectorType transposedInputType = builder;
+
+    // Check if the new normalized and transposed inputType is broadcastable to
+    // the output type.
+    if (vector::isBroadcastableTo(transposedInputType, outputType) !=
+        BroadcastableToResult::Success)
+      return failure();
 
-    // The preceding logic also ensures that at this point, the output of the
-    // transpose is definitely broadcastable from the input shape, assert so:
-    assert(vector::isBroadcastableTo(inputType, outputType) ==
-               vector::BroadcastableToResult::Success &&
-           "not broadcastable directly to transpose output");
+    // Retrieve the prospective broadcasted dimensions from transposedInputType
+    // to outputType.
+    ArrayRef<int64_t> transposedInputShape = transposedInputType.getShape();
+    BitVector newBroadcastDims(outputRank);
+    for (int64_t i = 0; i < outputRank; ++i) {
+      if (transposedInputShape[i] == 1 && outputShape[i] > 1)
+        newBroadcastDims.set(i);
+    }
+
+    // Check if the _transposed_ of the original broadcasted dimensions equals
+    // to the prospective broadcasted dimensions.
+    BitVector refBroadcastDims(outputRank);
+    for (unsigned bitIdx : origBroadcastDims.set_bits())
+      refBroadcastDims.set(permutation[bitIdx]);
+    if (refBroadcastDims != newBroadcastDims)
+      return failure();
+
+    // Check if this transpose(shape_cast(x)) could be folded
+    // into shape_cast(x).
+    if (!isOrderPreserving(permutation, normalizedInputType))
+      return failure();
 
+    // All checks pass, replace with broadcast(transpose(x')), where x' =
+    // shape_cast(x).
+    Value normalizedInput =
+        rewriter
+            .create<vector::ShapeCastOp>(loc, normalizedInputType,
+                                         broadcast.getSource())
+            .getResult();
+    Value newTranspose =
+        rewriter.create<vector::TransposeOp>(loc, normalizedInput, permutation)
+            .getResult();
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
-                                                     broadcast.getSource());
+                                                     newTranspose);
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index f1e1c5e896c66..359342bf155c9 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -91,12 +91,27 @@ func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector
 
 // -----
 
-// CHECK-LABEL: negative_broadcast_transpose_square
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
-//       CHECK:  return %[[TRP]] : vector<4x4xi8>
-func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x32xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x32xf32>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
+  %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+  %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+  return %t : vector<2x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_square(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<4x1xi8>) -> vector<4x4xi8> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<4x1xi8> to vector<1x4xi8>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x4xi8> to vector<4x4xi8>
+// CHECK:           return %[[VAL_1]] : vector<4x4xi8>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
   %0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
   %1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
   return %1 : vector<4x4xi8>
@@ -104,12 +119,13 @@ func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector
 
 // -----
 
-// CHECK-LABEL: negative_broadcast_transpose_hypercube
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
-//       CHECK:  return %[[TRP]] : vector<4x4x4x4xi8>
-func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_hypercube(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xi8> to vector<1x1x4x1xi8>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x1x4x1xi8> to vector<4x4x4x4xi8>
+// CHECK:           return %[[VAL_1]] : vector<4x4x4x4xi8>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
   %0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
   %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
   return %1 : vector<4x4x4x4xi8>
@@ -117,12 +133,13 @@ func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> v
 
 // -----
 
-// CHECK-LABEL: negative_broadcast_transpose_102
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
-//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
-func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_102(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x3xi8> to vector<1x3x3xi8>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x3x3xi8> to vector<3x3x3xi8>
+// CHECK:           return %[[VAL_1]] : vector<3x3x3xi8>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
   %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
   %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
   return %1 : vector<3x3x3xi8>
@@ -130,12 +147,13 @@ func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<
 
 // -----
 
-// CHECK-LABEL: negative_broadcast_transpose_021
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
-//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
-func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_021(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x3xi8> to vector<3x3x1xi8>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<3x3x1xi8> to vector<3x3x3xi8>
+// CHECK:           return %[[VAL_1]] : vector<3x3x3xi8>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
   %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
   %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
   return %1 : vector<3x3x3xi8>
@@ -143,6 +161,48 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
 
 // -----
 
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_210(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2x1x1xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x1xf32> to vector<2x1x32xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x1x32xf32>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_210(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
+  %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
+  %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
+  return %t : vector<2x1x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_tail_unit_dim(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2x1xf32>) -> vector<2x32x1xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2x1x1xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x1xf32> to vector<2x32x1xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x32x1xf32>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_tail_unit_dim(%arg0 : vector<2x1xf32>) -> vector<2x32x1xf32> {
+  %b = vector.broadcast %arg0 : vector<2x1xf32> to vector<32x2x1xf32>
+  %t = vector.transpose %b, [1, 0, 2] : vector<32x2x1xf32> to vector<2x32x1xf32>
+  return %t : vector<2x32x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @negative_broadcast_transpose_shapecast_not_order_preserving(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<14x7xf32>) -> vector<7x14x8x16xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<14x7xf32> to vector<8x16x14x7xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.transpose %[[VAL_0]], [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+// CHECK:           return %[[VAL_1]] : vector<7x14x8x16xf32>
+// CHECK:         }
+func.func @negative_broadcast_transpose_shapecast_not_order_preserving(%arg0 : vector<14x7xf32>) -> vector<7x14x8x16xf32> {
+  %b = vector.broadcast %arg0 : vector<14x7xf32> to vector<8x16x14x7xf32>
+  %t = vector.transpose %b, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+  return %t : vector<7x14x8x16xf32>
+}
+
+// -----
+
 /// +--------------------------------------------------------------------------
 ///  Tests of ShapeCastOp::fold:  shape_cast(transpose) -> shape_cast
 /// +--------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..d3cf534a369bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -369,9 +369,8 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
   %c0 = arith.constant 0 : index
 
   %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
-  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
 
   return %res : tensor<?x?x?x?xf32>
 }
@@ -385,9 +384,8 @@ func.func @transfer_write_broadcast_unit_dim_memref(
   %c0 = arith.constant 0 : index
 
   vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
-  // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC0]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
 
   return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/153056


More information about the Mlir-commits mailing list