[Mlir-commits] [mlir] [MLIR][IR] Enable linalg fusion for custom dialect constants (PR #177740)

Ryan Kim llvmlistbot at llvm.org
Sat Jan 24 17:53:39 PST 2026


https://github.com/chokobole updated https://github.com/llvm/llvm-project/pull/177740

>From 060001f0c7b3e73385d57ef024e9843f4cc8fae9 Mon Sep 17 00:00:00 2001
From: Ryan Kim <chokobole33 at gmail.com>
Date: Sun, 25 Jan 2026 10:52:54 +0900
Subject: [PATCH] [mlir][linalg] Use materializeConstant for dialect-agnostic
 constant creation

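Change scalar constant creation in the FoldScalarOrSplatConstant pattern
(ElementwiseOpFusion.cpp) to call materializeConstant on the dialect that
owns the original splat constant, instead of always creating an
arith::ConstantOp. This makes the pattern work with constant-like ops from
any dialect, not just arith. If that dialect cannot materialize the scalar
constant, the pattern now reports a match failure instead of rewriting.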
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  8 ++++--
 .../Linalg/fusion-elementwise-ops.mlir        | 27 +++++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 72acd02d0d13d..ffc4a43123731 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2307,8 +2307,12 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
       }
 
       // Create a constant scalar value from the splat constant.
-      Value scalarConstant =
-          arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
+      Operation *scalarConstantOp = def->getDialect()->materializeConstant(
+          rewriter, constantAttr, constantAttr.getType(), def->getLoc());
+      if (!scalarConstantOp)
+        return rewriter.notifyMatchFailure(
+            genericOp, "failed to materialize scalar constant");
+      Value scalarConstant = scalarConstantOp->getResult(0);
 
       SmallVector<Value> outputOperands = genericOp.getOutputs();
       auto fusedOp =
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 6f1a422324e08..3b97309731ade 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1056,3 +1056,30 @@ module {
 // CHECK:         tensor.expand_shape
 // CHECK:         linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:     ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
+
+// -----
+
+// Test that FoldScalarOrSplatConstant creates the scalar through the owning
+// dialect's materializeConstant (test.constant here), not arith.constant.
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @generic_op_test_constant_fusion(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32>
+{
+  %cst = "test.constant"() {value = dense<3.0> : tensor<4x8xf32>} : () -> tensor<4x8xf32> // splat from a non-arith dialect
+  %0 = tensor.empty() : tensor<4x8xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%cst, %arg0 : tensor<4x8xf32>, tensor<4x8xf32>)
+    outs(%0 : tensor<4x8xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // %arg1 is the element of %cst
+      %2 = arith.mulf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4x8xf32>
+  return %1 : tensor<4x8xf32>
+}
+// CHECK-LABEL: func @generic_op_test_constant_fusion
+//       CHECK:   %[[CST:.*]] = "test.constant"() <{value = 3.000000e+00 : f32}> : () -> f32
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}]
+//       CHECK:   ^{{.*}}(%[[ARG1:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
+//       CHECK:     arith.mulf %[[ARG1]], %[[CST]]



More information about the Mlir-commits mailing list