[Mlir-commits] [mlir] [mlir][c] Enable creating and setting greedy rewrite config. (PR #162429)

Jacques Pienaar llvmlistbot at llvm.org
Tue Oct 7 23:14:54 PDT 2025


https://github.com/jpienaar created https://github.com/llvm/llvm-project/pull/162429

Done very mechanically.

>From e7665936a84c8f6b7a286b7900cecce3572a5bf4 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Wed, 8 Oct 2025 05:54:41 +0000
Subject: [PATCH] [mlir][c] Enable creating and setting greedy rewrite config.

---
 mlir/include/mlir-c/Rewrite.h        |  98 +++++++++++++++++-
 mlir/lib/Bindings/Python/Rewrite.cpp | 118 ++++++++++++++++++++++
 mlir/lib/CAPI/Transforms/Rewrite.cpp | 145 ++++++++++++++++++++++++++-
 mlir/test/CAPI/rewrite.c             |  47 +++++++++
 mlir/test/python/rewrite.py          | 107 ++++++++++++++++++++
 5 files changed, 510 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/python/rewrite.py

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 77be1f480eacf..20e078a3c1e81 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -36,6 +36,26 @@ extern "C" {
 DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
+
+/// Greedy rewrite strictness levels.
+typedef enum {
+  /// No restrictions wrt. which ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
+  /// Only pre-existing and newly created ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
+  /// Only pre-existing ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
+} MlirGreedyRewriteStrictness;
+
+/// Greedy simplify region levels.
+typedef enum {
+  /// Disable region control-flow simplification.
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED,
+  /// Run the normal simplification (e.g. dead args elimination).
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL,
+  /// Run extra simplifications (e.g. block merging).
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
+} MlirGreedySimplifyRegionLevel;
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 
@@ -308,7 +328,83 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
 
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
-    MlirGreedyRewriteDriverConfig);
+    MlirGreedyRewriteDriverConfig config);
+
+//===----------------------------------------------------------------------===//
+/// GreedyRewriteDriverConfig API
+//===----------------------------------------------------------------------===//
+
+/// Creates a greedy rewrite driver configuration with default settings.
+MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig
+mlirGreedyRewriteDriverConfigCreate(void);
+
+/// Destroys a greedy rewrite driver configuration.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config);
+
+/// Sets the maximum number of iterations for the greedy rewrite driver.
+/// Use -1 for no limit.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(
+    MlirGreedyRewriteDriverConfig config, int64_t maxIterations);
+
+/// Sets the maximum number of rewrites within an iteration.
+/// Use -1 for no limit.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites);
+
+/// Sets whether to use top-down traversal for the initial population of the
+/// worklist.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal);
+
+/// Enables or disables folding during greedy rewriting.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(
+    MlirGreedyRewriteDriverConfig config, bool enable);
+
+/// Sets the strictness level for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedyRewriteStrictness strictness);
+
+/// Sets the region simplification level.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedySimplifyRegionLevel level);
+
+/// Enables or disables constant CSE.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(
+    MlirGreedyRewriteDriverConfig config, bool enable);
+
+/// Gets the maximum number of iterations for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the maximum number of rewrites within an iteration.
+MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether top-down traversal is used for initial worklist population.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether folding is enabled during greedy rewriting.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the strictness level for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness
+mlirGreedyRewriteDriverConfigGetStrictness(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the region simplification level.
+MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel
+mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether constant CSE is enabled.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
+    MlirGreedyRewriteDriverConfig config);
 
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 836f44fd7d4be..6908e9423e5a3 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -139,10 +139,96 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
+/// Owning Wrapper around a GreedyRewriteDriverConfig.
+class PyGreedyRewriteDriverConfig {
+public:
+  PyGreedyRewriteDriverConfig() 
+      : config(mlirGreedyRewriteDriverConfigCreate()) {}
+  PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
+      : config(other.config) {
+    other.config.ptr = nullptr;
+  }
+  ~PyGreedyRewriteDriverConfig() {
+    if (config.ptr != nullptr)
+      mlirGreedyRewriteDriverConfigDestroy(config);
+  }
+  MlirGreedyRewriteDriverConfig get() { return config; }
+
+  void setMaxIterations(int64_t maxIterations) {
+    mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
+  }
+  
+  void setMaxNumRewrites(int64_t maxNumRewrites) {
+    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
+  }
+  
+  void setUseTopDownTraversal(bool useTopDownTraversal) {
+    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config, useTopDownTraversal);
+  }
+  
+  void enableFolding(bool enable) {
+    mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
+  }
+  
+  void setStrictness(MlirGreedyRewriteStrictness strictness) {
+    mlirGreedyRewriteDriverConfigSetStrictness(config, strictness);
+  }
+  
+  void setRegionSimplificationLevel(MlirGreedySimplifyRegionLevel level) {
+    mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(config, level);
+  }
+  
+  void enableConstantCSE(bool enable) {
+    mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
+  }
+
+  int64_t getMaxIterations() {
+    return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
+  }
+  
+  int64_t getMaxNumRewrites() {
+    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
+  }
+  
+  bool getUseTopDownTraversal() {
+    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
+  }
+  
+  bool isFoldingEnabled() {
+    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
+  }
+  
+  MlirGreedyRewriteStrictness getStrictness() {
+    return mlirGreedyRewriteDriverConfigGetStrictness(config);
+  }
+  
+  MlirGreedySimplifyRegionLevel getRegionSimplificationLevel() {
+    return mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config);
+  }
+  
+  bool isConstantCSEEnabled() {
+    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
+  }
+
+private:
+  MlirGreedyRewriteDriverConfig config;
+};
+
 } // namespace
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  // Enum definitions
+  nb::enum_<MlirGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
+      .value("ANY_OP", MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP)
+      .value("EXISTING_AND_NEW_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS)
+      .value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
+      
+  nb::enum_<MlirGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
+      .value("DISABLED", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED)
+      .value("NORMAL", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL)
+      .value("AGGRESSIVE", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE);
+
   nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
@@ -228,6 +314,38 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           },
           nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  
+  nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
+      .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
+      .def("set_max_iterations", &PyGreedyRewriteDriverConfig::setMaxIterations,
+           "max_iterations"_a, "Set maximum number of iterations")
+      .def("set_max_num_rewrites", &PyGreedyRewriteDriverConfig::setMaxNumRewrites,
+           "max_num_rewrites"_a, "Set maximum number of rewrites per iteration")
+      .def("set_use_top_down_traversal", &PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
+           "use_top_down"_a, "Set whether to use top-down traversal")
+      .def("enable_folding", &PyGreedyRewriteDriverConfig::enableFolding,
+           "enable"_a, "Enable or disable folding")
+      .def("set_strictness", &PyGreedyRewriteDriverConfig::setStrictness,
+           "strictness"_a, "Set rewrite strictness level")
+      .def("set_region_simplification_level", &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
+           "level"_a, "Set region simplification level")
+      .def("enable_constant_cse", &PyGreedyRewriteDriverConfig::enableConstantCSE,
+           "enable"_a, "Enable or disable constant CSE")
+      .def("get_max_iterations", &PyGreedyRewriteDriverConfig::getMaxIterations,
+           "Get maximum number of iterations")
+      .def("get_max_num_rewrites", &PyGreedyRewriteDriverConfig::getMaxNumRewrites,
+           "Get maximum number of rewrites per iteration")
+      .def("get_use_top_down_traversal", &PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
+           "Get whether top-down traversal is used")
+      .def("is_folding_enabled", &PyGreedyRewriteDriverConfig::isFoldingEnabled,
+           "Check if folding is enabled")
+      .def("get_strictness", &PyGreedyRewriteDriverConfig::getStrictness,
+           "Get rewrite strictness level")
+      .def("get_region_simplification_level", &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
+           "Get region simplification level")
+      .def("is_constant_cse_enabled", &PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
+           "Check if constant CSE is enabled");
+
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
                    &PyFrozenRewritePatternSet::getCapsule)
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308cadf83..0741308e19077 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -289,18 +289,155 @@ void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
   op.ptr = nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+/// GreedyRewriteDriverConfig API
+//===----------------------------------------------------------------------===//
+
+inline mlir::GreedyRewriteConfig *unwrap(MlirGreedyRewriteDriverConfig config) {
+  assert(config.ptr && "unexpected null config");
+  return static_cast<mlir::GreedyRewriteConfig *>(config.ptr);
+}
+
+inline MlirGreedyRewriteDriverConfig wrap(mlir::GreedyRewriteConfig *config) {
+  return {config};
+}
+
+MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate() {
+  return wrap(new mlir::GreedyRewriteConfig());
+}
+
+void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config) {
+  delete unwrap(config);
+}
+
+void mlirGreedyRewriteDriverConfigSetMaxIterations(
+    MlirGreedyRewriteDriverConfig config, int64_t maxIterations) {
+  unwrap(config)->setMaxIterations(maxIterations);
+}
+
+void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites) {
+  unwrap(config)->setMaxNumRewrites(maxNumRewrites);
+}
+
+void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal) {
+  unwrap(config)->setUseTopDownTraversal(useTopDownTraversal);
+}
+
+void mlirGreedyRewriteDriverConfigEnableFolding(
+    MlirGreedyRewriteDriverConfig config, bool enable) {
+  unwrap(config)->enableFolding(enable);
+}
+
+void mlirGreedyRewriteDriverConfigSetStrictness(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedyRewriteStrictness strictness) {
+  auto cppStrictness = unwrap(config)->getStrictness(); // kept if no match
+  switch (strictness) {
+  case MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP:
+    cppStrictness = mlir::GreedyRewriteStrictness::AnyOp;
+    break;
+  case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS:
+    cppStrictness = mlir::GreedyRewriteStrictness::ExistingAndNewOps;
+    break;
+  case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS:
+    cppStrictness = mlir::GreedyRewriteStrictness::ExistingOps;
+    break;
+  }
+  unwrap(config)->setStrictness(cppStrictness);
+}
+
+void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedySimplifyRegionLevel level) {
+  auto cppLevel = unwrap(config)->getRegionSimplificationLevel(); // fallback
+  switch (level) {
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED:
+    cppLevel = mlir::GreedySimplifyRegionLevel::Disabled;
+    break;
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL:
+    cppLevel = mlir::GreedySimplifyRegionLevel::Normal;
+    break;
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE:
+    cppLevel = mlir::GreedySimplifyRegionLevel::Aggressive;
+    break;
+  }
+  unwrap(config)->setRegionSimplificationLevel(cppLevel);
+}
+
+void mlirGreedyRewriteDriverConfigEnableConstantCSE(
+    MlirGreedyRewriteDriverConfig config, bool enable) {
+  unwrap(config)->enableConstantCSE(enable);
+}
+
+int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getMaxIterations();
+}
+
+int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getMaxNumRewrites();
+}
+
+bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getUseTopDownTraversal();
+}
+
+bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->isFoldingEnabled();
+}
+
+MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(
+    MlirGreedyRewriteDriverConfig config) {
+  switch (unwrap(config)->getStrictness()) {
+  case mlir::GreedyRewriteStrictness::AnyOp:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP;
+  case mlir::GreedyRewriteStrictness::ExistingAndNewOps:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS;
+  case mlir::GreedyRewriteStrictness::ExistingOps:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS;
+  }
+  llvm_unreachable("unhandled mlir::GreedyRewriteStrictness value");
+}
+
+MlirGreedySimplifyRegionLevel
+mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config) {
+  switch (unwrap(config)->getRegionSimplificationLevel()) {
+  case mlir::GreedySimplifyRegionLevel::Disabled:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED;
+  case mlir::GreedySimplifyRegionLevel::Normal:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL;
+  case mlir::GreedySimplifyRegionLevel::Aggressive:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE;
+  }
+  // Every case above returns, so this is only reached with a corrupt value.
+  llvm_unreachable("unhandled mlir::GreedySimplifyRegionLevel value");
+}
+
+bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->isConstantCSEEnabled();
+}
+
 MlirLogicalResult
 mlirApplyPatternsAndFoldGreedily(MlirModule op,
                                  MlirFrozenRewritePatternSet patterns,
-                                 MlirGreedyRewriteDriverConfig) {
-  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+                                 MlirGreedyRewriteDriverConfig config) {
+  auto cfg = config.ptr ? *unwrap(config) : mlir::GreedyRewriteConfig();
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns), cfg));
 }
 
 MlirLogicalResult
 mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
                                        MlirFrozenRewritePatternSet patterns,
-                                       MlirGreedyRewriteDriverConfig) {
-  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+                                       MlirGreedyRewriteDriverConfig config) {
+  auto cfg = config.ptr ? *unwrap(config) : mlir::GreedyRewriteConfig();
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns), cfg));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/rewrite.c b/mlir/test/CAPI/rewrite.c
index b33d225767046..0745eb496c1d7 100644
--- a/mlir/test/CAPI/rewrite.c
+++ b/mlir/test/CAPI/rewrite.c
@@ -534,6 +534,52 @@ void testReplaceUses(MlirContext ctx) {
   mlirModuleDestroy(module);
 }
 
+void testGreedyRewriteDriverConfig(MlirContext ctx) {
+  // CHECK-LABEL: @testGreedyRewriteDriverConfig
+  fprintf(stderr, "@testGreedyRewriteDriverConfig\n");
+
+  // Test config creation and destruction
+  MlirGreedyRewriteDriverConfig config = mlirGreedyRewriteDriverConfigCreate();
+
+  // Test all configuration setters
+  mlirGreedyRewriteDriverConfigSetMaxIterations(config, 5);
+  mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, 100);
+  mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config, true);
+  mlirGreedyRewriteDriverConfigEnableFolding(config, false);
+  mlirGreedyRewriteDriverConfigSetStrictness(
+      config, MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
+  mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+      config, MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL);
+  mlirGreedyRewriteDriverConfigEnableConstantCSE(config, false);
+
+  // Read every value back; int64_t is cast to long long so %lld is portable.
+  // CHECK: MaxIterations: 5
+  fprintf(stderr, "MaxIterations: %lld\n",
+          (long long)mlirGreedyRewriteDriverConfigGetMaxIterations(config));
+  // CHECK: MaxNumRewrites: 100
+  fprintf(stderr, "MaxNumRewrites: %lld\n",
+          (long long)mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config));
+  // CHECK: UseTopDownTraversal: 1
+  fprintf(stderr, "UseTopDownTraversal: %d\n",
+          mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config));
+  // CHECK: FoldingEnabled: 0
+  fprintf(stderr, "FoldingEnabled: %d\n",
+          mlirGreedyRewriteDriverConfigIsFoldingEnabled(config));
+  // CHECK: Strictness: 2
+  fprintf(stderr, "Strictness: %d\n",
+          mlirGreedyRewriteDriverConfigGetStrictness(config));
+  // CHECK: RegionSimplificationLevel: 1
+  fprintf(stderr, "RegionSimplificationLevel: %d\n",
+          mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
+  // CHECK: ConstantCSEEnabled: 0
+  fprintf(stderr, "ConstantCSEEnabled: %d\n",
+          mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config));
+
+  // CHECK: Config test completed successfully
+  fprintf(stderr, "Config test completed successfully\n");
+  mlirGreedyRewriteDriverConfigDestroy(config);
+}
+
 int main(void) {
   MlirContext ctx = mlirContextCreate();
   mlirContextSetAllowUnregisteredDialects(ctx, true);
@@ -547,6 +593,7 @@ int main(void) {
   testMove(ctx);
   testOpModification(ctx);
   testReplaceUses(ctx);
+  testGreedyRewriteDriverConfig(ctx);
 
   mlirContextDestroy(ctx);
   return 0;
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..6f7deadd3cbba
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,107 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+from mlir.rewrite import *
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  return f
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigCreation
+@run
+def testGreedyRewriteDriverConfigCreation():
+  # Test basic config creation and destruction
+  config = GreedyRewriteDriverConfig()
+  # CHECK: Config created successfully
+  print("Config created successfully")
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigGetters
+@run
+def testGreedyRewriteDriverConfigGetters():
+  config = GreedyRewriteDriverConfig()
+
+  # Set some values
+  config.set_max_iterations(5)
+  config.set_max_num_rewrites(50)
+  config.set_use_top_down_traversal(True)
+  config.enable_folding(False)
+  config.set_strictness(GreedyRewriteStrictness.EXISTING_AND_NEW_OPS)
+  config.set_region_simplification_level(GreedySimplifyRegionLevel.AGGRESSIVE)
+  config.enable_constant_cse(True)
+
+  # Test all getter methods and print results
+  # CHECK: max_iterations: 5
+  max_iterations = config.get_max_iterations()
+  print(f"max_iterations: {max_iterations}")
+  # CHECK: max_rewrites: 50
+  max_rewrites = config.get_max_num_rewrites()
+  print(f"max_rewrites: {max_rewrites}")
+  # CHECK: use_top_down: True
+  use_top_down = config.get_use_top_down_traversal()
+  print(f"use_top_down: {use_top_down}")
+  # CHECK: folding_enabled: False
+  folding_enabled = config.is_folding_enabled()
+  print(f"folding_enabled: {folding_enabled}")
+  # CHECK: strictness: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+  strictness = config.get_strictness()
+  print(f"strictness: {strictness}")
+  # CHECK: region_level: GreedySimplifyRegionLevel.AGGRESSIVE
+  region_level = config.get_region_simplification_level()
+  print(f"region_level: {region_level}")
+  # CHECK: cse_enabled: True
+  cse_enabled = config.is_constant_cse_enabled()
+  print(f"cse_enabled: {cse_enabled}")
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum
+@run
+def testGreedyRewriteStrictnessEnum():
+  config = GreedyRewriteDriverConfig()
+
+  # Test ANY_OP
+  # CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP
+  config.set_strictness(GreedyRewriteStrictness.ANY_OP)
+  strictness = config.get_strictness()
+  print(f"strictness ANY_OP: {strictness}")
+
+  # Test EXISTING_AND_NEW_OPS
+  # CHECK: strictness EXISTING_AND_NEW_OPS: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+  config.set_strictness(GreedyRewriteStrictness.EXISTING_AND_NEW_OPS)
+  strictness = config.get_strictness()
+  print(f"strictness EXISTING_AND_NEW_OPS: {strictness}")
+
+  # Test EXISTING_OPS
+  # CHECK: strictness EXISTING_OPS: GreedyRewriteStrictness.EXISTING_OPS
+  config.set_strictness(GreedyRewriteStrictness.EXISTING_OPS)
+  strictness = config.get_strictness()
+  print(f"strictness EXISTING_OPS: {strictness}")
+
+
+# CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum
+@run
+def testGreedySimplifyRegionLevelEnum():
+  config = GreedyRewriteDriverConfig()
+
+  # Test DISABLED
+  # CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED
+  config.set_region_simplification_level(GreedySimplifyRegionLevel.DISABLED)
+  level = config.get_region_simplification_level()
+  print(f"region_level DISABLED: {level}")
+
+  # Test NORMAL
+  # CHECK: region_level NORMAL: GreedySimplifyRegionLevel.NORMAL
+  config.set_region_simplification_level(GreedySimplifyRegionLevel.NORMAL)
+  level = config.get_region_simplification_level()
+  print(f"region_level NORMAL: {level}")
+
+  # Test AGGRESSIVE
+  # CHECK: region_level AGGRESSIVE: GreedySimplifyRegionLevel.AGGRESSIVE
+  config.set_region_simplification_level(GreedySimplifyRegionLevel.AGGRESSIVE)
+  level = config.get_region_simplification_level()
+  print(f"region_level AGGRESSIVE: {level}")



More information about the Mlir-commits mailing list