[Mlir-commits] [mlir] [mlir][vector] transpose(broadcast) -> broadcast canonicalization (PR #135096)
James Newling
llvmlistbot at llvm.org
Thu Apr 10 11:17:32 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/135096
>From a60254eef1e7f5ac4516d8ddfae3f849a95eac94 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 9 Apr 2025 15:18:19 -0700
Subject: [PATCH 1/5] transpose(broadcast) -> broadcast folder
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 101 ++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 77 ++++++++++++++++
2 files changed, 177 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..ba56242edd6f3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2609,6 +2609,7 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -6155,12 +6156,110 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(broadcast(x)) into broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' here means the flattened
+/// inputs and outputs of the transpose have identical values.
+///
+/// Example:
+/// ```
+/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+/// to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+ static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
+
+ vector::BroadcastOp broadcast =
+ transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcast)
+ return false;
+
+ auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+ bool inputIsScalar = !inputType;
+ auto inputShape = inputType.getShape();
+ auto inputRank = inputType.getRank();
+ auto outputRank = transpose.getType().getRank();
+ auto deltaRank = outputRank - inputRank;
+
+ // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
+ if (inputIsScalar)
+ return true;
+
+ // Return true if all permutation destinations for indices in [low, high)
+ // are in [low, high), so the permutation is local to the group.
+ auto isGroupBound = [&](int low, int high) {
+ auto perm = transpose.getPermutation();
+ for (int j = low; j < high; ++j) {
+ if (perm[j] < low || perm[j] >= high) {
+ return false;
+ }
+ }
+ return true;
+ };
+
+ // Groups are either contiguous sequences of 1s and non-1s (1-element
+ // groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
+ // to broadcasting from 1x1x4x1x1x7.
+ // ^^^ ^ ^^^ ^
+ // groups: 0 1 2 3
+ // Order preserving permutations for this example are ones that only permute
+ // within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
+ int low = 0;
+ for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+ bool notOne = inputShape[inputIndex] != 1;
+ bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
+ bool groupEndFound = notOne || prevNotOne;
+ if (groupEndFound) {
+ int high = inputIndex + deltaRank;
+ if (!isGroupBound(low, high)) {
+ return false;
+ }
+ low = high;
+ }
+ }
+ if (!isGroupBound(low, outputRank)) {
+ return false;
+ }
+
+ bool isBroadcastable =
+ vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
+ vector::BroadcastableToResult::Success;
+ assert(isBroadcastable && "it should be broadcastable at this point");
+
+ return true;
+ }
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+ if (!canFoldIntoPrecedingBroadcast(transpose))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ transpose, transpose.getResultVectorType(), transpose.getVector());
+
+ return success();
+ }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat>(context);
+ TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..d443b85d40351 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1,5 +1,8 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+
+
// CHECK-LABEL: create_vector_mask_to_constant_mask
func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
%c2 = arith.constant 2 : index
@@ -2215,6 +2218,80 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
// -----
+// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+// CHECK: return %[[RES]] : vector<2x3x4xi8>
+func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+ return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+// CHECK: return %[[RES]] : vector<2x3x4xi8>
+func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+ return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+// CHECK: return %[[RES]] : vector<8x1xi8>
+func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+ %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+ return %1 : vector<8x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_mixed_example_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
+// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
+func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+ %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+ %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
+ return %1 : vector<3x2x4x5x6x7xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_102_nofold
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+// CHECK: return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+ %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+ return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_021_nofold
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+// CHECK: return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+ %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+ return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
// CHECK-LABEL: func.func @insert_1d_constant
// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
>From e63d35b0eff4d8787154d161f5525852aa3f32d8 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 9 Apr 2025 15:42:30 -0700
Subject: [PATCH 2/5] tidy
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 26 +++++++++++++---------
mlir/test/Dialect/Vector/canonicalize.mlir | 3 ---
2 files changed, 15 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ba56242edd6f3..05ff93da13aea 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2609,7 +2609,6 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
-
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -6156,9 +6155,9 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
-/// Folds transpose(broadcast(x)) into broadcast(x) if the transpose is
-/// 'order preserving', where 'order preserving' here means the flattened
-/// inputs and outputs of the transpose have identical values.
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
///
/// Example:
/// ```
@@ -6188,10 +6187,10 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
bool inputIsScalar = !inputType;
- auto inputShape = inputType.getShape();
- auto inputRank = inputType.getRank();
- auto outputRank = transpose.getType().getRank();
- auto deltaRank = outputRank - inputRank;
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t inputRank = inputType.getRank();
+ int64_t outputRank = transpose.getType().getRank();
+ int64_t deltaRank = outputRank - inputRank;
// transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
if (inputIsScalar)
@@ -6200,9 +6199,9 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
// Return true if all permutation destinations for indices in [low, high)
// are in [low, high), so the permutation is local to the group.
auto isGroupBound = [&](int low, int high) {
- auto perm = transpose.getPermutation();
+ ArrayRef<int64_t> permutation = transpose.getPermutation();
for (int j = low; j < high; ++j) {
- if (perm[j] < low || perm[j] >= high) {
+ if (permutation[j] < low || permutation[j] >= high) {
return false;
}
}
@@ -6233,10 +6232,15 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
return false;
}
+ // The preceding logic ensures that by this point, the ouutput of the
+ // transpose is definitely broadcastable from the input shape. So we don't
+ // need to call 'vector::isBroadcastableTo', but asserting here just as a
+ // sanity check:
bool isBroadcastable =
vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
vector::BroadcastableToResult::Success;
- assert(isBroadcastable && "it should be broadcastable at this point");
+ assert(isBroadcastable &&
+ "(I think) it must be broadcastable at this point.");
return true;
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d443b85d40351..03a338985299d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1,8 +1,5 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
-
-
-
// CHECK-LABEL: create_vector_mask_to_constant_mask
func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
%c2 = arith.constant 2 : index
>From 7a9035802e4d07fcd90028bf74726b797d5ee912 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Apr 2025 10:51:59 -0700
Subject: [PATCH 3/5] address review comments: notify match failure, and create
new test file
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 74 ++++++++-------
mlir/test/Dialect/Vector/canonicalize.mlir | 86 +++--------------
.../Vector/vector-transpose-canonicalize.mlir | 92 +++++++++++++++++++
3 files changed, 144 insertions(+), 108 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 05ff93da13aea..33d4f5eaab9d4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -42,6 +42,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <cstdint>
@@ -6172,49 +6173,57 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
/// The algorithm works by partitioning dimensions into groups that can be
/// locally permuted while preserving order, and checks that the transpose
/// only permutes within these groups.
+///
+/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
+/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
+/// broadcasting from 1x1x4x1x1x7.
+/// ^^^ ^ ^^^ ^
+/// groups: 0 1 2 3
+/// Order preserving permutations for this example are ones that only permute
+/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5).
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
- static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
vector::BroadcastOp broadcast =
transpose.getVector().getDefiningOp<vector::BroadcastOp>();
- if (!broadcast)
- return false;
+ if (!broadcast) {
+ return rewriter.notifyMatchFailure(transpose,
+ "not preceded by a broadcast");
+ }
auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+
+ // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
bool inputIsScalar = !inputType;
+ if (inputIsScalar) {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ transpose, transpose.getResultVectorType(), transpose.getVector());
+ return success();
+ }
+
+ ArrayRef<int64_t> permutation = transpose.getPermutation();
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputRank = inputType.getRank();
int64_t outputRank = transpose.getType().getRank();
int64_t deltaRank = outputRank - inputRank;
- // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
- if (inputIsScalar)
- return true;
-
// Return true if all permutation destinations for indices in [low, high)
// are in [low, high), so the permutation is local to the group.
- auto isGroupBound = [&](int low, int high) {
- ArrayRef<int64_t> permutation = transpose.getPermutation();
- for (int j = low; j < high; ++j) {
- if (permutation[j] < low || permutation[j] >= high) {
+ auto isGroupBound = [permutation](int low, int high) {
+ for (int i = low; i < high; ++i) {
+ if (permutation[i] < low || permutation[i] >= high) {
return false;
}
}
return true;
};
- // Groups are either contiguous sequences of 1s and non-1s (1-element
- // groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
- // to broadcasting from 1x1x4x1x1x7.
- // ^^^ ^ ^^^ ^
- // groups: 0 1 2 3
- // Order preserving permutations for this example are ones that only permute
- // within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
int low = 0;
for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
bool notOne = inputShape[inputIndex] != 1;
@@ -6223,32 +6232,29 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
if (groupEndFound) {
int high = inputIndex + deltaRank;
if (!isGroupBound(low, high)) {
- return false;
+ return rewriter.notifyMatchFailure(
+ transpose, llvm::formatv("output dimensions in interval [{0}, "
+ "{1}) aren't locally permuted.",
+ low, high));
}
low = high;
}
}
if (!isGroupBound(low, outputRank)) {
- return false;
+ return rewriter.notifyMatchFailure(
+ transpose,
+ llvm::formatv("output dimensions in final interval [{0}, {1}) "
+ "aren't locally permuted.",
+ low, outputRank));
}
- // The preceding logic ensures that by this point, the ouutput of the
- // transpose is definitely broadcastable from the input shape. So we don't
- // need to call 'vector::isBroadcastableTo', but asserting here just as a
- // sanity check:
+ // The preceding logic ensures that at this point, the output of the
+ // transpose is definitely broadcastable from the input shape. We confirm
+ // this as a sanity check:
bool isBroadcastable =
vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
vector::BroadcastableToResult::Success;
- assert(isBroadcastable &&
- "(I think) it must be broadcastable at this point.");
-
- return true;
- }
-
- LogicalResult matchAndRewrite(vector::TransposeOp transpose,
- PatternRewriter &rewriter) const override {
- if (!canFoldIntoPrecedingBroadcast(transpose))
- return failure();
+ assert(isBroadcastable && "It must be broadcastable at this point.");
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
transpose, transpose.getResultVectorType(), transpose.getVector());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 03a338985299d..d3eb52b4ac037 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -66,6 +66,18 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
// -----
+// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
+// CHECK: %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+ %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+// CHECK: return %[[BC]] : vector<2x3x4xi8>
+ return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
// CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
func.func @create_mask_transpose_to_transposed_create_mask(
@@ -2215,80 +2227,6 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
// -----
-// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
-// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
-// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
-// CHECK: return %[[RES]] : vector<2x3x4xi8>
-func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
- %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
- %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
- return %1 : vector<2x3x4xi8>
-}
-
-// -----
-
-// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
-// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
-// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
-// CHECK: return %[[RES]] : vector<2x3x4xi8>
-func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
- %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
- %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
- return %1 : vector<2x3x4xi8>
-}
-
-// -----
-
-// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
-// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
-// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
-// CHECK: return %[[RES]] : vector<8x1xi8>
-func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
- %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
- %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
- return %1 : vector<8x1xi8>
-}
-
-// -----
-
-// CHECK-LABEL: broadcast_transpose_mixed_example_folds
-// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
-// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
-// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
-func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
- %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
- %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
- return %1 : vector<3x2x4x5x6x7xi8>
-}
-
-// -----
-
-// CHECK-LABEL: broadcast_transpose_102_nofold
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
-// CHECK: return %[[TRP]] : vector<3x3x3xi8>
-func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
- %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
- %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
- return %1 : vector<3x3x3xi8>
-}
-
-// -----
-
-// CHECK-LABEL: broadcast_transpose_021_nofold
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
-// CHECK: return %[[TRP]] : vector<3x3x3xi8>
-func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
- %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
- %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
- return %1 : vector<3x3x3xi8>
-}
-
-// -----
-
// CHECK-LABEL: func.func @insert_1d_constant
// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
diff --git a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
new file mode 100644
index 0000000000000..0a9af4534a89d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose.
+
+intentional bug to sanity check CI picks this new test up
+
+// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+// CHECK: return %[[RES]] : vector<2x3x4xi8>
+func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
+ %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+ return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+// CHECK: return %[[RES]] : vector<8x1xi8>
+func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+ %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+ return %1 : vector<8x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_mixed_example_folds
+// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
+// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
+func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+ %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+ %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
+ return %1 : vector<3x2x4x5x6x7xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_square_nofold
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
+// CHECK: return %[[TRP]] : vector<4x4xi8>
+func.func @broadcast_transpose_square_nofold(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+ return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_hypercube_nofold
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
+// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
+func.func @broadcast_transpose_hypercube_nofold(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
+ %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
+ return %1 : vector<4x4x4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_102_nofold
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+// CHECK: return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+ %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+ return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_021_nofold
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
+// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+// CHECK: return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+ %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+ %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+ return %1 : vector<3x3x3xi8>
+}
+
>From 7b8fc7bf6717e8acf8bebbcb6e80cd07110d4002 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Apr 2025 11:17:31 -0700
Subject: [PATCH 4/5] confirmed that test failed with planted bug
---
mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
index 0a9af4534a89d..973796fa0b94f 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
@@ -2,8 +2,6 @@
// This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose.
-intentional bug to sanity check CI picks this new test up
-
// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
>From 482da07f5186c52fdcfe6a948faa3c9091d18496 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Apr 2025 11:22:06 -0700
Subject: [PATCH 5/5] whitespace
---
mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
index 973796fa0b94f..c41c2013e0107 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
-// This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose.
+// This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose.
// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
More information about the Mlir-commits
mailing list