[Mlir-commits] [mlir] b30034d - [mlir][linalg] Add folder for broadcast(broadcast) -> broadcast (#150825)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 29 04:25:29 PDT 2025


Author: Longsheng Mou
Date: 2025-07-29T19:25:26+08:00
New Revision: b30034da0ffaf67a144d062607e8f627e14227d1

URL: https://github.com/llvm/llvm-project/commit/b30034da0ffaf67a144d062607e8f627e14227d1
DIFF: https://github.com/llvm/llvm-project/commit/b30034da0ffaf67a144d062607e8f627e14227d1.diff

LOG: [mlir][linalg] Add folder for broadcast(broadcast) -> broadcast (#150825)

Back-to-back `linalg.broadcast` ops can be rewritten into a single broadcast.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b56a212c18cb3..34c63d378e1ca 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2293,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+/// Fold back-to-back broadcasts together.
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+    if (!defBroadcastOp)
+      return failure();
+    ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    SmallVector<int64_t> foldedDims(dimensions);
+    Value init = broadcastOp.getInit();
+    int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+    // Mapping from input dims to init dims.
+    SmallVector<int64_t> dimMap;
+    for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+      if (!llvm::is_contained(dimensions, dim))
+        dimMap.push_back(dim);
+    }
+    for (auto dim : defDimensions)
+      foldedDims.push_back(dimMap[dim]);
+
+    llvm::sort(foldedDims);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+    return success();
+  }
+};
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 39a7b1b1a2775..5c5f7e861d37d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
 
 // -----
 
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x3xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x3xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x3xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [2]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x4xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x4xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x4xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [1]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
   %transpose = linalg.transpose


        


More information about the Mlir-commits mailing list