[Mlir-commits] [mlir] 10c0b42 - [mlir][linalg] Add splat broadcast canonicalization pattern (#195980)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 20 23:21:11 PDT 2026


Author: Hocky Yudhiono
Date: 2026-05-21T14:21:06+08:00
New Revision: 10c0b42b5086e8d69dff8f3d46f5d28b374e50d8

URL: https://github.com/llvm/llvm-project/commit/10c0b42b5086e8d69dff8f3d46f5d28b374e50d8
DIFF: https://github.com/llvm/llvm-project/commit/10c0b42b5086e8d69dff8f3d46f5d28b374e50d8.diff

LOG: [mlir][linalg] Add splat broadcast canonicalization pattern (#195980)

Add a `linalg.broadcast` canonicalization that folds a broadcast of a splat constant into a splat `arith.constant` of the result type.

Assisted-by: Cursor (GPT-5.5)

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6ebf147394356..de7f4d1610bd0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2405,9 +2405,52 @@ struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
   }
 };
 
+/// Rewrite a `linalg.broadcast` whose input is a dense splat constant into a
+/// single splat `arith.constant` of the broadcast result type, e.g.
+///
+///   %cst = arith.constant dense<1.0> : tensor<3xf32>
+///   %0 = linalg.broadcast ins(%cst : tensor<3xf32>)
+///                         outs(%init : tensor<2x3xf32>) dimensions = [0]
+///   ==>
+///   %0 = arith.constant dense<1.0> : tensor<2x3xf32>
+///
+/// Every element of the broadcast result equals the splat value, so the `init`
+/// operand is not needed by the replacement. Only statically shaped results are
+/// handled because a DenseElementsAttr requires a static shape.
+struct FoldBroadcastSplatConstant : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    // With buffer semantics the op writes into its `init` memref; replacing it
+    // with an SSA constant would drop that side effect.
+    if (!broadcastOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(broadcastOp,
+                                         "expected pure tensor semantics");
+
+    auto splatValue =
+        getScalarConstantAttrFromDenseSplat(broadcastOp.getInput());
+    if (!splatValue.has_value())
+      return rewriter.notifyMatchFailure(broadcastOp,
+                                         "input is not a splat constant");
+
+    auto resultType =
+        cast<RankedTensorType>(broadcastOp.getResult()[0].getType());
+    if (!resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(broadcastOp,
+                                         "result type has dynamic shape");
+
+    auto resultAttr = DenseElementsAttr::get(resultType, splatValue.value());
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(broadcastOp, resultType,
+                                                   resultAttr);
+    return success();
+  }
+};
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts,
+              FoldBroadcastSplatConstant>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 8896cf1ab7e2d..12bdeb84e47e0 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,74 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
 
 // -----
 
+// A broadcast of a splat constant into a statically shaped result folds into a
+// single splat constant of the result type; the `init` operand is dropped.
+// CHECK-LABEL: @broadcast_splat_constant_to_dense
+//       CHECK:   %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<2x3xf32>
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[CST]] : tensor<2x3xf32>
+func.func @broadcast_splat_constant_to_dense(%init: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %cst = arith.constant dense<1.000000e+00> : tensor<3xf32>
+  %0 = linalg.broadcast
+      ins(%cst: tensor<3xf32>)
+      outs(%init: tensor<2x3xf32>)
+      dimensions = [0]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// Integer splats fold the same way, including when the broadcast dimension is
+// not the leading one.
+// CHECK-LABEL: @broadcast_splat_integer_constant_to_dense
+//       CHECK:   %[[CST:.+]] = arith.constant dense<7> : tensor<4x2xi32>
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[CST]] : tensor<4x2xi32>
+func.func @broadcast_splat_integer_constant_to_dense(%init: tensor<4x2xi32>) -> tensor<4x2xi32> {
+  %cst = arith.constant dense<7> : tensor<4xi32>
+  %0 = linalg.broadcast
+      ins(%cst: tensor<4xi32>)
+      outs(%init: tensor<4x2xi32>)
+      dimensions = [1]
+  return %0 : tensor<4x2xi32>
+}
+
+// -----
+
+// Negative test: a dense constant needs a static shape, so a splat broadcast
+// into a dynamically shaped result is left untouched.
+// CHECK-LABEL: @broadcast_splat_constant_dynamic_shape
+//       CHECK:   %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<3xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[CST]] : tensor<3xf32>) outs({{.*}} : tensor<?x3xf32>) dimensions = [0]
+//       CHECK:   return %[[BROADCAST]] : tensor<?x3xf32>
+func.func @broadcast_splat_constant_dynamic_shape(%init: tensor<?x3xf32>) -> tensor<?x3xf32> {
+  %cst = arith.constant dense<1.000000e+00> : tensor<3xf32>
+  %0 = linalg.broadcast
+      ins(%cst: tensor<3xf32>)
+      outs(%init: tensor<?x3xf32>)
+      dimensions = [0]
+  return %0 : tensor<?x3xf32>
+}
+
+// -----
+
+// Negative test: only splat constants are folded; a non-splat input keeps the
+// broadcast.
+// CHECK-LABEL: @broadcast_non_splat_constant
+//       CHECK:   %[[CST:.+]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[CST]] : tensor<3xf32>) outs({{.*}} : tensor<2x3xf32>) dimensions = [0]
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3xf32>
+func.func @broadcast_non_splat_constant(%init: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>
+  %0 = linalg.broadcast
+      ins(%cst: tensor<3xf32>)
+      outs(%init: tensor<2x3xf32>)
+      dimensions = [0]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @broadcast_broadcast_fold
 //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
 //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>


        


More information about the Mlir-commits mailing list