[Mlir-commits] [mlir] 4e9f4d0 - [mlir] Fix bug in copy removal

Ehsan Toosi llvmlistbot at llvm.org
Tue Sep 8 05:20:22 PDT 2020


Author: Ehsan Toosi
Date: 2020-09-08T14:17:13+02:00
New Revision: 4e9f4d0b9d1dbf2c1d3e389b870a16c3dbd5c302

URL: https://github.com/llvm/llvm-project/commit/4e9f4d0b9d1dbf2c1d3e389b870a16c3dbd5c302
DIFF: https://github.com/llvm/llvm-project/commit/4e9f4d0b9d1dbf2c1d3e389b870a16c3dbd5c302.diff

LOG: [mlir] Fix bug in copy removal

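The copy removal pass could crash. This is fixed by deferring all value
replacements until after the walk over copy operations, and by only removing
a buffer whose defining operation is an allocation in the copy's block. Two
more test cases are added.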

Differential Revision: https://reviews.llvm.org/D87128

Added: 
    

Modified: 
    mlir/lib/Transforms/CopyRemoval.cpp
    mlir/test/Transforms/copy-removal.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/CopyRemoval.cpp b/mlir/lib/Transforms/CopyRemoval.cpp
index ccfd02630ac2..c5a8da632956 100644
--- a/mlir/lib/Transforms/CopyRemoval.cpp
+++ b/mlir/lib/Transforms/CopyRemoval.cpp
@@ -30,16 +30,35 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
       reuseCopySourceAsTarget(copyOp);
       reuseCopyTargetAsSource(copyOp);
     });
+    for (std::pair<Value, Value> &pair : replaceList)
+      pair.first.replaceAllUsesWith(pair.second);
     for (Operation *op : eraseList)
       op->erase();
   }
 
 private:
   /// List of operations that need to be removed.
-  DenseSet<Operation *> eraseList;
+  llvm::SmallPtrSet<Operation *, 4> eraseList;
+
+  /// List of values that need to be replaced with their counterparts.
+  llvm::SmallDenseSet<std::pair<Value, Value>, 4> replaceList;
+
+  /// Returns the allocation operation for `value` in `block` if it exists.
+  /// nullptr otherwise.
+  Operation *getAllocationOpInBlock(Value value, Block *block) {
+    assert(block && "Block cannot be null");
+    Operation *op = value.getDefiningOp();
+    if (op && op->getBlock() == block) {
+      auto effects = dyn_cast<MemoryEffectOpInterface>(op);
+      if (effects && effects.hasEffect<Allocate>())
+        return op;
+    }
+    return nullptr;
+  }
 
   /// Returns the deallocation operation for `value` in `block` if it exists.
-  Operation *getDeallocationInBlock(Value value, Block *block) {
+  /// nullptr otherwise.
+  Operation *getDeallocationOpInBlock(Value value, Block *block) {
     assert(block && "Block cannot be null");
     auto valueUsers = value.getUsers();
     auto it = llvm::find_if(valueUsers, [&](Operation *op) {
@@ -119,9 +138,10 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
     Value to = copyOp.getTarget();
 
     Operation *copy = copyOp.getOperation();
+    Block *copyBlock = copy->getBlock();
     Operation *fromDefiningOp = from.getDefiningOp();
-    Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
-    Operation *toDefiningOp = to.getDefiningOp();
+    Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
+    Operation *toDefiningOp = getAllocationOpInBlock(to, copyBlock);
     if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp ||
         !areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) ||
         hasUsersBetween(to, toDefiningOp, copy) ||
@@ -129,7 +149,7 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
         hasMemoryEffectOpBetween(copy, fromFreeingOp))
       return;
 
-    to.replaceAllUsesWith(from);
+    replaceList.insert({to, from});
     eraseList.insert(copy);
     eraseList.insert(toDefiningOp);
     eraseList.insert(fromFreeingOp);
@@ -169,8 +189,9 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
     Value to = copyOp.getTarget();
 
     Operation *copy = copyOp.getOperation();
-    Operation *fromDefiningOp = from.getDefiningOp();
-    Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
+    Block *copyBlock = copy->getBlock();
+    Operation *fromDefiningOp = getAllocationOpInBlock(from, copyBlock);
+    Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
     if (!fromDefiningOp || !fromFreeingOp ||
         !areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) ||
         hasUsersBetween(to, fromDefiningOp, copy) ||
@@ -178,7 +199,7 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
         hasMemoryEffectOpBetween(copy, fromFreeingOp))
       return;
 
-    from.replaceAllUsesWith(to);
+    replaceList.insert({from, to});
     eraseList.insert(copy);
     eraseList.insert(fromDefiningOp);
     eraseList.insert(fromFreeingOp);

diff  --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir
index f750dabb18a0..a0d1193b77d5 100644
--- a/mlir/test/Transforms/copy-removal.mlir
+++ b/mlir/test/Transforms/copy-removal.mlir
@@ -283,3 +283,67 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
   dealloc %temp : memref<2xf32>
   return
 }
+
+// -----
+
+// The only redundant copy is linalg.copy(%4, %5).
+
+// CHECK-LABEL: func @loop_alloc
+func @loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) {
+  // CHECK: %{{.*}} = alloc()
+  %0 = alloc() : memref<2xf32>
+  dealloc %0 : memref<2xf32>
+  // CHECK: %{{.*}} = alloc()
+  %1 = alloc() : memref<2xf32>
+  // CHECK: linalg.copy
+  linalg.copy(%arg3, %1) : memref<2xf32>, memref<2xf32>
+  %2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) {
+    %3 = cmpi "eq", %arg5, %arg1 : index
+    // CHECK: dealloc
+    dealloc %arg6 : memref<2xf32>
+    // CHECK: %[[PERCENT4:.*]] = alloc()
+    %4 = alloc() : memref<2xf32>
+    // CHECK-NOT: alloc
+    // CHECK-NOT: linalg.copy
+    // CHECK-NOT: dealloc
+    %5 = alloc() : memref<2xf32>
+    linalg.copy(%4, %5) : memref<2xf32>, memref<2xf32>
+    dealloc %4 : memref<2xf32>
+    // CHECK: %[[PERCENT6:.*]] = alloc()
+    %6 = alloc() : memref<2xf32>
+    // CHECK: linalg.copy(%[[PERCENT4]], %[[PERCENT6]])
+    linalg.copy(%5, %6) : memref<2xf32>, memref<2xf32>
+    scf.yield %6 : memref<2xf32>
+  }
+  // CHECK: linalg.copy
+  linalg.copy(%2, %arg4) : memref<2xf32>, memref<2xf32>
+  dealloc %2 : memref<2xf32>
+  return
+}
+
+// -----
+
+// The linalg.copy operation can be removed in addition to alloc and dealloc
+// operations. All uses of %0 are then replaced with %arg2.
+
+// CHECK-LABEL: func @check_with_affine_dialect
+func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
+  // CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>, %[[RES:.*]]: memref<4xf32>)
+  // CHECK-NOT: alloc
+  %0 = alloc() : memref<4xf32>
+  affine.for %arg3 = 0 to 4 {
+    %5 = affine.load %arg0[%arg3] : memref<4xf32>
+    %6 = affine.load %arg1[%arg3] : memref<4xf32>
+    %7 = cmpf "ogt", %5, %6 : f32
+    // CHECK: %[[SELECT_RES:.*]] = select
+    %8 = select %7, %5, %6 : f32
+    // CHECK-NEXT: affine.store %[[SELECT_RES]], %[[RES]]
+    affine.store %8, %0[%arg3] : memref<4xf32>
+  }
+  // CHECK-NOT: linalg.copy
+  // CHECK-NOT: dealloc
+  "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
+  dealloc %0 : memref<4xf32>
+  //CHECK: return
+  return
+}


        


More information about the Mlir-commits mailing list