[Mlir-commits] [mlir] [mlir][vector] Generalize the canonicalization of transpose(broadcast(x)) (PR #153056)
llvmlistbot at llvm.org
Mon Aug 11 10:43:54 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Min-Yih Hsu (mshockwave)
Changes:
Previously, we canonicalized transpose(broadcast(x)) into broadcast(x) if the transpose is order preserving. This rule can be generalized further: canonicalize transpose(broadcast(x)) into broadcast(shape_cast(x)).
The rationale breaks down into two steps. First, transpose(broadcast(x)) can be turned into broadcast(transpose(x')), where x' is x normalized (padded with leading unit dims) to the rank of broadcast(x), provided the dimensions broadcast from x to broadcast(x) are the same as those broadcast from transpose(x') to broadcast(transpose(x')). Second, with x' = shape_cast(x), transpose(x') simplifies to just shape_cast(x) whenever the transpose is order preserving, which yields the final broadcast(shape_cast(x)).
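As a concrete illustration of these two steps (a sketch mirroring the `broadcast_transpose_shapecast_210` test added in this patch; the SSA names are illustrative), broadcasting vector<1x2xf32> to vector<32x1x2xf32> and then transposing with [2, 1, 0] decomposes as:
```
// x' = shape_cast(x): pad x with leading unit dims to the rank of broadcast(x).
%x1 = vector.shape_cast %x : vector<1x2xf32> to vector<1x1x2xf32>
// transpose(x') only moves unit dims, so it is order preserving and folds
// into the shape_cast itself.
%t1 = vector.transpose %x1, [2, 1, 0] : vector<1x1x2xf32> to vector<2x1x1xf32>
// The broadcasted dims match under the permutation, so this reproduces the
// original result, i.e. broadcast(shape_cast(x)).
%r = vector.broadcast %t1 : vector<2x1x1xf32> to vector<2x1x32xf32>
```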
------
This patch was inspired by #150562, where I attempted to lower the following snippet
```
%b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
%t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
```
into a sequence of 1-D vector.shuffle ops; a better way is to turn it into broadcast(shape_cast(%arg0)) instead, as this patch does.
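With this patch, the snippet above canonicalizes to the following instead (this is exactly what the new `broadcast_transpose_shapecast` test in the diff checks):
```
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
%1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x32xf32>
```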
---
Full diff: https://github.com/llvm/llvm-project/pull/153056.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+85-55)
- (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+84-24)
- (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+4-6)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cb4783d26a114..021a081ccb1c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5923,9 +5923,8 @@ LogicalResult ShapeCastOp::verify() {
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
-static bool isOrderPreserving(TransposeOp transpose) {
- ArrayRef<int64_t> permutation = transpose.getPermutation();
- VectorType sourceType = transpose.getSourceVectorType();
+static bool isOrderPreserving(ArrayRef<int64_t> permutation,
+ VectorType sourceType) {
ArrayRef<int64_t> inShape = sourceType.getShape();
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
auto isNonScalableUnitDim = [&](int64_t dim) {
@@ -5943,6 +5942,11 @@ static bool isOrderPreserving(TransposeOp transpose) {
return true;
}
+static bool isOrderPreserving(TransposeOp transpose) {
+ return isOrderPreserving(transpose.getPermutation(),
+ transpose.getSourceVectorType());
+}
+
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
@@ -6492,31 +6496,20 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
}
};
-/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
-/// 'order preserving', where 'order preserving' means the flattened
-/// inputs and outputs of the transpose have identical (numerical) values.
+/// Canonicalize transpose(broadcast(x)) into broadcast(transpose(x')),
+/// where x' is the normalized x, if the following conditions are met:
+/// (1) Normalize x to x' such that x' has the same rank as broadcast(x)
///
-/// Example:
-/// ```
-/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
-/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
-/// to vector<8x1xi32>
-/// ```
-/// can be rewritten as the equivalent
-/// ```
-/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
-/// ```
-/// The algorithm works by partitioning dimensions into groups that can be
-/// locally permuted while preserving order, and checks that the transpose
-/// only permutes within these groups.
+/// (2) Check if transpose(x') is broadcastable to the original output type.
///
-/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
-/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
-/// broadcasting from 1x1x4x1x1x7.
-/// ^^^ ^ ^^^ ^
-/// groups: 0 1 2 3
-/// Order preserving permutations for this example are ones that only permute
-/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
+/// (3) Check if the broadcasted dimensions in x -> broadcast(x) are the same as
+/// those in transpose(x') -> broadcast(transpose(x'))
+///
+/// (4) If the above conditions are met, we can generate
+/// broadcast(transpose(x')), where x' = shape_cast(x). However, this is only
+/// profitable if transpose(shape_cast(x)) can be folded into shape_cast(x), so
+/// check that such a transpose preserves the order.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -6525,7 +6518,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
PatternRewriter &rewriter) const override {
-
+ auto loc = transpose.getLoc();
vector::BroadcastOp broadcast =
transpose.getVector().getDefiningOp<vector::BroadcastOp>();
if (!broadcast) {
@@ -6544,44 +6537,81 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
return success();
}
+ VectorType transposeInputType = transpose.getSourceVectorType();
ArrayRef<int64_t> permutation = transpose.getPermutation();
ArrayRef<int64_t> inputShape = inputType.getShape();
+ // This is also the shape of the broadcast result.
+ ArrayRef<int64_t> transposeInputShape = transposeInputType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
int64_t inputRank = inputType.getRank();
- int64_t outputRank = transpose.getType().getRank();
+ int64_t outputRank = outputShape.size();
int64_t deltaRank = outputRank - inputRank;
+ assert(deltaRank >= 0);
+
+ // Normalize the input type.
+ VectorType normalizedInputType = inputType;
+ if (deltaRank > 0) {
+ // Fill leading dimensions with ones.
+ SmallVector<int64_t> newShape(deltaRank, 1);
+ newShape.append(inputShape.begin(), inputShape.end());
+ normalizedInputType =
+ VectorType::get(newShape, inputType.getElementType());
+ }
- int low = 0;
- for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
- bool notOne = inputShape[inputIndex] != 1;
- bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
- bool groupEndFound = notOne || prevNotOne;
- if (groupEndFound) {
- int high = inputIndex + deltaRank;
- // Return failure if not all permutation destinations for indices in
- // [low, high) are in [low, high), i.e. the permutation is not local to
- // the group.
- for (int i = low; i < high; ++i) {
- if (permutation[i] < low || permutation[i] >= high) {
- return rewriter.notifyMatchFailure(
- transpose, "permutation not local to group");
- }
- }
- low = high;
- }
+ ArrayRef<int64_t> normalizedInputShape = normalizedInputType.getShape();
+ // Retrieve the original broadcasted dimensions.
+ BitVector origBroadcastDims(outputRank);
+ for (int64_t i = 0; i < outputRank; ++i) {
+ if (normalizedInputShape[i] == 1 && transposeInputShape[i] > 1)
+ origBroadcastDims.set(i);
}
- // We don't need to check the final group [low, outputRank) because if it is
- // not locally bound, there must be a preceding group that already failed
- // the check (impossible to have just 1 non-locally bound group).
+ // Transpose the normalized input type.
+ VectorType::Builder builder(normalizedInputType);
+ for (auto [idx, idxNew] : enumerate(permutation))
+ builder.setDim(idx, normalizedInputShape[idxNew]);
+ VectorType transposedInputType = builder;
+
+ // Check if the normalized and transposed input type is broadcastable to
+ // the output type.
+ if (vector::isBroadcastableTo(transposedInputType, outputType) !=
+ BroadcastableToResult::Success)
+ return failure();
- // The preceding logic also ensures that at this point, the output of the
- // transpose is definitely broadcastable from the input shape, assert so:
- assert(vector::isBroadcastableTo(inputType, outputType) ==
- vector::BroadcastableToResult::Success &&
- "not broadcastable directly to transpose output");
+ // Retrieve the prospective broadcasted dimensions from transposedInputType
+ // to outputType.
+ ArrayRef<int64_t> transposedInputShape = transposedInputType.getShape();
+ BitVector newBroadcastDims(outputRank);
+ for (int64_t i = 0; i < outputRank; ++i) {
+ if (transposedInputShape[i] == 1 && outputShape[i] > 1)
+ newBroadcastDims.set(i);
+ }
+
+ // Check that the original broadcasted dimensions, mapped through the
+ // permutation, equal the prospective broadcasted dimensions.
+ BitVector refBroadcastDims(outputRank);
+ for (unsigned bitIdx : origBroadcastDims.set_bits())
+ refBroadcastDims.set(permutation[bitIdx]);
+ if (refBroadcastDims != newBroadcastDims)
+ return failure();
+
+ // Check if this transpose(shape_cast(x)) could be folded
+ // into shape_cast(x).
+ if (!isOrderPreserving(permutation, normalizedInputType))
+ return failure();
+ // All checks pass, replace with broadcast(transpose(x')), where x' =
+ // shape_cast(x).
+ Value normalizedInput =
+ rewriter
+ .create<vector::ShapeCastOp>(loc, normalizedInputType,
+ broadcast.getSource())
+ .getResult();
+ Value newTranspose =
+ rewriter.create<vector::TransposeOp>(loc, normalizedInput, permutation)
+ .getResult();
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
- broadcast.getSource());
+ newTranspose);
return success();
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index f1e1c5e896c66..359342bf155c9 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -91,12 +91,27 @@ func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector
// -----
-// CHECK-LABEL: negative_broadcast_transpose_square
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
-// CHECK: return %[[TRP]] : vector<4x4xi8>
-func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x32xf32>
+// CHECK: return %[[VAL_1]] : vector<2x32xf32>
+// CHECK: }
+func.func @broadcast_transpose_shapecast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
+ %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+ %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+ return %t : vector<2x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_square(
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<4x1xi8> to vector<1x4xi8>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x4xi8> to vector<4x4xi8>
+// CHECK: return %[[VAL_1]] : vector<4x4xi8>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
%0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
%1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
return %1 : vector<4x4xi8>
@@ -104,12 +119,13 @@ func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector
// -----
-// CHECK-LABEL: negative_broadcast_transpose_hypercube
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
-// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
-func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_hypercube(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xi8> to vector<1x1x4x1xi8>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x1x4x1xi8> to vector<4x4x4x4xi8>
+// CHECK: return %[[VAL_1]] : vector<4x4x4x4xi8>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
return %1 : vector<4x4x4x4xi8>
@@ -117,12 +133,13 @@ func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> v
// -----
-// CHECK-LABEL: negative_broadcast_transpose_102
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
-// CHECK: return %[[TRP]] : vector<3x3x3xi8>
-func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_102(
+// CHECK-SAME: %[[ARG0:.*]]: vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x3xi8> to vector<1x3x3xi8>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x3x3xi8> to vector<3x3x3xi8>
+// CHECK: return %[[VAL_1]] : vector<3x3x3xi8>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
%1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
return %1 : vector<3x3x3xi8>
@@ -130,12 +147,13 @@ func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<
// -----
-// CHECK-LABEL: negative_broadcast_transpose_021
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
-// CHECK: return %[[TRP]] : vector<3x3x3xi8>
-func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_021(
+// CHECK-SAME: %[[ARG0:.*]]: vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x3xi8> to vector<3x3x1xi8>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<3x3x1xi8> to vector<3x3x3xi8>
+// CHECK: return %[[VAL_1]] : vector<3x3x3xi8>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
return %1 : vector<3x3x3xi8>
@@ -143,6 +161,48 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
// -----
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_210(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2x1x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x1xf32> to vector<2x1x32xf32>
+// CHECK: return %[[VAL_1]] : vector<2x1x32xf32>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_210(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
+ %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
+ %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
+ return %t : vector<2x1x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_tail_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<2x32x1xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2x1x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x1xf32> to vector<2x32x1xf32>
+// CHECK: return %[[VAL_1]] : vector<2x32x1xf32>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_tail_unit_dim(%arg0 : vector<2x1xf32>) -> vector<2x32x1xf32> {
+ %b = vector.broadcast %arg0 : vector<2x1xf32> to vector<32x2x1xf32>
+ %t = vector.transpose %b, [1, 0, 2] : vector<32x2x1xf32> to vector<2x32x1xf32>
+ return %t : vector<2x32x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_broadcast_transpose_shapecast_not_order_preserving(
+// CHECK-SAME: %[[ARG0:.*]]: vector<14x7xf32>) -> vector<7x14x8x16xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<14x7xf32> to vector<8x16x14x7xf32>
+// CHECK: %[[VAL_1:.*]] = vector.transpose %[[VAL_0]], [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+// CHECK: return %[[VAL_1]] : vector<7x14x8x16xf32>
+// CHECK: }
+func.func @negative_broadcast_transpose_shapecast_not_order_preserving(%arg0 : vector<14x7xf32>) -> vector<7x14x8x16xf32> {
+ %b = vector.broadcast %arg0 : vector<14x7xf32> to vector<8x16x14x7xf32>
+ %t = vector.transpose %b, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+ return %t : vector<7x14x8x16xf32>
+}
+
+// -----
+
/// +--------------------------------------------------------------------------
/// Tests of ShapeCastOp::fold: shape_cast(transpose) -> shape_cast
/// +--------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..d3cf534a369bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -369,9 +369,8 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
%c0 = arith.constant 0 : index
%res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
- // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
- // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
+ // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
return %res : tensor<?x?x?x?xf32>
}
@@ -385,9 +384,8 @@ func.func @transfer_write_broadcast_unit_dim_memref(
%c0 = arith.constant 0 : index
vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
- // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
- // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32>
+ // CHECK: vector.transfer_write %[[NEW_VEC0]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
return
}
``````````
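A note on the profitability check in step (4): when the permutation changes the relative order of non-unit dimensions, transpose(shape_cast(x)) would not fold away, so the pattern deliberately bails out. For the negative test above, the rewrite (never emitted, shown purely for illustration) would have been:
```
// For %arg0 : vector<14x7xf32> with permutation [3, 2, 0, 1]:
%sc = vector.shape_cast %arg0 : vector<14x7xf32> to vector<1x1x14x7xf32>
// This transpose swaps the two non-unit dims 14 and 7, so it is not order
// preserving and cannot fold into the shape_cast; keeping the original
// broadcast + transpose is cheaper.
%tr = vector.transpose %sc, [3, 2, 0, 1] : vector<1x1x14x7xf32> to vector<7x14x1x1xf32>
```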
https://github.com/llvm/llvm-project/pull/153056