[Mlir-commits] [mlir] b30034d - [mlir][linalg] Add folder for broadcast(broadcast) -> broadcast (#150825)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 29 04:25:29 PDT 2025
Author: Longsheng Mou
Date: 2025-07-29T19:25:26+08:00
New Revision: b30034da0ffaf67a144d062607e8f627e14227d1
URL: https://github.com/llvm/llvm-project/commit/b30034da0ffaf67a144d062607e8f627e14227d1
DIFF: https://github.com/llvm/llvm-project/commit/b30034da0ffaf67a144d062607e8f627e14227d1.diff
LOG: [mlir][linalg] Add folder for broadcast(broadcast) -> broadcast (#150825)
Back-to-back `linalg.broadcast` ops can be rewritten into a single broadcast.
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b56a212c18cb3..34c63d378e1ca 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2293,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+/// Fold back-to-back broadcasts together: a broadcast whose input is produced by another broadcast is replaced by one broadcast from the inner op's input into the outer op's init.
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
+ using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>(); // Null unless the input is produced by a BroadcastOp.
+ if (!defBroadcastOp)
+ return failure(); // Nothing to fold.
+ ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions(); // Dims added by the inner (producer) broadcast.
+ ArrayRef<int64_t> dimensions = broadcastOp.getDimensions(); // Dims added by the outer broadcast.
+ SmallVector<int64_t> foldedDims(dimensions); // Start from the outer op's added dims.
+ Value init = broadcastOp.getInit();
+ int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+ // Mapping from input dims to init dims: dimMap[i] is the i-th init dim (in increasing order) that is not listed in `dimensions`.
+ SmallVector<int64_t> dimMap;
+ for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+ if (!llvm::is_contained(dimensions, dim))
+ dimMap.push_back(dim);
+ }
+ for (auto dim : defDimensions)
+ foldedDims.push_back(dimMap[dim]); // Re-express each inner added dim as an index into the outer init.
+
+ llvm::sort(foldedDims); // Emit the combined dimension list in ascending order.
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ broadcastOp, defBroadcastOp.getInput(), init, foldedDims); // New op reads the inner op's input and writes the outer init.
+ return success();
+ }
+};
+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+ results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context); // Also registers the broadcast-of-broadcast folding pattern.
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 39a7b1b1a2775..5c5f7e861d37d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
// -----
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+ %init1: tensor<2x3xf32>,
+ %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %broadcast1 = linalg.broadcast
+ ins(%input: tensor<2xf32>)
+ outs(%init1: tensor<2x3xf32>)
+ dimensions = [1] // Inner op adds dim 1 (size 3) to the 1-D input.
+ %broadcast2 = linalg.broadcast
+ ins(%broadcast1: tensor<2x3xf32>)
+ outs(%init2: tensor<2x3x4xf32>)
+ dimensions = [2] // Outer op appends dim 2 (size 4); the inner dim keeps index 1, so the fold yields [1, 2].
+ func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+ %init1: tensor<2x4xf32>,
+ %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %broadcast1 = linalg.broadcast
+ ins(%input: tensor<2xf32>)
+ outs(%init1: tensor<2x4xf32>)
+ dimensions = [1] // Inner op adds dim 1 (size 4) of the 2x4 intermediate.
+ %broadcast2 = linalg.broadcast
+ ins(%broadcast1: tensor<2x4xf32>)
+ outs(%init2: tensor<2x3x4xf32>)
+ dimensions = [1] // Outer op inserts dim 1 (size 3), so the inner added dim lands at index 2; the fold yields [1, 2].
+ func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
%transpose = linalg.transpose
More information about the Mlir-commits
mailing list