[Mlir-commits] [mlir] [mlir][linalg] Add folder for broadcast(broadcast) -> broadcast (PR #150825)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 27 03:53:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

Back-to-back `linalg.broadcast` ops can be rewritten as a single broadcast.
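
For illustration, this is the rewrite on the first test case added below (values and names are taken from that test): the producer's broadcast dimensions are remapped into the consumer's init space and merged with the consumer's own dimensions.

```mlir
// Before: two chained broadcasts.
%broadcast1 = linalg.broadcast
    ins(%input : tensor<2xf32>)
    outs(%init1 : tensor<2x3xf32>)
    dimensions = [1]
%broadcast2 = linalg.broadcast
    ins(%broadcast1 : tensor<2x3xf32>)
    outs(%init2 : tensor<2x3x4xf32>)
    dimensions = [2]

// After: a single broadcast into the final init with the merged dimensions.
%broadcast = linalg.broadcast
    ins(%input : tensor<2xf32>)
    outs(%init2 : tensor<2x3x4xf32>)
    dimensions = [1, 2]
```

The consumer adds dimension 2 of the final `tensor<2x3x4xf32>`, so the intermediate dims {0, 1} map to final dims {0, 1}; the producer's dimension 1 therefore stays dimension 1, and the folded op broadcasts along dimensions [1, 2].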

---
Full diff: https://github.com/llvm/llvm-project/pull/150825.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+32-1) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+46) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27b661781f10f..c1544d571a938 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2292,9 +2292,40 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+/// Fold back-to-back broadcast ops into a single broadcast.
+struct FoldBroadcastWithBroadcast : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+    if (!defBroadcastOp)
+      return failure();
+    ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    SmallVector<int64_t> foldedDims(dimensions);
+    Value init = broadcastOp.getInit();
+    int64_t initRank = init.getType().getRank();
+    // Mapping from input dims to init dims.
+    SmallVector<int64_t> dimMap;
+    for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+      if (!llvm::is_contained(dimensions, dim))
+        dimMap.push_back(dim);
+    }
+    for (auto dim : defDimensions)
+      foldedDims.push_back(dimMap[dim]);
+
+    llvm::sort(foldedDims);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+    return success();
+  }
+};
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcastWithBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9cbb56e4de884..0e8546404f9e2 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
 
 // -----
 
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x3xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x3xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x3xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [2]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x4xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x4xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x4xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [1]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
   %transpose = linalg.transpose

``````````

</details>


https://github.com/llvm/llvm-project/pull/150825


More information about the Mlir-commits mailing list