[Mlir-commits] [mlir] [MLIR][IR] Enable linalg fusion for custom dialect constants (PR #177740)
Ryan Kim
llvmlistbot at llvm.org
Sat Jan 24 17:53:39 PST 2026
https://github.com/chokobole updated https://github.com/llvm/llvm-project/pull/177740
>From 060001f0c7b3e73385d57ef024e9843f4cc8fae9 Mon Sep 17 00:00:00 2001
From: Ryan Kim <chokobole33 at gmail.com>
Date: Sun, 25 Jan 2026 10:52:54 +0900
Subject: [PATCH] [mlir][linalg] Use materializeConstant for dialect-agnostic
constant creation
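Change scalar constant creation in ElementwiseOpFusion to use
the dialect's materializeConstant method instead of directly
creating arith::ConstantOp. The scalar constant is now created by
the dialect that owns the original splat constant (e.g. test.constant
stays test.constant) instead of always becoming arith.constant. If that
dialect cannot materialize a constant of the scalar element type, the
pattern reports a match failure and leaves the generic op unfused.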
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 8 ++++--
.../Linalg/fusion-elementwise-ops.mlir | 27 +++++++++++++++++++
2 files changed, 33 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 72acd02d0d13d..ffc4a43123731 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2307,8 +2307,12 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
}
// Create a constant scalar value from the splat constant.
- Value scalarConstant =
- arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
+ Operation *scalarConstantOp = def->getDialect()->materializeConstant(
+ rewriter, constantAttr, constantAttr.getType(), def->getLoc());
+ if (!scalarConstantOp)
+ return rewriter.notifyMatchFailure(
+ genericOp, "failed to materialize scalar constant");
+ Value scalarConstant = scalarConstantOp->getResult(0);
SmallVector<Value> outputOperands = genericOp.getOutputs();
auto fusedOp =
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 6f1a422324e08..3b97309731ade 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1056,3 +1056,30 @@ module {
// CHECK: tensor.expand_shape
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
+
+// -----
+
+// Test that splat constant fusion uses dialect's materializeConstant.
+// This test uses test.constant to verify non-arith constant creation.
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @generic_op_test_constant_fusion(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32>
+{
+ %cst = "test.constant"() {value = dense<3.0> : tensor<4x8xf32>} : () -> tensor<4x8xf32>
+ %0 = tensor.empty() : tensor<4x8xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%cst, %arg0 : tensor<4x8xf32>, tensor<4x8xf32>)
+ outs(%0 : tensor<4x8xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %2 = arith.mulf %arg1, %arg2 : f32
+ linalg.yield %2 : f32
+ } -> tensor<4x8xf32>
+ return %1 : tensor<4x8xf32>
+}
+// CHECK-LABEL: func @generic_op_test_constant_fusion
+// CHECK: %[[CST:.*]] = "test.constant"() <{value = 3.000000e+00 : f32}> : () -> f32
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}]
+// CHECK: ^{{.*}}(%[[ARG1:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
+// CHECK: arith.mulf %[[ARG1]], %[[CST]]
More information about the Mlir-commits mailing list