[Mlir-commits] [mlir] [mlir][vector] transpose(broadcast) -> broadcast canonicalization (PR #135096)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Apr 11 05:51:23 PDT 2025
================
@@ -6155,12 +6156,120 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
+///
+/// Example:
+/// ```
+/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+/// to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+///
+/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
+/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
+/// broadcasting from 1x1x4x1x1x7.
+///                   ^^^ ^ ^^^ ^
+///           groups: 0   1  2  3
+/// Order preserving permutations for this example are ones that only permute
+/// within the dimension ranges [0,1] (group 0) and [3,4] (group 2), like
+/// (1 0 2 4 3 5).
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+
+ vector::BroadcastOp broadcast =
+ transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcast) {
+ return rewriter.notifyMatchFailure(transpose,
+ "not preceded by a broadcast");
+ }
+
+ auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+
+ // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
+ bool inputIsScalar = !inputType;
+ if (inputIsScalar) {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ transpose, transpose.getResultVectorType(), transpose.getVector());
+ return success();
+ }
+
+ ArrayRef<int64_t> permutation = transpose.getPermutation();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t inputRank = inputType.getRank();
+ int64_t outputRank = transpose.getType().getRank();
+ int64_t deltaRank = outputRank - inputRank;
+
+ // Return true if all permutation destinations for indices in [low, high)
+ // are in [low, high), so the permutation is local to the group.
+ auto isGroupBound = [permutation](int low, int high) {
+ for (int i = low; i < high; ++i) {
+ if (permutation[i] < low || permutation[i] >= high) {
+ return false;
+ }
+ }
+ return true;
+ };
+
+ int low = 0;
+ for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+ bool notOne = inputShape[inputIndex] != 1;
+ bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
+ bool groupEndFound = notOne || prevNotOne;
+ if (groupEndFound) {
+ int high = inputIndex + deltaRank;
+ if (!isGroupBound(low, high)) {
+ return rewriter.notifyMatchFailure(
+ transpose, llvm::formatv("output dimensions in interval [{0}, "
+ "{1}) aren't locally permuted.",
+ low, high));
----------------
banach-space wrote:
This is neat, but from what I can tell, we don't use `llvm::formatv` in `notifyMatchFailure()`. I don't remember the rationale, so I asked:
* https://discord.com/channels/636084430946959380/642426447167881246/1360229979534655779
https://github.com/llvm/llvm-project/pull/135096
More information about the Mlir-commits
mailing list