[Mlir-commits] [mlir] 64716b2 - [GreedyPatternRewriter] Introduce a config object that allows controlling internal parameters. NFC.

Chris Lattner llvmlistbot at llvm.org
Mon May 24 12:48:35 PDT 2021


Author: Chris Lattner
Date: 2021-05-24T12:40:40-07:00
New Revision: 64716b2c39c10a3ea3a893da6106d2d55a0e8deb

URL: https://github.com/llvm/llvm-project/commit/64716b2c39c10a3ea3a893da6106d2d55a0e8deb
DIFF: https://github.com/llvm/llvm-project/commit/64716b2c39c10a3ea3a893da6106d2d55a0e8deb.diff

LOG: [GreedyPatternRewriter] Introduce a config object that allows controlling internal parameters. NFC.

This exposes the iterations and top-down processing as flags, and also
allows controlling whether region simplification is desirable for a client.
This allows deleting some duplicated entrypoints to
applyPatternsAndFoldGreedily.

This also deletes the Constant Preprocessing pass, which isn't worth it
on balance.

All defaults are kept the same, so no one should see a behavior change.

Differential Revision: https://reviews.llvm.org/D102988

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/FoldUtils.h
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Transforms/Canonicalizer.cpp
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 9f78f374a6fef..7f4166c12ded9 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -33,11 +33,6 @@ class OperationFolder {
 public:
   OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
 
-  /// Scan the specified region for constants that can be used in folding,
-  /// moving them to the entry block (or any custom insertion location specified
-  /// by shouldMaterializeInto), and add them to our known-constants table.
-  void processExistingConstants(Region &region);
-
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, replaces `op`'s uses with
   /// folded results, and returns success. `preReplaceAction` is invoked on `op`

diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index b67fc5ffa06b7..a85bdef8cea01 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -18,6 +18,24 @@
 
 namespace mlir {
 
+/// This struct allows control over how the GreedyPatternRewriteDriver works.
+struct GreedyRewriteConfig {
+  /// This specifies the order of initial traversal that populates the
+  /// rewriter's worklist.  When set to true, it walks the operations top-down,
+  /// which is generally more efficient in compile time.  When set to false, its
+  /// initial traversal of the region tree is bottom up on each block, which may
+  /// match larger patterns when given an ambiguous pattern set.
+  bool useTopDownTraversal = false;
+
+  /// Perform control flow optimizations to the region tree after applying all
+  /// patterns.
+  bool enableRegionSimplification = true;
+
+  /// This specifies the maximum number of times the rewriter will iterate
+  /// between applying patterns and simplifying regions.
+  unsigned maxIterations = 10;
+};
+
 //===----------------------------------------------------------------------===//
 // applyPatternsGreedily
 //===----------------------------------------------------------------------===//
@@ -37,33 +55,17 @@ namespace mlir {
 ///       These methods also perform folding and simple dead-code elimination
 ///       before attempting to match any of the provided patterns.
 ///
-/// You may choose the order of initial traversal with the `useTopDownTraversal`
-/// boolean.  When set to true, it walks the operations top-down, which is
-/// generally more efficient in compile time.  When set to false, its initial
-/// traversal of the region tree is post-order, which may match larger patterns
-/// when given an ambiguous pattern set.
-LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternSet &patterns,
-                             bool useTopDownTraversal = false);
-
-/// Rewrite the regions of the specified operation, with a user-provided limit
-/// on iterations to attempt before reaching convergence.
+/// You may configure several aspects of this with GreedyRewriteConfig.
 LogicalResult applyPatternsAndFoldGreedily(
-    Operation *op, const FrozenRewritePatternSet &patterns,
-    unsigned maxIterations, bool useTopDownTraversal = false);
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    GreedyRewriteConfig config = GreedyRewriteConfig());
 
 /// Rewrite the given regions, which must be isolated from above.
-LogicalResult
-applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternSet &patterns,
-                             bool useTopDownTraversal = false);
-
-/// Rewrite the given regions, with a user-provided limit on iterations to
-/// attempt before reaching convergence.
-LogicalResult applyPatternsAndFoldGreedily(
-    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
-    unsigned maxIterations, bool useTopDownTraversal = false);
+inline LogicalResult applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternSet &patterns,
+    GreedyRewriteConfig config = GreedyRewriteConfig()) {
+  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
+}
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
 /// by selecting the highest benefits patterns in a greedy manner. Returns

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index cc3a2170190d2..5d91243507eb7 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -31,10 +31,10 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
     return success();
   }
   void runOnOperation() override {
-    (void)applyPatternsAndFoldGreedily(
-        getOperation()->getRegions(), patterns,
-        /*maxIterations=*/10, /*useTopDownTraversal=*/
-        topDownProcessingEnabled);
+    GreedyRewriteConfig config;
+    config.useTopDownTraversal = topDownProcessingEnabled;
+    (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns,
+                                       config);
   }
 
   FrozenRewritePatternSet patterns;

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index e6faa860bc76d..af415d5fc5228 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -84,85 +84,6 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
 // OperationFolder
 //===----------------------------------------------------------------------===//
 
-/// Scan the specified region for constants that can be used in folding,
-/// moving them to the entry block (or any custom insertion location specified
-/// by shouldMaterializeInto), and add them to our known-constants table.
-void OperationFolder::processExistingConstants(Region &region) {
-  if (region.empty())
-    return;
-
-  // March the constant insertion point forward, moving all constants to the
-  // top of the block, but keeping them in their order of discovery.
-  Region *insertRegion = getInsertionRegion(interfaces, &region.front());
-  auto &uniquedConstants = foldScopes[insertRegion];
-
-  Block &insertBlock = insertRegion->front();
-  Block::iterator constantIterator = insertBlock.begin();
-
-  // Process each constant that we discover in this region.
-  auto processConstant = [&](Operation *op, Attribute value) {
-    assert(op->getNumResults() == 1 && "constants have one result");
-    // Check to see if we already have an instance of this constant.
-    Operation *&constOp = uniquedConstants[std::make_tuple(
-        op->getDialect(), value, op->getResult(0).getType())];
-
-    // If we already have an instance of this constant, CSE/delete this one as
-    // we go.
-    if (constOp) {
-      if (constantIterator == Block::iterator(op))
-        ++constantIterator; // Don't invalidate our iterator when scanning.
-      op->getResult(0).replaceAllUsesWith(constOp->getResult(0));
-      op->erase();
-      return;
-    }
-
-    // Otherwise, remember that we have this constant.
-    constOp = op;
-    referencedDialects[op].push_back(op->getDialect());
-
-    // If the constant isn't already at the insertion point then move it up.
-    if (constantIterator != Block::iterator(op))
-      op->moveBefore(&insertBlock, constantIterator);
-    else
-      ++constantIterator; // It was pointing at the constant.
-  };
-
-  // Collect all the constants for this region of isolation or insertion (as
-  // specified by the shouldMaterializeInto hook).  Collect any subregions of
-  // isolation/constant insertion for subsequent processing.
-  SmallVector<Operation *> insertionSubregionOps;
-  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-    // If this is a constant, process it.
-    Attribute value;
-    if (matchPattern(op, m_Constant(&value))) {
-      processConstant(op, value);
-      // We may have deleted the operation, don't check it for regions.
-      return WalkResult::skip();
-    }
-
-    // If the operation has regions and is isolated, don't recurse into it.
-    if (op->getNumRegions() != 0) {
-      auto hasDifferentInsertRegion = [&](Region &region) {
-        return !region.empty() &&
-               getInsertionRegion(interfaces, &region.front()) != insertRegion;
-      };
-      if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) {
-        insertionSubregionOps.push_back(op);
-        return WalkResult::skip();
-      }
-    }
-
-    // Otherwise keep going.
-    return WalkResult::advance();
-  });
-
-  // Process regions in any isolated ops separately.
-  for (Operation *subregionOps : insertionSubregionOps) {
-    for (Region &region : subregionOps->getRegions())
-      processExistingConstants(region);
-  }
-}
-
 LogicalResult OperationFolder::tryToFold(
     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
     function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 1d0a79d054060..5b028d63b2379 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -24,9 +24,6 @@ using namespace mlir;
 
 #define DEBUG_TYPE "pattern-matcher"
 
-/// The max number of iterations scanning for pattern match.
-static unsigned maxPatternMatchIterations = 10;
-
 //===----------------------------------------------------------------------===//
 // GreedyPatternRewriteDriver
 //===----------------------------------------------------------------------===//
@@ -38,16 +35,15 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
-                                      bool useTopDownTraversal)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx),
-        useTopDownTraversal(useTopDownTraversal) {
+                                      const GreedyRewriteConfig &config)
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
     worklist.reserve(64);
 
     // Apply a simple cost model based solely on pattern benefit.
     matcher.applyDefaultCostModel();
   }
 
-  bool simplify(MutableArrayRef<Region> regions, int maxIterations);
+  bool simplify(MutableArrayRef<Region> regions);
 
   void addToWorklist(Operation *op) {
     // Check to see if the worklist already contains this op.
@@ -137,40 +133,30 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   /// Non-pattern based folder for operations.
   OperationFolder folder;
 
-  /// Whether to use a top-down or bottom-up traversal to seed the initial
-  /// worklist.
-  bool useTopDownTraversal;
+  /// Configuration information for how to simplify.
+  GreedyRewriteConfig config;
 };
 } // end anonymous namespace
 
 /// Performs the rewrites while folding and erasing any dead ops. Returns true
 /// if the rewrite converges in `maxIterations`.
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
-                                          int maxIterations) {
-  // For maximum compatibility with existing passes, do not process existing
-  // constants unless we're performing a top-down traversal.
-  // TODO: This is just for compatibility with older MLIR, remove this.
-  if (useTopDownTraversal) {
-    // Perform a prepass over the IR to discover constants.
-    for (auto &region : regions)
-      folder.processExistingConstants(region);
-  }
-
+bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
   bool changed = false;
-  int iteration = 0;
+  unsigned iteration = 0;
   do {
     worklist.clear();
     worklistMap.clear();
 
-    // Add all nested operations to the worklist in preorder.
-    for (auto &region : regions)
-      if (useTopDownTraversal)
+    if (!config.useTopDownTraversal) {
+      // Add operations to the worklist in postorder.
+      for (auto &region : regions)
+        region.walk([this](Operation *op) { addToWorklist(op); });
+    } else {
+      // Add all nested operations to the worklist in preorder.
+      for (auto &region : regions)
         region.walk<WalkOrder::PreOrder>(
             [this](Operation *op) { worklist.push_back(op); });
-      else
-        region.walk([this](Operation *op) { addToWorklist(op); });
 
-    if (useTopDownTraversal) {
       // Reverse the list so our pop-back loop processes them in-order.
       std::reverse(worklist.begin(), worklist.end());
       // Remember the reverse index.
@@ -234,8 +220,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 
     // After applying patterns, make sure that the CFG of each of the regions
     // is kept up to date.
-    changed |= succeeded(simplifyRegions(*this, regions));
-  } while (changed && ++iteration < maxIterations);
+    if (config.enableRegionSimplification)
+      changed |= succeeded(simplifyRegions(*this, regions));
+  } while (changed && ++iteration < config.maxIterations);
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
@@ -248,29 +235,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 /// top-level operation itself.
 ///
 LogicalResult
-mlir::applyPatternsAndFoldGreedily(Operation *op,
-                                   const FrozenRewritePatternSet &patterns,
-                                   bool useTopDownTraversal) {
-  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
-                                      useTopDownTraversal);
-}
-LogicalResult mlir::applyPatternsAndFoldGreedily(
-    Operation *op, const FrozenRewritePatternSet &patterns,
-    unsigned maxIterations, bool useTopDownTraversal) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
-                                      useTopDownTraversal);
-}
-/// Rewrite the given regions, which must be isolated from above.
-LogicalResult
 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                    const FrozenRewritePatternSet &patterns,
-                                   bool useTopDownTraversal) {
-  return applyPatternsAndFoldGreedily(
-      regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
-}
-LogicalResult mlir::applyPatternsAndFoldGreedily(
-    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
-    unsigned maxIterations, bool useTopDownTraversal) {
+                                   GreedyRewriteConfig config) {
   if (regions.empty())
     return success();
 
@@ -285,12 +252,11 @@ LogicalResult mlir::applyPatternsAndFoldGreedily(
          "patterns can only be applied to operations IsolatedFromAbove");
 
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
-                                    useTopDownTraversal);
-  bool converged = driver.simplify(regions, maxIterations);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
+  bool converged = driver.simplify(regions);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
-                 << maxIterations << " times\n";
+                 << config.maxIterations << " times\n";
   });
   return success(converged);
 }
@@ -391,15 +357,16 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
 LogicalResult mlir::applyOpPatternsAndFold(
     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
   // Start the pattern driver.
+  GreedyRewriteConfig config;
   OpPatternRewriteDriver driver(op->getContext(), patterns);
   bool opErased;
   LogicalResult converged =
-      driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
+      driver.simplifyLocally(op, config.maxIterations, opErased);
   if (erased)
     *erased = opErased;
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
-                 << maxPatternMatchIterations << " times";
+                 << config.maxIterations << " times";
   });
   return converged;
 }


        


More information about the Mlir-commits mailing list