[Mlir-commits] [mlir] 10c0b42 - [mlir][linalg] Add splat broadcast canonicalization pattern (#195980)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 20 23:21:11 PDT 2026
Author: Hocky Yudhiono
Date: 2026-05-21T14:21:06+08:00
New Revision: 10c0b42b5086e8d69dff8f3d46f5d28b374e50d8
URL: https://github.com/llvm/llvm-project/commit/10c0b42b5086e8d69dff8f3d46f5d28b374e50d8
DIFF: https://github.com/llvm/llvm-project/commit/10c0b42b5086e8d69dff8f3d46f5d28b374e50d8.diff
LOG: [mlir][linalg] Add splat broadcast canonicalization pattern (#195980)
Add a canonicalization that folds a `linalg.broadcast` of a dense splat constant into a splat `arith.constant` of the (statically shaped) broadcast result type.
Assisted-by: Cursor (GPT-5.5)
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6ebf147394356..de7f4d1610bd0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2405,9 +2405,39 @@ struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
}
};
+/// Fold a `linalg.broadcast` of a dense splat constant into a splat
+/// `arith.constant` that has the broadcast's result type.
+///
+/// Broadcasting only replicates the input along new dimensions, so
+/// broadcasting a splat yields a splat of the same value. With pure tensor
+/// semantics the op's result is determined by its input alone, so the `init`
+/// operand can be dropped. Results with dynamic dimensions are rejected
+/// because a `DenseElementsAttr` requires a statically shaped type.
+///
+/// Every bail-out reports a reason through `notifyMatchFailure`, so that
+/// pattern-driver debug output explains why the fold did not fire.
+struct FoldBroadcastSplatConstant : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    // Buffer (memref) broadcasts have no tensor result to replace.
+    if (!broadcastOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(broadcastOp,
+                                         "expected pure tensor semantics");
+
+    // NOTE(review): assumes the helper yields the scalar element attribute
+    // when the input comes from a dense splat constant, and an empty
+    // optional otherwise -- confirm against its definition.
+    auto splatValue =
+        getScalarConstantAttrFromDenseSplat(broadcastOp.getInput());
+    if (!splatValue)
+      return rewriter.notifyMatchFailure(broadcastOp,
+                                         "input is not a dense splat constant");
+
+    auto resultType =
+        cast<RankedTensorType>(broadcastOp.getResult()[0].getType());
+    if (!resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(broadcastOp,
+                                         "result type has dynamic shape");
+
+    // A single-element value list produces a splat attribute covering the
+    // whole result shape.
+    auto resultAttr = DenseElementsAttr::get(resultType, *splatValue);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(broadcastOp, resultType,
+                                                   resultAttr);
+    return success();
+  }
+};
+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
+  // Registration order is unchanged: identity erasure, broadcast-of-broadcast
+  // folding, then folding a broadcast of a splat constant.
+  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<FoldBroadcasts, FoldBroadcastSplatConstant>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 8896cf1ab7e2d..12bdeb84e47e0 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,51 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
// -----
+// A broadcast of a splat constant folds into a splat constant of the
+// static result shape, and the broadcast op itself is removed.
+// CHECK-LABEL: @broadcast_splat_constant_to_dense
+// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<2x3xf32>
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[CST]] : tensor<2x3xf32>
+func.func @broadcast_splat_constant_to_dense(%dest: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %splat = arith.constant dense<1.000000e+00> : tensor<3xf32>
+  %bcast = linalg.broadcast ins(%splat : tensor<3xf32>) outs(%dest : tensor<2x3xf32>) dimensions = [0]
+  return %bcast : tensor<2x3xf32>
+}
+
+// -----
+
+// A dynamically shaped result cannot be materialized as a dense constant,
+// so the broadcast of the splat is left untouched.
+// CHECK-LABEL: @broadcast_splat_constant_dynamic_shape
+// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[CST]] : tensor<3xf32>) outs({{.*}} : tensor<?x3xf32>) dimensions = [0]
+// CHECK: return %[[BROADCAST]] : tensor<?x3xf32>
+func.func @broadcast_splat_constant_dynamic_shape(%dest: tensor<?x3xf32>) -> tensor<?x3xf32> {
+  %splat = arith.constant dense<1.000000e+00> : tensor<3xf32>
+  %bcast = linalg.broadcast ins(%splat : tensor<3xf32>) outs(%dest : tensor<?x3xf32>) dimensions = [0]
+  return %bcast : tensor<?x3xf32>
+}
+
+// -----
+
+// Only splat inputs are folded; a broadcast of a non-splat constant stays.
+// CHECK-LABEL: @broadcast_non_splat_constant
+// CHECK: %[[CST:.+]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[CST]] : tensor<3xf32>) outs({{.*}} : tensor<2x3xf32>) dimensions = [0]
+// CHECK: return %[[BROADCAST]] : tensor<2x3xf32>
+func.func @broadcast_non_splat_constant(%dest: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %values = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>
+  %bcast = linalg.broadcast ins(%values : tensor<3xf32>) outs(%dest : tensor<2x3xf32>) dimensions = [0]
+  return %bcast : tensor<2x3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @broadcast_broadcast_fold
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
More information about the Mlir-commits
mailing list