[Mlir-commits] [mlir] cccc7e5 - [MLIR] Don't remove memref allocation if stored into another allocation

William S. Moses llvmlistbot at llvm.org
Mon Jun 28 09:06:15 PDT 2021


Author: William S. Moses
Date: 2021-06-28T12:05:59-04:00
New Revision: cccc7e5aa8088b3b721e1f430c47d199575fae9b

URL: https://github.com/llvm/llvm-project/commit/cccc7e5aa8088b3b721e1f430c47d199575fae9b
DIFF: https://github.com/llvm/llvm-project/commit/cccc7e5aa8088b3b721e1f430c47d199575fae9b.diff

LOG: [MLIR] Don't remove memref allocation if stored into another allocation

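The SimplifyDeadAlloc canonicalization removes a memref allocation whose only users are stores and deallocs. This is incorrect when the allocation is the value being stored (so it escapes into another memref), rather than the memref being stored into. The pattern now keeps the allocation in that case.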

Differential Revision: https://reviews.llvm.org/D104947

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index cc4e7a49363a..6f358d834bee 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -174,8 +174,10 @@ struct SimplifyDeadAlloc : public OpRewritePattern<T> {
 
   LogicalResult matchAndRewrite(T alloc,
                                 PatternRewriter &rewriter) const override {
-    if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
-          return !isa<StoreOp, DeallocOp>(op);
+    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
+        if (auto storeOp = dyn_cast<StoreOp>(op))
+          return storeOp.value() == alloc;
+        return !isa<DeallocOp>(op);
         }))
       return failure();
 

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index cbf2126a9ea2..c59d1d30f7ec 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -420,3 +420,13 @@ func @alloc_const_fold_with_symbols2() -> memref<?xi32, #map0> {
   %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, #map0>
   return %0 : memref<?xi32, #map0>
 }
+
+// -----
+// CHECK-LABEL: func @allocator
+// CHECK:   %[[alloc:.+]] = memref.alloc
+// CHECK:   memref.store %[[alloc]], %arg0
+func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index)  {
+  %0 = memref.alloc(%arg1) : memref<?xi32>
+  memref.store %0, %arg0[] : memref<memref<?xi32>>
+  return 
+}

