[Mlir-commits] [mlir] [mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (PR #130000)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 5 23:29:49 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Ian Wood (IanWood1)

<details>
<summary>Changes</summary>

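Adds a check to `EraseIdentityLinalgOp` so that, in the pure-buffer case, a linalg op is erased only when its `linalg.yield` yields one of the op's own block arguments. Previously, an op whose single input and single init were the same buffer was always erased, even if it was fill-like, i.e. it yielded a value defined outside the body (such as a constant) and therefore was not an identity copy.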


Closes https://github.com/llvm/llvm-project/issues/129414

---
Full diff: https://github.com/llvm/llvm-project/pull/130000.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+11-7) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+20-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..c044c94c5af3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1285,13 +1285,17 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
 
     // In the buffer case, we need to check exact buffer equality.
     if (linalgOp.hasPureBufferSemantics()) {
-      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
-          linalgOp.getDpsInputOperand(0)->get() ==
-              linalgOp.getDpsInitOperand(0)->get()) {
-        rewriter.eraseOp(linalgOp);
-        return success();
-      }
-      return failure();
+      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
+          linalgOp.getDpsInputOperand(0)->get() !=
+              linalgOp.getDpsInitOperand(0)->get())
+        return failure();
+
+      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return failure();
+
+      rewriter.eraseOp(linalgOp);
+      return success();
     }
 
     // Mixed semantics is not supported yet.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..08d99c65a291d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -415,7 +415,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
 
 // -----
 
-// CHECK: func @fold_self_copy
+// CHECK-LABEL: func @fold_self_copy
 func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return
   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -431,6 +431,25 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @no_fold_fill_like
+//       CHECK:   %[[VAL0:.+]] = arith.constant 0.000000e+00 : f32
+//       CHECK:   linalg.generic 
+//       CHECK:     linalg.yield %[[VAL0]] : f32
+func.func @no_fold_fill_like(%0 : memref<4x16xf32>) {
+  %1 = arith.constant 0.0 : f32
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<4x16xf32>)
+    outs(%0 : memref<4x16xf32>) {
+      ^bb0(%arg4: f32, %arg5: f32):
+        linalg.yield %1 : f32
+    }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_static_pad_fill
 //       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 //       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/130000


More information about the Mlir-commits mailing list