[Mlir-commits] [mlir] [mlir][c] Enable creating and setting greedy rewrite config. (PR #162429)

Jacques Pienaar llvmlistbot at llvm.org
Thu Jan 1 20:34:42 PST 2026


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/162429

>From 4f7dfcc8539b96cdbdd080bf5f8b82f7f4e79f4d Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Mon, 13 Oct 2025 03:51:57 +0000
Subject: [PATCH 1/2] [mlir][c] Enable creating and setting greedy rewrite
 config.

---
 mlir/include/mlir-c/Rewrite.h        |  98 +++++++++++++++++-
 mlir/lib/Bindings/Python/Rewrite.cpp | 119 ++++++++++++++++++++++
 mlir/lib/CAPI/Transforms/Rewrite.cpp | 145 ++++++++++++++++++++++++++-
 mlir/test/CAPI/rewrite.c             |  47 +++++++++
 mlir/test/python/rewrite.py          |  99 +++++++++++++++++-
 5 files changed, 502 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 9e2685719fe4d..dec533882aeea 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -36,6 +36,26 @@ extern "C" {
 DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
+
+/// Greedy rewrite strictness levels; enumerator values are fixed explicitly.
+typedef enum {
+  /// No restrictions with respect to which ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP = 0,
+  /// Only pre-existing and newly created ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS = 1,
+  /// Only pre-existing ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS = 2
+} MlirGreedyRewriteStrictness;
+
+/// Greedy region simplification levels; enumerator values are fixed explicitly.
+typedef enum {
+  /// Disable region control-flow simplification.
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED = 0,
+  /// Run the normal simplification (e.g. dead args elimination).
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL = 1,
+  /// Run extra simplifications (e.g. block merging).
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE = 2
+} MlirGreedySimplifyRegionLevel;
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
@@ -319,7 +339,85 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(

 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
-    MlirGreedyRewriteDriverConfig);
+    MlirGreedyRewriteDriverConfig config);
+
+//===----------------------------------------------------------------------===//
+/// GreedyRewriteDriverConfig API
+//===----------------------------------------------------------------------===//
+
+/// Creates a greedy rewrite driver configuration with default settings. The
+/// caller owns the returned config and must release it with
+/// mlirGreedyRewriteDriverConfigDestroy.
+MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig
+mlirGreedyRewriteDriverConfigCreate(void);
+
+/// Destroys a configuration created by mlirGreedyRewriteDriverConfigCreate.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config);
+
+/// Sets the maximum number of iterations for the greedy rewrite driver.
+/// Use -1 for no limit.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(
+    MlirGreedyRewriteDriverConfig config, int64_t maxIterations);
+
+/// Sets the maximum number of rewrites within an iteration.
+/// Use -1 for no limit.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites);
+
+/// Sets whether to use top-down traversal for the initial population of the
+/// worklist.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal);
+
+/// Enables or disables folding during greedy rewriting.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config,
+                                           bool enable);
+
+/// Sets the strictness level for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedyRewriteStrictness strictness);
+
+/// Sets the region simplification level.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level);
+
+/// Enables or disables constant CSE.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(
+    MlirGreedyRewriteDriverConfig config, bool enable);
+
+/// Gets the maximum number of iterations for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the maximum number of rewrites within an iteration.
+MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether top-down traversal is used for initial worklist population.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether folding is enabled during greedy rewriting.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the strictness level for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness
+mlirGreedyRewriteDriverConfigGetStrictness(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the region simplification level.
+MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel
+mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether constant CSE is enabled.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
+    MlirGreedyRewriteDriverConfig config);

 /// Applies the given patterns to the given op by a fast walk-based pattern
 /// rewrite driver.
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0df9d0cbc7ffc..97da49bb6aac8 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -223,10 +223,108 @@ class PyRewritePatternSet {
   MlirContext ctx;
 };
 
+/// Owning wrapper around a MlirGreedyRewriteDriverConfig. The C config is
+/// created by the constructor and destroyed by the destructor; ownership can
+/// only be transferred through the move constructor.
+class PyGreedyRewriteDriverConfig {
+public:
+  PyGreedyRewriteDriverConfig()
+      : config(mlirGreedyRewriteDriverConfigCreate()) {}
+  PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
+      : config(other.config) {
+    other.config.ptr = nullptr;
+  }
+  // A copy would destroy the same C config twice, so copies are disallowed.
+  PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &) = delete;
+  PyGreedyRewriteDriverConfig &
+  operator=(const PyGreedyRewriteDriverConfig &) = delete;
+  PyGreedyRewriteDriverConfig &
+  operator=(PyGreedyRewriteDriverConfig &&) = delete;
+  ~PyGreedyRewriteDriverConfig() {
+    if (config.ptr != nullptr)
+      mlirGreedyRewriteDriverConfigDestroy(config);
+  }
+
+  /// Returns the wrapped C config; ownership stays with this wrapper.
+  MlirGreedyRewriteDriverConfig get() const { return config; }
+
+  void setMaxIterations(int64_t maxIterations) {
+    mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
+  }
+
+  void setMaxNumRewrites(int64_t maxNumRewrites) {
+    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
+  }
+
+  void setUseTopDownTraversal(bool useTopDownTraversal) {
+    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config,
+                                                        useTopDownTraversal);
+  }
+
+  void enableFolding(bool enable) {
+    mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
+  }
+
+  void setStrictness(MlirGreedyRewriteStrictness strictness) {
+    mlirGreedyRewriteDriverConfigSetStrictness(config, strictness);
+  }
+
+  void setRegionSimplificationLevel(MlirGreedySimplifyRegionLevel level) {
+    mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(config, level);
+  }
+
+  void enableConstantCSE(bool enable) {
+    mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
+  }
+
+  int64_t getMaxIterations() const {
+    return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
+  }
+
+  int64_t getMaxNumRewrites() const {
+    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
+  }
+
+  bool getUseTopDownTraversal() const {
+    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
+  }
+
+  bool isFoldingEnabled() const {
+    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
+  }
+
+  MlirGreedyRewriteStrictness getStrictness() const {
+    return mlirGreedyRewriteDriverConfigGetStrictness(config);
+  }
+
+  MlirGreedySimplifyRegionLevel getRegionSimplificationLevel() const {
+    return mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config);
+  }
+
+  bool isConstantCSEEnabled() const {
+    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
+  }
+
+private:
+  MlirGreedyRewriteDriverConfig config;
+};
+
 } // namespace
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  // Enum definitions
+  nb::enum_<MlirGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
+      .value("ANY_OP", MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP)
+      .value("EXISTING_AND_NEW_OPS",
+             MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS)
+      .value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
+
+  nb::enum_<MlirGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
+      .value("DISABLED", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED)
+      .value("NORMAL", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL)
+      .value("AGGRESSIVE", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE);
+
   //----------------------------------------------------------------------------
   // Mapping of the PatternRewriter
   //----------------------------------------------------------------------------
@@ -373,6 +461,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           },
           nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+  nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
+      .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
+      .def_prop_rw("max_iterations",
+                   &PyGreedyRewriteDriverConfig::getMaxIterations,
+                   &PyGreedyRewriteDriverConfig::setMaxIterations,
+                   "Maximum number of iterations (-1 for no limit)")
+      .def_prop_rw("max_num_rewrites",
+                   &PyGreedyRewriteDriverConfig::getMaxNumRewrites,
+                   &PyGreedyRewriteDriverConfig::setMaxNumRewrites,
+                   "Maximum rewrites per iteration (-1 for no limit)")
+      .def_prop_rw("use_top_down_traversal",
+                   &PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
+                   &PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
+                   "Use top-down traversal to seed the initial worklist")
+      .def_prop_rw("enable_folding",
+                   &PyGreedyRewriteDriverConfig::isFoldingEnabled,
+                   &PyGreedyRewriteDriverConfig::enableFolding,
+                   "Enable or disable folding")
+      .def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness,
+                   &PyGreedyRewriteDriverConfig::setStrictness,
+                   "Rewrite strictness level (GreedyRewriteStrictness)")
+      .def_prop_rw("region_simplification_level",
+                   &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
+                   &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
+                   "Region simplification level (GreedySimplifyRegionLevel)")
+      .def_prop_rw("enable_constant_cse",
+                   &PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
+                   &PyGreedyRewriteDriverConfig::enableConstantCSE,
+                   "Enable or disable constant CSE");
+
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
                    &PyFrozenRewritePatternSet::getCapsule)
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index fd4ae6ffc72be..798ca1de651c1 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -283,18 +283,179 @@ void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
   set.ptr = nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+/// GreedyRewriteDriverConfig API
+//===----------------------------------------------------------------------===//
+
+inline mlir::GreedyRewriteConfig *unwrap(MlirGreedyRewriteDriverConfig config) {
+  assert(config.ptr && "unexpected null config");
+  return static_cast<mlir::GreedyRewriteConfig *>(config.ptr);
+}
+
+inline MlirGreedyRewriteDriverConfig wrap(mlir::GreedyRewriteConfig *config) {
+  return {config};
+}
+
+/// Returns a copy of `config`, or the default settings when `config` is null
+/// so that callers passing an empty config keep the default driver behavior.
+static mlir::GreedyRewriteConfig
+getConfigOrDefault(MlirGreedyRewriteDriverConfig config) {
+  return config.ptr ? *unwrap(config) : mlir::GreedyRewriteConfig();
+}
+
+/// Converts a C strictness value into the C++ enum.
+static mlir::GreedyRewriteStrictness
+unwrapStrictness(MlirGreedyRewriteStrictness strictness) {
+  switch (strictness) {
+  case MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP:
+    return mlir::GreedyRewriteStrictness::AnyOp;
+  case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS:
+    return mlir::GreedyRewriteStrictness::ExistingAndNewOps;
+  case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS:
+    return mlir::GreedyRewriteStrictness::ExistingOps;
+  }
+  // A C enum can hold any integer; values outside the enum are a caller bug.
+  llvm_unreachable("unknown MlirGreedyRewriteStrictness value");
+}
+
+/// Converts a C++ strictness value into the C enum.
+static MlirGreedyRewriteStrictness
+wrapStrictness(mlir::GreedyRewriteStrictness strictness) {
+  switch (strictness) {
+  case mlir::GreedyRewriteStrictness::AnyOp:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP;
+  case mlir::GreedyRewriteStrictness::ExistingAndNewOps:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS;
+  case mlir::GreedyRewriteStrictness::ExistingOps:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS;
+  }
+  llvm_unreachable("unknown GreedyRewriteStrictness value");
+}
+
+/// Converts a C region simplification level into the C++ enum.
+static mlir::GreedySimplifyRegionLevel
+unwrapRegionLevel(MlirGreedySimplifyRegionLevel level) {
+  switch (level) {
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED:
+    return mlir::GreedySimplifyRegionLevel::Disabled;
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL:
+    return mlir::GreedySimplifyRegionLevel::Normal;
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE:
+    return mlir::GreedySimplifyRegionLevel::Aggressive;
+  }
+  // A C enum can hold any integer; values outside the enum are a caller bug.
+  llvm_unreachable("unknown MlirGreedySimplifyRegionLevel value");
+}
+
+/// Converts a C++ region simplification level into the C enum.
+static MlirGreedySimplifyRegionLevel
+wrapRegionLevel(mlir::GreedySimplifyRegionLevel level) {
+  switch (level) {
+  case mlir::GreedySimplifyRegionLevel::Disabled:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED;
+  case mlir::GreedySimplifyRegionLevel::Normal:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL;
+  case mlir::GreedySimplifyRegionLevel::Aggressive:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE;
+  }
+  llvm_unreachable("unknown GreedySimplifyRegionLevel value");
+}
+
+MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate() {
+  return wrap(new mlir::GreedyRewriteConfig());
+}
+
+void mlirGreedyRewriteDriverConfigDestroy(
+    MlirGreedyRewriteDriverConfig config) {
+  delete unwrap(config);
+}
+
+void mlirGreedyRewriteDriverConfigSetMaxIterations(
+    MlirGreedyRewriteDriverConfig config, int64_t maxIterations) {
+  unwrap(config)->setMaxIterations(maxIterations);
+}
+
+void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites) {
+  unwrap(config)->setMaxNumRewrites(maxNumRewrites);
+}
+
+void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal) {
+  unwrap(config)->setUseTopDownTraversal(useTopDownTraversal);
+}
+
+void mlirGreedyRewriteDriverConfigEnableFolding(
+    MlirGreedyRewriteDriverConfig config, bool enable) {
+  unwrap(config)->enableFolding(enable);
+}
+
+void mlirGreedyRewriteDriverConfigSetStrictness(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedyRewriteStrictness strictness) {
+  unwrap(config)->setStrictness(unwrapStrictness(strictness));
+}
+
+void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level) {
+  unwrap(config)->setRegionSimplificationLevel(unwrapRegionLevel(level));
+}
+
+void mlirGreedyRewriteDriverConfigEnableConstantCSE(
+    MlirGreedyRewriteDriverConfig config, bool enable) {
+  unwrap(config)->enableConstantCSE(enable);
+}
+
+int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getMaxIterations();
+}
+
+int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getMaxNumRewrites();
+}
+
+bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getUseTopDownTraversal();
+}
+
+bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->isFoldingEnabled();
+}
+
+MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(
+    MlirGreedyRewriteDriverConfig config) {
+  return wrapStrictness(unwrap(config)->getStrictness());
+}
+
+MlirGreedySimplifyRegionLevel
+mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config) {
+  return wrapRegionLevel(unwrap(config)->getRegionSimplificationLevel());
+}
+
+bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->isConstantCSEEnabled();
+}
+
 MlirLogicalResult
 mlirApplyPatternsAndFoldGreedily(MlirModule op,
                                  MlirFrozenRewritePatternSet patterns,
-                                 MlirGreedyRewriteDriverConfig) {
-  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+                                 MlirGreedyRewriteDriverConfig config) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns),
+                                          getConfigOrDefault(config)));
 }

 MlirLogicalResult
 mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
                                        MlirFrozenRewritePatternSet patterns,
-                                       MlirGreedyRewriteDriverConfig) {
-  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+                                       MlirGreedyRewriteDriverConfig config) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns),
+                                          getConfigOrDefault(config)));
 }
 
 void mlirWalkAndApplyPatterns(MlirOperation op,
diff --git a/mlir/test/CAPI/rewrite.c b/mlir/test/CAPI/rewrite.c
index b33d225767046..0745eb496c1d7 100644
--- a/mlir/test/CAPI/rewrite.c
+++ b/mlir/test/CAPI/rewrite.c
@@ -534,6 +534,52 @@ void testReplaceUses(MlirContext ctx) {
   mlirModuleDestroy(module);
 }
 
+void testGreedyRewriteDriverConfig(MlirContext ctx) {
+  // CHECK-LABEL: @testGreedyRewriteDriverConfig
+  fprintf(stderr, "@testGreedyRewriteDriverConfig\n");
+
+  // Test config creation and destruction
+  MlirGreedyRewriteDriverConfig config = mlirGreedyRewriteDriverConfigCreate();
+
+  // Test all configuration setters
+  mlirGreedyRewriteDriverConfigSetMaxIterations(config, 5);
+  mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, 100);
+  mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config, true);
+  mlirGreedyRewriteDriverConfigEnableFolding(config, false);
+  mlirGreedyRewriteDriverConfigSetStrictness(
+      config, MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
+  mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+      config, MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL);
+  mlirGreedyRewriteDriverConfigEnableConstantCSE(config, false);
+
+  // Verify the getters; int64_t is printed as long long for portability.
+  // CHECK: MaxIterations: 5
+  fprintf(stderr, "MaxIterations: %lld\n",
+          (long long)mlirGreedyRewriteDriverConfigGetMaxIterations(config));
+  // CHECK: MaxNumRewrites: 100
+  fprintf(stderr, "MaxNumRewrites: %lld\n",
+          (long long)mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config));
+  // CHECK: UseTopDownTraversal: 1
+  fprintf(stderr, "UseTopDownTraversal: %d\n",
+          mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config));
+  // CHECK: FoldingEnabled: 0
+  fprintf(stderr, "FoldingEnabled: %d\n",
+          mlirGreedyRewriteDriverConfigIsFoldingEnabled(config));
+  // CHECK: Strictness: 2
+  fprintf(stderr, "Strictness: %d\n",
+          mlirGreedyRewriteDriverConfigGetStrictness(config));
+  // CHECK: RegionSimplificationLevel: 1
+  fprintf(stderr, "RegionSimplificationLevel: %d\n",
+          mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
+  // CHECK: ConstantCSEEnabled: 0
+  fprintf(stderr, "ConstantCSEEnabled: %d\n",
+          mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config));
+
+  // CHECK: Config test completed successfully
+  fprintf(stderr, "Config test completed successfully\n");
+  mlirGreedyRewriteDriverConfigDestroy(config);
+}
+
 int main(void) {
   MlirContext ctx = mlirContextCreate();
   mlirContextSetAllowUnregisteredDialects(ctx, true);
@@ -547,6 +593,7 @@ int main(void) {
   testMove(ctx);
   testOpModification(ctx);
   testReplaceUses(ctx);
+  testGreedyRewriteDriverConfig(ctx);
 
   mlirContextDestroy(ctx);
   return 0;
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index e40d5eb92b86f..269ace8c0bba0 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -1,15 +1,17 @@
 # RUN: %PYTHON %s 2>&1 | FileCheck %s
 
+import gc
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import arith
-from mlir.rewrite import *
 
 
 def run(f):
     print("\nTEST:", f.__name__)
     f()
+    gc.collect()  # Collect objects left over from f before the next test runs.
+    return f  # Keep the decorated name bound to the test function.
 
 
 # CHECK-LABEL: TEST: testRewritePattern
@@ -84,3 +86,98 @@ def constant_1_to_2(op, rewriter):
         # CHECK: %0 = arith.muli %arg0, %arg1 : i64
         # CHECK: return %0 : i64
         print(module)
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigCreation
+@run
+def testGreedyRewriteDriverConfigCreation():
+    # Test basic config creation and destruction
+    config = GreedyRewriteDriverConfig()
+    # CHECK: Config created successfully
+    print("Config created successfully")
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigGetters
+@run
+def testGreedyRewriteDriverConfigGetters():
+    config = GreedyRewriteDriverConfig()
+
+    # Set some values
+    config.max_iterations = 5
+    config.max_num_rewrites = 50
+    config.use_top_down_traversal = True
+    config.enable_folding = False
+    config.strictness = GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+    config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
+    config.enable_constant_cse = True
+
+    # Test all getter methods and print results
+    # CHECK: max_iterations: 5
+    max_iterations = config.max_iterations
+    print(f"max_iterations: {max_iterations}")
+    # CHECK: max_rewrites: 50
+    max_rewrites = config.max_num_rewrites
+    print(f"max_rewrites: {max_rewrites}")
+    # CHECK: use_top_down: True
+    use_top_down = config.use_top_down_traversal
+    print(f"use_top_down: {use_top_down}")
+    # CHECK: folding_enabled: False
+    folding_enabled = config.enable_folding
+    print(f"folding_enabled: {folding_enabled}")
+    # CHECK: strictness: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+    strictness = config.strictness
+    print(f"strictness: {strictness}")
+    # CHECK: region_level: GreedySimplifyRegionLevel.AGGRESSIVE
+    region_level = config.region_simplification_level
+    print(f"region_level: {region_level}")
+    # CHECK: cse_enabled: True
+    cse_enabled = config.enable_constant_cse
+    print(f"cse_enabled: {cse_enabled}")
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum
+@run
+def testGreedyRewriteStrictnessEnum():
+    config = GreedyRewriteDriverConfig()
+
+    # Test ANY_OP
+    # CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP
+    config.strictness = GreedyRewriteStrictness.ANY_OP
+    strictness = config.strictness
+    print(f"strictness ANY_OP: {strictness}")
+
+    # Test EXISTING_AND_NEW_OPS
+    # CHECK: strictness EXISTING_AND_NEW_OPS: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+    config.strictness = GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+    strictness = config.strictness
+    print(f"strictness EXISTING_AND_NEW_OPS: {strictness}")
+
+    # Test EXISTING_OPS
+    # CHECK: strictness EXISTING_OPS: GreedyRewriteStrictness.EXISTING_OPS
+    config.strictness = GreedyRewriteStrictness.EXISTING_OPS
+    strictness = config.strictness
+    print(f"strictness EXISTING_OPS: {strictness}")
+
+
+# CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum
+@run
+def testGreedySimplifyRegionLevelEnum():
+    config = GreedyRewriteDriverConfig()
+
+    # Test DISABLED
+    # CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED
+    config.region_simplification_level = GreedySimplifyRegionLevel.DISABLED
+    level = config.region_simplification_level
+    print(f"region_level DISABLED: {level}")
+
+    # Test NORMAL
+    # CHECK: region_level NORMAL: GreedySimplifyRegionLevel.NORMAL
+    config.region_simplification_level = GreedySimplifyRegionLevel.NORMAL
+    level = config.region_simplification_level
+    print(f"region_level NORMAL: {level}")
+
+    # Test AGGRESSIVE
+    # CHECK: region_level AGGRESSIVE: GreedySimplifyRegionLevel.AGGRESSIVE
+    config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
+    level = config.region_simplification_level
+    print(f"region_level AGGRESSIVE: {level}")

>From 9b33f416114ded015eeee5361b2ca3d887c7ae30 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Fri, 2 Jan 2026 04:34:27 +0000
Subject: [PATCH 2/2] Fix bad merge & create config rather than default
 constructor

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 13 +++++++------
 mlir/test/python/rewrite.py          |  1 +
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 97da49bb6aac8..c3a17d57d43ad 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -500,8 +500,10 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
   m.def(
        "apply_patterns_and_fold_greedily",
        [](PyModule &module, PyFrozenRewritePatternSet &set) {
-         auto status =
-             mlirApplyPatternsAndFoldGreedily(module.get(), set.get(), {});
+         // Owned default config; destroyed when this call returns.
+         PyGreedyRewriteDriverConfig config;
+         auto status = mlirApplyPatternsAndFoldGreedily(
+             module.get(), set.get(), config.get());
          if (mlirLogicalResultIsFailure(status))
            throw std::runtime_error("pattern application failed to converge");
        },
@@ -514,8 +516,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "apply_patterns_and_fold_greedily",
           [](PyModule &module, MlirFrozenRewritePatternSet set) {
-            auto status =
-                mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
+            PyGreedyRewriteDriverConfig config;
+            auto status = mlirApplyPatternsAndFoldGreedily(
+                module.get(), set, config.get());
             if (mlirLogicalResultIsFailure(status))
               throw std::runtime_error(
                   "pattern application failed to converge");
@@ -531,7 +534,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           "apply_patterns_and_fold_greedily",
           [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
+            PyGreedyRewriteDriverConfig config;
             auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
-                op.getOperation(), set.get(), {});
+                op.getOperation(), set.get(), config.get());
             if (mlirLogicalResultIsFailure(status))
               throw std::runtime_error(
                   "pattern application failed to converge");
@@ -547,7 +551,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           "apply_patterns_and_fold_greedily",
           [](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
+            PyGreedyRewriteDriverConfig config;
             auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
-                op.getOperation(), set, {});
+                op.getOperation(), set, config.get());
             if (mlirLogicalResultIsFailure(status))
               throw std::runtime_error(
                   "pattern application failed to converge");
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 269ace8c0bba0..c2af43201bc89 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -5,6 +5,7 @@
 from mlir.passmanager import *
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import arith
+from mlir.rewrite import *
 
 
 def run(f):



More information about the Mlir-commits mailing list