[Mlir-commits] [mlir] [mlir][vector] transpose(broadcast) -> broadcast canonicalization (PR #135096)

James Newling llvmlistbot at llvm.org
Mon Apr 14 08:10:24 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/135096

>From 2498d7d11a3da7c8cdca6646fa6f23cd24053e50 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Apr 2025 13:51:54 -0700
Subject: [PATCH 1/5]  transpose(broadcast) -> broadcast folder (squashed)

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  97 ++++++++++++++++-
 .../Vector/canonicalize/vector-transpose.mlir | 102 ++++++++++++++++++
 2 files changed, 198 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 754dab21ee1f3..4ecf65153a167 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6151,12 +6151,107 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
+///
+/// Example:
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+///                                                 to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+///
+/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
+/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
+/// broadcasting from 1x1x4x1x1x7.
+///                   ^^^ ^ ^^^ ^
+///          groups:   0  1  2  3
+/// Order preserving permutations for this example are ones that only permute
+/// within dims [0,1] (group 0) and dims [3,4] (group 2), like (1 0 2 4 3 5).
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+
+    vector::BroadcastOp broadcast =
+        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast) {
+      return rewriter.notifyMatchFailure(transpose,
+                                         "not preceded by a broadcast");
+    }
+
+    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid, and
+    // transpose(broadcast(all ones)) -> broadcast(all ones) is always valid
+    bool inputIsScalar = !inputType;
+    bool inputIsSizeOneVector = inputType.getNumElements() == 1;
+    if (inputIsScalar || inputIsSizeOneVector) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          transpose, transpose.getResultVectorType(), transpose.getVector());
+      return success();
+    }
+
+    ArrayRef<int64_t> permutation = transpose.getPermutation();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t inputRank = inputType.getRank();
+    int64_t outputRank = transpose.getType().getRank();
+    int64_t deltaRank = outputRank - inputRank;
+
+    int low = 0;
+    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+      bool notOne = inputShape[inputIndex] != 1;
+      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
+      bool groupEndFound = notOne || prevNotOne;
+      if (groupEndFound) {
+        // Return failure if all not permutation destinations for indices in
+        // [low, high) are in [low, high), i.e. the permutation is not local to
+        // the group.
+        int high = inputIndex + deltaRank;
+        for (int i = low; i < high; ++i) {
+          if (permutation[i] < low || permutation[i] >= high) {
+            return rewriter.notifyMatchFailure(
+                transpose, "permutation not local to group");
+          }
+        }
+      }
+    }
+
+    // We don't need to check the final group [low, outputRank) because
+    // if it is not locally bound, there must be a preceding group that
+    // already failed the check (impossible to have just 1 non-locally
+    // bound group).
+
+    // The preceding logic also ensures that at this point, the output of the
+    // transpose is definitely broadcastable from the input shape, so we
+    // don't need to check vector::isBroadcastableTo now.
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        transpose, transpose.getResultVectorType(), transpose.getVector());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
new file mode 100644
index 0000000000000..87aad69d62508
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some canonicalizations tests involving vector.transpose.
+
+// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast
+//  CHECK-SAME:  %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+func.func @scalar_broadcast_transpose_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
+//       CHECK:  %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+  %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+//       CHECK:  return %[[BC]] : vector<2x3x4xi8>
+  return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: ones_broadcast_transpose_to_broadcast
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
+func.func @ones_broadcast_transpose_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+  return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+//       CHECK:  return %[[RES]] : vector<8x1xi8>
+func.func @partial_ones_broadcast_transpose_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+  %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+  return %1 : vector<8x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_mixed_example
+//  CHECK-SAME:  %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
+//       CHECK:  return %[[RES]] : vector<3x2x4x5x6x7xi8>
+func.func @broadcast_transpose_mixed_example(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+  %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+  %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
+  return %1 : vector<3x2x4x5x6x7xi8>
+}
+
+// -----
+
+// CHECK-LABEL: negative_broadcast_transpose_square
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
+//       CHECK:  return %[[TRP]] : vector<4x4xi8>
+func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: negative_broadcast_transpose_hypercube
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
+//       CHECK:  return %[[TRP]] : vector<4x4x4x4xi8>
+func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
+  %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
+  return %1 : vector<4x4x4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: negative_broadcast_transpose_102
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: negative_broadcast_transpose_021
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @neagtive_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %1 : vector<3x3x3xi8>
+}
+

>From d3fe38a6f972ae9c6e5793b92c45855fba07f930 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 11 Apr 2025 11:13:47 -0700
Subject: [PATCH 2/5] simplify and add one test

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 15 +++++------
 .../Vector/canonicalize/vector-transpose.mlir | 26 ++++++++++++++-----
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4ecf65153a167..93ea5ba860324 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6216,10 +6216,10 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
       bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
       bool groupEndFound = notOne || prevNotOne;
       if (groupEndFound) {
-        // Return failure if all not permutation destinations for indices in
+        int high = inputIndex + deltaRank;
+        // Return failure if not all permutation destinations for indices in
         // [low, high) are in [low, high), i.e. the permutation is not local to
         // the group.
-        int high = inputIndex + deltaRank;
         for (int i = low; i < high; ++i) {
           if (permutation[i] < low || permutation[i] >= high) {
             return rewriter.notifyMatchFailure(
@@ -6229,14 +6229,13 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
       }
     }
 
-    // We don't need to check the final group [low, outputRank) because
-    // if it is not locally bound, there must be a preceding group that
-    // already failed the check (impossible to have just 1 non-locally
-    // bound group).
+    // We don't need to check the final group [low, outputRank) because if it is
+    // not locally bound, there must be a preceding group that already failed
+    // the check (impossible to have just 1 non-locally bound group).
 
     // The preceding logic also ensures that at this point, the output of the
-    // transpose is definitely broadcastable from the input shape, so we
-    // don't need to check vector::isBroadcastableTo now.
+    // transpose is definitely broadcastable from the input shape, so we don't
+    // need to check vector::isBroadcastableTo now.
 
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         transpose, transpose.getResultVectorType(), transpose.getVector());
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 87aad69d62508..28ce34b6ec9cd 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -2,9 +2,9 @@
 
 // This file contains some canonicalizations tests involving vector.transpose.
 
-// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast
+// CHECK-LABEL: broadcast_transpose_scalar_to_broadcast
 //  CHECK-SAME:  %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
-func.func @scalar_broadcast_transpose_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
+func.func @broadcast_transpose_scalar_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
 //       CHECK:  %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
   %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
   %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
@@ -14,11 +14,11 @@ func.func @scalar_broadcast_transpose_to_broadcast(%arg0 : i8) -> vector<2x3x4xi
 
 // -----
 
-// CHECK-LABEL: ones_broadcast_transpose_to_broadcast
+// CHECK-LABEL: broadcast_transpose_ones_to_broadcast
 //  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
 //       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
 //       CHECK:  return %[[RES]] : vector<2x3x4xi8>
-func.func @ones_broadcast_transpose_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+func.func @broadcast_transpose_ones_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
   %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
   %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
   return %1 : vector<2x3x4xi8>
@@ -26,11 +26,11 @@ func.func @ones_broadcast_transpose_to_broadcast(%arg0 : vector<1x1x1xi8>) -> ve
 
 // -----
 
-// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast
+// CHECK-LABEL: broadcast_transpose_partial_ones_to_broadcast
 //  CHECK-SAME:  %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
 //       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
 //       CHECK:  return %[[RES]] : vector<8x1xi8>
-func.func @partial_ones_broadcast_transpose_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+func.func @broadcast_transpose_partial_ones_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
   %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
   %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
   return %1 : vector<8x1xi8>
@@ -50,6 +50,18 @@ func.func @broadcast_transpose_mixed_example(%arg0 : vector<4x1x1x7xi8>) -> vect
 
 // -----
 
+// CHECK-LABEL: broadcast_transpose_final_group
+//  CHECK-SAME:  %[[ARG:.*]]: vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x7x1x1xi8> to vector<4x7x2x3xi8>
+//       CHECK:  return %[[RES]] : vector<4x7x2x3xi8>
+func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<4x7x1x1xi8> to vector<4x7x3x2xi8>
+  %1 = vector.transpose %0, [0, 1, 3, 2] : vector<4x7x3x2xi8> to vector<4x7x2x3xi8>
+  return %1 : vector<4x7x2x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: negative_broadcast_transpose_square
 //  CHECK-SAME:  %[[ARG:.*]]:
 //       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
@@ -94,7 +106,7 @@ func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<
 //       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
 //       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
 //       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
-func.func @neagtive_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
   %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
   %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
   return %1 : vector<3x3x3xi8>

>From af69672e2dd85265ad311665b51844992659924a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 11 Apr 2025 13:45:13 -0700
Subject: [PATCH 3/5] remove subsumed pattern, remove edge case check for
 nelms=1 (not needed)

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 33 +++---------------------
 1 file changed, 4 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 93ea5ba860324..1fb02a49516ff 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6075,28 +6075,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
-// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
-struct FoldTransposedScalarBroadcast final
-    : public OpRewritePattern<vector::TransposeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
-    if (!bcastOp)
-      return failure();
-
-    auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
-    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
 public:
@@ -6194,11 +6172,9 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 
     auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
 
-    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid, and
-    // transpose(broadcast(all ones)) -> broadcast(all ones) is always valid
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
     bool inputIsScalar = !inputType;
-    bool inputIsSizeOneVector = inputType.getNumElements() == 1;
-    if (inputIsScalar || inputIsSizeOneVector) {
+    if (inputIsScalar) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
           transpose, transpose.getResultVectorType(), transpose.getVector());
       return success();
@@ -6248,9 +6224,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
-      context);
+  results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
+              FoldTransposeBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//

>From 7df6355d4c63fefa5a0c6f43b861ee2ebe2ce9f6 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 14 Apr 2025 07:44:03 -0700
Subject: [PATCH 4/5] move tests for removed pattern to location of subsuming
 pattern's tests

---
 mlir/test/Dialect/Vector/canonicalize.mlir    | 24 ------------------
 .../Vector/canonicalize/vector-transpose.mlir | 25 +++++++++++++++++++
 2 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a6d82b85777b0..0065635b662d5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2218,30 +2218,6 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
 
 // -----
 
-// CHECK-LABEL: func @transpose_scalar_broadcast1
-//  CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
-//       CHECK:   %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
-//       CHECK:   return %[[V]] : vector<1x8xf32>
-func.func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
-  %bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
-  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
-  return %t : vector<1x8xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_scalar_broadcast2
-//  CHECK-SAME: (%[[ARG:.+]]: f32)
-//       CHECK:   %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
-//       CHECK:   return %[[V]] : vector<1x8xf32>
-func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
-  %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
-  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
-  return %t : vector<1x8xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @transpose_splat_constant
 //       CHECK:   %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
 //       CHECK:   return %[[CST]]
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 28ce34b6ec9cd..e97e147459de2 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -2,6 +2,31 @@
 
 // This file contains some canonicalizations tests involving vector.transpose.
 
+// CHECK-LABEL: func @transpose_scalar_broadcast1
+//  CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
+//       CHECK:   %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
+//       CHECK:   return %[[V]] : vector<1x8xf32>
+func.func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
+  %bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
+  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+  return %t : vector<1x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_scalar_broadcast2
+//  CHECK-SAME: (%[[ARG:.+]]: f32)
+//       CHECK:   %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
+//       CHECK:   return %[[V]] : vector<1x8xf32>
+func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
+  %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
+  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+  return %t : vector<1x8xf32>
+}
+
+// -----
+
+
 // CHECK-LABEL: broadcast_transpose_scalar_to_broadcast
 //  CHECK-SAME:  %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
 func.func @broadcast_transpose_scalar_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {

>From 384a5ca121b9fc842c4822c283a5df86f61d4448 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 14 Apr 2025 08:14:44 -0700
Subject: [PATCH 5/5] add back assert

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb02a49516ff..a600f8114404f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6171,12 +6171,13 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
     }
 
     auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+    VectorType outputType = transpose.getResultVectorType();
 
     // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
     bool inputIsScalar = !inputType;
     if (inputIsScalar) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          transpose, transpose.getResultVectorType(), transpose.getVector());
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
+                                                       broadcast.getSource());
       return success();
     }
 
@@ -6202,6 +6203,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
                 transpose, "permutation not local to group");
           }
         }
+        low = high;
       }
     }
 
@@ -6210,11 +6212,13 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
     // the check (impossible to have just 1 non-locally bound group).
 
     // The preceding logic also ensures that at this point, the output of the
-    // transpose is definitely broadcastable from the input shape, so we don't
-    // need to check vector::isBroadcastableTo now.
+    // transpose is definitely broadcastable from the input shape, assert so:
+    assert(vector::isBroadcastableTo(inputType, outputType) ==
+               vector::BroadcastableToResult::Success &&
+           "not broadcastable directly to transpose output");
 
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        transpose, transpose.getResultVectorType(), transpose.getVector());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
+                                                     broadcast.getSource());
 
     return success();
   }



More information about the Mlir-commits mailing list