[Mlir-commits] [mlir] 977cddb - [mlir] GreedyPatternRewriteDriver: All entry points take a config

Matthias Springer llvmlistbot at llvm.org
Fri Jan 27 05:34:11 PST 2023


Author: Matthias Springer
Date: 2023-01-27T14:33:54+01:00
New Revision: 977cddb95eac67a6dc6680a7d0fadee81114de11

URL: https://github.com/llvm/llvm-project/commit/977cddb95eac67a6dc6680a7d0fadee81114de11
DIFF: https://github.com/llvm/llvm-project/commit/977cddb95eac67a6dc6680a7d0fadee81114de11.diff

LOG: [mlir] GreedyPatternRewriteDriver: All entry points take a config

The multi-op entry point now also takes a GreedyRewriteConfig and respects config.maxNumRewrites. The scope is also a part of the config now.

Differential Revision: https://reviews.llvm.org/D142614

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 6ee565ffe5ec4..e4251103d6d58 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -37,15 +37,21 @@ class GreedyRewriteConfig {
   /// generally more efficient in compile time.  When set to false, its initial
   /// traversal of the region tree is bottom up on each block, which may match
   /// larger patterns when given an ambiguous pattern set.
+  ///
+  /// Note: Only applicable when simplifying entire regions.
   bool useTopDownTraversal = false;
 
-  // Perform control flow optimizations to the region tree after applying all
-  // patterns.
+  /// Perform control flow optimizations to the region tree after applying all
+  /// patterns.
+  ///
+  /// Note: Only applicable when simplifying entire regions.
   bool enableRegionSimplification = true;
 
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
   /// disable this iteration limit.
+  ///
+  /// Note: Only applicable when simplifying entire regions.
   int64_t maxIterations = 10;
 
   /// This specifies the maximum number of rewrites within an iteration. Use
@@ -53,6 +59,10 @@ class GreedyRewriteConfig {
   int64_t maxNumRewrites = kNoLimit;
 
   static constexpr int64_t kNoLimit = -1;
+
+  /// Only ops within the scope are added to the worklist. If no scope is
+  /// specified, the closest enclosing region is used as a scope.
+  Region *scope = nullptr;
 };
 
 //===----------------------------------------------------------------------===//
@@ -117,12 +127,12 @@ inline LogicalResult applyPatternsAndFoldGreedily(
 /// Returns success if the iterative process converged and no more patterns can
 /// be matched. `changed` is set to true if the IR was modified at all.
 /// `allOpsErased` is set to true if all ops in `ops` were erased.
-LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
-                                     const FrozenRewritePatternSet &patterns,
-                                     GreedyRewriteStrictness strictMode,
-                                     bool *changed = nullptr,
-                                     bool *allErased = nullptr,
-                                     Region *scope = nullptr);
+LogicalResult
+applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+                       const FrozenRewritePatternSet &patterns,
+                       GreedyRewriteStrictness strictMode,
+                       GreedyRewriteConfig config = GreedyRewriteConfig(),
+                       bool *changed = nullptr, bool *allErased = nullptr);
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
 /// by selecting the highest benefits patterns in a greedy manner. Returns
@@ -133,9 +143,10 @@ LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
 /// be matched.
 inline LogicalResult
 applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
+                       GreedyRewriteConfig config = GreedyRewriteConfig(),
                        bool *erased = nullptr) {
   return applyOpPatternsAndFold(ArrayRef(op), patterns,
-                                GreedyRewriteStrictness::ExistingOps,
+                                GreedyRewriteStrictness::ExistingOps, config,
                                 /*changed=*/nullptr, erased);
 }
 

diff  --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index ec85e566e778b..a146d5adfa637 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -322,8 +322,8 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
         RewritePatternSet patterns(res.getContext());
         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
         bool erased;
-        (void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
-
+        (void)applyOpPatternsAndFold(res, std::move(patterns),
+                                     GreedyRewriteConfig(), &erased);
         if (!erased && !prologue)
           prologue = res;
         if (!erased)

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 180fef853e202..286c38ec18ee5 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -415,7 +415,8 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
   bool erased;
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
+  (void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(),
+                               &erased);
   if (erased) {
     if (folded)
       *folded = true;

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 36317e039ef2f..4a37730129691 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,8 +39,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
-                                      const GreedyRewriteConfig &config,
-                                      const Region &scope);
+                                      const GreedyRewriteConfig &config);
 
   /// Simplify the ops within the given region.
   bool simplify(Region &region) &&;
@@ -103,9 +102,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   /// Configuration information for how to simplify.
   const GreedyRewriteConfig config;
 
-  /// Only ops within this scope are simplified.
-  const Region &scope;
-
 private:
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
@@ -116,9 +112,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config, const Region &scope)
-    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
-      scope(scope) {
+    const GreedyRewriteConfig &config)
+    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+  assert(config.scope && "scope is not specified");
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
@@ -313,7 +309,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
   SmallVector<Operation *, 8> ancestors;
   ancestors.push_back(op);
   while (Region *region = op->getParentRegion()) {
-    if (&scope == region) {
+    if (config.scope == region) {
       // All gathered ops are in fact ancestors.
       for (Operation *op : ancestors)
         addSingleOpToWorklist(op);
@@ -434,9 +430,12 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
          "patterns can only be applied to operations IsolatedFromAbove");
 
+  // Set scope if not specified.
+  if (!config.scope)
+    config.scope = &region;
+
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(region.getContext(), patterns, config,
-                                    region);
+  GreedyPatternRewriteDriver driver(region.getContext(), patterns, config);
   bool converged = std::move(driver).simplify(region);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -460,9 +459,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
   explicit MultiOpPatternRewriteDriver(
       MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-      const Region &scope, GreedyRewriteStrictness strictMode,
+      GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
       llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
-      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
+      : GreedyPatternRewriteDriver(ctx, patterns, config),
         strictMode(strictMode), survivingOps(survivingOps) {}
 
   /// Performs the specified rewrites on `ops` while also trying to fold these
@@ -636,11 +635,10 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   return region;
 }
 
-LogicalResult
-mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
-                             const FrozenRewritePatternSet &patterns,
-                             GreedyRewriteStrictness strictMode, bool *changed,
-                             bool *allErased, Region *scope) {
+LogicalResult mlir::applyOpPatternsAndFold(
+    ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
+    GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
+    bool *changed, bool *allErased) {
   if (ops.empty()) {
     if (changed)
       *changed = false;
@@ -649,14 +647,15 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
     return success();
   }
 
-  if (!scope) {
+  // Determine scope of rewrite.
+  if (!config.scope) {
     // Compute scope if none was provided.
-    scope = findCommonAncestor(ops);
+    config.scope = findCommonAncestor(ops);
   } else {
     // If a scope was provided, make sure that all ops are in scope.
 #ifndef NDEBUG
     bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
-      return static_cast<bool>(scope->findAncestorOpInRegion(*op));
+      return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
     });
     assert(allOpsInScope && "ops must be within the specified scope");
 #endif // NDEBUG
@@ -665,14 +664,14 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
   // Start the pattern driver.
   llvm::SmallDenseSet<Operation *, 4> surviving;
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     *scope, strictMode,
+                                     strictMode, config,
                                      allErased ? &surviving : nullptr);
   LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
   if (allErased)
     *allErased = surviving.empty();
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after "
-                 << GreedyRewriteConfig().maxNumRewrites << " rewrites";
+                 << config.maxNumRewrites << " rewrites";
   });
   return converged;
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c47c8f139e406..da018708c702e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -283,7 +283,7 @@ struct TestStrictPatternDriver
     bool changed = false;
     bool allErased = false;
     (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
-                                 &changed, &allErased);
+                                 GreedyRewriteConfig(), &changed, &allErased);
     Builder b(ctx);
     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
     getOperation()->setAttr("pattern_driver_all_erased",


        


More information about the Mlir-commits mailing list