[Mlir-commits] [mlir] 8498c9e - [mlir][GreedyPatternRewriter] Add out param to detect changes in IR in `applyPatternsAndFoldGreedily`

Matthias Springer llvmlistbot at llvm.org
Thu Jun 29 03:55:41 PDT 2023


Author: Joel Wee
Date: 2023-06-29T12:48:00+02:00
New Revision: 8498c9e9489f57c6eb59b464d2409b4974a22a02

URL: https://github.com/llvm/llvm-project/commit/8498c9e9489f57c6eb59b464d2409b4974a22a02
DIFF: https://github.com/llvm/llvm-project/commit/8498c9e9489f57c6eb59b464d2409b4974a22a02.diff

LOG: [mlir][GreedyPatternRewriter] Add out param to detect changes in IR in `applyPatternsAndFoldGreedily`

This allows users of `applyPatternsAndFoldGreedily` to detect whether the IR was changed. An example use case is one where we expect `applyPatternsAndFoldGreedily` to change the IR and want to validate that it indeed does.

Differential Revision: https://reviews.llvm.org/D153986

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 2c5c63e0ac4dc4..aaf6f0a8951257 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -94,25 +94,40 @@ class GreedyRewriteConfig {
 /// in absence of convergence.
 ///
 /// Return success if the iterative process converged and no more patterns can
-/// be matched in the result operation regions.
+/// be matched in the result operation regions. If `changed` is non-null, it is
+/// set to true if the IR was modified at all and to false otherwise.
 ///
 /// Note: This does not apply patterns to the top-level operation itself.
 ///       These methods also perform folding and simple dead-code elimination
 ///       before attempting to match any of the provided patterns.
 ///
 /// You may configure several aspects of this with GreedyRewriteConfig.
-LogicalResult applyPatternsAndFoldGreedily(
-    Region &region, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteConfig config = GreedyRewriteConfig());
+LogicalResult
+applyPatternsAndFoldGreedily(Region &region,
+                             const FrozenRewritePatternSet &patterns,
+                             GreedyRewriteConfig config = GreedyRewriteConfig(),
+                             bool *changed = nullptr);
 
 /// Rewrite ops in all regions of the given op, which must be isolated from
 /// above.
-inline LogicalResult applyPatternsAndFoldGreedily(
-    Operation *op, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteConfig config = GreedyRewriteConfig()) {
+inline LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+                             const FrozenRewritePatternSet &patterns,
+                             GreedyRewriteConfig config = GreedyRewriteConfig(),
+                             bool *changed = nullptr) {
   bool failed = false;
-  for (Region &region : op->getRegions())
-    failed |= applyPatternsAndFoldGreedily(region, patterns, config).failed();
+  // Accumulate into a local: `*changed` is an out-param that callers may pass
+  // uninitialized, so it must only be written, never read.
+  bool anyRegionChanged = false;
+  for (Region &region : op->getRegions()) {
+    bool regionChanged = false;
+    failed |=
+        applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
+            .failed();
+    anyRegionChanged |= regionChanged;
+  }
+  if (changed)
+    *changed = anyRegionChanged;
   return failure(failed);
 }
 

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 2a39cccfc580d4..fba4944f130c23 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -616,7 +616,8 @@ class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
 
-  /// Simplify ops inside `region` and simplify the region itself. Return
-  /// success if the transformation converged.
-  LogicalResult simplify() &&;
+  /// Simplify ops inside `region` and simplify the region itself. Return
+  /// success if the transformation converged. If `changed` is non-null, it is
+  /// set to true if the IR was modified and to false otherwise.
+  LogicalResult simplify(bool *changed) &&;
 
 private:
   /// The region that is simplified.
@@ -652,7 +653,7 @@ class GreedyPatternRewriteIteration
 };
 } // namespace
 
-LogicalResult RegionPatternRewriteDriver::simplify() && {
+LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
   auto insertKnownConstant = [&](Operation *op) {
     // Check for existing constants when populating the worklist. This avoids
     // accidentally reversing the constant order during processing.
@@ -663,12 +664,12 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
     return false;
   };
 
-  bool changed = false;
+  bool continueRewrites = false;
   int64_t iteration = 0;
   MLIRContext *ctx = getContext();
   do {
     // Check if the iteration limit was reached.
-    if (iteration++ >= config.maxIterations &&
+    if (++iteration > config.maxIterations &&
         config.maxIterations != GreedyRewriteConfig::kNoLimit)
       break;
 
@@ -696,24 +697,33 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
 
     ctx->executeAction<GreedyPatternRewriteIteration>(
         [&] {
-          changed = processWorklist();
+          continueRewrites = processWorklist();
 
           // After applying patterns, make sure that the CFG of each of the
           // regions is kept up to date.
           if (config.enableRegionSimplification)
-            changed |= succeeded(simplifyRegions(*this, region));
+            continueRewrites |= succeeded(simplifyRegions(*this, region));
         },
         {&region}, iteration);
-  } while (changed);
+  } while (continueRewrites);
+
+  // `iteration` is bumped at the top of every trip through the loop, and the
+  // loop only repeats when the previous trip set `continueRewrites`. Hence
+  // `iteration > 1` exactly when the first trip reported a change.
+  // NOTE(review): only changes reported by processWorklist()/simplifyRegions()
+  // are counted here; confirm that populating the worklist (e.g. via
+  // insertKnownConstant) cannot modify the IR on its own.
+  if (changed)
+    *changed = iteration > 1;
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
-  return success(!changed);
+  return success(!continueRewrites);
 }
 
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(Region &region,
                                    const FrozenRewritePatternSet &patterns,
-                                   GreedyRewriteConfig config) {
+                                   GreedyRewriteConfig config, bool *changed) {
   // The top-level operation must be known to be isolated from above to
   // prevent performing canonicalizations on operations defined at or above
   // the region containing 'op'.
@@ -727,7 +737,7 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
   // Start the pattern driver.
   RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
                                     region);
-  LogicalResult converged = std::move(driver).simplify();
+  LogicalResult converged = std::move(driver).simplify(changed);
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
                  << config.maxIterations << " times\n";


        


More information about the Mlir-commits mailing list