[Mlir-commits] [mlir] 0daf20b - [mlir][vector] transpose(broadcast) -> broadcast canonicalization (#135096)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 16 10:08:39 PDT 2025
Author: James Newling
Date: 2025-04-16T13:08:36-04:00
New Revision: 0daf20b3605f19271af7afa4175e7d62194e5578
URL: https://github.com/llvm/llvm-project/commit/0daf20b3605f19271af7afa4175e7d62194e5578
DIFF: https://github.com/llvm/llvm-project/commit/0daf20b3605f19271af7afa4175e7d62194e5578.diff
LOG: [mlir][vector] transpose(broadcast) -> broadcast canonicalization (#135096)
Example seen in the 'real world':
```
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
```
This PR adds a canonicalizer that rewrites the above as
```
%1 = vector.broadcast %arg0 : vector<1xi8> to vector<8x1xi8>
```
It works by determining if a transpose is only shuffling contiguous
broadcast dimensions.
Added:
mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..504032a398fbe 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6085,28 +6085,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};
-// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
-struct FoldTransposedScalarBroadcast final
- : public OpRewritePattern<vector::TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
- PatternRewriter &rewriter) const override {
- auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
- if (!bcastOp)
- return failure();
-
- auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
- if (!srcVectorType || srcVectorType.getNumElements() == 1) {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
- return success();
- }
-
- return failure();
- }
-};
-
// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
@@ -6161,12 +6139,106 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
+///
+/// Example:
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+///                                  to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+///
+/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
+/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
+/// broadcasting from 1x1x4x1x1x7.
+///                   ^^^ ^ ^^^ ^
+///          groups:   0  1  2  3
+/// Order preserving permutations for this example are ones that only permute
+/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5).
+///
+/// Only fixed-size unit dims count as 1s here: a scalable unit dim ([1])
+/// cannot be broadcast to a fixed-size dim, so it is treated like a non-1 dim
+/// (i.e. it forms its own 1-element group and must stay in place).
+class FoldTransposeBroadcast final
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+    vector::BroadcastOp broadcast =
+        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast) {
+      return rewriter.notifyMatchFailure(transpose,
+                                         "not preceded by a broadcast");
+    }
+
+    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+    VectorType outputType = transpose.getResultVectorType();
+
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
+    // The new broadcast must read the original broadcast's *source*: its
+    // result (the transpose operand) generally has an incompatible shape.
+    bool inputIsScalar = !inputType;
+    if (inputIsScalar) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
+                                                       broadcast.getSource());
+      return success();
+    }
+
+    ArrayRef<int64_t> permutation = transpose.getPermutation();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    ArrayRef<bool> inputScalableDims = inputType.getScalableDims();
+    int64_t inputRank = inputType.getRank();
+    int64_t outputRank = outputType.getRank();
+    // The broadcast implicitly prepends `deltaRank` unit dims to the input,
+    // so input dim `i` lines up with output dim `i + deltaRank`.
+    int64_t deltaRank = outputRank - inputRank;
+
+    // True iff input dim `dim` is a fixed-size 1, i.e. it may be broadcast to
+    // an output dim of any size.
+    auto isUnitDim = [&](int64_t dim) {
+      return inputShape[dim] == 1 && !inputScalableDims[dim];
+    };
+
+    // [low, high) is the range of output dims forming the current group.
+    int64_t low = 0;
+    for (int64_t inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+      bool notOne = !isUnitDim(inputIndex);
+      bool prevNotOne = inputIndex != 0 && !isUnitDim(inputIndex - 1);
+      bool groupEndFound = notOne || prevNotOne;
+      if (!groupEndFound)
+        continue;
+
+      int64_t high = inputIndex + deltaRank;
+      // Return failure if not all permutation destinations for indices in
+      // [low, high) are in [low, high), i.e. the permutation is not local to
+      // the group.
+      for (int64_t i = low; i < high; ++i) {
+        if (permutation[i] < low || permutation[i] >= high) {
+          return rewriter.notifyMatchFailure(
+              transpose, "permutation not local to group");
+        }
+      }
+      // The next group starts where this one ends.
+      low = high;
+    }
+
+    // The final group [low, outputRank) needs no explicit check: the
+    // permutation is a bijection, so once every preceding group is mapped
+    // onto itself, the remaining dims can only be mapped onto themselves.
+
+    // The preceding logic also ensures that at this point, the output of the
+    // transpose is definitely broadcastable from the input shape, assert so:
+    assert(vector::isBroadcastableTo(inputType, outputType) ==
+               vector::BroadcastableToResult::Success &&
+           "not broadcastable directly to transpose output");
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
+
} // namespace
+/// Registers the canonicalization patterns for vector.transpose. Folding of
+/// transpose(broadcast(x)) into a single broadcast is done by
+/// FoldTransposeBroadcast, for both scalar and vector x.
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat>(context);
+ results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
+ FoldTransposeBroadcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..733a2c67d2c0c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2218,30 +2218,6 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
// -----
-// CHECK-LABEL: func @transpose_scalar_broadcast1
-// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
-// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
-// CHECK: return %[[V]] : vector<1x8xf32>
-func.func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
- %bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
- %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
- return %t : vector<1x8xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_scalar_broadcast2
-// CHECK-SAME: (%[[ARG:.+]]: f32)
-// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
-// CHECK: return %[[V]] : vector<1x8xf32>
-func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
- %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
- %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
- return %t : vector<1x8xf32>
-}
-
-// -----
-
// CHECK-LABEL: func @transpose_splat_constant
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
// CHECK: return %[[CST]]
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
new file mode 100644
index 0000000000000..e97e147459de2
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -0,0 +1,139 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains canonicalization tests involving vector.transpose.
+
+// Single-element vector source: broadcast + transpose folds into one broadcast.
+// CHECK-LABEL: func @transpose_scalar_broadcast1
+// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
+// CHECK: return %[[V]] : vector<1x8xf32>
+func.func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
+ %bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
+ %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+ return %t : vector<1x8xf32>
+}
+
+// -----
+
+// Scalar source: broadcast + transpose folds into one broadcast of the scalar.
+// CHECK-LABEL: func @transpose_scalar_broadcast2
+// CHECK-SAME: (%[[ARG:.+]]: f32)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
+// CHECK: return %[[V]] : vector<1x8xf32>
+func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
+ %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
+ %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+ return %t : vector<1x8xf32>
+}
+
+// -----
+
+
+// Scalar source with a rank-3 transpose: folds into one broadcast.
+// CHECK-LABEL: broadcast_transpose_scalar_to_broadcast
+// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+func.func @broadcast_transpose_scalar_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
+// CHECK: %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+ %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+// CHECK: return %[[BC]] : vector<2x3x4xi8>
+ return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// All-unit-dim vector source: every dim is one group, so any permutation folds.
+// CHECK-LABEL: broadcast_transpose_ones_to_broadcast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+// CHECK: return %[[RES]] : vector<2x3x4xi8>
+func.func @broadcast_transpose_ones_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+ return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// Both output dims come from unit dims of the rank-extended source, so
+// swapping them folds into one broadcast (the example from the commit message).
+// CHECK-LABEL: broadcast_transpose_partial_ones_to_broadcast
+// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+// CHECK: return %[[RES]] : vector<8x1xi8>
+func.func @broadcast_transpose_partial_ones_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+ %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+ return %1 : vector<8x1xi8>
+}
+
+// -----
+
+// Swaps only the two leading dims added by the broadcast (one group): folds.
+// CHECK-LABEL: broadcast_transpose_mixed_example
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
+// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
+func.func @broadcast_transpose_mixed_example(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+ %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+ %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
+ return %1 : vector<3x2x4x5x6x7xi8>
+}
+
+// -----
+
+// Swaps the two trailing dims, which come from the source's trailing unit dims
+// (the final group, which the pattern does not check explicitly): folds.
+// CHECK-LABEL: broadcast_transpose_final_group
+// CHECK-SAME: %[[ARG:.*]]: vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x7x1x1xi8> to vector<4x7x2x3xi8>
+// CHECK: return %[[RES]] : vector<4x7x2x3xi8>
+func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<4x7x1x1xi8> to vector<4x7x3x2xi8>
+ %1 = vector.transpose %0, [0, 1, 3, 2] : vector<4x7x3x2xi8> to vector<4x7x2x3xi8>
+ return %1 : vector<4x7x2x3xi8>
+}
+
+// -----
+
+// Negative: [1, 0] moves the source's non-unit dim (4), so the transpose stays.
+// CHECK-LABEL: negative_broadcast_transpose_square
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
+// CHECK: return %[[TRP]] : vector<4x4xi8>
+func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+ return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// Negative: swapping output dims 2 and 3 moves the source's trailing non-unit
+// dim out of place, so the transpose stays.
+// CHECK-LABEL: negative_broadcast_transpose_hypercube
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
+// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
+func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
+ %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
+ return %1 : vector<4x4x4x4xi8>
+}
+
+// -----
+
+// Negative: [1, 0, 2] moves the source's leading non-unit dim, so no fold.
+// CHECK-LABEL: negative_broadcast_transpose_102
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+// CHECK: return %[[TRP]] : vector<3x3x3xi8>
+func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+ %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+ return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
+// Negative: [0, 2, 1] swaps the source's unit dim with its trailing non-unit
+// dim, so no fold.
+// CHECK-LABEL: negative_broadcast_transpose_021
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+// CHECK: return %[[TRP]] : vector<3x3x3xi8>
+func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+ %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+ return %1 : vector<3x3x3xi8>
+}
+
More information about the Mlir-commits
mailing list