[Mlir-commits] [mlir] [mlir] Enable decoupling two kinds of greedy behavior. (PR #104649)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 16 16:22:26 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Jacques Pienaar (jpienaar)

<details>
<summary>Changes</summary>

The greedy rewriter is used in many different flows and provides a lot of convenience (worklist management, debugging actions, tracing, etc.). But it combines two kinds of greedy behavior: 1) how ops are matched, and 2) folding wherever it can.

These are independent forms of greediness, and coupling them leads to inefficiency. E.g., there are cases where one needs to create different phases in lowering and is required to apply patterns in a specific order, split across different passes. Using the driver, one ends up needlessly retrying folding (multiple rounds of folding attempts) where one final run would have sufficed.

Of course, folks can locally avoid this behavior by just building their own driver, but this is also a commonly requested feature that folks keep working around locally in suboptimal ways.
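
To illustrate the cost described above, here is a minimal toy model, not MLIR code. `ToyOp`, `ToyConfig`, `ToyPattern`, `tryFold`, `applyToyPatternsGreedily`, and `runPhases` are all hypothetical names invented for this sketch; only the `fold` flag mirrors the option added in this PR. It counts fold attempts when folding runs in every lowering phase versus once at the end.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy stand-ins for illustration only; none of these are MLIR types.
struct ToyOp {
  std::string name;
  int lhs;
  int rhs;
};

struct ToyConfig {
  bool fold = true; // plays the role of the new GreedyRewriteConfig::fold
};

// A pattern returns true if it rewrote the op in place.
using ToyPattern = std::function<bool(ToyOp &)>;

// Toy folder: an "add" of two constant operands becomes a "const".
static bool tryFold(ToyOp &op) {
  if (op.name != "add")
    return false;
  op = ToyOp{"const", op.lhs + op.rhs, 0};
  return true;
}

// Runs the patterns to a fixpoint and returns the number of fold attempts.
static int applyToyPatternsGreedily(std::vector<ToyOp> &ops,
                                    const std::vector<ToyPattern> &patterns,
                                    ToyConfig config) {
  int foldAttempts = 0;
  bool changed = true;
  while (changed) {
    changed = false;
    for (ToyOp &op : ops) {
      // Constant-like ops are skipped, loosely mirroring the real driver.
      if (config.fold && op.name != "const") {
        ++foldAttempts;
        if (tryFold(op)) {
          changed = true;
          continue;
        }
      }
      for (const ToyPattern &pattern : patterns) {
        if (pattern(op)) {
          changed = true;
          break;
        }
      }
    }
  }
  return foldAttempts;
}

// Two lowering phases ("a" -> "b", then "b" -> "add"), followed by one
// pattern-free run that only folds. Returns the total fold attempts.
static int runPhases(bool foldInEveryPhase) {
  std::vector<ToyOp> ops = {ToyOp{"a", 1, 2}};
  auto renamePattern = [](std::string from, std::string to) {
    return ToyPattern([from, to](ToyOp &op) {
      if (op.name != from)
        return false;
      op.name = to;
      return true;
    });
  };
  ToyConfig phaseConfig;
  phaseConfig.fold = foldInEveryPhase;
  int attempts = 0;
  attempts += applyToyPatternsGreedily(ops, {renamePattern("a", "b")}, phaseConfig);
  attempts += applyToyPatternsGreedily(ops, {renamePattern("b", "add")}, phaseConfig);
  attempts += applyToyPatternsGreedily(ops, {}, ToyConfig{});
  // Either way the op ends up folded to the same constant.
  assert(ops[0].name == "const" && ops[0].lhs == 3);
  return attempts;
}
```

In this toy setup, `runPhases(true)` makes four fold attempts while `runPhases(false)` makes one, and both end with the same folded constant. The real driver's worklist and folding are considerably more involved, so treat the numbers as illustrative only.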

---
Full diff: https://github.com/llvm/llvm-project/pull/104649.diff


2 Files Affected:

- (modified) mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h (+46-12) 
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+11-6) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index eaff85804f6b3d..061cdd4b7d4d94 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -91,6 +91,15 @@ class GreedyRewriteConfig {
 
   /// An optional listener that should be notified about IR modifications.
   RewriterBase::Listener *listener = nullptr;
+
+  // Whether this should fold while greedily rewriting.
+  //
+  // Note: greedy here generally refers to two forms, 1) greedily applying
+  // patterns based purely on benefit and applying without backtracking using
+  // default cost model, 2) greedily folding where possible while attempting to
+  // match and rewrite using the provided patterns. With this option set to
+  // false it only does the former.
+  bool fold = true;
 };
 
 //===----------------------------------------------------------------------===//
@@ -104,8 +113,8 @@ class GreedyRewriteConfig {
 /// The greedy rewrite may prematurely stop after a maximum number of
 /// iterations, which can be configured in the configuration parameter.
 ///
-/// Also performs folding and simple dead-code elimination before attempting to
-/// match any of the provided patterns.
+/// Also performs simple dead-code elimination before attempting to match any of
+/// the provided patterns.
 ///
 /// A region scope can be set in the configuration parameter. By default, the
 /// scope is set to the specified region. Only in-scope ops are added to the
@@ -117,10 +126,18 @@ class GreedyRewriteConfig {
 ///
 /// Note: This method does not apply patterns to the region's parent operation.
 LogicalResult
+applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
+                      GreedyRewriteConfig config = GreedyRewriteConfig(),
+                      bool *changed = nullptr);
+/// Same as `applyPatternsGreedily` above with folding.
+inline LogicalResult
 applyPatternsAndFoldGreedily(Region &region,
                              const FrozenRewritePatternSet &patterns,
                              GreedyRewriteConfig config = GreedyRewriteConfig(),
-                             bool *changed = nullptr);
+                             bool *changed = nullptr) {
+  config.fold = true;
+  return applyPatternsGreedily(region, patterns, config, changed);
+}
 
 /// Rewrite ops nested under the given operation, which must be isolated from
 /// above, by repeatedly applying the highest benefit patterns in a greedy
@@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region &region,
 /// The greedy rewrite may prematurely stop after a maximum number of
 /// iterations, which can be configured in the configuration parameter.
 ///
-/// Also performs folding and simple dead-code elimination before attempting to
-/// match any of the provided patterns.
+/// Also performs simple dead-code elimination before attempting to match any of
+/// the provided patterns.
 ///
 /// This overload runs a separate greedy rewrite for each region of the
 /// specified op. A region scope can be set in the configuration parameter. By
@@ -147,10 +164,9 @@ applyPatternsAndFoldGreedily(Region &region,
 ///
 /// Note: This method does not apply patterns to the given operation itself.
 inline LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternSet &patterns,
-                             GreedyRewriteConfig config = GreedyRewriteConfig(),
-                             bool *changed = nullptr) {
+applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
+                      GreedyRewriteConfig config = GreedyRewriteConfig(),
+                      bool *changed = nullptr) {
   bool anyRegionChanged = false;
   bool failed = false;
   for (Region &region : op->getRegions()) {
@@ -164,6 +180,15 @@ applyPatternsAndFoldGreedily(Operation *op,
     *changed = anyRegionChanged;
   return failure(failed);
 }
+/// Same as `applyPatternsGreedily` above with folding.
+inline LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+                             const FrozenRewritePatternSet &patterns,
+                             GreedyRewriteConfig config = GreedyRewriteConfig(),
+                             bool *changed = nullptr) {
+  config.fold = true;
+  return applyPatternsGreedily(op, patterns, config, changed);
+}
 
 /// Rewrite the specified ops by repeatedly applying the highest benefit
 /// patterns in a greedy worklist driven manner until a fixpoint is reached.
@@ -171,8 +196,8 @@ applyPatternsAndFoldGreedily(Operation *op,
 /// The greedy rewrite may prematurely stop after a maximum number of
 /// iterations, which can be configured in the configuration parameter.
 ///
-/// Also performs folding and simple dead-code elimination before attempting to
-/// match any of the provided patterns.
+/// Also performs simple dead-code elimination before attempting to match any of
+/// the provided patterns.
 ///
 /// Newly created ops and other pre-existing ops that use results of rewritten
 /// ops or supply operands to such ops are also processed, unless such ops are
@@ -194,10 +219,19 @@ applyPatternsAndFoldGreedily(Operation *op,
 /// the IR was modified at all. `allOpsErased` is set to "true" if all ops in
 /// `ops` were erased.
 LogicalResult
+applyOpPatternsGreedily(ArrayRef<Operation *> ops,
+                        const FrozenRewritePatternSet &patterns,
+                        GreedyRewriteConfig config = GreedyRewriteConfig(),
+                        bool *changed = nullptr, bool *allErased = nullptr);
+/// Same as `applyOpPatternsGreedily` with folding.
+inline LogicalResult
 applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                        const FrozenRewritePatternSet &patterns,
                        GreedyRewriteConfig config = GreedyRewriteConfig(),
-                       bool *changed = nullptr, bool *allErased = nullptr);
+                       bool *changed = nullptr, bool *allErased = nullptr) {
+  config.fold = true;
+  return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
+}
 
 } // namespace mlir
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index e0d0acd122e26b..4e8b74620da5fe 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements mlir::applyPatternsAndFoldGreedily.
+// This file implements mlir::applyPatternsGreedily.
 //
 //===----------------------------------------------------------------------===//
 
@@ -488,7 +488,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
     // infinite folding loop, as every constant op would be folded to an
     // Attribute and then immediately be rematerialized as a constant op, which
     // is then put on the worklist.
-    if (!op->hasTrait<OpTrait::ConstantLike>()) {
+    if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
       SmallVector<OpFoldResult> foldResults;
       if (succeeded(op->fold(foldResults))) {
         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
@@ -840,6 +840,11 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
     // regions to enable more aggressive CSE'ing).
     OperationFolder folder(ctx, this);
     auto insertKnownConstant = [&](Operation *op) {
+      // This hoisting is to enable more folding, so skip checking if known
+      // constant, updating dense map etc if not doing folding.
+      if (!config.fold)
+        return false;
+
       // Check for existing constants when populating the worklist. This avoids
       // accidentally reversing the constant order during processing.
       Attribute constValue;
@@ -894,9 +899,9 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
 }
 
 LogicalResult
-mlir::applyPatternsAndFoldGreedily(Region &region,
-                                   const FrozenRewritePatternSet &patterns,
-                                   GreedyRewriteConfig config, bool *changed) {
+mlir::applyPatternsGreedily(Region &region,
+                            const FrozenRewritePatternSet &patterns,
+                            GreedyRewriteConfig config, bool *changed) {
   // The top-level operation must be known to be isolated from above to
   // prevent performing canonicalizations on operations defined at or above
   // the region containing 'op'.
@@ -1012,7 +1017,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   return region;
 }
 
-LogicalResult mlir::applyOpPatternsAndFold(
+LogicalResult mlir::applyOpPatternsGreedily(
     ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
     GreedyRewriteConfig config, bool *changed, bool *allErased) {
   if (ops.empty()) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/104649

