[Mlir-commits] [mlir] c005df3 - [mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (#130000)
llvmlistbot at llvm.org
Mon Jun 2 12:19:00 PDT 2025
Author: Ian Wood
Date: 2025-06-02T12:18:57-07:00
New Revision: c005df3c7e7f8bf788803a95e27d57b339c965fe
URL: https://github.com/llvm/llvm-project/commit/c005df3c7e7f8bf788803a95e27d57b339c965fe
DIFF: https://github.com/llvm/llvm-project/commit/c005df3c7e7f8bf788803a95e27d57b339c965fe.diff
LOG: [mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (#130000)
Adds a check to make sure that the linalg op is safe to erase by
ensuring that the `linalg.yield` yields one of the linalg op's block
arguments. This check already exists for linalg ops with pure tensor
semantics.
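For illustration, a sketch of the problematic memref pattern (adapted from the
`no_fold_fill_like_memref` test added below; the function name here is
illustrative): the op reads and writes the same buffer, but its body yields a
value defined outside the block, so erasing it would drop the fill.

    func.func @fill_like(%in_out : memref<4x16xf32>, %fill_val : f32) {
      linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                       affine_map<(d0, d1) -> (d0, d1)>],
                      iterator_types = ["parallel", "parallel"]}
        ins(%in_out : memref<4x16xf32>)
        outs(%in_out : memref<4x16xf32>) {
      ^bb0(%arg0: f32, %arg1: f32):
        // Yields %fill_val rather than a block argument, so this op writes
        // %fill_val into %in_out and must not be erased.
        linalg.yield %fill_val : f32
      }
      return
    }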
Closes https://github.com/llvm/llvm-project/issues/129414
---------
Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5fc3ace5d6aab..5dbb2403eddbd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1278,8 +1278,9 @@ LogicalResult GenericOp::verify() { return success(); }
namespace {
-/// Remove any linalg operation (on tensors) that are just copying
-/// the values from inputs to the results. Requirements are
+/// Remove linalg operations that are just copying the values from inputs to
+/// results. In the memref case, the operation must be copying to and from the
+/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
@@ -1304,18 +1305,27 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
// In the buffer case, we need to check exact buffer equality.
if (linalgOp.hasPureBufferSemantics()) {
- if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
- linalgOp.getDpsInputOperand(0)->get() ==
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
+ linalgOp.getDpsInputOperand(0)->get() !=
linalgOp.getDpsInitOperand(0)->get()) {
- rewriter.eraseOp(linalgOp);
- return success();
+ return rewriter.notifyMatchFailure(
+ linalgOp, "expected single input and output to be the same value");
}
- return failure();
+
+ auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
+ if (!yieldArg || yieldArg.getOwner() != &body) {
+ return rewriter.notifyMatchFailure(linalgOp,
+ "cannot fold fill-like op");
+ }
+
+ rewriter.eraseOp(linalgOp);
+ return success();
}
- // Mixed semantics is not supported yet.
- if (!linalgOp.hasPureTensorSemantics())
- return failure();
+ if (!linalgOp.hasPureTensorSemantics()) {
+ return rewriter.notifyMatchFailure(
+ linalgOp, "mixed semantics is not supported yet");
+ }
// Get the argument number of the returned values. That is the operand
// number to use for replacing uses of this operation.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 3daf221f4402d..7284ae7dbd673 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -495,7 +495,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
// -----
-// CHECK: func @fold_self_copy
+// CHECK-LABEL: func @fold_self_copy
func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// CHECK-NEXT: return
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -511,6 +511,36 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// -----
+// CHECK-LABEL: func @no_fold_fill_like_memref
+// CHECK-NEXT: linalg.generic
+func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) {
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%in_out : memref<4x16xf32>)
+ outs(%in_out : memref<4x16xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ linalg.yield %fill_val : f32
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @no_fold_fill_like_tensor
+// CHECK-NEXT: linalg.generic
+func.func @no_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> {
+ %result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%in_out : tensor<4x16xf32>)
+ outs(%in_out : tensor<4x16xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ linalg.yield %fill_val : f32
+ } -> tensor<4x16xf32>
+ return %result : tensor<4x16xf32>
+}
+
// CHECK-LABEL: func @fold_static_pad_fill
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>