[Mlir-commits] [mlir] [mlir][linalg] Add support for inlined const to isaFillOpInterface (PR #144870)
Shay Kleiman
llvmlistbot at llvm.org
Thu Jun 19 03:51:21 PDT 2025
https://github.com/shay-kl created https://github.com/llvm/llvm-project/pull/144870
None
>From 3cf863dcb1524cd9fb1b4dd4df543d3bb33806ae Mon Sep 17 00:00:00 2001
From: Shay Kleiman <shay.kleiman at mobileye.com>
Date: Thu, 19 Jun 2025 13:44:38 +0300
Subject: [PATCH] [mlir][linalg] Add support for inlined const to
isaFillOpInterface
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 3 ++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 38 ++++++++++++++++++-
.../Dialect/Linalg/Transforms/Specialize.cpp | 5 ++-
.../Linalg/transform-op-specialize.mlir | 25 ++++++++++++
4 files changed, 68 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index df32cafd2d024..0f960fb5ad795 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -142,6 +142,9 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+/// Supports two all-parallel patterns whose body is a lone linalg.yield:
+/// 1. External: linalg.generic ins(%scalar) outs(%tensor) { yield %scalar }
+/// 2. Inlined: linalg.generic outs(%tensor) { yield %arith_or_complex_const }
/// Returns the scalar fill value if true.
std::optional<Value> isaFillOpInterface(GenericOp genericOp);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7d1844df42195..139e9901b0a29 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -77,7 +77,37 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
 //===----------------------------------------------------------------------===//
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
-std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
+/// Detects if a linalg.generic operation represents a fill with an inlined
+/// constant: no inputs, one init that the payload never reads, and a body
+/// that only yields an arith/complex constant. If so, returns that constant.
+static std::optional<Value> isaInlinedFillOp(GenericOp op) {
+  // Structural: all-parallel, no inputs, one init, body is a lone yield.
+  if (!op.isAllParallelLoops() || op.getNumDpsInputs() != 0 ||
+      op.getNumDpsInits() != 1 || !op.isSingleYieldOp())
+    return std::nullopt;
+
+  // A fill overwrites the init, so the payload must not read it.
+  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
+    return std::nullopt;
+
+  // The single op in the body must be a one-operand linalg.yield.
+  auto yieldOp = dyn_cast<linalg::YieldOp>(op.getBody()->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1)
+    return std::nullopt;
+
+  // Only constants are recognized as inlined fill values.
+  Value yieldOperand = yieldOp->getOperand(0);
+  if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
+      !yieldOperand.getDefiningOp<complex::ConstantOp>())
+    return std::nullopt;
+
+  return yieldOperand;
+}
+
+/// Detects if a linalg.generic operation represents a fill from an external
+/// scalar input (`ins(%scalar) outs(%init) { yield %scalar }`). If so,
+/// returns that scalar input value. Otherwise, returns std::nullopt.
+static std::optional<Value> isaExternalFillOp(GenericOp op) {
   // Structural.
   if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
       !op.isSingleYieldOp())
@@ -94,6 +124,12 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
   return value->get();
 }
 
+std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
+  // Prefer the inlined-constant form; fall back to the external-scalar form.
+  std::optional<Value> inlined = isaInlinedFillOp(op);
+  return inlined ? inlined : isaExternalFillOp(op);
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 512fb7555a6b7..933f5966cc6ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -267,9 +267,10 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   }
 
   // Fill
-  if (isaFillOpInterface(genericOp)) {
+  if (std::optional<Value> detectedFill = isaFillOpInterface(genericOp)) {
+    // The fill value is either the `ins` scalar or the yielded constant.
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
-        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+        genericOp, *detectedFill, genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 31f2f6b1ab513..8ede2e0add10b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -154,3 +154,55 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<7x7xf32>
+  return %0 : tensor<7x7xf32>
+}
+
+// CHECK-LABEL: linalg_generic_inlined_constant_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK-NOT: linalg.generic
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_inlined_complex_constant_fill(%arg0: tensor<7x7xcomplex<f32>>) -> tensor<7x7xcomplex<f32>> {
+  %cst = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xcomplex<f32>>) {
+  ^bb0(%out: complex<f32>):
+    linalg.yield %cst : complex<f32>
+  } -> tensor<7x7xcomplex<f32>>
+  return %0 : tensor<7x7xcomplex<f32>>
+}
+
+// CHECK-LABEL: linalg_generic_inlined_complex_constant_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xcomplex<f32>>) -> tensor<7x7xcomplex<f32>>
+// CHECK: %[[CST:.+]] = complex.constant
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : complex<f32>) outs(%[[ARG0]] : tensor<7x7xcomplex<f32>>) -> tensor<7x7xcomplex<f32>>
+// CHECK-NOT: linalg.generic
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
More information about the Mlir-commits
mailing list