[Mlir-commits] [mlir] 4e9f4d0 - [mlir] Fix bug in copy removal
Ehsan Toosi
llvmlistbot at llvm.org
Tue Sep 8 05:20:22 PDT 2020
Author: Ehsan Toosi
Date: 2020-09-08T14:17:13+02:00
New Revision: 4e9f4d0b9d1dbf2c1d3e389b870a16c3dbd5c302
URL: https://github.com/llvm/llvm-project/commit/4e9f4d0b9d1dbf2c1d3e389b870a16c3dbd5c302
DIFF: https://github.com/llvm/llvm-project/commit/4e9f4d0b9d1dbf2c1d3e389b870a16c3dbd5c302.diff
LOG: [mlir] Fix bug in copy removal
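Copy removal could crash. The fix defers the value replacements until the walk
over the IR has finished, and only treats a copy operand as removable when it is
produced by an allocation in the same block as the copy. Two more test cases are
added.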
Differential Revision: https://reviews.llvm.org/D87128
Added:
Modified:
mlir/lib/Transforms/CopyRemoval.cpp
mlir/test/Transforms/copy-removal.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/CopyRemoval.cpp b/mlir/lib/Transforms/CopyRemoval.cpp
index ccfd02630ac2..c5a8da632956 100644
--- a/mlir/lib/Transforms/CopyRemoval.cpp
+++ b/mlir/lib/Transforms/CopyRemoval.cpp
@@ -30,16 +30,35 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
reuseCopySourceAsTarget(copyOp);
reuseCopyTargetAsSource(copyOp);
});
+ for (std::pair<Value, Value> &pair : replaceList)
+ pair.first.replaceAllUsesWith(pair.second);
for (Operation *op : eraseList)
op->erase();
}
private:
/// List of operations that need to be removed.
- DenseSet<Operation *> eraseList;
+ llvm::SmallPtrSet<Operation *, 4> eraseList;
+
+ /// List of values that need to be replaced with their counterparts.
+ llvm::SmallDenseSet<std::pair<Value, Value>, 4> replaceList;
+
+ /// Returns the allocation operation for `value` in `block` if it exists.
+ /// nullptr otherwise.
+ Operation *getAllocationOpInBlock(Value value, Block *block) {
+ assert(block && "Block cannot be null");
+ Operation *op = value.getDefiningOp();
+ if (op && op->getBlock() == block) {
+ auto effects = dyn_cast<MemoryEffectOpInterface>(op);
+ if (effects && effects.hasEffect<Allocate>())
+ return op;
+ }
+ return nullptr;
+ }
/// Returns the deallocation operation for `value` in `block` if it exists.
- Operation *getDeallocationInBlock(Value value, Block *block) {
+ /// nullptr otherwise.
+ Operation *getDeallocationOpInBlock(Value value, Block *block) {
assert(block && "Block cannot be null");
auto valueUsers = value.getUsers();
auto it = llvm::find_if(valueUsers, [&](Operation *op) {
@@ -119,9 +138,10 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
Value to = copyOp.getTarget();
Operation *copy = copyOp.getOperation();
+ Block *copyBlock = copy->getBlock();
Operation *fromDefiningOp = from.getDefiningOp();
- Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
- Operation *toDefiningOp = to.getDefiningOp();
+ Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
+ Operation *toDefiningOp = getAllocationOpInBlock(to, copyBlock);
if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp ||
!areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) ||
hasUsersBetween(to, toDefiningOp, copy) ||
@@ -129,7 +149,7 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
hasMemoryEffectOpBetween(copy, fromFreeingOp))
return;
- to.replaceAllUsesWith(from);
+ replaceList.insert({to, from});
eraseList.insert(copy);
eraseList.insert(toDefiningOp);
eraseList.insert(fromFreeingOp);
@@ -169,8 +189,9 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
Value to = copyOp.getTarget();
Operation *copy = copyOp.getOperation();
- Operation *fromDefiningOp = from.getDefiningOp();
- Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
+ Block *copyBlock = copy->getBlock();
+ Operation *fromDefiningOp = getAllocationOpInBlock(from, copyBlock);
+ Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
if (!fromDefiningOp || !fromFreeingOp ||
!areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) ||
hasUsersBetween(to, fromDefiningOp, copy) ||
@@ -178,7 +199,7 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
hasMemoryEffectOpBetween(copy, fromFreeingOp))
return;
- from.replaceAllUsesWith(to);
+ replaceList.insert({from, to});
eraseList.insert(copy);
eraseList.insert(fromDefiningOp);
eraseList.insert(fromFreeingOp);
diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir
index f750dabb18a0..a0d1193b77d5 100644
--- a/mlir/test/Transforms/copy-removal.mlir
+++ b/mlir/test/Transforms/copy-removal.mlir
@@ -283,3 +283,67 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
dealloc %temp : memref<2xf32>
return
}
+
+// -----
+
+// The only redundant copy is linalg.copy(%4, %5).
+
+// CHECK-LABEL: func @loop_alloc
+func @loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) {
+ // CHECK: %{{.*}} = alloc()
+ %0 = alloc() : memref<2xf32>
+ dealloc %0 : memref<2xf32>
+ // CHECK: %{{.*}} = alloc()
+ %1 = alloc() : memref<2xf32>
+ // CHECK: linalg.copy
+ linalg.copy(%arg3, %1) : memref<2xf32>, memref<2xf32>
+ %2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) {
+ %3 = cmpi "eq", %arg5, %arg1 : index
+ // CHECK: dealloc
+ dealloc %arg6 : memref<2xf32>
+ // CHECK: %[[PERCENT4:.*]] = alloc()
+ %4 = alloc() : memref<2xf32>
+ // CHECK-NOT: alloc
+ // CHECK-NOT: linalg.copy
+ // CHECK-NOT: dealloc
+ %5 = alloc() : memref<2xf32>
+ linalg.copy(%4, %5) : memref<2xf32>, memref<2xf32>
+ dealloc %4 : memref<2xf32>
+ // CHECK: %[[PERCENT6:.*]] = alloc()
+ %6 = alloc() : memref<2xf32>
+ // CHECK: linalg.copy(%[[PERCENT4]], %[[PERCENT6]])
+ linalg.copy(%5, %6) : memref<2xf32>, memref<2xf32>
+ scf.yield %6 : memref<2xf32>
+ }
+ // CHECK: linalg.copy
+ linalg.copy(%2, %arg4) : memref<2xf32>, memref<2xf32>
+ dealloc %2 : memref<2xf32>
+ return
+}
+
+// -----
+
+// The linalg.copy operation can be removed in addition to alloc and dealloc
+// operations. All uses of %0 are then replaced with %arg2.
+
+// CHECK-LABEL: func @check_with_affine_dialect
+func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>, %[[RES:.*]]: memref<4xf32>)
+ // CHECK-NOT: alloc
+ %0 = alloc() : memref<4xf32>
+ affine.for %arg3 = 0 to 4 {
+ %5 = affine.load %arg0[%arg3] : memref<4xf32>
+ %6 = affine.load %arg1[%arg3] : memref<4xf32>
+ %7 = cmpf "ogt", %5, %6 : f32
+ // CHECK: %[[SELECT_RES:.*]] = select
+ %8 = select %7, %5, %6 : f32
+ // CHECK-NEXT: affine.store %[[SELECT_RES]], %[[RES]]
+ affine.store %8, %0[%arg3] : memref<4xf32>
+ }
+ // CHECK-NOT: linalg.copy
+ // CHECK-NOT: dealloc
+ "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
+ dealloc %0 : memref<4xf32>
+ //CHECK: return
+ return
+}
More information about the Mlir-commits
mailing list