[Mlir-commits] [mlir] [mlir][vector] transpose(broadcast) -> broadcast canonicalization (PR #135096)

James Newling llvmlistbot at llvm.org
Thu Apr 10 11:17:32 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/135096

>From a60254eef1e7f5ac4516d8ddfae3f849a95eac94 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 9 Apr 2025 15:18:19 -0700
Subject: [PATCH 1/5] transpose(broadcast) -> broadcast folder

Signed-off-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 101 ++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir |  77 ++++++++++++++++
 2 files changed, 177 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..ba56242edd6f3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2609,6 +2609,7 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -6155,12 +6156,110 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(broadcast(x)) into broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' here means the flattened
+/// inputs and outputs of the transpose have identical values.
+///
+/// Example:
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+///                                                 to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
+
+    vector::BroadcastOp broadcast =
+        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast)
+      return false;
+
+    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+    bool inputIsScalar = !inputType;
+    auto inputShape = inputType.getShape();
+    auto inputRank = inputType.getRank();
+    auto outputRank = transpose.getType().getRank();
+    auto deltaRank = outputRank - inputRank;
+
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
+    if (inputIsScalar)
+      return true;
+
+    // Return true if all permutation destinations for indices in [low, high)
+    // are in [low, high), so the permutation is local to the group.
+    auto isGroupBound = [&](int low, int high) {
+      auto perm = transpose.getPermutation();
+      for (int j = low; j < high; ++j) {
+        if (perm[j] < low || perm[j] >= high) {
+          return false;
+        }
+      }
+      return true;
+    };
+
+    // Groups are either contiguous sequences  of 1s and non-1s (1-element
+    // groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
+    // to broadcasting from 1x1x4x1x1x7.
+    //                      ^^^ ^ ^^^ ^
+    //               groups: 0  1  2  3
+    // Order preserving permutations for this example are ones that only permute
+    // within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
+    int low = 0;
+    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+      bool notOne = inputShape[inputIndex] != 1;
+      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
+      bool groupEndFound = notOne || prevNotOne;
+      if (groupEndFound) {
+        int high = inputIndex + deltaRank;
+        if (!isGroupBound(low, high)) {
+          return false;
+        }
+        low = high;
+      }
+    }
+    if (!isGroupBound(low, outputRank)) {
+      return false;
+    }
+
+    bool isBroadcastable =
+        vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
+        vector::BroadcastableToResult::Success;
+    assert(isBroadcastable && "it should be broadcastable at this point");
+
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+    if (!canFoldIntoPrecedingBroadcast(transpose))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        transpose, transpose.getResultVectorType(), transpose.getVector());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..d443b85d40351 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
+
+
+
 // CHECK-LABEL: create_vector_mask_to_constant_mask
 func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index
@@ -2215,6 +2218,80 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
+func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+  return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
+func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+  return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+//       CHECK:  return %[[RES]] : vector<8x1xi8>
+func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+  %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+  return %1 : vector<8x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_mixed_example_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
+//       CHECK:  return %[[RES]] : vector<3x2x4x5x6x7xi8>
+func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+  %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+  %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
+  return %1 : vector<3x2x4x5x6x7xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_102_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_021_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @insert_1d_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>

>From e63d35b0eff4d8787154d161f5525852aa3f32d8 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 9 Apr 2025 15:42:30 -0700
Subject: [PATCH 2/5] tidy

Signed-off-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 26 +++++++++++++---------
 mlir/test/Dialect/Vector/canonicalize.mlir |  3 ---
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ba56242edd6f3..05ff93da13aea 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2609,7 +2609,6 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
-
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -6156,9 +6155,9 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(broadcast(x)) into broadcast(x) if the transpose is
-/// 'order preserving', where 'order preserving' here means the flattened
-/// inputs and outputs of the transpose have identical values.
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
 ///
 /// Example:
 /// ```
@@ -6188,10 +6187,10 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 
     auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
     bool inputIsScalar = !inputType;
-    auto inputShape = inputType.getShape();
-    auto inputRank = inputType.getRank();
-    auto outputRank = transpose.getType().getRank();
-    auto deltaRank = outputRank - inputRank;
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t inputRank = inputType.getRank();
+    int64_t outputRank = transpose.getType().getRank();
+    int64_t deltaRank = outputRank - inputRank;
 
     // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
     if (inputIsScalar)
@@ -6200,9 +6199,9 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
     // Return true if all permutation destinations for indices in [low, high)
     // are in [low, high), so the permutation is local to the group.
     auto isGroupBound = [&](int low, int high) {
-      auto perm = transpose.getPermutation();
+      ArrayRef<int64_t> permutation = transpose.getPermutation();
       for (int j = low; j < high; ++j) {
-        if (perm[j] < low || perm[j] >= high) {
+        if (permutation[j] < low || permutation[j] >= high) {
           return false;
         }
       }
@@ -6233,10 +6232,15 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
       return false;
     }
 
+    // The preceding logic ensures that by this point, the ouutput of the
+    // transpose is definitely broadcastable from the input shape. So we don't
+    // need to call 'vector::isBroadcastableTo', but asserting here just as a
+    // sanity check:
     bool isBroadcastable =
         vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
         vector::BroadcastableToResult::Success;
-    assert(isBroadcastable && "it should be broadcastable at this point");
+    assert(isBroadcastable &&
+           "(I think) it must be broadcastable at this point.");
 
     return true;
   }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d443b85d40351..03a338985299d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1,8 +1,5 @@
 // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
-
-
-
 // CHECK-LABEL: create_vector_mask_to_constant_mask
 func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index

>From 7a9035802e4d07fcd90028bf74726b797d5ee912 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Apr 2025 10:51:59 -0700
Subject: [PATCH 3/5] address review comments: notify match failure, and create
 new test file

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 74 ++++++++-------
 mlir/test/Dialect/Vector/canonicalize.mlir    | 86 +++--------------
 .../Vector/vector-transpose-canonicalize.mlir | 92 +++++++++++++++++++
 3 files changed, 144 insertions(+), 108 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 05ff93da13aea..33d4f5eaab9d4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -42,6 +42,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
 
 #include <cassert>
 #include <cstdint>
@@ -6172,49 +6173,57 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
 /// The algorithm works by partitioning dimensions into groups that can be
 /// locally permuted while preserving order, and checks that the transpose
 /// only permutes within these groups.
+///
+/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
+/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
+/// broadcasting from 1x1x4x1x1x7.
+///                   ^^^ ^ ^^^ ^
+///          groups:   0  1  2  3
+/// Order preserving permutations for this example are ones that only permute
+/// within dimensions [0,1] and [3,4] (groups 0 and 2), like (1 0 2 4 3 5).
 class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
 
-  static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
 
     vector::BroadcastOp broadcast =
         transpose.getVector().getDefiningOp<vector::BroadcastOp>();
-    if (!broadcast)
-      return false;
+    if (!broadcast) {
+      return rewriter.notifyMatchFailure(transpose,
+                                         "not preceded by a broadcast");
+    }
 
     auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
     bool inputIsScalar = !inputType;
+    if (inputIsScalar) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          transpose, transpose.getResultVectorType(), broadcast.getSource());
+      return success();
+    }
+
+    ArrayRef<int64_t> permutation = transpose.getPermutation();
     ArrayRef<int64_t> inputShape = inputType.getShape();
     int64_t inputRank = inputType.getRank();
     int64_t outputRank = transpose.getType().getRank();
     int64_t deltaRank = outputRank - inputRank;
 
-    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
-    if (inputIsScalar)
-      return true;
-
     // Return true if all permutation destinations for indices in [low, high)
     // are in [low, high), so the permutation is local to the group.
-    auto isGroupBound = [&](int low, int high) {
-      ArrayRef<int64_t> permutation = transpose.getPermutation();
-      for (int j = low; j < high; ++j) {
-        if (permutation[j] < low || permutation[j] >= high) {
+    auto isGroupBound = [permutation](int low, int high) {
+      for (int i = low; i < high; ++i) {
+        if (permutation[i] < low || permutation[i] >= high) {
           return false;
         }
       }
       return true;
     };
 
-    // Groups are either contiguous sequences  of 1s and non-1s (1-element
-    // groups). Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent
-    // to broadcasting from 1x1x4x1x1x7.
-    //                      ^^^ ^ ^^^ ^
-    //               groups: 0  1  2  3
-    // Order preserving permutations for this example are ones that only permute
-    // within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
     int low = 0;
     for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
       bool notOne = inputShape[inputIndex] != 1;
@@ -6223,35 +6232,34 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
       if (groupEndFound) {
         int high = inputIndex + deltaRank;
         if (!isGroupBound(low, high)) {
-          return false;
+          return rewriter.notifyMatchFailure(
+              transpose, llvm::formatv("output dimensions in interval [{0}, "
+                                       "{1}) aren't locally permuted.",
+                                       low, high));
         }
         low = high;
       }
     }
     if (!isGroupBound(low, outputRank)) {
-      return false;
+      return rewriter.notifyMatchFailure(
+          transpose,
+          llvm::formatv("output dimensions in final interval [{0}, {1}) "
+                        "aren't locally permuted.",
+                        low, outputRank));
     }
 
-    // The preceding logic ensures that by this point, the ouutput of the
-    // transpose is definitely broadcastable from the input shape. So we don't
-    // need to call 'vector::isBroadcastableTo', but asserting here just as a
-    // sanity check:
-    bool isBroadcastable =
-        vector::isBroadcastableTo(inputType, transpose.getResultVectorType()) ==
-        vector::BroadcastableToResult::Success;
-    assert(isBroadcastable &&
-           "(I think) it must be broadcastable at this point.");
-
-    return true;
-  }
-
-  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
-                                PatternRewriter &rewriter) const override {
-    if (!canFoldIntoPrecedingBroadcast(transpose))
-      return failure();
+    // The preceding logic ensures that the output of the transpose is
+    // broadcastable from the input shape. Confirm this in debug builds; the
+    // query lives inside the assert so NDEBUG builds have no unused variable.
+    assert(vector::isBroadcastableTo(inputType,
+                                     transpose.getResultVectorType()) ==
+               vector::BroadcastableToResult::Success &&
+           "It must be broadcastable at this point.");
 
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        transpose, transpose.getResultVectorType(), transpose.getVector());
+    // Broadcast from the broadcast's source, not from its (pre-transpose
+    // shaped) result, which is in general not broadcastable to the output.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        transpose, transpose.getResultVectorType(), broadcast.getSource());
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 03a338985299d..d3eb52b4ac037 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -66,6 +66,18 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
 
 // -----
 
+// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
+//       CHECK:  %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+  %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+//       CHECK:  return %[[BC]] : vector<2x3x4xi8>
+  return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
 // CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
 func.func @create_mask_transpose_to_transposed_create_mask(
@@ -2215,80 +2227,6 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
 
 // -----
 
-// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
-//  CHECK-SAME:  %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
-//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
-//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
-func.func @scalar_broadcast_transpose_to_broadcast_folds(%arg0 : i8) -> vector<2x3x4xi8> {
-  %0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
-  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
-  return %1 : vector<2x3x4xi8>
-}
-
-// -----
-
-// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
-//  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
-//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
-//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
-func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
-  %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
-  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
-  return %1 : vector<2x3x4xi8>
-}
-
-// -----
-
-// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
-//  CHECK-SAME:  %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
-//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
-//       CHECK:  return %[[RES]] : vector<8x1xi8>
-func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
-  %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
-  %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
-  return %1 : vector<8x1xi8>
-}
-
-// -----
-
-// CHECK-LABEL: broadcast_transpose_mixed_example_folds
-//  CHECK-SAME:  %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
-//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
-//       CHECK:  return %[[RES]] : vector<3x2x4x5x6x7xi8>
-func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
-  %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
-  %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
-  return %1 : vector<3x2x4x5x6x7xi8>
-}
-
-// -----
-
-// CHECK-LABEL: broadcast_transpose_102_nofold
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
-//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
-func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
-  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
-  %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
-  return %1 : vector<3x3x3xi8>
-}
-
-// -----
-
-// CHECK-LABEL: broadcast_transpose_021_nofold
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
-//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
-func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
-  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
-  %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
-  return %1 : vector<3x3x3xi8>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @insert_1d_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
diff --git a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
new file mode 100644
index 0000000000000..0a9af4534a89d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose. 
+
+intentional bug to sanity check CI picks this new test up
+
+// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
+func.func @ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+  return %1 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+//       CHECK:  return %[[RES]] : vector<8x1xi8>
+func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
+  %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+  return %1 : vector<8x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_mixed_example_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
+//       CHECK:  return %[[RES]] : vector<3x2x4x5x6x7xi8>
+func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
+  %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+  %1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
+  return %1 : vector<3x2x4x5x6x7xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_square_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
+//       CHECK:  return %[[TRP]] : vector<4x4xi8>
+func.func @broadcast_transpose_square_nofold(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_hypercube_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
+//       CHECK:  return %[[TRP]] : vector<4x4x4x4xi8>
+func.func @broadcast_transpose_hypercube_nofold(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
+  %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
+  return %1 : vector<4x4x4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_102_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_102_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %1 : vector<3x3x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_transpose_021_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_021_nofold(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %1 : vector<3x3x3xi8>
+}
+

>From 7b8fc7bf6717e8acf8bebbcb6e80cd07110d4002 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Apr 2025 11:17:31 -0700
Subject: [PATCH 4/5] confirmed that test failed with planted bug

---
 mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
index 0a9af4534a89d..973796fa0b94f 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
@@ -2,8 +2,6 @@
 
 // This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose. 
 
-intentional bug to sanity check CI picks this new test up
-
 // CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
 //  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
 //       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>

>From 482da07f5186c52fdcfe6a948faa3c9091d18496 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Apr 2025 11:22:06 -0700
Subject: [PATCH 5/5] whitespace

---
 mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
index 973796fa0b94f..c41c2013e0107 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-canonicalize.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
-// This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose. 
+// This file contains some (but not all) tests of canonicalizations that eliminate vector.transpose.
 
 // CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
 //  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {



More information about the Mlir-commits mailing list