[Mlir-commits] [mlir] [mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (PR #130000)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 5 23:29:49 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Ian Wood (IanWood1)
<details>
<summary>Changes</summary>
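Adds a check to make sure that the linalg op is safe to erase: in the pure-buffer case, the pattern now also requires that the `linalg.yield` yields one of the op's own block arguments. Without this check, a fill-like op whose input and output are the same buffer, but whose body yields a value defined outside the body (e.g. a constant), could be erased even though it is not an identity copy.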
Closes https://github.com/llvm/llvm-project/issues/129414
---
Full diff: https://github.com/llvm/llvm-project/pull/130000.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+11-7)
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+20-1)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..c044c94c5af3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1285,13 +1285,17 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
// In the buffer case, we need to check exact buffer equality.
if (linalgOp.hasPureBufferSemantics()) {
- if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
- linalgOp.getDpsInputOperand(0)->get() ==
- linalgOp.getDpsInitOperand(0)->get()) {
- rewriter.eraseOp(linalgOp);
- return success();
- }
- return failure();
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
+ linalgOp.getDpsInputOperand(0)->get() !=
+ linalgOp.getDpsInitOperand(0)->get())
+ return failure();
+
+ auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
+ if (!yieldArg || yieldArg.getOwner() != &body)
+ return failure();
+
+ rewriter.eraseOp(linalgOp);
+ return success();
}
// Mixed semantics is not supported yet.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index db4f6181f517c..08d99c65a291d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -415,7 +415,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
// -----
-// CHECK: func @fold_self_copy
+// CHECK-LABEL: func @fold_self_copy
func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// CHECK-NEXT: return
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -431,6 +431,25 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// -----
+// CHECK-LABEL: func @no_fold_fill_like
+// CHECK: %[[VAL0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: linalg.generic
+// CHECK: linalg.yield %[[VAL0]] : f32
+func.func @no_fold_fill_like(%0 : memref<4x16xf32>) {
+ %1 = arith.constant 0.0 : f32
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : memref<4x16xf32>)
+ outs(%0 : memref<4x16xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32):
+ linalg.yield %1 : f32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @fold_static_pad_fill
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/130000
More information about the Mlir-commits mailing list