[Mlir-commits] [mlir] [mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (PR #130000)
Ian Wood
llvmlistbot at llvm.org
Wed Mar 5 23:29:16 PST 2025
https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/130000
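Adds a check to `EraseIdentityLinalgOp` so that, for ops with pure buffer semantics, the op is erased only when its `linalg.yield` yields one of the op body's block arguments. Fill-like ops, whose body yields a value defined outside the region (e.g. a constant), are no longer erased. A `no_fold_fill_like` regression test is added to `canonicalize.mlir`.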
Closes https://github.com/llvm/llvm-project/issues/129414
>From 878b5826028f09977980260b0a8b341b4edba9cd Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 6 Mar 2025 11:09:21 -0800
Subject: [PATCH] [mlir][linalg] Fix pattern to erase identity linalg ops
Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 +++++++++++-------
mlir/test/Dialect/Linalg/canonicalize.mlir | 21 ++++++++++++++++++++-
2 files changed, 31 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..c044c94c5af3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1285,13 +1285,17 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
// In the buffer case, we need to check exact buffer equality.
if (linalgOp.hasPureBufferSemantics()) {
- if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
- linalgOp.getDpsInputOperand(0)->get() ==
- linalgOp.getDpsInitOperand(0)->get()) {
- rewriter.eraseOp(linalgOp);
- return success();
- }
- return failure();
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
+ linalgOp.getDpsInputOperand(0)->get() !=
+ linalgOp.getDpsInitOperand(0)->get())
+ return failure();
+
+ auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
+ if (!yieldArg || yieldArg.getOwner() != &body)
+ return failure();
+
+ rewriter.eraseOp(linalgOp);
+ return success();
}
// Mixed semantics is not supported yet.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..08d99c65a291d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -415,7 +415,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
// -----
-// CHECK: func @fold_self_copy
+// CHECK-LABEL: func @fold_self_copy
func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// CHECK-NEXT: return
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -431,6 +431,25 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// -----
+// CHECK-LABEL: func @no_fold_fill_like
+// CHECK: %[[VAL0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: linalg.generic
+// CHECK: linalg.yield %[[VAL0]] : f32
+func.func @no_fold_fill_like(%0 : memref<4x16xf32>) {
+ %1 = arith.constant 0.0 : f32
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : memref<4x16xf32>)
+ outs(%0 : memref<4x16xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32):
+ linalg.yield %1 : f32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @fold_static_pad_fill
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>
More information about the Mlir-commits mailing list