[Mlir-commits] [mlir] 486d275 - [mlir][NFC] Polish copy removal transform

Ehsan Toosi llvmlistbot at llvm.org
Mon Jul 27 23:36:42 PDT 2020


Author: Ehsan Toosi
Date: 2020-07-28T08:34:44+02:00
New Revision: 486d2750c7151d3d93b785a4669e2d7d5c9286ac

URL: https://github.com/llvm/llvm-project/commit/486d2750c7151d3d93b785a4669e2d7d5c9286ac
DIFF: https://github.com/llvm/llvm-project/commit/486d2750c7151d3d93b785a4669e2d7d5c9286ac.diff

LOG: [mlir][NFC] Polish copy removal transform

Address a few remaining review comments on the copy removal transform.

Differential Revision: https://reviews.llvm.org/D84529

Added: 
    

Modified: 
    mlir/lib/Transforms/CopyRemoval.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/CopyRemoval.cpp b/mlir/lib/Transforms/CopyRemoval.cpp
index 28648e0b4294..ccfd02630ac2 100644
--- a/mlir/lib/Transforms/CopyRemoval.cpp
+++ b/mlir/lib/Transforms/CopyRemoval.cpp
@@ -19,16 +19,28 @@ namespace {
 //===----------------------------------------------------------------------===//
 // CopyRemovalPass
 //===----------------------------------------------------------------------===//
+
 /// This pass removes the redundant Copy operations. Additionally, it
 /// removes the leftover definition and deallocation operations by erasing the
 /// copy operation.
 class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
+public:
+  void runOnOperation() override {
+    getOperation()->walk([&](CopyOpInterface copyOp) {
+      reuseCopySourceAsTarget(copyOp);
+      reuseCopyTargetAsSource(copyOp);
+    });
+    for (Operation *op : eraseList)
+      op->erase();
+  }
+
 private:
   /// List of operations that need to be removed.
   DenseSet<Operation *> eraseList;
 
   /// Returns the deallocation operation for `value` in `block` if it exists.
   Operation *getDeallocationInBlock(Value value, Block *block) {
+    assert(block && "Block cannot be null");
     auto valueUsers = value.getUsers();
     auto it = llvm::find_if(valueUsers, [&](Operation *op) {
       auto effects = dyn_cast<MemoryEffectOpInterface>(op);
@@ -40,12 +52,12 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
   /// Returns true if an operation between start and end operations has memory
   /// effect.
   bool hasMemoryEffectOpBetween(Operation *start, Operation *end) {
+    assert((start || end) && "Start and end operations cannot be null");
     assert(start->getBlock() == end->getBlock() &&
            "Start and end operations should be in the same block.");
     Operation *op = start->getNextNode();
     while (op->isBeforeInBlock(end)) {
-      auto effects = dyn_cast<MemoryEffectOpInterface>(op);
-      if (effects)
+      if (isa<MemoryEffectOpInterface>(op))
         return true;
       op = op->getNextNode();
     }
@@ -55,6 +67,7 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
   /// Returns true if `val` value has at least a user between `start` and
   /// `end` operations.
   bool hasUsersBetween(Value val, Operation *start, Operation *end) {
+    assert((start || end) && "Start and end operations cannot be null");
     Block *block = start->getBlock();
     assert(block == end->getBlock() &&
            "Start and end operations should be in the same block.");
@@ -65,10 +78,11 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
   };
 
   bool areOpsInTheSameBlock(ArrayRef<Operation *> operations) {
-    llvm::SmallPtrSet<Block *, 4> blocks;
-    for (Operation *op : operations)
-      blocks.insert(op->getBlock());
-    return blocks.size() == 1;
+    assert(!operations.empty() &&
+           "The operations list should contain at least a single operation");
+    Block *block = operations.front()->getBlock();
+    return llvm::none_of(
+        operations, [&](Operation *op) { return block != op->getBlock(); });
   }
 
   /// Input:
@@ -97,7 +111,7 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
   /// TODO: Alias analysis is not available at the moment. Currently, we check
   /// if there are any operations with memory effects between copy and
   /// deallocation operations.
-  void ReuseCopySourceAsTarget(CopyOpInterface copyOp) {
+  void reuseCopySourceAsTarget(CopyOpInterface copyOp) {
     if (eraseList.count(copyOp))
       return;
 
@@ -147,7 +161,7 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
   /// TODO: Alias analysis is not available at the moment. Currently, we check
   /// if there are any operations with memory effects between copy and
   /// deallocation operations.
-  void ReuseCopyTargetAsSource(CopyOpInterface copyOp) {
+  void reuseCopyTargetAsSource(CopyOpInterface copyOp) {
     if (eraseList.count(copyOp))
       return;
 
@@ -169,16 +183,6 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
     eraseList.insert(fromDefiningOp);
     eraseList.insert(fromFreeingOp);
   }
-
-public:
-  void runOnOperation() override {
-    getOperation()->walk([&](CopyOpInterface copyOp) {
-      ReuseCopySourceAsTarget(copyOp);
-      ReuseCopyTargetAsSource(copyOp);
-    });
-    for (Operation *op : eraseList)
-      op->erase();
-  }
 };
 
 } // end anonymous namespace
@@ -186,6 +190,7 @@ class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
 //===----------------------------------------------------------------------===//
 // CopyRemovalPass construction
 //===----------------------------------------------------------------------===//
+
 std::unique_ptr<Pass> mlir::createCopyRemovalPass() {
   return std::make_unique<CopyRemovalPass>();
 }


        


More information about the Mlir-commits mailing list