[Mlir-commits] [mlir] [mlir][linalg] Add folder for broadcast(broadcast) -> broadcast (PR #150825)

Longsheng Mou llvmlistbot at llvm.org
Sun Jul 27 03:53:20 PDT 2025


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/150825

Back-to-back `linalg.broadcast` ops can be folded into a single broadcast.

>From 11eb0e0343244b60524827361d4a99cf5c5a66f2 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 27 Jul 2025 18:49:38 +0800
Subject: [PATCH] [mlir][linalg] Add folder for broadcast(broadcast) ->
 broadcast

Back-to-back `linalg.broadcast` ops can be folded into a single broadcast.
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 33 +++++++++++++++-
 mlir/test/Dialect/Linalg/canonicalize.mlir | 46 ++++++++++++++++++++++
 2 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27b661781f10f..c1544d571a938 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2292,9 +2292,40 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+/// Folds broadcast(broadcast(x)) into one broadcast of x into the outer init.
+struct FoldBroadcastWithBroadcast : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+    if (!defBroadcastOp)
+      return failure();
+    ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    // Dims added by the outer broadcast are already numbered in init space.
+    SmallVector<int64_t> foldedDims(dimensions);
+    Value init = broadcastOp.getInit();
+    int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+    // dimMap[i] is the init dim that dim `i` of the outer input maps to.
+    SmallVector<int64_t> dimMap;
+    for (int64_t dim : llvm::seq<int64_t>(0, initRank))
+      if (!llvm::is_contained(dimensions, dim))
+        dimMap.push_back(dim);
+    // Renumber the dims added by the inner broadcast into init space.
+    for (int64_t dim : defDimensions)
+      foldedDims.push_back(dimMap[dim]);
+    llvm::sort(foldedDims);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+    return success();
+  }
+};
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcastWithBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9cbb56e4de884..0e8546404f9e2 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
 
 // -----
 
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x3xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x3xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x3xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [2]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x4xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x4xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x4xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [1]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
   %transpose = linalg.transpose



More information about the Mlir-commits mailing list