[Mlir-commits] [mlir] [mlir][linalg] Add support for inlined const to isaFillOpInterface (PR #144870)

Shay Kleiman llvmlistbot at llvm.org
Thu Jun 19 03:51:21 PDT 2025


https://github.com/shay-kl created https://github.com/llvm/llvm-project/pull/144870

None

>From 3cf863dcb1524cd9fb1b4dd4df543d3bb33806ae Mon Sep 17 00:00:00 2001
From: Shay Kleiman <shay.kleiman at mobileye.com>
Date: Thu, 19 Jun 2025 13:44:38 +0300
Subject: [PATCH] [mlir][linalg] Add support for inlined const to
 isaFillOpInterface

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  3 ++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 38 ++++++++++++++++++-
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  5 ++-
 .../Linalg/transform-op-specialize.mlir       | 25 ++++++++++++
 4 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index df32cafd2d024..0f960fb5ad795 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -142,6 +142,9 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
 bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
 
 /// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+/// Supports two patterns:
+/// 1. External: linalg.generic ins(%scalar) outs(%tensor) { yield %scalar }
+/// 2. Inlined: linalg.generic outs(%tensor) { yield %constant }
 /// Returns the scalar fill value if true.
 std::optional<Value> isaFillOpInterface(GenericOp genericOp);
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7d1844df42195..139e9901b0a29 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -77,7 +77,37 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
 //===----------------------------------------------------------------------===//
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
-std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
+/// Matches a `linalg.generic` with no inputs whose body only yields a value
+/// defined by an `arith.constant` or `complex.constant`. Returns that yielded
+/// value on a match and std::nullopt otherwise.
+static std::optional<Value> isaInlinedFillOp(GenericOp op) {
+  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
+      op.getNumDpsInputs() != 0)
+    return std::nullopt;
+
+  // Reject ops whose payload uses the value of the (single) init operand.
+  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
+    return std::nullopt;
+
+  Block *body = op.getBody();
+  if (body->getOperations().size() != 1)
+    return std::nullopt;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1)
+    return std::nullopt;
+
+  Value yieldOperand = yieldOp->getOperand(0);
+  if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
+      !yieldOperand.getDefiningOp<complex::ConstantOp>())
+    return std::nullopt;
+
+  return yieldOperand;
+}
+
+/// Detects if a linalg.generic operation fills its output with an external
+/// scalar input. If so, returns that scalar. Otherwise, returns std::nullopt.
+static std::optional<Value> isaExternalFillOp(GenericOp op) {
   // Structural.
   if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
       !op.isSingleYieldOp())
@@ -94,6 +124,12 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
   return value->get();
 }
 
+std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
+  // Prefer the inlined-constant form; otherwise try the external-scalar form.
+  std::optional<Value> inlinedFill = isaInlinedFillOp(op);
+  return inlinedFill ? inlinedFill : isaExternalFillOp(op);
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 512fb7555a6b7..933f5966cc6ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -267,9 +267,10 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   }
 
   // Fill
-  if (isaFillOpInterface(genericOp)) {
+  if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
+    // Always use the detected fill value, regardless of pattern
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
-        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+        genericOp, *fillValue, genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 31f2f6b1ab513..8ede2e0add10b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -154,3 +154,28 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<7x7xf32>
+  return %0 : tensor<7x7xf32>
+}
+
+// CHECK-LABEL: linalg_generic_inlined_constant_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list