[Mlir-commits] [mlir] c005df3 - [mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (#130000)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 2 12:19:00 PDT 2025


Author: Ian Wood
Date: 2025-06-02T12:18:57-07:00
New Revision: c005df3c7e7f8bf788803a95e27d57b339c965fe

URL: https://github.com/llvm/llvm-project/commit/c005df3c7e7f8bf788803a95e27d57b339c965fe
DIFF: https://github.com/llvm/llvm-project/commit/c005df3c7e7f8bf788803a95e27d57b339c965fe.diff

LOG: [mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (#130000)

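Adds a check that the linalg op is safe to erase: its `linalg.yield` must
yield one of the linalg op's own block arguments. Previously, with buffer
(memref) semantics, an op whose single input and single output were the same
buffer was erased unconditionally, even when it yielded some other value
(e.g. a fill-like op yielding a scalar). An equivalent check already exists
for linalg ops with pure tensor semantics.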


Closes https://github.com/llvm/llvm-project/issues/129414

---------

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5fc3ace5d6aab..5dbb2403eddbd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1278,8 +1278,9 @@ LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
 
-/// Remove any linalg operation (on tensors) that are just copying
-/// the values from inputs to the results. Requirements are
+/// Remove linalg operations that are just copying the values from inputs to
+/// results. In the memref case, the operation must be copying to and from the
+/// same value. Requirements are:
 /// 1) All iterator types are parallel
 /// 2) The body contains just a yield operation with the yielded values being
 ///    the arguments corresponding to the operands.
@@ -1304,18 +1305,27 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
 
     // In the buffer case, we need to check exact buffer equality.
     if (linalgOp.hasPureBufferSemantics()) {
-      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
-          linalgOp.getDpsInputOperand(0)->get() ==
+      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
+          linalgOp.getDpsInputOperand(0)->get() !=
               linalgOp.getDpsInitOperand(0)->get()) {
-        rewriter.eraseOp(linalgOp);
-        return success();
+        return rewriter.notifyMatchFailure(
+            linalgOp, "expected single input and output to be the same value");
       }
-      return failure();
+
+      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
+      if (!yieldArg || yieldArg.getOwner() != &body) {
+        return rewriter.notifyMatchFailure(linalgOp,
+                                           "cannot fold fill-like op");
+      }
+
+      rewriter.eraseOp(linalgOp);
+      return success();
     }
 
-    // Mixed semantics is not supported yet.
-    if (!linalgOp.hasPureTensorSemantics())
-      return failure();
+    if (!linalgOp.hasPureTensorSemantics()) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "mixed semantics is not supported yet");
+    }
 
     // Get the argument number of the returned values. That is the operand
     // number to use for replacing uses of this operation.

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 3daf221f4402d..7284ae7dbd673 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -495,7 +495,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
 
 // -----
 
-// CHECK: func @fold_self_copy
+// CHECK-LABEL: func @fold_self_copy
 func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return
   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -511,6 +511,36 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @no_fold_fill_like_memref
+//  CHECK-NEXT:   linalg.generic 
+func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) {
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%in_out : memref<4x16xf32>)
+    outs(%in_out : memref<4x16xf32>) {
+      ^bb0(%arg0: f32, %arg1: f32):
+        linalg.yield %fill_val : f32
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @no_fold_fill_like_tensor
+//  CHECK-NEXT:   linalg.generic 
+func.func @no_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> {
+  %result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%in_out : tensor<4x16xf32>)
+    outs(%in_out : tensor<4x16xf32>) {
+      ^bb0(%arg0: f32, %arg1: f32):
+        linalg.yield %fill_val : f32
+  } -> tensor<4x16xf32>
+  return %result : tensor<4x16xf32>
+}
+
 // CHECK-LABEL: func @fold_static_pad_fill
 //       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 //       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>


        


More information about the Mlir-commits mailing list