[Mlir-commits] [mlir] [mlir][linalg] Add folder for broadcast(broadcast) -> broadcast (PR #150825)

Longsheng Mou llvmlistbot at llvm.org
Mon Jul 28 05:36:30 PDT 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/150825

>From 23f53b5ad7b0520e5fadf24e22b570bc9e030fc0 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 27 Jul 2025 18:49:38 +0800
Subject: [PATCH 1/2] [mlir][linalg] Add folder for broadcast(broadcast) ->
 broadcast

Back-to-back `linalg.broadcast` ops can be rewritten into a single broadcast.
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 33 +++++++++++++++-
 mlir/test/Dialect/Linalg/canonicalize.mlir | 46 ++++++++++++++++++++++
 2 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27b661781f10f..81b66b692ca3c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2292,9 +2292,40 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+/// Fold broadcast with broadcast.
+struct FoldBroadcastWithBroadcast : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+    if (!defBroadcastOp)
+      return failure();
+    ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    SmallVector<int64_t> foldedDims(dimensions);
+    Value init = broadcastOp.getInit();
+    int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+    // Mapping from input dims to init dims.
+    SmallVector<int64_t> dimMap;
+    for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+      if (!llvm::is_contained(dimensions, dim))
+        dimMap.push_back(dim);
+    }
+    for (auto dim : defDimensions)
+      foldedDims.push_back(dimMap[dim]);
+
+    llvm::sort(foldedDims);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+    return success();
+  }
+};
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcastWithBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9cbb56e4de884..fc88bb2dad0bd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
 
 // -----
 
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x3xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x3xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x3xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [2]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x4xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x4xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x4xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [1]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
   %transpose = linalg.transpose

>From ba7f3ea8abb2336693a1478fc2a637a65900e0ba Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 28 Jul 2025 20:35:47 +0800
Subject: [PATCH 2/2] rename pattern

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 81b66b692ca3c..232ed60a0e8ce 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2292,8 +2292,8 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
-/// Fold broadcast with broadcast.
-struct FoldBroadcastWithBroadcast : OpRewritePattern<linalg::BroadcastOp> {
+/// Fold back-to-back broadcasts together.
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
   using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
@@ -2324,8 +2324,7 @@ struct FoldBroadcastWithBroadcast : OpRewritePattern<linalg::BroadcastOp> {
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcastWithBroadcast>(
-      context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list