[Mlir-commits] [mlir] 6bdecbc - [mlir] GreedyPatternRewriteDriver: Move strict mode to GreedyPatternRewriteDriver

Matthias Springer llvmlistbot at llvm.org
Fri Jan 27 06:58:16 PST 2023


Author: Matthias Springer
Date: 2023-01-27T15:52:01+01:00
New Revision: 6bdecbcb99bc6b8fa25a2841bf2087bdbb91b4aa

URL: https://github.com/llvm/llvm-project/commit/6bdecbcb99bc6b8fa25a2841bf2087bdbb91b4aa
DIFF: https://github.com/llvm/llvm-project/commit/6bdecbcb99bc6b8fa25a2841bf2087bdbb91b4aa.diff

LOG: [mlir] GreedyPatternRewriteDriver: Move strict mode to GreedyPatternRewriteDriver

`strictMode` is moved to GreedyRewriteConfig to simplify the API and state of rewriter classes. The region-based GreedyPatternRewriteDriver now also supports strict mode.

MultiOpPatternRewriteDriver becomes simpler: fewer methods must be overridden.

Differential Revision: https://reviews.llvm.org/D142623

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
    mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
    mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/lib/Reducer/ReductionTreePass.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index e4251103d6d58..217953493c41d 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -63,6 +63,18 @@ class GreedyRewriteConfig {
   /// Only ops within the scope are added to the worklist. If no scope is
   /// specified, the closest enclosing region is used as a scope.
   Region *scope = nullptr;
+
+  /// Strict mode can restrict the ops that are added to the worklist during
+  /// the rewrite.
+  ///
+  /// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
+  /// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing ops (that
+  ///   were on the worklist at the very beginning) and newly created ops are
+  ///   enqueued. All other ops are excluded.
+  /// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
+  ///   on the worklist at the very beginning) are enqueued. All other ops are
+  ///   excluded.
+  GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
 };
 
 //===----------------------------------------------------------------------===//
@@ -105,14 +117,8 @@ inline LogicalResult applyPatternsAndFoldGreedily(
 ///
 /// Newly created ops and other pre-existing ops that use results of rewritten
 /// ops or supply operands to such ops are simplified, unless such ops are
-/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless
-/// of `strictMode`).
-///
-/// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
-/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing and newly
-///   created ops are simplified. All other ops are excluded.
-/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are
-///   simplified. All other ops are excluded.
+/// excluded via `config.strictMode`. Any other ops remain unmodified (i.e.,
+/// regardless of `strictMode`).
 ///
 /// In addition to strictness, a region scope can be specified. Only ops within
 /// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
@@ -130,23 +136,17 @@ inline LogicalResult applyPatternsAndFoldGreedily(
 LogicalResult
 applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                        const FrozenRewritePatternSet &patterns,
-                       GreedyRewriteStrictness strictMode,
                        GreedyRewriteConfig config = GreedyRewriteConfig(),
                        bool *changed = nullptr, bool *allErased = nullptr);
 
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead.
-///
-/// Returns success if the iterative process converged and no more patterns can
-/// be matched.
+/// Applies the specified patterns on `op` while also trying to fold it.
+/// This function is a shortcut for the ArrayRef<Operation *> overload and
+/// behaves the same way.
 inline LogicalResult
 applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
                        GreedyRewriteConfig config = GreedyRewriteConfig(),
                        bool *erased = nullptr) {
-  return applyOpPatternsAndFold(ArrayRef(op), patterns,
-                                GreedyRewriteStrictness::ExistingOps, config,
+  return applyOpPatternsAndFold(ArrayRef(op), patterns, config,
                                 /*changed=*/nullptr, erased);
 }
 

diff  --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index d516de8ee6d26..19ee04d10b19a 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -130,10 +130,10 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
   patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
                   SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+  GreedyRewriteConfig config;
+  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
   // Apply the simplification pattern to a fixpoint.
-  if (failed(
-          applyOpPatternsAndFold(targets, frozenPatterns,
-                                 GreedyRewriteStrictness::ExistingAndNewOps))) {
+  if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
     auto diag = emitDefiniteFailure()
                 << "affine.min/max simplification did not converge";
     return diag;

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index a9d6f940200b0..716f95314eac9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -239,6 +239,7 @@ void AffineDataCopyGeneration::runOnOperation() {
   AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
   AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  (void)applyOpPatternsAndFold(copyOps, frozenPatterns,
-                               GreedyRewriteStrictness::ExistingAndNewOps);
+  GreedyRewriteConfig config;
+  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+  (void)applyOpPatternsAndFold(copyOps, frozenPatterns, config);
 }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 6cb0a30dce39f..8039484ce8d62 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -105,6 +105,7 @@ void SimplifyAffineStructures::runOnOperation() {
     if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
       opsToSimplify.push_back(op);
   });
-  (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns,
-                               GreedyRewriteStrictness::ExistingAndNewOps);
+  GreedyRewriteConfig config;
+  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+  (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config);
 }

diff  --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index a146d5adfa637..54d2bd469db66 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -321,9 +321,10 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
         // Simplify/canonicalize the affine.for.
         RewritePatternSet patterns(res.getContext());
         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
+        GreedyRewriteConfig config;
+        config.strictMode = GreedyRewriteStrictness::ExistingOps;
         bool erased;
-        (void)applyOpPatternsAndFold(res, std::move(patterns),
-                                     GreedyRewriteConfig(), &erased);
+        (void)applyOpPatternsAndFold(res, std::move(patterns), config, &erased);
         if (!erased && !prologue)
           prologue = res;
         if (!erased)

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 286c38ec18ee5..24acc60cd1832 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -413,10 +413,11 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
   // in which case we return with `folded` being set.
   RewritePatternSet patterns(ifOp.getContext());
   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
-  bool erased;
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  (void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(),
-                               &erased);
+  GreedyRewriteConfig config;
+  config.strictMode = GreedyRewriteStrictness::ExistingOps;
+  bool erased;
+  (void)applyOpPatternsAndFold(ifOp, frozenPatterns, config, &erased);
   if (erased) {
     if (folded)
       *folded = true;

diff  --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index a30b6d19413e0..b00045a3a41b7 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -60,10 +60,13 @@ static void applyPatterns(Region &region,
   // matching in above iteration. Besides, erase op not-in-range may end up in
   // invalid module, so `applyOpPatternsAndFold` should come before that
   // transform.
-  for (Operation *op : opsInRange)
+  for (Operation *op : opsInRange) {
     // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
     // because we don't have expectation this reduction will be success or not.
-    (void)applyOpPatternsAndFold(op, patterns);
+    GreedyRewriteConfig config;
+    config.strictMode = GreedyRewriteStrictness::ExistingOps;
+    (void)applyOpPatternsAndFold(op, patterns, config);
+  }
 
   if (eraseOpNotInRange)
     for (Operation *op : opsNotInRange) {

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 4a37730129691..2ef46b1fdcfc1 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -59,7 +59,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
 protected:
   /// Add the given operation to the worklist.
-  virtual void addSingleOpToWorklist(Operation *op);
+  void addSingleOpToWorklist(Operation *op);
 
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
@@ -102,6 +102,12 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   /// Configuration information for how to simplify.
   const GreedyRewriteConfig config;
 
+  /// The list of ops we are restricting our rewrites to. These include the
+  /// supplied set of ops as well as new ops created while rewriting those ops
+  /// depending on `strictMode`. This set is not maintained when
+  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
+  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+
 private:
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
@@ -150,6 +156,12 @@ bool GreedyPatternRewriteDriver::simplify(Region &region) && {
     return false;
   };
 
+  // Populate strict mode ops.
+  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
+    strictModeFilteredOps.clear();
+    region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
+  }
+
   bool changed = false;
   int64_t iteration = 0;
   do {
@@ -323,12 +335,15 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
 }
 
 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
-  // Check to see if the worklist already contains this op.
-  if (worklistMap.count(op))
-    return;
-
-  worklistMap[op] = worklist.size();
-  worklist.push_back(op);
+  if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
+      strictModeFilteredOps.contains(op)) {
+    // Check to see if the worklist already contains this op.
+    if (worklistMap.count(op))
+      return;
+
+    worklistMap[op] = worklist.size();
+    worklist.push_back(op);
+  }
 }
 
 Operation *GreedyPatternRewriteDriver::popFromWorklist() {
@@ -355,6 +370,8 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+  if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
+    strictModeFilteredOps.insert(op);
   addToWorklist(op);
 }
 
@@ -391,6 +408,9 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
     removeFromWorklist(operation);
     folder.notifyRemoval(operation);
   });
+
+  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+    strictModeFilteredOps.erase(op);
 }
 
 void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
@@ -459,10 +479,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
   explicit MultiOpPatternRewriteDriver(
       MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-      GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
+      const GreedyRewriteConfig &config,
       llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
       : GreedyPatternRewriteDriver(ctx, patterns, config),
-        strictMode(strictMode), survivingOps(survivingOps) {}
+        survivingOps(survivingOps) {}
 
   /// Performs the specified rewrites on `ops` while also trying to fold these
   /// ops. `strictMode` controls which other ops are simplified. Only ops
@@ -476,38 +496,13 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
   LogicalResult simplifyLocally(ArrayRef<Operation *> op,
                                 bool *changed = nullptr) &&;
 
-protected:
-  void addSingleOpToWorklist(Operation *op) override {
-    if (strictMode == GreedyRewriteStrictness::AnyOp ||
-        strictModeFilteredOps.contains(op))
-      GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
-  }
-
 private:
-  void notifyOperationInserted(Operation *op) override {
-    if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
-      strictModeFilteredOps.insert(op);
-    GreedyPatternRewriteDriver::notifyOperationInserted(op);
-  }
-
   void notifyOperationRemoved(Operation *op) override {
     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
     if (survivingOps)
       survivingOps->erase(op);
-    if (strictMode != GreedyRewriteStrictness::AnyOp)
-      strictModeFilteredOps.erase(op);
   }
 
-  /// `strictMode` control which ops are added to the worklist during
-  /// simplification.
-  const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
-
-  /// The list of ops we are restricting our rewrites to. These include the
-  /// supplied set of ops as well as new ops created while rewriting those ops
-  /// depending on `strictMode`. This set is not maintained when `strictMode`
-  /// is GreedyRewriteStrictness::AnyOp.
-  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
-
   /// An optional set of ops that survived the rewrite. This set is populated
   /// at the beginning of `simplifyLocally` with the inititally provided list
   /// of ops.
@@ -524,7 +519,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
     survivingOps->insert(ops.begin(), ops.end());
   }
 
-  if (strictMode != GreedyRewriteStrictness::AnyOp) {
+  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
     strictModeFilteredOps.clear();
     strictModeFilteredOps.insert(ops.begin(), ops.end());
   }
@@ -549,7 +544,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
     if (op == nullptr)
       continue;
 
-    assert((strictMode == GreedyRewriteStrictness::AnyOp ||
+    assert((config.strictMode == GreedyRewriteStrictness::AnyOp ||
             strictModeFilteredOps.contains(op)) &&
            "unexpected op was inserted under strict mode");
 
@@ -637,8 +632,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
 
 LogicalResult mlir::applyOpPatternsAndFold(
     ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
-    bool *changed, bool *allErased) {
+    GreedyRewriteConfig config, bool *changed, bool *allErased) {
   if (ops.empty()) {
     if (changed)
       *changed = false;
@@ -664,8 +658,7 @@ LogicalResult mlir::applyOpPatternsAndFold(
   // Start the pattern driver.
   llvm::SmallDenseSet<Operation *, 4> surviving;
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     strictMode, config,
-                                     allErased ? &surviving : nullptr);
+                                     config, allErased ? &surviving : nullptr);
   LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
   if (allErased)
     *allErased = surviving.empty();

diff  --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 7dc478c8b9cf1..62671dde34b0e 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -132,8 +132,9 @@ void TestAffineDataCopy::runOnOperation() {
       AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
     }
   }
-  (void)applyOpPatternsAndFold(copyOps, std::move(patterns),
-                               GreedyRewriteStrictness::ExistingAndNewOps);
+  GreedyRewriteConfig config;
+  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+  (void)applyOpPatternsAndFold(copyOps, std::move(patterns), config);
 }
 
 namespace mlir {

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index da018708c702e..de4fa39218844 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -266,13 +266,13 @@ struct TestStrictPatternDriver
       }
     });
 
-    GreedyRewriteStrictness mode;
+    GreedyRewriteConfig config;
     if (strictMode == "AnyOp") {
-      mode = GreedyRewriteStrictness::AnyOp;
+      config.strictMode = GreedyRewriteStrictness::AnyOp;
     } else if (strictMode == "ExistingAndNewOps") {
-      mode = GreedyRewriteStrictness::ExistingAndNewOps;
+      config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
     } else if (strictMode == "ExistingOps") {
-      mode = GreedyRewriteStrictness::ExistingOps;
+      config.strictMode = GreedyRewriteStrictness::ExistingOps;
     } else {
       llvm_unreachable("invalid strictness option");
     }
@@ -282,8 +282,8 @@ struct TestStrictPatternDriver
     // operation will trigger the assertion while processing.
     bool changed = false;
     bool allErased = false;
-    (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
-                                 GreedyRewriteConfig(), &changed, &allErased);
+    (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
+                                 &changed, &allErased);
     Builder b(ctx);
     getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
     getOperation()->setAttr("pattern_driver_all_erased",


        


More information about the Mlir-commits mailing list