[Mlir-commits] [mlir] 87e345b - [mlir] GreedyPatternRewriteDriver: Add new strict mode option

Matthias Springer llvmlistbot at llvm.org
Fri Jan 20 01:08:22 PST 2023


Author: Matthias Springer
Date: 2023-01-20T10:08:11+01:00
New Revision: 87e345b1bdb76867cc6e9ae59b6dd2633a480d38

URL: https://github.com/llvm/llvm-project/commit/87e345b1bdb76867cc6e9ae59b6dd2633a480d38
DIFF: https://github.com/llvm/llvm-project/commit/87e345b1bdb76867cc6e9ae59b6dd2633a480d38.diff

LOG: [mlir] GreedyPatternRewriteDriver: Add new strict mode option

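The `bool strict` parameter of `applyOpPatternsAndFold(ArrayRef<Operation *>, ...)`
is replaced by a `GreedyRewriteStrictness` enum. There are now three options: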
* `AnyOp` (previously `false`)
* `ExistingAndNewOps` (previously `true`)
* `ExistingOps`: this one is new.

The last option corresponds to what the `applyOpPatternsAndFold(Operation*, ...)` overload is doing. It is now also supported on the `applyOpPatternsAndFold(ArrayRef<Operation *>, ...)` overload.

Differential Revision: https://reviews.llvm.org/D141904

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
    mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/Transforms/test-strict-pattern-driver.mlir
    mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index aaafebf252fab..72b24754fac27 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -18,6 +18,17 @@
 
 namespace mlir {
 
+/// This enum controls which ops are put on the worklist during a greedy
+/// pattern rewrite.
+enum class GreedyRewriteStrictness {
+  /// No restrictions wrt. which ops are processed.
+  AnyOp,
+  /// Only pre-existing and newly created ops are processed.
+  ExistingAndNewOps,
+  /// Only pre-existing ops are processed.
+  ExistingOps
+};
+
 /// This class allows control over how the GreedyPatternRewriteDriver works.
 class GreedyRewriteConfig {
 public:
@@ -88,21 +99,29 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
                                      bool *erased = nullptr);
 
 /// Applies the specified rewrite patterns on `ops` while also trying to fold
-/// these ops as well as any other ops that were in turn created due to such
-/// rewrites. Furthermore, any pre-existing ops in the IR outside of `ops`
-/// remain completely unmodified if `strict` is set to true. If `strict` is
-/// false, other operations that use results of rewritten ops or supply operands
-/// to such ops are in turn simplified; any other ops still remain unmodified
-/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a
-/// result of folding, becoming dead, or via pattern rewrites. If more far
-/// reaching simplification is desired, applyPatternsAndFoldGreedily should be
-/// used.
+/// these ops.
+///
+/// Newly created ops and other pre-existing ops that use results of rewritten
+/// ops or supply operands to such ops are simplified, unless such ops are
+/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless
+/// of `strictMode`).
+///
+/// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
+/// * GreedyRewriteStrictness::ExistingAndNewOps: Only ops in `ops` and newly
+///   created ops are simplified. All other ops are excluded.
+/// * GreedyRewriteStrictness::ExistingOps: Only ops in `ops` are simplified.
+///   All other ops are excluded.
+///
+/// Note that ops in `ops` could be erased as a result of folding, becoming
+/// dead, or via pattern rewrites. If more far-reaching simplification is
+/// desired, applyPatternsAndFoldGreedily should be used.
 ///
 /// Returns success if the iterative process converged and no more patterns can
 /// be matched. `changed` is set to true if the IR was modified at all.
 LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                                      const FrozenRewritePatternSet &patterns,
-                                     bool strict, bool *changed = nullptr);
+                                     GreedyRewriteStrictness strictMode,
+                                     bool *changed = nullptr);
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 282a35b45c537..d516de8ee6d26 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -132,7 +132,8 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   // Apply the simplification pattern to a fixpoint.
   if (failed(
-          applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true))) {
+          applyOpPatternsAndFold(targets, frozenPatterns,
+                                 GreedyRewriteStrictness::ExistingAndNewOps))) {
     auto diag = emitDefiniteFailure()
                 << "affine.min/max simplification did not converge";
     return diag;

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 0d84d384012c2..a9d6f940200b0 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -239,5 +239,6 @@ void AffineDataCopyGeneration::runOnOperation() {
   AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
   AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  (void)applyOpPatternsAndFold(copyOps, frozenPatterns, /*strict=*/true);
+  (void)applyOpPatternsAndFold(copyOps, frozenPatterns,
+                               GreedyRewriteStrictness::ExistingAndNewOps);
 }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index bb5b390cb2ea4..6cb0a30dce39f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -105,5 +105,6 @@ void SimplifyAffineStructures::runOnOperation() {
     if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
       opsToSimplify.push_back(op);
   });
-  (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, /*strict=*/true);
+  (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns,
+                               GreedyRewriteStrictness::ExistingAndNewOps);
 }

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index b7ea592bfcc7d..56a94663e43e9 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -575,66 +575,54 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
                                        const FrozenRewritePatternSet &patterns,
-                                       bool strict)
+                                       GreedyRewriteStrictness strictMode)
       : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
-        strictMode(strict) {}
+        strictMode(strictMode) {}
 
+  /// Performs the specified rewrites on `ops` while also trying to fold these
+  /// ops. `strictMode` controls which other ops are simplified.
+  ///
+  /// Note that ops in `ops` could be erased as a result of folding, becoming
+  /// dead, or via pattern rewrites. The return value indicates convergence.
   LogicalResult simplifyLocally(ArrayRef<Operation *> op,
                                 bool *changed = nullptr);
 
   void addToWorklist(Operation *op) override {
-    if (!strictMode || strictModeFilteredOps.contains(op))
+    if (strictMode == GreedyRewriteStrictness::AnyOp ||
+        strictModeFilteredOps.contains(op))
       GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
   }
 
 private:
   void notifyOperationInserted(Operation *op) override {
-    if (strictMode)
+    if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
       strictModeFilteredOps.insert(op);
     GreedyPatternRewriteDriver::notifyOperationInserted(op);
   }
 
   void notifyOperationRemoved(Operation *op) override {
     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
-    if (strictMode)
+    if (strictMode != GreedyRewriteStrictness::AnyOp)
       strictModeFilteredOps.erase(op);
   }
 
-  /// If `strictMode` is true, any pre-existing ops outside of
-  /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
-  /// If `strictMode` is false, operations that use results of (or supply
-  /// operands to) any rewritten ops stemming from the simplification of the
-  /// provided ops are in turn simplified; any other ops still remain untouched
-  /// (i.e., regardless of `strictMode`).
-  bool strictMode = false;
-
-  /// The list of ops we are restricting our rewrites to if `strictMode` is on.
-  /// These include the supplied set of ops as well as new ops created while
-  /// rewriting those ops. This set is not maintained when strictMode is off.
+  /// `strictMode` controls which ops are added to the worklist during
+  /// simplification.
+  GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+
+  /// The list of ops we are restricting our rewrites to. These include the
+  /// supplied set of ops as well as new ops created while rewriting those ops
+  /// depending on `strictMode`. This set is not maintained when `strictMode`
+  /// is GreedyRewriteStrictness::AnyOp.
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
 };
 
 } // namespace
 
-/// Performs the specified rewrites on `ops` while also trying to fold these ops
-/// as well as any other ops that were in turn created due to these rewrite
-/// patterns. Any pre-existing ops outside of `ops` remain completely
-/// unmodified if `strictMode` is true. If `strictMode` is false, other
-/// operations that use results of rewritten ops or supply operands to such ops
-/// are in turn simplified; any other ops still remain unmodified (i.e.,
-/// regardless of `strictMode`). Note that ops in `ops` could be erased as a
-/// result of folding, becoming dead, or via pattern rewrites. Returns true if
-/// at all any changes happened.
-// Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
-// or GreedyPatternRewriteDriver::simplify, this method just iterates until
-// the worklist is empty. As our objective is to keep simplification "local",
-// there is no strong rationale to re-add all operations into the worklist and
-// rerun until an iteration changes nothing. If more widereaching simplification
-// is desired, GreedyPatternRewriteDriver should be used.
 LogicalResult
 MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
                                              bool *changed) {
-  if (strictMode) {
+  if (strictMode != GreedyRewriteStrictness::AnyOp) {
     strictModeFilteredOps.clear();
     strictModeFilteredOps.insert(ops.begin(), ops.end());
   }
@@ -659,7 +647,8 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
     if (op == nullptr)
       continue;
 
-    assert((!strictMode || strictModeFilteredOps.contains(op)) &&
+    assert((strictMode == GreedyRewriteStrictness::AnyOp ||
+            strictModeFilteredOps.contains(op)) &&
            "unexpected op was inserted under strict mode");
 
     // If the operation is trivially dead - remove it.
@@ -718,9 +707,6 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
   return success(worklist.empty());
 }
 
-/// Rewrites only `op` using the supplied canonicalization patterns and
-/// folding. `erased` is set to true if the op is erased as a result of being
-/// folded, replaced, or dead.
 LogicalResult mlir::applyOpPatternsAndFold(
     Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
   // Start the pattern driver.
@@ -738,10 +724,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
   return converged;
 }
 
-LogicalResult
-mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
-                             const FrozenRewritePatternSet &patterns,
-                             bool strict, bool *changed) {
+LogicalResult mlir::applyOpPatternsAndFold(
+    ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
+    GreedyRewriteStrictness strictMode, bool *changed) {
   if (ops.empty()) {
     if (changed)
       *changed = false;
@@ -750,6 +735,6 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
 
   // Start the pattern driver.
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     strict);
+                                     strictMode);
   return driver.simplifyLocally(ops, changed);
 }

diff  --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 8c6eaf345d92d..ad6f6a5f01e20 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -1,9 +1,15 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN:     -test-strict-pattern-driver="strictness=ExistingAndNewOps" \
+// RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-EN
 
-// CHECK-LABEL: func @test_erase
-//       CHECK:   test.arg0
-//       CHECK:   test.arg1
-//   CHECK-NOT:   test.erase_op
+// RUN: mlir-opt \
+// RUN:     -test-strict-pattern-driver="strictness=ExistingOps" \
+// RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX
+
+// CHECK-EN-LABEL: func @test_erase
+//       CHECK-EN:   test.arg0
+//       CHECK-EN:   test.arg1
+//   CHECK-EN-NOT:   test.erase_op
 func.func @test_erase() {
   %0 = "test.arg0"() : () -> (i32)
   %1 = "test.arg1"() : () -> (i32)
@@ -11,18 +17,22 @@ func.func @test_erase() {
   return
 }
 
-// CHECK-LABEL: func @test_insert_same_op
-//       CHECK:   "test.insert_same_op"() {skip = true}
-//       CHECK:   "test.insert_same_op"() {skip = true}
+// -----
+
+// CHECK-EN-LABEL: func @test_insert_same_op
+//       CHECK-EN:   "test.insert_same_op"() {skip = true}
+//       CHECK-EN:   "test.insert_same_op"() {skip = true}
 func.func @test_insert_same_op() {
   %0 = "test.insert_same_op"() : () -> (i32)
   return
 }
 
-// CHECK-LABEL: func @test_replace_with_new_op
-//       CHECK:   %[[n:.*]] = "test.new_op"
-//       CHECK:   "test.dummy_user"(%[[n]])
-//       CHECK:   "test.dummy_user"(%[[n]])
+// -----
+
+// CHECK-EN-LABEL: func @test_replace_with_new_op
+//       CHECK-EN:   %[[n:.*]] = "test.new_op"
+//       CHECK-EN:   "test.dummy_user"(%[[n]])
+//       CHECK-EN:   "test.dummy_user"(%[[n]])
 func.func @test_replace_with_new_op() {
   %0 = "test.replace_with_new_op"() : () -> (i32)
   %1 = "test.dummy_user"(%0) : (i32) -> (i32)
@@ -30,9 +40,15 @@ func.func @test_replace_with_new_op() {
   return
 }
 
-// CHECK-LABEL: func @test_replace_with_erase_op
-//   CHECK-NOT:   test.replace_with_new_op
-//   CHECK-NOT:   test.erase_op
+// -----
+
+// CHECK-EN-LABEL: func @test_replace_with_erase_op
+//   CHECK-EN-NOT:   test.replace_with_new_op
+//   CHECK-EN-NOT:   test.erase_op
+
+// CHECK-EX-LABEL: func @test_replace_with_erase_op
+//   CHECK-EX-NOT:   test.replace_with_new_op
+//       CHECK-EX:   test.erase_op
 func.func @test_replace_with_erase_op() {
   "test.replace_with_new_op"() {create_erase_op} : () -> ()
   return

diff  --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 117f83e01f9ca..7dc478c8b9cf1 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -132,7 +132,8 @@ void TestAffineDataCopy::runOnOperation() {
       AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
     }
   }
-  (void)applyOpPatternsAndFold(copyOps, std::move(patterns), /*strict=*/true);
+  (void)applyOpPatternsAndFold(copyOps, std::move(patterns),
+                               GreedyRewriteStrictness::ExistingAndNewOps);
 }
 
 namespace mlir {

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d3ef160176cc3..286d0de72cc9a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -244,11 +244,13 @@ struct TestStrictPatternDriver
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
 
   TestStrictPatternDriver() = default;
-  TestStrictPatternDriver(const TestStrictPatternDriver &other) = default;
+  TestStrictPatternDriver(const TestStrictPatternDriver &other) {
+    strictMode = other.strictMode;
+  }
 
   StringRef getArgument() const final { return "test-strict-pattern-driver"; }
   StringRef getDescription() const final {
-    return "Run strict mode of pattern driver";
+    return "Test strict mode of pattern driver";
   }
 
   void runOnOperation() override {
@@ -263,13 +265,28 @@ struct TestStrictPatternDriver
       }
     });
 
+    GreedyRewriteStrictness mode;
+    if (strictMode == "AnyOp") {
+      mode = GreedyRewriteStrictness::AnyOp;
+    } else if (strictMode == "ExistingAndNewOps") {
+      mode = GreedyRewriteStrictness::ExistingAndNewOps;
+    } else if (strictMode == "ExistingOps") {
+      mode = GreedyRewriteStrictness::ExistingOps;
+    } else {
+      llvm_unreachable("invalid strictness option");
+    }
+
     // Check if these transformations introduce visiting of operations that
     // are not in the `ops` set (The new created ops are valid). An invalid
     // operation will trigger the assertion while processing.
-    (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns),
-                                 /*strict=*/true);
+    (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode);
   }
 
+  Option<std::string> strictMode{
+      *this, "strictness",
+      llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
+      llvm::cl::init("AnyOp")};
+
 private:
   // New inserted operation is valid for further transformation.
   class InsertSameOp : public RewritePattern {


        


More information about the Mlir-commits mailing list