[Mlir-commits] [mlir] [mlir][linalg] Fix a DCE crash with memref<0x..> and the op has uses (PR #73908)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 00:14:15 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kohei Yamaguchi (sott0n)

<details>
<summary>Changes</summary>

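The dead-code-elimination pattern for LinalgOp (`EraseDeadLinalgOp`) erases the op whenever one of its operands is a zero-sized `memref<0x..>`. If the op still has uses (for example, a `tensor` result that is returned from the function), erasing it causes a crash. This PR fixes the crash by erasing the op only when it has no uses.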

Fixes https://github.com/llvm/llvm-project/issues/73547

---
Full diff: https://github.com/llvm/llvm-project/pull/73908.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+1-1) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+28-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 58af9995548e939..323ded3aadcd3a6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2101,7 +2101,7 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
       auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
       if (!mt)
         continue;
-      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
+      if (llvm::is_contained(op.getShape(&opOperand), 0) && op->use_empty()) {
         rewriter.eraseOp(op);
         return success();
       }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index e875bae4730946b..42188c01be16a57 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -28,7 +28,7 @@ func.func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
 }
 
 func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
-  // memref<0x32> is expected to be dce'ed
+  // memref<0xf32> is expected to be dce'ed
   memref.copy %arg0, %arg0 : memref<0xf32> to memref<0xf32>
 
   // tensor<0xf32> cannot be dce'ed
@@ -47,6 +47,33 @@ func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tenso
 
 // -----
 
+#accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+func.func @no_dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
+  // memref<0xf32> cannot be dce'ed
+  %2 = linalg.generic #trait ins(%arg0: memref<0xf32>) outs(%arg1 : tensor<0xf32>) {
+  ^bb(%0: f32, %1: f32) :
+    linalg.yield %1 : f32
+  } -> tensor<0xf32>
+
+  return %2: tensor<0xf32>
+}
+
+// CHECK-LABEL: @no_dce_zero_memref
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
+//  CHECK-NEXT:   linalg.generic
+
+// -----
+
 func.func @dce_self_linalg_copy(%arg0 : memref<?xf32>) {
   linalg.copy ins(%arg0: memref<?xf32>) outs(%arg0: memref<?xf32>)
   return

``````````

</details>


https://github.com/llvm/llvm-project/pull/73908


More information about the Mlir-commits mailing list