[Mlir-commits] [mlir] [mlir][linalg] Fix a DCE crash with memref<0x..> and the op has uses (PR #73908)

Kohei Yamaguchi llvmlistbot at llvm.org
Thu Nov 30 00:13:47 PST 2023


https://github.com/sott0n created https://github.com/llvm/llvm-project/pull/73908

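The dead-code elimination pattern for `LinalgOp` (`EraseDeadLinalgOp`) erases an op when one of its operands is a zero-sized memref (`memref<0x..>`), which causes a crash if the op still has uses. This PR fixes the crash by erasing the op only when it has no uses.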

Fixes https://github.com/llvm/llvm-project/issues/73547

>From 937d314d43631aa3015a1a0de460a06e113ef766 Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Thu, 30 Nov 2023 16:46:06 +0000
Subject: [PATCH] [mlir][linalg] Fix a DCE crash with memref<0x..> and the op
 has uses

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   |  2 +-
 mlir/test/Dialect/Linalg/canonicalize.mlir | 29 +++++++++++++++++++++-
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 58af9995548e939..323ded3aadcd3a6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2101,7 +2101,7 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
       auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
       if (!mt)
         continue;
-      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
+      if (llvm::is_contained(op.getShape(&opOperand), 0) && op->use_empty()) {
         rewriter.eraseOp(op);
         return success();
       }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index e875bae4730946b..42188c01be16a57 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -28,7 +28,7 @@ func.func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
 }
 
 func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
-  // memref<0x32> is expected to be dce'ed
+  // memref<0xf32> is expected to be dce'ed
   memref.copy %arg0, %arg0 : memref<0xf32> to memref<0xf32>
 
   // tensor<0xf32> cannot be dce'ed
@@ -47,6 +47,33 @@ func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tenso
 
 // -----
 
+#accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+func.func @no_dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
+  // memref<0xf32> cannot be dce'ed
+  %2 = linalg.generic #trait ins(%arg0: memref<0xf32>) outs(%arg1 : tensor<0xf32>) {
+  ^bb(%0: f32, %1: f32) :
+    linalg.yield %1 : f32
+  } -> tensor<0xf32>
+
+  return %2: tensor<0xf32>
+}
+
+// CHECK-LABEL: @no_dce_zero_memref
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
+//  CHECK-NEXT:   linalg.generic
+
+// -----
+
 func.func @dce_self_linalg_copy(%arg0 : memref<?xf32>) {
   linalg.copy ins(%arg0: memref<?xf32>) outs(%arg0: memref<?xf32>)
   return



More information about the Mlir-commits mailing list