[Mlir-commits] [mlir] [mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (PR #130000)

Ian Wood llvmlistbot at llvm.org
Thu May 29 10:20:17 PDT 2025


https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/130000

>From 878b5826028f09977980260b0a8b341b4edba9cd Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 6 Mar 2025 11:09:21 -0800
Subject: [PATCH 1/3] [mlir][linalg] Fix pattern to erase identity linalg ops

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 18 +++++++++++-------
 mlir/test/Dialect/Linalg/canonicalize.mlir | 21 ++++++++++++++++++++-
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..c044c94c5af3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1285,13 +1285,17 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
 
     // In the buffer case, we need to check exact buffer equality.
     if (linalgOp.hasPureBufferSemantics()) {
-      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
-          linalgOp.getDpsInputOperand(0)->get() ==
-              linalgOp.getDpsInitOperand(0)->get()) {
-        rewriter.eraseOp(linalgOp);
-        return success();
-      }
-      return failure();
+      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
+          linalgOp.getDpsInputOperand(0)->get() !=
+              linalgOp.getDpsInitOperand(0)->get())
+        return failure();
+
+      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return failure();
+
+      rewriter.eraseOp(linalgOp);
+      return success();
     }
 
     // Mixed semantics is not supported yet.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..08d99c65a291d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -415,7 +415,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
 
 // -----
 
-// CHECK: func @fold_self_copy
+// CHECK-LABEL: func @fold_self_copy
 func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return
   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -431,6 +431,25 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @no_fold_fill_like
+//       CHECK:   %[[VAL0:.+]] = arith.constant 0.000000e+00 : f32
+//       CHECK:   linalg.generic 
+//       CHECK:     linalg.yield %[[VAL0]] : f32
+func.func @no_fold_fill_like(%0 : memref<4x16xf32>) {
+  %1 = arith.constant 0.0 : f32
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<4x16xf32>)
+    outs(%0 : memref<4x16xf32>) {
+      ^bb0(%arg4: f32, %arg5: f32):
+        linalg.yield %1 : f32
+    }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_static_pad_fill
 //       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 //       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>

>From a5539f66c1a902b5d155e7de69a69883944c414c Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 29 May 2025 10:11:36 -0700
Subject: [PATCH 2/3] Address review comments: report match failures and simplify test

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 10 ++++++----
 mlir/test/Dialect/Linalg/canonicalize.mlir | 15 ++++++---------
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8724014e69a5e..51ff12794eeb4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1307,19 +1307,21 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
       if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
           linalgOp.getDpsInputOperand(0)->get() !=
               linalgOp.getDpsInitOperand(0)->get())
-        return failure();
+        return rewriter.notifyMatchFailure(
+            linalgOp, "expected single input and output to be the same value");
 
       auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
       if (!yieldArg || yieldArg.getOwner() != &body)
-        return failure();
+        return rewriter.notifyMatchFailure(linalgOp,
+                                           "cannot fold fill-like op");
 
       rewriter.eraseOp(linalgOp);
       return success();
     }
 
-    // Mixed semantics is not supported yet.
     if (!linalgOp.hasPureTensorSemantics())
-      return failure();
+      return rewriter.notifyMatchFailure(
+          linalgOp, "mixed semantics is not supported yet");
 
     // Get the argument number of the returned values. That is the operand
     // number to use for replacing uses of this operation.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 62382de9fd277..1b2857af42f54 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -512,18 +512,15 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // -----
 
 // CHECK-LABEL: func @no_fold_fill_like
-//       CHECK:   %[[VAL0:.+]] = arith.constant 0.000000e+00 : f32
-//       CHECK:   linalg.generic 
-//       CHECK:     linalg.yield %[[VAL0]] : f32
-func.func @no_fold_fill_like(%0 : memref<4x16xf32>) {
-  %1 = arith.constant 0.0 : f32
+//  CHECK-NEXT:   linalg.generic 
+func.func @no_fold_fill_like(%in_out : memref<4x16xf32>, %fill_val : f32) {
   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                    affine_map<(d0, d1) -> (d0, d1)>],
                   iterator_types = ["parallel", "parallel"]}
-    ins(%0 : memref<4x16xf32>)
-    outs(%0 : memref<4x16xf32>) {
-      ^bb0(%arg4: f32, %arg5: f32):
-        linalg.yield %1 : f32
+    ins(%in_out : memref<4x16xf32>)
+    outs(%in_out : memref<4x16xf32>) {
+      ^bb0(%arg0: f32, %arg1: f32):
+        linalg.yield %fill_val : f32
     }
   return
 }

>From 9fc46c578b3f70c2e8b6aa45ef4ddfd28e32662f Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 29 May 2025 10:20:49 -0700
Subject: [PATCH 3/3] Update pattern doc comment and add tensor fill-like test

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   |  5 +++--
 mlir/test/Dialect/Linalg/canonicalize.mlir | 20 +++++++++++++++++---
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 51ff12794eeb4..7b485f142f138 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1278,8 +1278,9 @@ LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
 
-/// Remove any linalg operation (on tensors) that are just copying
-/// the values from inputs to the results. Requirements are
+/// Remove linalg operations that are just copying the values from inputs to
+/// results. In the memref case, the operation must be copying to and from the
+/// same value. Requirements are:
 /// 1) All iterator types are parallel
 /// 2) The body contains just a yield operation with the yielded values being
 ///    the arguments corresponding to the operands.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 1b2857af42f54..7284ae7dbd673 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -511,9 +511,9 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @no_fold_fill_like
+// CHECK-LABEL: func @no_fold_fill_like_memref
 //  CHECK-NEXT:   linalg.generic 
-func.func @no_fold_fill_like(%in_out : memref<4x16xf32>, %fill_val : f32) {
+func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) {
   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                    affine_map<(d0, d1) -> (d0, d1)>],
                   iterator_types = ["parallel", "parallel"]}
@@ -521,12 +521,26 @@ func.func @no_fold_fill_like(%in_out : memref<4x16xf32>, %fill_val : f32) {
     outs(%in_out : memref<4x16xf32>) {
       ^bb0(%arg0: f32, %arg1: f32):
         linalg.yield %fill_val : f32
-    }
+  }
   return
 }
 
 // -----
 
+// CHECK-LABEL: func @no_fold_fill_like_tensor
+//  CHECK-NEXT:   linalg.generic 
+func.func @no_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> {
+  %result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%in_out : tensor<4x16xf32>)
+    outs(%in_out : tensor<4x16xf32>) {
+      ^bb0(%arg0: f32, %arg1: f32):
+        linalg.yield %fill_val : f32
+  } -> tensor<4x16xf32>
+  return %result : tensor<4x16xf32>
+}
+
 // CHECK-LABEL: func @fold_static_pad_fill
 //       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 //       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>



More information about the Mlir-commits mailing list