[Mlir-commits] [mlir] [mlir][vector] transpose(broadcast) -> broadcast canonicalization (PR #135096)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 9 15:43:56 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

<details>
<summary>Changes</summary>

Example seen in the 'real world':
 
 ```
 %0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
 %1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
 ```
 
 This PR adds a canonicalizer that rewrites the above as
 
```
  %1 = vector.broadcast %arg0 : vector<1xi8> to vector<8x1xi8>
```

It works by checking whether the transpose only permutes dimensions within contiguous groups of broadcast (size-1) dimensions; if so, the transpose can be absorbed into the broadcast.

---
Full diff: https://github.com/llvm/llvm-project/pull/135096.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+104-1) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+74) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..05ff93da13aea 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6155,12 +6155,115 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
+/// 'order preserving', where 'order preserving' means the flattened
+/// inputs and outputs of the transpose have identical (numerical) values.
+///
+/// Example:
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
+///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
+///                                                 to vector<8x1xi32>
+/// ```
+/// can be rewritten as the equivalent
+/// ```
+///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
+/// ```
+/// The algorithm works by partitioning dimensions into groups that can be
+/// locally permuted while preserving order, and checks that the transpose
+/// only permutes within these groups.
+///
+/// Groups are either maximal contiguous runs of (non-scalable) unit dims, or
+/// single non-unit dims. Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is
+/// equivalent to broadcasting from 1x1x4x1x1x7, whose groups are
+///
+///   1x1 | 4 | 1x1 | 7   ->   result dims [0, 2), [2, 3), [3, 5), [5, 6)
+///
+/// Order preserving permutations for this example are ones that only permute
+/// within the groups [0, 2) and [3, 5), like (1 0 2 4 3 5).
+class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  /// Returns true if `transpose` is fed by a vector.broadcast and only
+  /// permutes dimensions within the groups described above, so that the pair
+  /// can be replaced by a single broadcast of the original broadcast source.
+  static bool canFoldIntoPrecedingBroadcast(vector::TransposeOp transpose) {
+    auto broadcast =
+        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast)
+      return false;
+
+    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid. This
+    // must be decided before `inputType` is used below: for a scalar source
+    // the dyn_cast yields a null VectorType.
+    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!inputType)
+      return true;
+
+    ArrayRef<int64_t> permutation = transpose.getPermutation();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    ArrayRef<bool> inputScalableDims = inputType.getScalableDims();
+    int64_t inputRank = inputType.getRank();
+    int64_t outputRank = transpose.getResultVectorType().getRank();
+    int64_t deltaRank = outputRank - inputRank;
+
+    // Return true if all permutation destinations for indices in [low, high)
+    // are in [low, high), so the permutation is local to the group.
+    auto isGroupBound = [&](int64_t low, int64_t high) {
+      for (int64_t j = low; j < high; ++j) {
+        if (permutation[j] < low || permutation[j] >= high)
+          return false;
+      }
+      return true;
+    };
+
+    // Only a fixed-size 1 is a broadcast ("unit") dimension. A scalable [1]
+    // dimension is not broadcastable to a different size, so it is treated
+    // like a non-unit dimension and forms its own single-element group.
+    auto isUnitDim = [&](int64_t i) {
+      return inputShape[i] == 1 && !inputScalableDims[i];
+    };
+
+    // A group boundary occurs right before and right after every non-unit
+    // input dim. Input dim `i` sits at result position `i + deltaRank`, since
+    // broadcasting prepends `deltaRank` leading dims.
+    int64_t low = 0;
+    for (int64_t inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
+      bool prevIsNonUnit = inputIndex != 0 && !isUnitDim(inputIndex - 1);
+      if (!isUnitDim(inputIndex) || prevIsNonUnit) {
+        int64_t high = inputIndex + deltaRank;
+        if (!isGroupBound(low, high))
+          return false;
+        low = high;
+      }
+    }
+    if (!isGroupBound(low, outputRank))
+      return false;
+
+    // The checks above guarantee that the transposed shape is broadcastable
+    // from the input shape; assert it as a sanity check. The call lives
+    // inside the assert so nothing is left unused in NDEBUG builds.
+    assert(vector::isBroadcastableTo(inputType,
+                                     transpose.getResultVectorType()) ==
+               vector::BroadcastableToResult::Success &&
+           "expected the transpose result to be broadcastable from the "
+           "broadcast source");
+
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+    if (!canFoldIntoPrecedingBroadcast(transpose))
+      return failure();
+
+    // Broadcast the *source* of the original broadcast straight to the
+    // transpose result type. `transpose.getVector()` is the broadcast's
+    // result, which is in general not broadcastable to the transposed shape.
+    auto broadcast =
+        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        transpose, transpose.getResultVectorType(), broadcast.getSource());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
+  // FoldTransposeBroadcast rewrites transpose(broadcast(x)) into a single
+  // broadcast(x) when the permutation only shuffles dimensions within groups
+  // of contiguous broadcast dimensions.
+  // NOTE(review): by name, FoldTransposedScalarBroadcast looks like it covers
+  // a subset of the same rewrite (scalar sources); consider dropping it once
+  // confirmed against its definition.
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..03a338985299d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2215,6 +2215,80 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
 
 // -----
 
+// A scalar source has no dims to reorder, so any transpose of its broadcast
+// is itself a broadcast.
+// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
+//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
+func.func @scalar_broadcast_transpose_to_broadcast_folds(%scalar : i8) -> vector<2x3x4xi8> {
+  %bcast = vector.broadcast %scalar : i8 to vector<3x4x2xi8>
+  %transposed = vector.transpose %bcast, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+  return %transposed : vector<2x3x4xi8>
+}
+
+// -----
+
+// An all-unit source forms a single group, so every permutation folds.
+// CHECK-LABEL: ones_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
+//       CHECK:  return %[[RES]] : vector<2x3x4xi8>
+func.func @ones_broadcast_transpose_to_broadcast_folds(%unit : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
+  %bcast = vector.broadcast %unit : vector<1x1x1xi8> to vector<3x4x2xi8>
+  %transposed = vector.transpose %bcast, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
+  return %transposed : vector<2x3x4xi8>
+}
+
+// -----
+
+// The motivating case: a rank-extending broadcast whose new leading dim is
+// swapped with the (unit) source dim.
+// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
+//       CHECK:  return %[[RES]] : vector<8x1xi8>
+func.func @partial_ones_broadcast_transpose_to_broadcast_folds(%src : vector<1xi8>) -> vector<8x1xi8> {
+  %bcast = vector.broadcast %src : vector<1xi8> to vector<1x8xi8>
+  %transposed = vector.transpose %bcast, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
+  return %transposed : vector<8x1xi8>
+}
+
+// -----
+
+// Source 4x1x1x7 behaves like 1x1x4x1x1x7: dims [0, 2) and [3, 5) are unit
+// groups. Permute within both groups to exercise each of them.
+// CHECK-LABEL: broadcast_transpose_mixed_example_folds
+//  CHECK-SAME:  %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x6x5x7xi8> {
+//       CHECK:  %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x6x5x7xi8>
+//       CHECK:  return %[[RES]] : vector<3x2x4x6x5x7xi8>
+func.func @broadcast_transpose_mixed_example_folds(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x6x5x7xi8> {
+  %0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
+  %1 = vector.transpose %0, [1, 0, 2, 4, 3, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x6x5x7xi8>
+  return %1 : vector<3x2x4x6x5x7xi8>
+}
+
+// -----
+
+// Dim 0 of the source is non-unit, so moving it across dim 1 changes the
+// element order and must not fold.
+// CHECK-LABEL: broadcast_transpose_102_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_102_nofold(%src : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %bcast = vector.broadcast %src : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %transposed = vector.transpose %bcast, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %transposed : vector<3x3x3xi8>
+}
+
+// -----
+
+// Dim 2 of the source is non-unit, so swapping it with the unit dim 1
+// changes the element order and must not fold.
+// CHECK-LABEL: broadcast_transpose_021_nofold
+//  CHECK-SAME:  %[[ARG:.*]]:
+//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
+//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
+//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
+func.func @broadcast_transpose_021_nofold(%src : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+  %bcast = vector.broadcast %src : vector<3x1x3xi8> to vector<3x3x3xi8>
+  %transposed = vector.transpose %bcast, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
+  return %transposed : vector<3x3x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @insert_1d_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/135096


More information about the Mlir-commits mailing list