[flang-commits] [flang] [mlir] [mlir] add a fluent API to GreedyRewriteConfig (PR #137122)

Oleksandr Alex Zinenko via flang-commits flang-commits at lists.llvm.org
Wed Apr 23 23:38:56 PDT 2025


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/137122

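Replace the public fields of `GreedyRewriteConfig` with private members accessed through getters and chainable setters (a fluent API). This is similar to other configuration objects used across MLIR.

Rename some fields to better reflect that they are no longer booleans: `enableRegionSimplification` becomes `regionSimplificationLevel` (the canonicalizer pass option becomes `regionSimplifyLevel`), and `strictMode` becomes `strictness`.

Reland of commit 04d261101b4f229189463136a794e3e362a793af (PR #132253).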

From 7574cada6ed7eeb0c8b8cb5cb38c41ad4eda9597 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Fri, 4 Apr 2025 14:37:25 +0200
Subject: [PATCH] [mlir] add a fluent API to GreedyRewriteConfig

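Replace the public fields of `GreedyRewriteConfig` with private members accessed through getters and chainable setters (a fluent API). This is similar to other configuration objects used across MLIR.

Rename some fields to better reflect that they are no longer booleans: `enableRegionSimplification` becomes `regionSimplificationLevel` (the canonicalizer pass option becomes `regionSimplifyLevel`), and `strictMode` becomes `strictness`.

Reland of commit 04d261101b4f229189463136a794e3e362a793af (PR #132253).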
---
 .../Optimizer/CodeGen/LowerRepackArrays.cpp   |   4 +-
 .../HLFIR/Transforms/InlineElementals.cpp     |   4 +-
 .../HLFIR/Transforms/InlineHLFIRAssign.cpp    |   4 +-
 .../HLFIR/Transforms/LowerHLFIRIntrinsics.cpp |   4 +-
 .../Transforms/OptimizedBufferization.cpp     |   4 +-
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    |   4 +-
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   6 +-
 .../Transforms/AssumedRankOpConversion.cpp    |   4 +-
 .../ConstantArgumentGlobalisation.cpp         |   6 +-
 .../Transforms/SimplifyFIROperations.cpp      |   3 +-
 .../lib/Optimizer/Transforms/StackArrays.cpp  |   3 +-
 .../Transforms/GreedyPatternRewriteDriver.h   |  76 ++++++++++---
 mlir/include/mlir/Transforms/Passes.td        |   2 +-
 .../TransformOps/AffineTransformOps.cpp       |  11 +-
 .../Transforms/AffineDataCopyGeneration.cpp   |   7 +-
 .../Transforms/SimplifyAffineStructures.cpp   |   7 +-
 mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp   |   9 +-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |   8 +-
 .../Transforms/IntRangeOptimizations.cpp      |  16 ++-
 .../BufferDeallocationSimplification.cpp      |  14 +--
 .../TransformOps/LinalgTransformOps.cpp       |   6 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |   7 +-
 .../SCF/Transforms/TileUsingInterface.cpp     |   8 +-
 .../SPIRV/Transforms/SPIRVConversion.cpp      |   6 +-
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  14 +--
 mlir/lib/Reducer/ReductionTreePass.cpp        |   8 +-
 mlir/lib/Transforms/Canonicalizer.cpp         |  16 +--
 .../Utils/GreedyPatternRewriteDriver.cpp      | 103 +++++++++---------
 .../lib/Dialect/Affine/TestAffineDataCopy.cpp |   2 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  28 ++---
 30 files changed, 225 insertions(+), 169 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
index 7deed3d44ae5b..7fb713ff1a6c7 100644
--- a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
+++ b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
@@ -357,8 +357,8 @@ class LowerRepackArraysPass
     patterns.insert<PackArrayConversion>(context);
     patterns.insert<UnpackArrayConversion>(context);
     mlir::GreedyRewriteConfig config;
-    config.enableRegionSimplification =
-        mlir::GreedySimplifyRegionLevel::Disabled;
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
     (void)applyPatternsGreedily(module, std::move(patterns), config);
   }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
index b68fe6ee0c747..c42b895946d19 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
@@ -119,8 +119,8 @@ class InlineElementalsPass
 
     mlir::GreedyRewriteConfig config;
     // Prevent the pattern driver from merging blocks.
-    config.enableRegionSimplification =
-        mlir::GreedySimplifyRegionLevel::Disabled;
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
 
     mlir::RewritePatternSet patterns(context);
     patterns.insert<InlineElementalConversion>(context);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
index 249976d5509b0..6e209cce07ad4 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
@@ -135,8 +135,8 @@ class InlineHLFIRAssignPass
 
     mlir::GreedyRewriteConfig config;
     // Prevent the pattern driver from merging blocks.
-    config.enableRegionSimplification =
-        mlir::GreedySimplifyRegionLevel::Disabled;
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
 
     mlir::RewritePatternSet patterns(context);
     patterns.insert<InlineHLFIRAssignConversion>(context);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 7c0fcba806869..31e5bc1193e22 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -557,8 +557,8 @@ class LowerHLFIRIntrinsics
     // Pattern rewriting only requires that the resulting IR is still valid
     mlir::GreedyRewriteConfig config;
     // Prevent the pattern driver from merging blocks
-    config.enableRegionSimplification =
-        mlir::GreedySimplifyRegionLevel::Disabled;
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
 
     if (mlir::failed(
             mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 79aabd2981e1a..2f6ee2592a84f 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -875,8 +875,8 @@ class OptimizedBufferizationPass
 
     mlir::GreedyRewriteConfig config;
     // Prevent the pattern driver from merging blocks
-    config.enableRegionSimplification =
-        mlir::GreedySimplifyRegionLevel::Disabled;
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
 
     mlir::RewritePatternSet patterns(context);
     // TODO: right now the patterns are non-conflicting,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index e9d820adbd22b..1dea7d89e180d 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -2132,8 +2132,8 @@ class SimplifyHLFIRIntrinsics
 
     mlir::GreedyRewriteConfig config;
     // Prevent the pattern driver from merging blocks
-    config.enableRegionSimplification =
-        mlir::GreedySimplifyRegionLevel::Disabled;
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
 
     mlir::RewritePatternSet patterns(context);
     patterns.insert<TransposeAsElementalConversion>(context);
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 81ff6bf9b2c6a..7a06a27748ebd 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -35,7 +35,8 @@ void addNestedPassToAllTopLevelOperationsConditionally(
 
 void addCanonicalizerPassWithoutRegionSimplification(mlir::OpPassManager &pm) {
   mlir::GreedyRewriteConfig config;
-  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
+  config.setRegionSimplificationLevel(
+      mlir::GreedySimplifyRegionLevel::Disabled);
   pm.addPass(mlir::createCanonicalizerPass(config));
 }
 
@@ -163,7 +164,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
 
   // simplify the IR
   mlir::GreedyRewriteConfig config;
-  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
+  config.setRegionSimplificationLevel(
+      mlir::GreedySimplifyRegionLevel::Disabled);
   pm.addPass(mlir::createCSEPass());
   fir::addAVC(pm, pc.OptLevel);
   addNestedPassToAllTopLevelOperations<PassConstructor>(
diff --git a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
index d0bd67a236419..6af1cb988a4c1 100644
--- a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
@@ -152,8 +152,8 @@ class AssumedRankOpConversion
     patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
     patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
     mlir::GreedyRewriteConfig config;
-    config.enableRegionSimplification =
-        mlir::GreedySimplifyRegionLevel::Disabled;
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
     (void)applyPatternsGreedily(mod, std::move(patterns), config);
   }
 };
diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 562f3058f20f3..239a7cdaa4cf2 100644
--- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -168,9 +168,9 @@ class ConstantArgumentGlobalisationOpt
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
     mlir::GreedyRewriteConfig config;
-    config.enableRegionSimplification =
-        mlir::GreedySimplifyRegionLevel::Disabled;
-    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
+    config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps);
 
     patterns.insert<CallOpRewriter>(context, *di);
     if (mlir::failed(
diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
index 212de2f2286db..6d106046b70f2 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
@@ -205,7 +205,8 @@ void SimplifyFIROperationsPass::runOnOperation() {
   fir::populateSimplifyFIROperationsPatterns(patterns,
                                              preferInlineImplementation);
   mlir::GreedyRewriteConfig config;
-  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
+  config.setRegionSimplificationLevel(
+      mlir::GreedySimplifyRegionLevel::Disabled);
 
   if (mlir::failed(
           mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index 9a6566bef50f1..f9b9b4f4ff385 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -806,7 +806,8 @@ void StackArraysPass::runOnOperation() {
   mlir::RewritePatternSet patterns(&context);
   mlir::GreedyRewriteConfig config;
   // prevent the pattern driver form merging blocks
-  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
+  config.setRegionSimplificationLevel(
+      mlir::GreedySimplifyRegionLevel::Disabled);
 
   patterns.insert<AllocMemConversion>(&context, *candidateOps);
   if (mlir::failed(mlir::applyOpPatternsGreedily(
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 110b4f64856eb..45e61b68f5db2 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -49,25 +49,43 @@ class GreedyRewriteConfig {
   /// larger patterns when given an ambiguous pattern set.
   ///
   /// Note: Only applicable when simplifying entire regions.
-  bool useTopDownTraversal = false;
+  bool getUseTopDownTraversal() const { return useTopDownTraversal; }
+  GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
+    useTopDownTraversal = use;
+    return *this;
+  }
 
   /// Perform control flow optimizations to the region tree after applying all
   /// patterns.
   ///
   /// Note: Only applicable when simplifying entire regions.
-  GreedySimplifyRegionLevel enableRegionSimplification =
-      GreedySimplifyRegionLevel::Aggressive;
+  GreedySimplifyRegionLevel getRegionSimplificationLevel() const {
+    return regionSimplificationLevel;
+  }
+  GreedyRewriteConfig &
+  setRegionSimplificationLevel(GreedySimplifyRegionLevel level) {
+    regionSimplificationLevel = level;
+    return *this;
+  }
 
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
   /// disable this iteration limit.
   ///
   /// Note: Only applicable when simplifying entire regions.
-  int64_t maxIterations = 10;
+  int64_t getMaxIterations() const { return maxIterations; }
+  GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
+    maxIterations = iterations;
+    return *this;
+  }
 
   /// This specifies the maximum number of rewrites within an iteration. Use
   /// `kNoLimit` to disable this limit.
-  int64_t maxNumRewrites = kNoLimit;
+  int64_t getMaxNumRewrites() const { return maxNumRewrites; }
+  GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
+    maxNumRewrites = limit;
+    return *this;
+  }
 
   static constexpr int64_t kNoLimit = -1;
 
@@ -75,7 +93,11 @@ class GreedyRewriteConfig {
   /// specified, the closest enclosing region around the initial list of ops
   /// (or the specified region, depending on which greedy rewrite entry point
   /// is used) is used as a scope.
-  Region *scope = nullptr;
+  Region *getScope() const { return scope; }
+  GreedyRewriteConfig &setScope(Region *scope) {
+    this->scope = scope;
+    return *this;
+  }
 
   /// Strict mode can restrict the ops that are added to the worklist during
   /// the rewrite.
@@ -87,16 +109,44 @@ class GreedyRewriteConfig {
   /// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
   ///   were on the worklist at the very beginning) enqueued. All other ops are
   ///   excluded.
-  GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+  GreedyRewriteStrictness getStrictness() const { return strictness; }
+  GreedyRewriteConfig &setStrictness(GreedyRewriteStrictness mode) {
+    strictness = mode;
+    return *this;
+  }
 
   /// An optional listener that should be notified about IR modifications.
-  RewriterBase::Listener *listener = nullptr;
+  RewriterBase::Listener *getListener() const { return listener; }
+  GreedyRewriteConfig &setListener(RewriterBase::Listener *listener) {
+    this->listener = listener;
+    return *this;
+  }
 
   /// Whether this should fold while greedily rewriting.
-  bool fold = true;
+  bool isFoldingEnabled() const { return fold; }
+  GreedyRewriteConfig &enableFolding(bool enable = true) {
+    fold = enable;
+    return *this;
+  }
 
   /// If set to "true", constants are CSE'd (even across multiple regions that
   /// are in a parent-ancestor relationship).
+  bool isConstantCSEEnabled() const { return cseConstants; }
+  GreedyRewriteConfig &enableConstantCSE(bool enable = true) {
+    cseConstants = enable;
+    return *this;
+  }
+
+private:
+  Region *scope = nullptr;
+  bool useTopDownTraversal = false;
+  GreedySimplifyRegionLevel regionSimplificationLevel =
+      GreedySimplifyRegionLevel::Aggressive;
+  int64_t maxIterations = 10;
+  int64_t maxNumRewrites = kNoLimit;
+  GreedyRewriteStrictness strictness = GreedyRewriteStrictness::AnyOp;
+  RewriterBase::Listener *listener = nullptr;
+  bool fold = true;
   bool cseConstants = true;
 };
 
@@ -128,14 +178,14 @@ applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
                       GreedyRewriteConfig config = GreedyRewriteConfig(),
                       bool *changed = nullptr);
 /// Same as `applyPatternsAndGreedily` above with folding.
-/// FIXME: Remove this once transition to above is complieted.
+/// FIXME: Remove this once transition to above is completed.
 LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
 inline LogicalResult
 applyPatternsAndFoldGreedily(Region &region,
                              const FrozenRewritePatternSet &patterns,
                              GreedyRewriteConfig config = GreedyRewriteConfig(),
                              bool *changed = nullptr) {
-  config.fold = true;
+  config.enableFolding();
   return applyPatternsGreedily(region, patterns, config, changed);
 }
 
@@ -187,7 +237,7 @@ applyPatternsAndFoldGreedily(Operation *op,
                              const FrozenRewritePatternSet &patterns,
                              GreedyRewriteConfig config = GreedyRewriteConfig(),
                              bool *changed = nullptr) {
-  config.fold = true;
+  config.enableFolding();
   return applyPatternsGreedily(op, patterns, config, changed);
 }
 
@@ -233,7 +283,7 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                        const FrozenRewritePatternSet &patterns,
                        GreedyRewriteConfig config = GreedyRewriteConfig(),
                        bool *changed = nullptr, bool *allErased = nullptr) {
-  config.fold = true;
+  config.enableFolding();
   return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
 }
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..1e89a78912e99 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -33,7 +33,7 @@ def Canonicalizer : Pass<"canonicalize"> {
     Option<"topDownProcessingEnabled", "top-down", "bool",
            /*default=*/"true",
            "Seed the worklist in general top-down order">,
-    Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
+    Option<"regionSimplifyLevel", "region-simplify", "mlir::GreedySimplifyRegionLevel",
            /*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
            "Perform control flow optimizations to the region tree",
              [{::llvm::cl::values(
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 9f7df7823d997..43d37ee3332ef 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -127,12 +127,13 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
   patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
                   SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  GreedyRewriteConfig config;
-  config.listener =
-      static_cast<RewriterBase::Listener *>(rewriter.getListener());
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
   // Apply the simplification pattern to a fixpoint.
-  if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
+  if (failed(applyOpPatternsGreedily(
+          targets, frozenPatterns,
+          GreedyRewriteConfig()
+              .setListener(
+                  static_cast<RewriterBase::Listener *>(rewriter.getListener()))
+              .setStrictness(GreedyRewriteStrictness::ExistingAndNewOps)))) {
     auto diag = emitDefiniteFailure()
                 << "affine.min/max simplification did not converge";
     return diag;
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 4d30213cc6ec2..62c1857e4b1da 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -237,7 +237,8 @@ void AffineDataCopyGeneration::runOnOperation() {
   AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
   AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  GreedyRewriteConfig config;
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
-  (void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
+  (void)applyOpPatternsGreedily(
+      copyOps, frozenPatterns,
+      GreedyRewriteConfig().setStrictness(
+          GreedyRewriteStrictness::ExistingAndNewOps));
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 31711ade3153b..9e9096c2e3186 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -109,7 +109,8 @@ void SimplifyAffineStructures::runOnOperation() {
     if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
       opsToSimplify.push_back(op);
   });
-  GreedyRewriteConfig config;
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
-  (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
+  (void)applyOpPatternsGreedily(
+      opsToSimplify, frozenPatterns,
+      GreedyRewriteConfig().setStrictness(
+          GreedyRewriteStrictness::ExistingAndNewOps));
 }
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index dd539ff685653..0d4ba3940c48e 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -315,11 +315,12 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
         // Simplify/canonicalize the affine.for.
         RewritePatternSet patterns(res.getContext());
         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
-        GreedyRewriteConfig config;
-        config.strictMode = GreedyRewriteStrictness::ExistingOps;
         bool erased;
-        (void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
-                                      config, /*changed=*/nullptr, &erased);
+        (void)applyOpPatternsGreedily(
+            res.getOperation(), std::move(patterns),
+            GreedyRewriteConfig().setStrictness(
+                GreedyRewriteStrictness::ExistingOps),
+            /*changed=*/nullptr, &erased);
         if (!erased && !prologue)
           prologue = res;
         if (!erased)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 2925aa918cb1c..11798b99fa879 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -426,11 +426,11 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
   RewritePatternSet patterns(ifOp.getContext());
   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  GreedyRewriteConfig config;
-  config.strictMode = GreedyRewriteStrictness::ExistingOps;
   bool erased;
-  (void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
-                                /*changed=*/nullptr, &erased);
+  (void)applyOpPatternsGreedily(
+      ifOp.getOperation(), frozenPatterns,
+      GreedyRewriteConfig().setStrictness(GreedyRewriteStrictness::ExistingOps),
+      /*changed=*/nullptr, &erased);
   if (erased) {
     if (folded)
       *folded = true;
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 602d80a45993e..f2f93883eb2b7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -494,10 +494,9 @@ struct IntRangeOptimizationsPass final
     RewritePatternSet patterns(ctx);
     populateIntRangeOptimizationsPatterns(patterns, solver);
 
-    GreedyRewriteConfig config;
-    config.listener = &listener;
-
-    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+    if (failed(applyPatternsGreedily(
+            op, std::move(patterns),
+            GreedyRewriteConfig().setListener(&listener))))
       signalPassFailure();
   }
 };
@@ -520,13 +519,12 @@ struct IntRangeNarrowingPass final
     RewritePatternSet patterns(ctx);
     populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
 
-    GreedyRewriteConfig config;
     // We specifically need bottom-up traversal as cmpi pattern needs range
     // data, attached to its original argument values.
-    config.useTopDownTraversal = false;
-    config.listener = &listener;
-
-    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+    if (failed(applyPatternsGreedily(
+            op, std::move(patterns),
+            GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
+                &listener))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 35f86a62ae592..c5fab80ecaa08 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,15 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+
+    populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
     // We don't want that the block structure changes invalidating the
-    // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+    // `BufferOriginAnalysis` so we apply the rewrites with `Normal` level of
     // region simplification
-    GreedyRewriteConfig config;
-    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
-    populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
-
-    if (failed(
-            applyPatternsGreedily(getOperation(), std::move(patterns), config)))
+    if (failed(applyPatternsGreedily(
+            getOperation(), std::move(patterns),
+            GreedyRewriteConfig().setRegionSimplificationLevel(
+                GreedySimplifyRegionLevel::Normal))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c90ebe4487ca4..b20e6050fb4f8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3587,9 +3587,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
   vector::populateVectorStepLoweringPatterns(patterns);
 
   TrackingListener listener(state, *this);
-  GreedyRewriteConfig config;
-  config.listener = &listener;
-  if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
+  if (failed(
+          applyPatternsGreedily(target, std::move(patterns),
+                                GreedyRewriteConfig().setListener(&listener))))
     return emitDefaultDefiniteFailure(target);
 
   results.push_back(target);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..62d016b87d627 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2327,10 +2327,9 @@ struct LinalgElementwiseOpFusionPass
     // Add constant folding patterns.
     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
 
-    // Use TopDownTraversal for compile time reasons
-    GreedyRewriteConfig grc;
-    grc.useTopDownTraversal = true;
-    (void)applyPatternsGreedily(op, std::move(patterns), grc);
+    // Use TopDownTraversal for compile time reasons.
+    (void)applyPatternsGreedily(op, std::move(patterns),
+                                GreedyRewriteConfig().setUseTopDownTraversal());
   }
 };
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 91862d2e17d71..7edf19689d2e1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1438,10 +1438,10 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
   if (!patterns)
     return success();
 
-  GreedyRewriteConfig config;
-  config.listener = this;
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
-  return applyOpPatternsGreedily(ops, patterns.value(), config);
+  return applyOpPatternsGreedily(
+      ops, patterns.value(),
+      GreedyRewriteConfig().setListener(this).setStrictness(
+          GreedyRewriteStrictness::ExistingAndNewOps));
 }
 
 void SliceTrackingListener::notifyOperationInserted(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 19b9af146f4a4..811f03abb3461 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1353,9 +1353,9 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
   // We only want to apply signature conversion once to the existing func ops.
   // Without specifying strictMode, the greedy pattern rewriter will keep
   // looking for newly created func ops.
-  GreedyRewriteConfig config;
-  config.strictMode = GreedyRewriteStrictness::ExistingOps;
-  return applyPatternsGreedily(op, std::move(patterns), config);
+  return applyPatternsGreedily(op, std::move(patterns),
+                               GreedyRewriteConfig().setStrictness(
+                                   GreedyRewriteStrictness::ExistingOps));
 }
 
 LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 4fe89f3f7fb9e..84d339a985c38 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -394,16 +394,16 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
 
   // Configure the GreedyPatternRewriteDriver.
   GreedyRewriteConfig config;
-  config.listener =
-      static_cast<RewriterBase::Listener *>(rewriter.getListener());
+  config.setListener(
+      static_cast<RewriterBase::Listener *>(rewriter.getListener()));
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
 
-  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
-                             ? GreedyRewriteConfig::kNoLimit
-                             : getMaxIterations();
-  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
+  config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
                               ? GreedyRewriteConfig::kNoLimit
-                              : getMaxNumRewrites();
+                              : getMaxIterations());
+  config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
+                               ? GreedyRewriteConfig::kNoLimit
+                               : getMaxNumRewrites());
 
   // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
   // was requested, apply the greedy pattern rewrite only once. (The greedy
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 7292752c712ae..549e4f2bd813b 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -62,11 +62,11 @@ static void applyPatterns(Region &region,
   // before that transform.
   for (Operation *op : opsInRange) {
     // `applyOpPatternsGreedily` with folding returns whether the op is
-    // convered. Omit it because we don't have expectation this reduction will
+    // converged. Omit it because we don't have expectation this reduction will
     // be success or not.
-    GreedyRewriteConfig config;
-    config.strictMode = GreedyRewriteStrictness::ExistingOps;
-    (void)applyOpPatternsGreedily(op, patterns, config);
+    (void)applyOpPatternsGreedily(op, patterns,
+                                  GreedyRewriteConfig().setStrictness(
+                                      GreedyRewriteStrictness::ExistingOps));
   }
 
   if (eraseOpNotInRange)
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 7ccd503fb0288..4b0ac28a03713 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -32,10 +32,10 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
                 ArrayRef<std::string> disabledPatterns,
                 ArrayRef<std::string> enabledPatterns)
       : config(config) {
-    this->topDownProcessingEnabled = config.useTopDownTraversal;
-    this->enableRegionSimplification = config.enableRegionSimplification;
-    this->maxIterations = config.maxIterations;
-    this->maxNumRewrites = config.maxNumRewrites;
+    this->topDownProcessingEnabled = config.getUseTopDownTraversal();
+    this->regionSimplifyLevel = config.getRegionSimplificationLevel();
+    this->maxIterations = config.getMaxIterations();
+    this->maxNumRewrites = config.getMaxNumRewrites();
     this->disabledPatterns = disabledPatterns;
     this->enabledPatterns = enabledPatterns;
   }
@@ -44,10 +44,10 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
     // Set the config from possible pass options set in the meantime.
-    config.useTopDownTraversal = topDownProcessingEnabled;
-    config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
-    config.maxNumRewrites = maxNumRewrites;
+    config.setUseTopDownTraversal(topDownProcessingEnabled);
+    config.setRegionSimplificationLevel(regionSimplifyLevel);
+    config.setMaxIterations(maxIterations);
+    config.setMaxNumRewrites(maxNumRewrites);
 
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7c1cfd91f85e6..41baad9c1997d 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -416,7 +416,8 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
       // clang-format off
       , expensiveChecks(
           /*driver=*/this,
-          /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
+          /*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
+                                         : nullptr)
 // clang-format on
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 {
@@ -455,8 +456,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
   bool changed = false;
   int64_t numRewrites = 0;
   while (!worklist.empty() &&
-         (numRewrites < config.maxNumRewrites ||
-          config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
+         (numRewrites < config.getMaxNumRewrites() ||
+          config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
     auto *op = worklist.pop();
 
     LLVM_DEBUG({
@@ -488,7 +489,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
     // infinite folding loop, as every constant op would be folded to an
     // Attribute and then immediately be rematerialized as a constant op, which
     // is then put on the worklist.
-    if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
+    if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
       SmallVector<OpFoldResult> foldResults;
       if (succeeded(op->fold(foldResults))) {
         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
@@ -574,21 +575,21 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         logger.getOStream() << ")' {\n";
         logger.indent();
       });
-      if (config.listener)
-        config.listener->notifyPatternBegin(pattern, op);
+      if (RewriterBase::Listener *listener = config.getListener())
+        listener->notifyPatternBegin(pattern, op);
       return true;
     };
     function_ref<bool(const Pattern &)> canApply = canApplyCallback;
     auto onFailureCallback = [&](const Pattern &pattern) {
       LLVM_DEBUG(logResult("failure", "pattern failed to match"));
-      if (config.listener)
-        config.listener->notifyPatternEnd(pattern, failure());
+      if (RewriterBase::Listener *listener = config.getListener())
+        listener->notifyPatternEnd(pattern, failure());
     };
     function_ref<void(const Pattern &)> onFailure = onFailureCallback;
     auto onSuccessCallback = [&](const Pattern &pattern) {
       LLVM_DEBUG(logResult("success", "pattern applied successfully"));
-      if (config.listener)
-        config.listener->notifyPatternEnd(pattern, success());
+      if (RewriterBase::Listener *listener = config.getListener())
+        listener->notifyPatternEnd(pattern, success());
       return success();
     };
     function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
@@ -596,7 +597,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 #ifdef NDEBUG
     // Optimization: PatternApplicator callbacks are not needed when running in
     // optimized mode and without a listener.
-    if (!config.listener) {
+    if (!config.getListener()) {
       canApply = nullptr;
       onFailure = nullptr;
       onSuccess = nullptr;
@@ -604,8 +605,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 #endif // NDEBUG
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-    if (config.scope) {
-      expensiveChecks.computeFingerPrints(config.scope->getParentOp());
+    if (config.getScope()) {
+      expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
     }
     auto clearFingerprints =
         llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
@@ -640,7 +641,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
   do {
     ancestors.push_back(op);
     region = op->getParentRegion();
-    if (config.scope == region) {
+    if (config.getScope() == region) {
       // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
       for (Operation *op : ancestors)
         addSingleOpToWorklist(op);
@@ -652,20 +653,20 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
 }
 
 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
-  if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
+  if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
       strictModeFilteredOps.contains(op))
     worklist.push(op);
 }
 
 void GreedyPatternRewriteDriver::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
-  if (config.listener)
-    config.listener->notifyBlockInserted(block, previous, previousIt);
+  if (RewriterBase::Listener *listener = config.getListener())
+    listener->notifyBlockInserted(block, previous, previousIt);
 }
 
 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
-  if (config.listener)
-    config.listener->notifyBlockErased(block);
+  if (RewriterBase::Listener *listener = config.getListener())
+    listener->notifyBlockErased(block);
 }
 
 void GreedyPatternRewriteDriver::notifyOperationInserted(
@@ -674,9 +675,9 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  if (config.listener)
-    config.listener->notifyOperationInserted(op, previous);
-  if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
+  if (RewriterBase::Listener *listener = config.getListener())
+    listener->notifyOperationInserted(op, previous);
+  if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
     strictModeFilteredOps.insert(op);
   addToWorklist(op);
 }
@@ -686,8 +687,8 @@ void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  if (config.listener)
-    config.listener->notifyOperationModified(op);
+  if (RewriterBase::Listener *listener = config.getListener())
+    listener->notifyOperationModified(op);
   addToWorklist(op);
 }
 
@@ -736,18 +737,18 @@ void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
   // the part of the IR that is taken into account for the "expensive checks".
   // A greedy pattern rewrite is not allowed to erase the parent op of the scope
   // region, as that would break the worklist handling and the expensive checks.
-  if (config.scope && config.scope->getParentOp() == op)
+  if (Region *scope = config.getScope(); scope && scope->getParentOp() == op)
     llvm_unreachable(
         "scope region must not be erased during greedy pattern rewrite");
 #endif // NDEBUG
 
-  if (config.listener)
-    config.listener->notifyOperationErased(op);
+  if (RewriterBase::Listener *listener = config.getListener())
+    listener->notifyOperationErased(op);
 
   addOperandsToWorklist(op);
   worklist.remove(op);
 
-  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
 }
 
@@ -757,8 +758,8 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  if (config.listener)
-    config.listener->notifyOperationReplaced(op, replacement);
+  if (RewriterBase::Listener *listener = config.getListener())
+    listener->notifyOperationReplaced(op, replacement);
 }
 
 void GreedyPatternRewriteDriver::notifyMatchFailure(
@@ -768,8 +769,8 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
     reasonCallback(diag);
     logger.startLine() << "** Match Failure : " << diag.str() << "\n";
   });
-  if (config.listener)
-    config.listener->notifyMatchFailure(loc, reasonCallback);
+  if (RewriterBase::Listener *listener = config.getListener())
+    listener->notifyMatchFailure(loc, reasonCallback);
 }
 
 //===----------------------------------------------------------------------===//
@@ -800,7 +801,7 @@ RegionPatternRewriteDriver::RegionPatternRewriteDriver(
     const GreedyRewriteConfig &config, Region &region)
     : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
   // Populate strict mode ops.
-  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
+  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
     region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
   }
 }
@@ -829,8 +830,8 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
   MLIRContext *ctx = rewriter.getContext();
   do {
     // Check if the iteration limit was reached.
-    if (++iteration > config.maxIterations &&
-        config.maxIterations != GreedyRewriteConfig::kNoLimit)
+    if (++iteration > config.getMaxIterations() &&
+        config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
       break;
 
     // New iteration: start with an empty worklist.
@@ -849,16 +850,16 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
       return false;
     };
 
-    if (!config.useTopDownTraversal) {
+    if (!config.getUseTopDownTraversal()) {
       // Add operations to the worklist in postorder.
       region.walk([&](Operation *op) {
-        if (!config.cseConstants || !insertKnownConstant(op))
+        if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
           addToWorklist(op);
       });
     } else {
       // Add all nested operations to the worklist in preorder.
       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-        if (!config.cseConstants || !insertKnownConstant(op)) {
+        if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
           addToWorklist(op);
           return WalkResult::advance();
         }
@@ -875,11 +876,11 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
 
           // After applying patterns, make sure that the CFG of each of the
           // regions is kept up to date.
-          if (config.enableRegionSimplification !=
+          if (config.getRegionSimplificationLevel() !=
               GreedySimplifyRegionLevel::Disabled) {
             continueRewrites |= succeeded(simplifyRegions(
                 rewriter, region,
-                /*mergeBlocks=*/config.enableRegionSimplification ==
+                /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
                     GreedySimplifyRegionLevel::Aggressive));
           }
         },
@@ -904,11 +905,11 @@ mlir::applyPatternsGreedily(Region &region,
          "patterns can only be applied to operations IsolatedFromAbove");
 
   // Set scope if not specified.
-  if (!config.scope)
-    config.scope = &region;
+  if (!config.getScope())
+    config.setScope(&region);
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-  if (failed(verify(config.scope->getParentOp())))
+  if (failed(verify(config.getScope()->getParentOp())))
     llvm::report_fatal_error(
         "greedy pattern rewriter input IR failed to verify");
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -919,7 +920,7 @@ mlir::applyPatternsGreedily(Region &region,
   LogicalResult converged = std::move(driver).simplify(changed);
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
-                 << config.maxIterations << " times\n";
+                 << config.getMaxIterations() << " times\n";
   });
   return converged;
 }
@@ -960,7 +961,7 @@ MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
     llvm::SmallDenseSet<Operation *, 4> *survivingOps)
     : GreedyPatternRewriteDriver(ctx, patterns, config),
       survivingOps(survivingOps) {
-  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+  if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.insert_range(ops);
 
   if (survivingOps) {
@@ -1024,22 +1025,22 @@ LogicalResult mlir::applyOpPatternsGreedily(
   }
 
   // Determine scope of rewrite.
-  if (!config.scope) {
+  if (!config.getScope()) {
     // Compute scope if none was provided. The scope will remain `nullptr` if
     // there is a top-level op among `ops`.
-    config.scope = findCommonAncestor(ops);
+    config.setScope(findCommonAncestor(ops));
   } else {
     // If a scope was provided, make sure that all ops are in scope.
 #ifndef NDEBUG
     bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
-      return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
+      return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
     });
     assert(allOpsInScope && "ops must be within the specified scope");
 #endif // NDEBUG
   }
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-  if (config.scope && failed(verify(config.scope->getParentOp())))
+  if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
     llvm::report_fatal_error(
         "greedy pattern rewriter input IR failed to verify");
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -1050,11 +1051,11 @@ LogicalResult mlir::applyOpPatternsGreedily(
                                      config, ops,
                                      allErased ? &surviving : nullptr);
   LogicalResult converged = std::move(driver).simplify(ops, changed);
   if (allErased)
     *allErased = surviving.empty();
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after "
-                 << config.maxNumRewrites << " rewrites";
+                 << config.getMaxNumRewrites() << " rewrites";
   });
   return converged;
 }
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index d6aaa6faf94cb..2a54e0c28f71f 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -144,7 +144,7 @@ void TestAffineDataCopy::runOnOperation() {
     }
   }
   GreedyRewriteConfig config;
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+  config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
   (void)applyOpPatternsGreedily(copyOps, std::move(patterns), config);
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index db02a122872d9..d073843484d81 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -386,26 +386,26 @@ struct TestGreedyPatternDriver
     patterns.insert<IncrementIntAttribute<3>>(&getContext());
 
     GreedyRewriteConfig config;
-    config.useTopDownTraversal = this->useTopDownTraversal;
-    config.maxIterations = this->maxIterations;
-    config.fold = this->fold;
-    config.cseConstants = this->cseConstants;
+    config.setUseTopDownTraversal(useTopDownTraversal)
+        .setMaxIterations(this->maxIterations)
+        .enableFolding(this->fold)
+        .enableConstantCSE(this->cseConstants);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
   }
 
   Option<bool> useTopDownTraversal{
       *this, "top-down",
       llvm::cl::desc("Seed the worklist in general top-down order"),
-      llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
+      llvm::cl::init(GreedyRewriteConfig().getUseTopDownTraversal())};
   Option<int> maxIterations{
       *this, "max-iterations",
       llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
-      llvm::cl::init(GreedyRewriteConfig().maxIterations)};
+      llvm::cl::init(GreedyRewriteConfig().getMaxIterations())};
   Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
-                    llvm::cl::init(GreedyRewriteConfig().fold)};
-  Option<bool> cseConstants{*this, "cse-constants",
-                            llvm::cl::desc("Whether to CSE constants"),
-                            llvm::cl::init(GreedyRewriteConfig().cseConstants)};
+                    llvm::cl::init(GreedyRewriteConfig().isFoldingEnabled())};
+  Option<bool> cseConstants{
+      *this, "cse-constants", llvm::cl::desc("Whether to CSE constants"),
+      llvm::cl::init(GreedyRewriteConfig().isConstantCSEEnabled())};
 };
 
 struct DumpNotifications : public RewriterBase::Listener {
@@ -501,13 +501,13 @@ struct TestStrictPatternDriver
 
     DumpNotifications dumpNotifications;
     GreedyRewriteConfig config;
-    config.listener = &dumpNotifications;
+    config.setListener(&dumpNotifications);
     if (strictMode == "AnyOp") {
-      config.strictMode = GreedyRewriteStrictness::AnyOp;
+      config.setStrictness(GreedyRewriteStrictness::AnyOp);
     } else if (strictMode == "ExistingAndNewOps") {
-      config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+      config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
     } else if (strictMode == "ExistingOps") {
-      config.strictMode = GreedyRewriteStrictness::ExistingOps;
+      config.setStrictness(GreedyRewriteStrictness::ExistingOps);
     } else {
       llvm_unreachable("invalid strictness option");
     }



More information about the flang-commits mailing list