[Mlir-commits] [mlir] [mlir][linalg] Enhance `isaInlinedFillOp` (PR #151155)
Longsheng Mou
llvmlistbot at llvm.org
Tue Jul 29 07:06:58 PDT 2025
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/151155
This PR extends `isaInlinedFillOp` to support converting a `linalg.generic` whose input operands are unused in its payload (the body only yields a constant) into a `linalg.fill`.
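For illustration, here is a sketch adapted from the test added in this patch (the function name `@inlined_fill_with_unused_input` is made up for this example): the generic takes an input whose block argument is never read, and its body only yields a constant.

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func.func @inlined_fill_with_unused_input(%input: tensor<8x8xi32>,
                                            %init: tensor<8x8xi32>) -> tensor<8x8xi32> {
    %c0_i32 = arith.constant 0 : i32
    // %in is never used in the body below; only the constant %c0_i32 is yielded.
    %res = linalg.generic {indexing_maps = [#map, #map],
                           iterator_types = ["parallel", "parallel"]}
        ins(%input : tensor<8x8xi32>) outs(%init : tensor<8x8xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %c0_i32 : i32
    } -> tensor<8x8xi32>
    return %res : tensor<8x8xi32>
  }

With this change, `transform.structured.specialize` recognizes the op and rewrites it to a fill of the output:

  %res = linalg.fill ins(%c0_i32 : i32) outs(%init : tensor<8x8xi32>) -> tensor<8x8xi32>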
From 5a406da6417e866e3c76ebd428a7d43c0fc0ee83 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Tue, 29 Jul 2025 11:37:59 +0800
Subject: [PATCH] [mlir][linalg] Enhance `isaInlinedFillOp`
This PR extends `isaInlinedFillOp` to support converting a generic
operation with unused input operands to `linalg.fill`.
---
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 3 +-
.../Linalg/specialize-generic-ops-fail.mlir | 17 -----------
.../Linalg/transform-op-specialize.mlir | 29 +++++++++++--------
3 files changed, 18 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f49d9a1eb96b5..66c282ef155a7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -85,8 +85,7 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
/// constant. If so, returns the constant value. Otherwise, returns
/// std::nullopt.
static std::optional<Value> isaInlinedFillOp(GenericOp op) {
- if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
- op.getNumDpsInputs() != 0)
+ if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1)
return std::nullopt;
// Init should not be referenced.
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 5d66837fca510..357f2c11a7936 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -29,20 +29,3 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x
} -> tensor<8xi32>
return %res : tensor<8xi32>
}
-
-// -----
-
-#map = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func @not_copy
-// CHECK-NOT: linalg.copy
-// CHECK: linalg.generic
-func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> {
- %c0_i32 = arith.constant 0 : i32
- %res = linalg.generic {
- indexing_maps = [#map, #map], iterator_types = ["parallel"]
- } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) {
- ^bb0(%in: i32, %out: i32):
- linalg.yield %c0_i32 : i32
- } -> tensor<8xi32>
- return %res : tensor<8xi32>
-}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8ede2e0add10b..801c834a36970 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -142,25 +142,15 @@ func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
} -> tensor<7x7xf32>
return %0 : tensor<7x7xf32>
}
+
// CHECK-LABEL: linalg_generic_fill
// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
%cst = arith.constant 0.000000e+00 : f32
- %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
+ %0 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
^bb0(%out: f32):
linalg.yield %cst : f32
} -> tensor<7x7xf32>
@@ -172,6 +162,21 @@ func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tenso
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+func.func @linalg_generic_inlined_constant_fill_has_input(%input: tensor<8x8xi32>, %init: tensor<8x8xi32>) -> tensor<8x8xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %res = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%input: tensor<8x8xi32>) outs(%init: tensor<8x8xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ linalg.yield %c0_i32 : i32
+ } -> tensor<8x8xi32>
+ return %res : tensor<8x8xi32>
+}
+
+// CHECK-LABEL: func @linalg_generic_inlined_constant_fill_has_input
+// CHECK-SAME: %[[INPUT:.+]]: tensor<8x8xi32>,
+// CHECK-SAME: %[[INIT:.+]]: tensor<8x8xi32>) -> tensor<8x8xi32>
+// CHECK: %[[CST:.+]] = arith.constant 0 : i32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT]] : tensor<8x8xi32>) -> tensor<8x8xi32>
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op