[Mlir-commits] [mlir] 391cb54 - [mlir] Add option to limit number of pattern rewrites in CanonicalizerPass

Matthias Springer llvmlistbot at llvm.org
Fri Dec 23 04:09:23 PST 2022


Author: Matthias Springer
Date: 2022-12-23T13:08:53+01:00
New Revision: 391cb541223bb0d41620eb5e25c107563dc3e12c

URL: https://github.com/llvm/llvm-project/commit/391cb541223bb0d41620eb5e25c107563dc3e12c
DIFF: https://github.com/llvm/llvm-project/commit/391cb541223bb0d41620eb5e25c107563dc3e12c.diff

LOG: [mlir] Add option to limit number of pattern rewrites in CanonicalizerPass

The greedy pattern rewriter consists of two nested loops. `config.maxIterations` (which is configurable on the CanonicalizerPass) controls the maximum number of iterations of the outer loop.

```
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
int64_t maxIterations = 10;
```

This change adds `config.maxNumRewrites`, which controls the maximum number of pattern rewrites within an iteration. (It effectively controls the maximum number of iterations of the inner loop.)

This flag is meant for debugging and is useful in cases where one or more faulty patterns can be applied indefinitely, resulting in an infinite loop.

Differential Revision: https://reviews.llvm.org/D140525

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/Canonicalizer.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/Pass/run-reproducer.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index d9d272110b31..5478587dcc43 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -33,11 +33,15 @@ class GreedyRewriteConfig {
   bool enableRegionSimplification = true;
 
   /// This specifies the maximum number of times the rewriter will iterate
-  /// between applying patterns and simplifying regions. Use `kNoIterationLimit`
-  /// to disable this iteration limit.
+  /// between applying patterns and simplifying regions. Use `kNoLimit` to
+  /// disable this iteration limit.
   int64_t maxIterations = 10;
 
-  static constexpr int64_t kNoIterationLimit = -1;
+  /// This specifies the maximum number of rewrites within an iteration. Use
+  /// `kNoLimit` to disable this limit.
+  int64_t maxNumRewrites = kNoLimit;
+
+  static constexpr int64_t kNoLimit = -1;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index d45f5f08b300..e7d122323ae3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -30,10 +30,12 @@ def Canonicalizer : Pass<"canonicalize"> {
            "Seed the worklist in general top-down order">,
     Option<"enableRegionSimplification", "region-simplify", "bool",
            /*default=*/"true",
-           "Seed the worklist in general top-down order">,
+           "Perform control flow optimizations to the region tree">,
     Option<"maxIterations", "max-iterations", "int64_t",
            /*default=*/"10",
-           "Seed the worklist in general top-down order">
+           "Max. iterations between applying patterns / simplifying regions">,
+    Option<"maxNumRewrites", "max-num-rewrites", "int64_t", /*default=*/"-1",
+           "Max. number of pattern rewrites within an iteration">
   ] # RewritePassUtils.options;
 }
 

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index a4215629a964..dc3bf97b3238 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
     this->maxIterations = config.maxIterations;
+    this->maxNumRewrites = config.maxNumRewrites;
     this->disabledPatterns = disabledPatterns;
     this->enabledPatterns = enabledPatterns;
   }
@@ -55,6 +56,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     config.useTopDownTraversal = topDownProcessingEnabled;
     config.enableRegionSimplification = enableRegionSimplification;
     config.maxIterations = maxIterations;
+    config.maxNumRewrites = maxNumRewrites;
     (void)applyPatternsAndFoldGreedily(getOperation(), patterns, config);
   }
 

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 935ca2eb9374..0d6fdaf3039c 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -183,6 +183,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
     SmallVector<Value, 8> originalOperands, resultValues;
 
     changed = false;
+    int64_t numRewrites = 0;
     while (!worklist.empty()) {
       auto *op = popFromWorklist();
 
@@ -279,16 +280,20 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
 #else
       LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
 #endif
-      changed |= succeeded(matchResult);
+      if (succeeded(matchResult)) {
+        changed = true;
+        if (numRewrites++ >= config.maxNumRewrites &&
+            config.maxNumRewrites != GreedyRewriteConfig::kNoLimit)
+          break;
+      }
     }
 
     // After applying patterns, make sure that the CFG of each of the regions
     // is kept up to date.
     if (config.enableRegionSimplification)
       changed |= succeeded(simplifyRegions(*this, regions));
-  } while (changed &&
-           (iteration++ < config.maxIterations ||
-            config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
+  } while (changed && (iteration++ < config.maxIterations ||
+                       config.maxIterations == GreedyRewriteConfig::kNoLimit));
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
@@ -506,9 +511,8 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
     changed |= succeeded(matcher.matchAndRewrite(op, *this));
     if ((erased = opErasedViaPatternRewrites))
       return success();
-  } while (changed &&
-           (++iterations < maxIterations ||
-            maxIterations == GreedyRewriteConfig::kNoIterationLimit));
+  } while (changed && (++iterations < maxIterations ||
+                       maxIterations == GreedyRewriteConfig::kNoLimit));
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return failure(changed);

diff  --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 496471d032a5..3a958f8a9250 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -14,8 +14,8 @@ func.func @bar() {
   external_resources: {
     mlir_reproducer: {
       verify_each: true,
-      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 region-simplify=false top-down=false}))
-      pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 region-simplify=false top-down=false}))",
+      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))
+      pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
       disable_threading: true
     }
   }


        


More information about the Mlir-commits mailing list