[Mlir-commits] [mlir] [mlir] Enable decoupling two kinds of greedy behavior. (PR #104649)

Jacques Pienaar llvmlistbot at llvm.org
Mon Sep 2 20:52:28 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/104649

From 441584257819ba27087c81658cef467add39e395 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 16 Aug 2024 21:23:56 +0000
Subject: [PATCH 1/3] [mlir] Enable decoupling two kinds of greedy behavior.

The greedy rewriter is used in many different flows and offers a lot of
conveniences (worklist management, debugging actions, tracing, etc.). But
it combines two kinds of greedy behavior: 1) with respect to how ops are
matched, and 2) folding wherever it can.
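A toy model may help make the two kinds concrete. The sketch below is not
MLIR code: `ToyPattern`, `ToyConfig`, `Rewrite` and `processOp` are
hypothetical names, ops are plain strings, and only a single op is handled
rather than a whole worklist. It assumes, as the driver's doc comment says,
that folding is attempted before any pattern, and that patterns are tried in
descending benefit order with the first match committed.

```cpp
#include <algorithm>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy model (not MLIR's API): an "op" is a plain string, and a
// rewrite returns the replacement op, or std::nullopt if it does not match.
using Rewrite =
    std::function<std::optional<std::string>(const std::string &)>;

struct ToyPattern {
  int benefit;
  Rewrite rewrite;
};

// Stands in for the proposed GreedyRewriteConfig::fold option.
struct ToyConfig {
  bool fold = true;
};

// Processes a single op the way the greedy driver would.
std::string processOp(const std::string &op, std::vector<ToyPattern> patterns,
                      const Rewrite &fold, const ToyConfig &config) {
  // Greedy kind 2: try to fold wherever possible, before any pattern.
  if (config.fold) {
    if (std::optional<std::string> folded = fold(op))
      return *folded;
  }
  // Greedy kind 1: try patterns in descending benefit order and commit to the
  // first match; alternatives are never revisited (no backtracking).
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &pattern : patterns) {
    if (std::optional<std::string> rewritten = pattern.rewrite(op))
      return *rewritten;
  }
  return op; // Nothing applied; the op is left as-is.
}
```

With a fold hook that turns "add(1,2)" into "3" and two patterns that both
match "add(1,2)", the default config returns "3", while setting
`fold = false` returns whatever the highest-benefit pattern produces. That
independence is exactly what the new option is meant to expose.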

These are independent forms of greediness, and coupling them leads to
inefficiency. E.g., in cases where one needs to create different phases in
lowering, one is required to apply patterns in a specific order or in
different passes. But when using the driver, one ends up needlessly
retrying folding or making multiple rounds of folding attempts, where one
final run would have sufficed. It is also rather confusing for users who
just want to apply some patterns, with all the convenience and structure of
the driver, to end up with unrelated changes to the IR.
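As a rough back-of-the-envelope illustration of the phase-ordering cost,
assume (hypothetically) that each phase makes exactly one pass over a fixed
number of ops, performs no rewrites, and makes one fold attempt per op when
folding is on. The functions below are illustrative only; real runs iterate
to a fixpoint and the op count changes, so the numbers are indicative at
best.

```cpp
#include <cstddef>

// Hypothetical cost model: each phase visits every op exactly once, performs
// no rewrites, and (when folding is enabled) makes one fold attempt per op.
std::size_t foldAttemptsInPhase(std::size_t numOps, bool fold) {
  return fold ? numOps : 0;
}

// Counts fold attempts across a lowering split into `numPhases` phases.
// If `foldEveryPhase` is false, the phases only apply patterns and a single
// folding run is added at the end.
std::size_t totalFoldAttempts(std::size_t numOps, std::size_t numPhases,
                              bool foldEveryPhase) {
  std::size_t attempts = 0;
  for (std::size_t phase = 0; phase < numPhases; ++phase)
    attempts += foldAttemptsInPhase(numOps, foldEveryPhase);
  if (!foldEveryPhase)
    attempts += foldAttemptsInPhase(numOps, /*fold=*/true);
  return attempts;
}
```

Under these assumptions, 100 ops over 3 phases cost 300 fold attempts when
every phase folds, versus 100 when the phases skip folding and one final run
does it.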

Of course, folks can avoid this behavior locally by building their own
driver, but this is also a commonly requested feature that folks keep
working around locally in suboptimal ways.
---
 .../Transforms/GreedyPatternRewriteDriver.h   | 58 +++++++++++++++----
 .../Utils/GreedyPatternRewriteDriver.cpp      | 17 ++++--
 2 files changed, 57 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index eaff85804f6b3d..061cdd4b7d4d94 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -91,6 +91,15 @@ class GreedyRewriteConfig {
 
   /// An optional listener that should be notified about IR modifications.
   RewriterBase::Listener *listener = nullptr;
+
+  // Whether this should fold while greedily rewriting.
+  //
+  // Note: greedy here generally refers to two forms, 1) greedily applying
+  // patterns based purely on benefit and applying without backtracking using
+  // default cost model, 2) greedily folding where possible while attempting to
+  // match and rewrite using the provided patterns. With this option set to
+  // false it only does the former.
+  bool fold = true;
 };
 
 //===----------------------------------------------------------------------===//
@@ -104,8 +113,8 @@ class GreedyRewriteConfig {
 /// The greedy rewrite may prematurely stop after a maximum number of
 /// iterations, which can be configured in the configuration parameter.
 ///
-/// Also performs folding and simple dead-code elimination before attempting to
-/// match any of the provided patterns.
+/// Also performs simple dead-code elimination before attempting to match any of
+/// the provided patterns.
 ///
 /// A region scope can be set in the configuration parameter. By default, the
 /// scope is set to the specified region. Only in-scope ops are added to the
@@ -117,10 +126,18 @@ class GreedyRewriteConfig {
 ///
 /// Note: This method does not apply patterns to the region's parent operation.
 LogicalResult
+applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
+                      GreedyRewriteConfig config = GreedyRewriteConfig(),
+                      bool *changed = nullptr);
+/// Same as `applyPatternsGreedily` above with folding.
+inline LogicalResult
 applyPatternsAndFoldGreedily(Region &region,
                              const FrozenRewritePatternSet &patterns,
                              GreedyRewriteConfig config = GreedyRewriteConfig(),
-                             bool *changed = nullptr);
+                             bool *changed = nullptr) {
+  config.fold = true;
+  return applyPatternsGreedily(region, patterns, config, changed);
+}
 
 /// Rewrite ops nested under the given operation, which must be isolated from
 /// above, by repeatedly applying the highest benefit patterns in a greedy
@@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region &region,
 /// The greedy rewrite may prematurely stop after a maximum number of
 /// iterations, which can be configured in the configuration parameter.
 ///
-/// Also performs folding and simple dead-code elimination before attempting to
-/// match any of the provided patterns.
+/// Also performs simple dead-code elimination before attempting to match any of
+/// the provided patterns.
 ///
 /// This overload runs a separate greedy rewrite for each region of the
 /// specified op. A region scope can be set in the configuration parameter. By
@@ -147,10 +164,9 @@ applyPatternsAndFoldGreedily(Region &region,
 ///
 /// Note: This method does not apply patterns to the given operation itself.
 inline LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternSet &patterns,
-                             GreedyRewriteConfig config = GreedyRewriteConfig(),
-                             bool *changed = nullptr) {
+applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
+                      GreedyRewriteConfig config = GreedyRewriteConfig(),
+                      bool *changed = nullptr) {
   bool anyRegionChanged = false;
   bool failed = false;
   for (Region &region : op->getRegions()) {
@@ -164,6 +180,15 @@ applyPatternsAndFoldGreedily(Operation *op,
     *changed = anyRegionChanged;
   return failure(failed);
 }
+/// Same as `applyPatternsGreedily` above with folding.
+inline LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+                             const FrozenRewritePatternSet &patterns,
+                             GreedyRewriteConfig config = GreedyRewriteConfig(),
+                             bool *changed = nullptr) {
+  config.fold = true;
+  return applyPatternsGreedily(op, patterns, config, changed);
+}
 
 /// Rewrite the specified ops by repeatedly applying the highest benefit
 /// patterns in a greedy worklist driven manner until a fixpoint is reached.
@@ -171,8 +196,8 @@ applyPatternsAndFoldGreedily(Operation *op,
 /// The greedy rewrite may prematurely stop after a maximum number of
 /// iterations, which can be configured in the configuration parameter.
 ///
-/// Also performs folding and simple dead-code elimination before attempting to
-/// match any of the provided patterns.
+/// Also performs simple dead-code elimination before attempting to match any of
+/// the provided patterns.
 ///
 /// Newly created ops and other pre-existing ops that use results of rewritten
 /// ops or supply operands to such ops are also processed, unless such ops are
@@ -194,10 +219,19 @@ applyPatternsAndFoldGreedily(Operation *op,
 /// the IR was modified at all. `allOpsErased` is set to "true" if all ops in
 /// `ops` were erased.
 LogicalResult
+applyOpPatternsGreedily(ArrayRef<Operation *> ops,
+                        const FrozenRewritePatternSet &patterns,
+                        GreedyRewriteConfig config = GreedyRewriteConfig(),
+                        bool *changed = nullptr, bool *allErased = nullptr);
+/// Same as `applyOpPatternsGreedily` with folding.
+inline LogicalResult
 applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                        const FrozenRewritePatternSet &patterns,
                        GreedyRewriteConfig config = GreedyRewriteConfig(),
-                       bool *changed = nullptr, bool *allErased = nullptr);
+                       bool *changed = nullptr, bool *allErased = nullptr) {
+  config.fold = true;
+  return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
+}
 
 } // namespace mlir
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index e0d0acd122e26b..4e8b74620da5fe 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements mlir::applyPatternsAndFoldGreedily.
+// This file implements mlir::applyPatternsGreedily.
 //
 //===----------------------------------------------------------------------===//
 
@@ -488,7 +488,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
     // infinite folding loop, as every constant op would be folded to an
     // Attribute and then immediately be rematerialized as a constant op, which
     // is then put on the worklist.
-    if (!op->hasTrait<OpTrait::ConstantLike>()) {
+    if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
       SmallVector<OpFoldResult> foldResults;
       if (succeeded(op->fold(foldResults))) {
         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
@@ -840,6 +840,11 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
     // regions to enable more aggressive CSE'ing).
     OperationFolder folder(ctx, this);
     auto insertKnownConstant = [&](Operation *op) {
+      // This hoisting is to enable more folding, so skip checking if known
+      // constant, updating dense map etc if not doing folding.
+      if (!config.fold)
+        return false;
+
       // Check for existing constants when populating the worklist. This avoids
       // accidentally reversing the constant order during processing.
       Attribute constValue;
@@ -894,9 +899,9 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
 }
 
 LogicalResult
-mlir::applyPatternsAndFoldGreedily(Region &region,
-                                   const FrozenRewritePatternSet &patterns,
-                                   GreedyRewriteConfig config, bool *changed) {
+mlir::applyPatternsGreedily(Region &region,
+                            const FrozenRewritePatternSet &patterns,
+                            GreedyRewriteConfig config, bool *changed) {
   // The top-level operation must be known to be isolated from above to
   // prevent performing canonicalizations on operations defined at or above
   // the region containing 'op'.
@@ -1012,7 +1017,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   return region;
 }
 
-LogicalResult mlir::applyOpPatternsAndFold(
+LogicalResult mlir::applyOpPatternsGreedily(
     ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
     GreedyRewriteConfig config, bool *changed, bool *allErased) {
   if (ops.empty()) {

From 1d50218fed60fa3030e4c312916aa06083ac68df Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Mon, 2 Sep 2024 23:34:16 +0000
Subject: [PATCH 2/3] Address review comments.

This partially incorporates #89552. I haven't exposed it on the
canonicalizer pass, as that could be a distinct discussion.
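The new `cseConstants` option controls whether the driver deduplicates
constants while populating its worklist. A loose model of that step
follows. `ToyOp` and `populateWorklist` are hypothetical; constants are
compared by an integer value only (the real driver goes through
`OperationFolder`, which also takes the dialect and type into account), and
a duplicate is simply dropped here, whereas the driver replaces its uses
with the existing constant and erases it.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical toy op: either a constant with an integer value or some
// other op.
struct ToyOp {
  std::string name;
  bool isConstant = false;
  int value = 0;
};

// Loose model of how the driver seeds its worklist. With `cseConstants` set,
// a constant whose value has already been seen is treated as a duplicate
// (the real driver redirects its uses to the earlier constant and erases it)
// and is not enqueued. Without it, every op is enqueued as-is.
std::vector<ToyOp> populateWorklist(const std::vector<ToyOp> &ops,
                                    bool cseConstants) {
  std::vector<ToyOp> worklist;
  std::set<int> seenConstants;
  for (const ToyOp &op : ops) {
    if (cseConstants && op.isConstant &&
        !seenConstants.insert(op.value).second)
      continue; // Duplicate constant: dropped in favor of the first one.
    worklist.push_back(op);
  }
  return worklist;
}
```

Keeping this separate from `fold` lets a caller keep folding while leaving
existing constants untouched, which is what the `cse-constants=false` RUN
line in the tests exercises.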
---
 .../mlir/Transforms/GreedyPatternRewriteDriver.h    | 13 ++++++-------
 .../Transforms/Utils/GreedyPatternRewriteDriver.cpp |  9 ++-------
 2 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 061cdd4b7d4d94..e320a64481f530 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -92,14 +92,13 @@ class GreedyRewriteConfig {
   /// An optional listener that should be notified about IR modifications.
   RewriterBase::Listener *listener = nullptr;
 
-  // Whether this should fold while greedily rewriting.
-  //
-  // Note: greedy here generally refers to two forms, 1) greedily applying
-  // patterns based purely on benefit and applying without backtracking using
-  // default cost model, 2) greedily folding where possible while attempting to
-  // match and rewrite using the provided patterns. With this option set to
-  // false it only does the former.
+  /// Whether this should fold while greedily rewriting. This also disables
+  /// CSE'ing constants.
   bool fold = true;
+
+  /// If set to "true", constants are CSE'd (even across multiple regions that
+  /// are in a parent-ancestor relationship).
+  bool cseConstants = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 4e8b74620da5fe..99f3569b767b1c 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -840,11 +840,6 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
     // regions to enable more aggressive CSE'ing).
     OperationFolder folder(ctx, this);
     auto insertKnownConstant = [&](Operation *op) {
-      // This hoisting is to enable more folding, so skip checking if known
-      // constant, updating dense map etc if not doing folding.
-      if (!config.fold)
-        return false;
-
       // Check for existing constants when populating the worklist. This avoids
       // accidentally reversing the constant order during processing.
       Attribute constValue;
@@ -857,13 +852,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
       region.walk([&](Operation *op) {
-        if (!insertKnownConstant(op))
+        if (!config.cseConstants || !insertKnownConstant(op))
           addToWorklist(op);
       });
     } else {
       // Add all nested operations to the worklist in preorder.
       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-        if (!insertKnownConstant(op)) {
+        if (!config.cseConstants || !insertKnownConstant(op)) {
           addToWorklist(op);
           return WalkResult::advance();
         }

From 9e2c773400646d4bbcc721080d6b3b86b613d1e0 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Tue, 3 Sep 2024 03:52:14 +0000
Subject: [PATCH 3/3] Add tests & fix call site
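
The expected outcomes of the new `test_dont_fold` case can be worked out by
hand: folding turns the `arith.addi` chain into the values 1, 2 and 3. The
hypothetical `simulate` function below replays that arithmetic under
simplifying assumptions (integer values only, and folded results assumed to
be deduplicated among themselves) and counts what each RUN configuration
should leave behind.

```cpp
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical model of the @test_dont_fold function (not MLIR code). SSA
// values are numbered in definition order: the arith.constant ops come
// first, then one value per arith.addi.
struct AddI {
  std::size_t lhs;
  std::size_t rhs;
};

struct Outcome {
  std::size_t numConstants;
  std::size_t numAddI;
};

Outcome simulate(const std::vector<int> &constants,
                 const std::vector<AddI> &adds, bool fold, bool cseConstants) {
  // Without folding nothing is rewritten. The constants in this test are
  // already distinct, so constant CSE would not change their count here.
  if (!fold)
    return {constants.size(), adds.size()};

  // Fold each addi into a constant, in order, so later adds can use the
  // results of earlier ones.
  std::vector<int> values(constants);
  std::set<int> foldedValues; // Assumption: folded constants are deduplicated.
  for (const AddI &add : adds) {
    int result = values.at(add.lhs) + values.at(add.rhs);
    values.push_back(result);
    foldedValues.insert(result);
  }
  if (cseConstants) {
    // Every constant, original or folded, is deduplicated by value.
    std::set<int> allValues(values.begin(), values.end());
    return {allValues.size(), 0};
  }
  // The original constants are kept as-is; folded results get fresh ones.
  return {constants.size() + foldedValues.size(), 0};
}
```

With constants {0, 1, 2} and adds {0, 1}, {3, 1}, {2, 1}, this model gives
4 constants by default, 6 constants with `cse-constants=false`, and 3
untouched `arith.addi` ops with `fold=false`, in line with the CHECK, NOCSE
and NOFOLD expectations.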

---
 .../Transforms/GreedyPatternRewriteDriver.h   |  5 +--
 .../Transforms/test-operation-folder.mlir     | 39 ++++++++++++++++++-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 10 ++++-
 3 files changed, 48 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index e320a64481f530..5dea8340b510bc 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -170,9 +170,8 @@ applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
   bool failed = false;
   for (Region &region : op->getRegions()) {
     bool regionChanged;
-    failed |=
-        applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
-            .failed();
+    failed |= applyPatternsGreedily(region, patterns, config, &regionChanged)
+                  .failed();
     anyRegionChanged |= regionChanged;
   }
   if (changed)
diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 46ee07af993cc7..31f71b5d3b4d7b 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
 // RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
+// RUN: mlir-opt -test-patterns='cse-constants=false' %s | FileCheck %s --check-prefix=NOCSE
+// RUN: mlir-opt -test-patterns='fold=false' %s | FileCheck %s --check-prefix=NOFOLD
 
 func.func @foo() -> i32 {
   %c42 = arith.constant 42 : i32
@@ -25,7 +27,8 @@ func.func @test_fold_before_previously_folded_op() -> (i32, i32) {
 }
 
 func.func @test_dont_reorder_constants() -> (i32, i32, i32) {
-  // Test that we don't reorder existing constants during folding if it isn't necessary.
+  // Test that we don't reorder existing constants during folding if it isn't
+  // necessary.
   // CHECK: %[[CST:.+]] = arith.constant 1
   // CHECK-NEXT: %[[CST:.+]] = arith.constant 2
   // CHECK-NEXT: %[[CST:.+]] = arith.constant 3
@@ -34,3 +37,37 @@ func.func @test_dont_reorder_constants() -> (i32, i32, i32) {
   %2 = arith.constant 3 : i32
   return %0, %1, %2 : i32, i32, i32
 }
+
+func.func @test_dont_fold() -> (i32, i32, i32, i32, i32, i32) {
+  // Test either not folding or deduping constants.
+
+  // CHECK-LABEL: test_dont_fold
+  // CHECK-NOT: arith.constant 0
+  // CHECK-DAG: %[[CST:.+]] = arith.constant 0
+  // CHECK-DAG: %[[CST:.+]] = arith.constant 1
+  // CHECK-DAG: %[[CST:.+]] = arith.constant 2
+  // CHECK-DAG: %[[CST:.+]] = arith.constant 3
+  // CHECK-NEXT: return
+
+  // NOCSE-LABEL: test_dont_fold
+  // NOCSE-DAG: arith.constant 0 : i32
+  // NOCSE-DAG: arith.constant 1 : i32
+  // NOCSE-DAG: arith.constant 2 : i32
+  // NOCSE-DAG: arith.constant 1 : i32
+  // NOCSE-DAG: arith.constant 2 : i32
+  // NOCSE-DAG: arith.constant 3 : i32
+  // NOCSE-NEXT: return
+
+  // NOFOLD-LABEL: test_dont_fold
+  // NOFOLD: arith.addi
+  // NOFOLD: arith.addi
+  // NOFOLD: arith.addi
+
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2 : i32
+  %0 = arith.addi %c0, %c1 : i32
+  %1 = arith.addi %0, %c1 : i32
+  %2 = arith.addi %c2, %c1 : i32
+  return %0, %1, %2, %c0, %c1, %c2 : i32, i32, i32, i32, i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 91dfb2faa80a17..023d1783c7ad4b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -311,8 +311,9 @@ struct TestPatternDriver
     GreedyRewriteConfig config;
     config.useTopDownTraversal = this->useTopDownTraversal;
     config.maxIterations = this->maxIterations;
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
-                                       config);
+    config.fold = this->fold;
+    config.cseConstants = this->cseConstants;
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
   }
 
   Option<bool> useTopDownTraversal{
@@ -323,6 +324,11 @@ struct TestPatternDriver
       *this, "max-iterations",
       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
+  Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
+                    llvm::cl::init(GreedyRewriteConfig().fold)};
+  Option<bool> cseConstants{*this, "cse-constants",
+                            llvm::cl::desc("Whether to CSE constants"),
+                            llvm::cl::init(GreedyRewriteConfig().cseConstants)};
 };
 
 struct DumpNotifications : public RewriterBase::Listener {


