[Mlir-commits] [mlir] a2b837a - [mlir] GreedyPatternRewriteDriver: Entry point takes single region

Matthias Springer llvmlistbot at llvm.org
Fri Jan 27 02:24:37 PST 2023


Author: Matthias Springer
Date: 2023-01-27T11:23:04+01:00
New Revision: a2b837ab0448869c74cc042155dd454833c60d62

URL: https://github.com/llvm/llvm-project/commit/a2b837ab0448869c74cc042155dd454833c60d62
DIFF: https://github.com/llvm/llvm-project/commit/a2b837ab0448869c74cc042155dd454833c60d62.diff

LOG: [mlir] GreedyPatternRewriteDriver: Entry point takes single region

The rewrite driver is typically applied to a single region or all regions of the same op. There is no longer an overload to apply the rewrite driver to a list of regions.

This simplifies the rewrite driver implementation because the scope is now a single region as opposed to a list of regions.

Note: This change is not NFC because `config.maxIterations` and `config.maxNumRewrites` are now counted for each region separately. Furthermore, worklist filtering (`scope`) is now applied to each region separately.

Differential Revision: https://reviews.llvm.org/D142611

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 5a043775a01d2..6ee565ffe5ec4 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -59,9 +59,9 @@ class GreedyRewriteConfig {
 // applyPatternsGreedily
 //===----------------------------------------------------------------------===//
 
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner.
+/// Rewrite ops in the given region, which must be isolated from above, by
+/// repeatedly applying the highest benefit patterns in a greedy work-list
+/// driven manner.
 ///
 /// This variant may stop after a predefined number of iterations, see the
 /// alternative below to provide a specific number of iterations before stopping
@@ -76,14 +76,18 @@ class GreedyRewriteConfig {
 ///
 /// You may configure several aspects of this with GreedyRewriteConfig.
 LogicalResult applyPatternsAndFoldGreedily(
-    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    Region &region, const FrozenRewritePatternSet &patterns,
     GreedyRewriteConfig config = GreedyRewriteConfig());
 
-/// Rewrite the given regions, which must be isolated from above.
+/// Rewrite ops in all regions of the given op, which must be isolated from
+/// above. Each region is rewritten separately; fails if any region fails.
 inline LogicalResult applyPatternsAndFoldGreedily(
     Operation *op, const FrozenRewritePatternSet &patterns,
     GreedyRewriteConfig config = GreedyRewriteConfig()) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
+  bool failed = false;
+  for (Region &region : op->getRegions())
+    failed |= applyPatternsAndFoldGreedily(region, patterns, config).failed();
+  return failure(failed);
 }
 
 /// Applies the specified rewrite patterns on `ops` while also trying to fold

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e9a8c8326fee2..2504a2ab0c9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1867,8 +1867,7 @@ struct LinalgElementwiseOpFusionPass
     // Use TopDownTraversal for compile time reasons
     GreedyRewriteConfig grc;
     grc.useTopDownTraversal = true;
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
-                                       grc);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
   }
 };
 

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 71c99e9174ea6..8ef16d5eeaec4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -781,8 +781,7 @@ struct ExpandStridedMetadataPass final
 void ExpandStridedMetadataPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateExpandStridedMetadataPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
-                                     std::move(patterns));
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
 std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 92f02c068d2b9..33e9ee71ee3b5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -605,8 +605,7 @@ struct FoldMemRefAliasOpsPass final
 void FoldMemRefAliasOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateFoldMemRefAliasOpPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
-                                     std::move(patterns));
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
 std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 7d3f1fbd5293d..650d71e732a7c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -149,8 +149,7 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
-                                          std::move(patterns))))
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
 
@@ -158,8 +157,7 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
-                                          std::move(patterns))))
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index b225662e58c5b..eaba09753f7f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -120,8 +120,7 @@ struct LowerVectorMaskPass
     RewritePatternSet loweringPatterns(context);
     populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
 
-    if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
-                                            std::move(loweringPatterns))))
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
       signalPassFailure();
   }
 

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index a5ddd9138873b..36317e039ef2f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -40,10 +40,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
                                       const GreedyRewriteConfig &config,
-                                      const DenseSet<Region *> &scope);
+                                      const Region &scope);
 
-  /// Simplify the operations within the given regions.
-  bool simplify(MutableArrayRef<Region> regions) &&;
+  /// Simplify the ops within the given region.
+  bool simplify(Region &region) &&;
 
   /// Add the given operation and its ancestors to the worklist.
   void addToWorklist(Operation *op);
@@ -104,7 +104,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   const GreedyRewriteConfig config;
 
   /// Only ops within this scope are simplified.
-  const DenseSet<Region *> scope;
+  const Region &scope;
 
 private:
 #ifndef NDEBUG
@@ -116,7 +116,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config, const DenseSet<Region *> &scope)
+    const GreedyRewriteConfig &config, const Region &scope)
     : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
       scope(scope) {
   worklist.reserve(64);
@@ -125,7 +125,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
   matcher.applyDefaultCostModel();
 }
 
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
+bool GreedyPatternRewriteDriver::simplify(Region &region) && {
 #ifndef NDEBUG
   const char *logLineComment =
       "//===-------------------------------------------===//\n";
@@ -167,15 +167,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
 
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
-      for (auto &region : regions) {
         region.walk([&](Operation *op) {
           if (!insertKnownConstant(op))
             addToWorklist(op);
         });
-      }
     } else {
       // Add all nested operations to the worklist in preorder.
-      for (auto &region : regions) {
         region.walk<WalkOrder::PreOrder>([&](Operation *op) {
           if (!insertKnownConstant(op)) {
             worklist.push_back(op);
@@ -183,7 +180,6 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
           }
           return WalkResult::skip();
         });
-      }
 
       // Reverse the list so our pop-back loop processes them in-order.
       std::reverse(worklist.begin(), worklist.end());
@@ -305,7 +301,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
     // After applying patterns, make sure that the CFG of each of the regions
     // is kept up to date.
     if (config.enableRegionSimplification)
-      changed |= succeeded(simplifyRegions(*this, regions));
+      changed |= succeeded(simplifyRegions(*this, region));
   } while (changed);
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
@@ -317,7 +313,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
   SmallVector<Operation *, 8> ancestors;
   ancestors.push_back(op);
   while (Region *region = op->getParentRegion()) {
-    if (scope.contains(region)) {
+    if (&scope == region) {
       // All gathered ops are in fact ancestors.
       for (Operation *op : ancestors)
         addSingleOpToWorklist(op);
@@ -429,31 +425,19 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
 /// top-level operation itself.
 ///
 LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+mlir::applyPatternsAndFoldGreedily(Region &region,
                                    const FrozenRewritePatternSet &patterns,
                                    GreedyRewriteConfig config) {
-  if (regions.empty())
-    return success();
-
   // The top-level operation must be known to be isolated from above to
   // prevent performing canonicalizations on operations defined at or above
   // the region containing 'op'.
-  auto regionIsIsolated = [](Region &region) {
-    return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
-  };
-  (void)regionIsIsolated;
-  assert(llvm::all_of(regions, regionIsIsolated) &&
+  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
          "patterns can only be applied to operations IsolatedFromAbove");
 
-  // Limit ops on the worklist to this scope.
-  DenseSet<Region *> scope;
-  for (Region &r : regions)
-    scope.insert(&r);
-
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
-                                    scope);
-  bool converged = std::move(driver).simplify(regions);
+  GreedyPatternRewriteDriver driver(region.getContext(), patterns, config,
+                                    region);
+  bool converged = std::move(driver).simplify(region);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
                  << config.maxIterations << " times\n";
@@ -476,7 +460,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
   explicit MultiOpPatternRewriteDriver(
       MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-      const DenseSet<Region *> &scope, GreedyRewriteStrictness strictMode,
+      const Region &scope, GreedyRewriteStrictness strictMode,
       llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
         strictMode(strictMode), survivingOps(survivingOps) {}
@@ -680,10 +664,8 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
 
   // Start the pattern driver.
   llvm::SmallDenseSet<Operation *, 4> surviving;
-  DenseSet<Region *> scopeSet;
-  scopeSet.insert(scope);
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     scopeSet, strictMode,
+                                     *scope, strictMode,
                                      allErased ? &surviving : nullptr);
   LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
   if (allErased)

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 98896c736a3cb..c47c8f139e406 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1633,8 +1633,7 @@ struct TestSelectiveReplacementPatternDriver
     MLIRContext *context = &getContext();
     mlir::RewritePatternSet patterns(context);
     patterns.add<TestSelectiveOpReplacementPattern>(context);
-    (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
-                                       std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
 } // namespace


        


More information about the Mlir-commits mailing list