[Mlir-commits] [mlir] [mlir] add a fluent API to GreedyRewriteConfig (PR #132253)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 20 10:03:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

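This adds chainable setters (e.g. `setListener`, `setStrictMode`, `setUseTopDownTraversal`), each returning `*this`, to `GreedyRewriteConfig`, and updates several call sites to build the config inline. This is similar to other configuration objects used across MLIR.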

---
Full diff: https://github.com/llvm/llvm-project/pull/132253.diff


11 Files Affected:

- (modified) mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h (+37) 
- (modified) mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp (+6-5) 
- (modified) mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp (+4-3) 
- (modified) mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp (+4-3) 
- (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+5-4) 
- (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+7-9) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+7-7) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+3-4) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+4-4) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 110b4f64856eb..aff2616d11276 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -50,6 +50,10 @@ class GreedyRewriteConfig {
   ///
   /// Note: Only applicable when simplifying entire regions.
   bool useTopDownTraversal = false;
+  GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
+    useTopDownTraversal = use;
+    return *this;
+  }
 
   /// Perform control flow optimizations to the region tree after applying all
   /// patterns.
@@ -57,6 +61,11 @@ class GreedyRewriteConfig {
   /// Note: Only applicable when simplifying entire regions.
   GreedySimplifyRegionLevel enableRegionSimplification =
       GreedySimplifyRegionLevel::Aggressive;
+  GreedyRewriteConfig &
+  setEnableRegionSimplification(GreedySimplifyRegionLevel level) {
+    enableRegionSimplification = level;
+    return *this;
+  }
 
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
@@ -64,10 +73,18 @@ class GreedyRewriteConfig {
   ///
   /// Note: Only applicable when simplifying entire regions.
   int64_t maxIterations = 10;
+  GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
+    maxIterations = iterations;
+    return *this;
+  }
 
   /// This specifies the maximum number of rewrites within an iteration. Use
   /// `kNoLimit` to disable this limit.
   int64_t maxNumRewrites = kNoLimit;
+  GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
+    maxNumRewrites = limit;
+    return *this;
+  }
 
   static constexpr int64_t kNoLimit = -1;
 
@@ -76,6 +93,10 @@ class GreedyRewriteConfig {
   /// (or the specified region, depending on which greedy rewrite entry point
   /// is used) is used as a scope.
   Region *scope = nullptr;
+  GreedyRewriteConfig &setScope(Region *newScope) {
+    scope = newScope;
+    return *this;
+  }
 
   /// Strict mode can restrict the ops that are added to the worklist during
   /// the rewrite.
@@ -88,16 +109,32 @@ class GreedyRewriteConfig {
   ///   were on the worklist at the very beginning) enqueued. All other ops are
   ///   excluded.
   GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+  GreedyRewriteConfig &setStrictMode(GreedyRewriteStrictness mode) {
+    strictMode = mode;
+    return *this;
+  }
 
   /// An optional listener that should be notified about IR modifications.
   RewriterBase::Listener *listener = nullptr;
+  GreedyRewriteConfig &setListener(RewriterBase::Listener *newListener) {
+    listener = newListener;
+    return *this;
+  }
 
   /// Whether this should fold while greedily rewriting.
   bool fold = true;
+  GreedyRewriteConfig &setFold(bool enable = true) {
+    fold = enable;
+    return *this;
+  }
 
   /// If set to "true", constants are CSE'd (even across multiple regions that
   /// are in a parent-ancestor relationship).
   bool cseConstants = true;
+  GreedyRewriteConfig &setCSEConstants(bool enable = true) {
+    cseConstants = enable;
+    return *this;
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 9f7df7823d997..46c2d11d89d42 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -127,12 +127,13 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
   patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
                   SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  GreedyRewriteConfig config;
-  config.listener =
-      static_cast<RewriterBase::Listener *>(rewriter.getListener());
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
   // Apply the simplification pattern to a fixpoint.
-  if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
+  if (failed(applyOpPatternsGreedily(
+          targets, frozenPatterns,
+          GreedyRewriteConfig()
+              .setListener(
+                  static_cast<RewriterBase::Listener *>(rewriter.getListener()))
+              .setStrictMode(GreedyRewriteStrictness::ExistingAndNewOps)))) {
     auto diag = emitDefiniteFailure()
                 << "affine.min/max simplification did not converge";
     return diag;
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 4d30213cc6ec2..a7386af950f56 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -237,7 +237,8 @@ void AffineDataCopyGeneration::runOnOperation() {
   AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
   AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-  GreedyRewriteConfig config;
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
-  (void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
+  (void)applyOpPatternsGreedily(
+      copyOps, frozenPatterns,
+      GreedyRewriteConfig().setStrictMode(
+          GreedyRewriteStrictness::ExistingAndNewOps));
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 31711ade3153b..96a7300ee7de6 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -109,7 +109,8 @@ void SimplifyAffineStructures::runOnOperation() {
     if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
       opsToSimplify.push_back(op);
   });
-  GreedyRewriteConfig config;
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
-  (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
+  (void)applyOpPatternsGreedily(
+      opsToSimplify, frozenPatterns,
+      GreedyRewriteConfig().setStrictMode(
+          GreedyRewriteStrictness::ExistingAndNewOps));
 }
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index dd539ff685653..e52196b892dc5 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -315,11 +315,12 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
         // Simplify/canonicalize the affine.for.
         RewritePatternSet patterns(res.getContext());
         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
-        GreedyRewriteConfig config;
-        config.strictMode = GreedyRewriteStrictness::ExistingOps;
         bool erased;
-        (void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
-                                      config, /*changed=*/nullptr, &erased);
+        (void)applyOpPatternsGreedily(
+            res.getOperation(), std::move(patterns),
+            GreedyRewriteConfig().setStrictMode(
+                GreedyRewriteStrictness::ExistingOps),
+            /*changed=*/nullptr, &erased);
         if (!erased && !prologue)
           prologue = res;
         if (!erased)
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index f866c91ef6e39..72b7f09344384 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -490,10 +490,9 @@ struct IntRangeOptimizationsPass final
     RewritePatternSet patterns(ctx);
     populateIntRangeOptimizationsPatterns(patterns, solver);
 
-    GreedyRewriteConfig config;
-    config.listener = &listener;
-
-    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+    if (failed(applyPatternsGreedily(
+            op, std::move(patterns),
+            GreedyRewriteConfig().setListener(&listener))))
       signalPassFailure();
   }
 };
@@ -516,13 +515,12 @@ struct IntRangeNarrowingPass final
     RewritePatternSet patterns(ctx);
     populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
 
-    GreedyRewriteConfig config;
     // We specifically need bottom-up traversal as cmpi pattern needs range
     // data, attached to its original argument values.
-    config.useTopDownTraversal = false;
-    config.listener = &listener;
-
-    if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+    if (failed(applyPatternsGreedily(
+            op, std::move(patterns),
+            GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
+                &listener))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 35f86a62ae592..a60fd50b6c7b9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,15 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+
+    populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
     // We don't want that the block structure changes invalidating the
-    // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+    // `BufferOriginAnalysis` so we apply the rewrites with `Normal` level of
     // region simplification
-    GreedyRewriteConfig config;
-    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
-    populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
-
-    if (failed(
-            applyPatternsGreedily(getOperation(), std::move(patterns), config)))
+    if (failed(applyPatternsGreedily(
+            getOperation(), std::move(patterns),
+            GreedyRewriteConfig().setEnableRegionSimplification(
+                GreedySimplifyRegionLevel::Normal))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ef7b4757a04b4..a6d05a65f68e4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3523,9 +3523,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
   vector::populateVectorStepLoweringPatterns(patterns);
 
   TrackingListener listener(state, *this);
-  GreedyRewriteConfig config;
-  config.listener = &listener;
-  if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
+  if (failed(
+          applyPatternsGreedily(target, std::move(patterns),
+                                GreedyRewriteConfig().setListener(&listener))))
     return emitDefaultDefiniteFailure(target);
 
   results.push_back(target);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 927ce066038d8..5e2b9877f90eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2325,10 +2325,9 @@ struct LinalgElementwiseOpFusionPass
     // Add constant folding patterns.
     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
 
-    // Use TopDownTraversal for compile time reasons
-    GreedyRewriteConfig grc;
-    grc.useTopDownTraversal = true;
-    (void)applyPatternsGreedily(op, std::move(patterns), grc);
+    // Use TopDownTraversal for compile time reasons.
+    (void)applyPatternsGreedily(op, std::move(patterns),
+                                GreedyRewriteConfig().setUseTopDownTraversal());
   }
 };
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index af87fb7a79d04..e02048944c6f3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1438,10 +1438,10 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
   if (!patterns)
     return success();
 
-  GreedyRewriteConfig config;
-  config.listener = this;
-  config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
-  return applyOpPatternsGreedily(ops, patterns.value(), config);
+  return applyOpPatternsGreedily(
+      ops, patterns.value(),
+      GreedyRewriteConfig().setListener(this).setStrictMode(
+          GreedyRewriteStrictness::ExistingAndNewOps));
 }
 
 void SliceTrackingListener::notifyOperationInserted(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index a60410d01ac57..15a988e93c479 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1352,9 +1352,9 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
   // We only want to apply signature conversion once to the existing func ops.
   // Without specifying strictMode, the greedy pattern rewriter will keep
   // looking for newly created func ops.
-  GreedyRewriteConfig config;
-  config.strictMode = GreedyRewriteStrictness::ExistingOps;
-  return applyPatternsGreedily(op, std::move(patterns), config);
+  return applyPatternsGreedily(op, std::move(patterns),
+                               GreedyRewriteConfig().setStrictMode(
+                                   GreedyRewriteStrictness::ExistingOps));
 }
 
 LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/132253


More information about the Mlir-commits mailing list