[Mlir-commits] [mlir] d75a611 - [mlir] Update `simplifyRegions` to use RewriterBase for erasure notifications

River Riddle llvmlistbot at llvm.org
Fri Mar 19 16:40:25 PDT 2021


Author: River Riddle
Date: 2021-03-19T16:33:54-07:00
New Revision: d75a611afbc7c5f8c343e0398dd2b506684e506b

URL: https://github.com/llvm/llvm-project/commit/d75a611afbc7c5f8c343e0398dd2b506684e506b
DIFF: https://github.com/llvm/llvm-project/commit/d75a611afbc7c5f8c343e0398dd2b506684e506b.diff

LOG: [mlir] Update `simplifyRegions` to use RewriterBase for erasure notifications

This allows for notifying callers when operations/blocks get erased, which is especially useful for the greedy pattern driver. The current greedy pattern driver "throws away" all information about constants in the operation folder because it doesn't know whether they were erased. By passing in a RewriterBase, we can track erasures directly and avoid the need for the pattern driver to rediscover all of the existing constants. In some situations this cuts the compile time of the canonicalizer in half.

Differential Revision: https://reviews.llvm.org/D98755

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/RegionUtils.h
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/lib/Transforms/Utils/RegionUtils.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 72c2f51c9e70..c2124d8b70f0 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SetVector.h"
 
 namespace mlir {
+class RewriterBase;
 
 /// Check if all values in the provided range are defined above the `limit`
 /// region.  That is, if they are defined in a region that is a proper ancestor
@@ -53,8 +54,10 @@ void getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
-/// of the regions were simplified, failure otherwise.
-LogicalResult simplifyRegions(MutableArrayRef<Region> regions);
+/// of the regions were simplified, failure otherwise. The provided rewriter is
+/// used to notify callers of operation and block deletion.
+LogicalResult simplifyRegions(RewriterBase &rewriter,
+                              MutableArrayRef<Region> regions);
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 9ed3b3514db6..922fbb1bee06 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -114,7 +114,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
       // TODO: This is based on the fact that zero use operations
       // may be deleted, and that single use values often have more
       // canonicalization opportunities.
-      if (!operand.use_empty() && !operand.hasOneUse())
+      if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
         continue;
       if (auto *defInst = operand.getDefiningOp())
         addToWorklist(defInst);
@@ -202,10 +202,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 
     // After applying patterns, make sure that the CFG of each of the regions is
     // kept up to date.
-    if (succeeded(simplifyRegions(regions))) {
-      folder.clear();
-      changed = true;
-    }
+    changed |= succeeded(simplifyRegions(*this, regions));
   } while (changed && ++i < maxIterations);
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 21d0ff53fdc8..47635c3bbf49 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -75,7 +76,8 @@ void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
 /// Erase the unreachable blocks within the provided regions. Returns success
 /// if any blocks were erased, failure otherwise.
 // TODO: We could likely merge this with the DCE algorithm below.
-static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
+static LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
   // Set of blocks found to be reachable within a given region.
   llvm::df_iterator_default_set<Block *, 16> reachable;
   // If any blocks were found to be dead.
@@ -108,7 +110,7 @@ static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
     for (Block &block : llvm::make_early_inc_range(*region)) {
       if (!reachable.count(&block)) {
         block.dropAllDefinedValueUses();
-        block.erase();
+        rewriter.eraseBlock(&block);
         erasedDeadBlocks = true;
         continue;
       }
@@ -305,7 +307,8 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
   }
 }
 
-static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
+static LogicalResult deleteDeadness(RewriterBase &rewriter,
+                                    MutableArrayRef<Region> regions,
                                     LiveMap &liveMap) {
   bool erasedAnything = false;
   for (Region &region : regions) {
@@ -324,10 +327,10 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
         if (!liveMap.wasProvenLive(&childOp)) {
           erasedAnything = true;
           childOp.dropAllUses();
-          childOp.erase();
+          rewriter.eraseOp(&childOp);
         } else {
-          erasedAnything |=
-              succeeded(deleteDeadness(childOp.getRegions(), liveMap));
+          erasedAnything |= succeeded(
+              deleteDeadness(rewriter, childOp.getRegions(), liveMap));
         }
       }
     }
@@ -359,7 +362,8 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
 //
 // This function returns success if any operations or arguments were deleted,
 // failure otherwise.
-static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
+static LogicalResult runRegionDCE(RewriterBase &rewriter,
+                                  MutableArrayRef<Region> regions) {
   LiveMap liveMap;
   do {
     liveMap.resetChanged();
@@ -368,7 +372,7 @@ static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
       propagateLiveness(region, liveMap);
   } while (liveMap.hasChanged());
 
-  return deleteDeadness(regions, liveMap);
+  return deleteDeadness(rewriter, regions, liveMap);
 }
 
 //===----------------------------------------------------------------------===//
@@ -456,7 +460,7 @@ class BlockMergeCluster {
   LogicalResult addToCluster(BlockEquivalenceData &blockData);
 
   /// Try to merge all of the blocks within this cluster into the leader block.
-  LogicalResult merge();
+  LogicalResult merge(RewriterBase &rewriter);
 
 private:
   /// The equivalence data for the leader of the cluster.
@@ -550,7 +554,7 @@ static bool ableToUpdatePredOperands(Block *block) {
   return true;
 }
 
-LogicalResult BlockMergeCluster::merge() {
+LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
     return failure();
@@ -613,7 +617,7 @@ LogicalResult BlockMergeCluster::merge() {
   // Replace all uses of the merged blocks with the leader and erase them.
   for (Block *block : blocksToMerge) {
     block->replaceAllUsesWith(leaderBlock);
-    block->erase();
+    rewriter.eraseBlock(block);
   }
   return success();
 }
@@ -621,7 +625,8 @@ LogicalResult BlockMergeCluster::merge() {
 /// Identify identical blocks within the given region and merge them, inserting
 /// new block arguments as necessary. Returns success if any blocks were merged,
 /// failure otherwise.
-static LogicalResult mergeIdenticalBlocks(Region &region) {
+static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
+                                          Region &region) {
   if (region.empty() || llvm::hasSingleElement(region))
     return failure();
 
@@ -659,7 +664,7 @@ static LogicalResult mergeIdenticalBlocks(Region &region) {
         clusters.emplace_back(std::move(data));
     }
     for (auto &cluster : clusters)
-      mergedAnyBlocks |= succeeded(cluster.merge());
+      mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
   }
 
   return success(mergedAnyBlocks);
@@ -667,14 +672,15 @@ static LogicalResult mergeIdenticalBlocks(Region &region) {
 
 /// Identify identical blocks within the given regions and merge them, inserting
 /// new block arguments as necessary.
-static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) {
+static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
+                                          MutableArrayRef<Region> regions) {
   llvm::SmallSetVector<Region *, 1> worklist;
   for (auto &region : regions)
     worklist.insert(&region);
   bool anyChanged = false;
   while (!worklist.empty()) {
     Region *region = worklist.pop_back_val();
-    if (succeeded(mergeIdenticalBlocks(*region))) {
+    if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
       worklist.insert(region);
       anyChanged = true;
     }
@@ -697,10 +703,12 @@ static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) {
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise.
-LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
-  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions));
-  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions));
-  bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions));
+LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
+                                    MutableArrayRef<Region> regions) {
+  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
+  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
+  bool mergedIdenticalBlocks =
+      succeeded(mergeIdenticalBlocks(rewriter, regions));
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2824fdea6e90..0a1558f31c18 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -21,12 +21,12 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
 
 // CHECK-LABEL:   func @single_iteration(
 // CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?x?xi32>) {
+// CHECK:           [[C42:%.*]] = constant 42 : i32
 // CHECK:           [[C0:%.*]] = constant 0 : index
 // CHECK:           [[C2:%.*]] = constant 2 : index
 // CHECK:           [[C3:%.*]] = constant 3 : index
 // CHECK:           [[C6:%.*]] = constant 6 : index
 // CHECK:           [[C7:%.*]] = constant 7 : index
-// CHECK:           [[C42:%.*]] = constant 42 : i32
 // CHECK:           scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) {
 // CHECK:             memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref<?x?x?xi32>
 // CHECK:             scf.yield


        


More information about the Mlir-commits mailing list