[Mlir-commits] [mlir] [mlir][linalg] Enhance `isaInlinedFillOp` (PR #151155)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 29 07:07:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR relaxes `isaInlinedFillOp` by dropping its requirement that the op have no inputs, so that a `linalg.generic` whose input operands are unused can also be converted to `linalg.fill`.

---
Full diff: https://github.com/llvm/llvm-project/pull/151155.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+1-2) 
- (modified) mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir (-17) 
- (modified) mlir/test/Dialect/Linalg/transform-op-specialize.mlir (+17-12) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f49d9a1eb96b5..66c282ef155a7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -85,8 +85,7 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
 /// constant. If so, returns the constant value. Otherwise, returns
 /// std::nullopt.
 static std::optional<Value> isaInlinedFillOp(GenericOp op) {
-  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
-      op.getNumDpsInputs() != 0)
+  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1)
     return std::nullopt;
 
   // Init should not be referenced.
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 5d66837fca510..357f2c11a7936 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -29,20 +29,3 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x
   } -> tensor<8xi32>
   return %res : tensor<8xi32>
 }
-
-// -----
-
-#map = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func @not_copy
-//  CHECK-NOT:    linalg.copy
-//      CHECK:    linalg.generic
-func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> {
-  %c0_i32 = arith.constant 0 : i32
-  %res = linalg.generic {
-    indexing_maps = [#map, #map], iterator_types = ["parallel"]
-  } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) {
-  ^bb0(%in: i32, %out: i32):
-    linalg.yield %c0_i32 : i32
-  } -> tensor<8xi32>
-  return %res : tensor<8xi32>
-}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8ede2e0add10b..801c834a36970 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -142,25 +142,15 @@ func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
   } -> tensor<7x7xf32>
   return %0 : tensor<7x7xf32>
 }
+
 // CHECK-LABEL: linalg_generic_fill
 // CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
 // CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
+  %0 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
   ^bb0(%out: f32):
     linalg.yield %cst : f32
   } -> tensor<7x7xf32>
@@ -172,6 +162,21 @@ func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tenso
 // CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
 
+func.func @linalg_generic_inlined_constant_fill_has_input(%input: tensor<8x8xi32>, %init: tensor<8x8xi32>) -> tensor<8x8xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %res = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%input: tensor<8x8xi32>) outs(%init: tensor<8x8xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    linalg.yield %c0_i32 : i32
+  } -> tensor<8x8xi32>
+  return %res : tensor<8x8xi32>
+}
+
+// CHECK-LABEL: func @linalg_generic_inlined_constant_fill_has_input
+//  CHECK-SAME:   %[[INPUT:.+]]: tensor<8x8xi32>,
+//  CHECK-SAME:   %[[INIT:.+]]: tensor<8x8xi32>) -> tensor<8x8xi32>
+//       CHECK: %[[CST:.+]] = arith.constant 0 : i32
+//       CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT]] : tensor<8x8xi32>) -> tensor<8x8xi32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op

``````````

</details>


https://github.com/llvm/llvm-project/pull/151155


More information about the Mlir-commits mailing list