[Mlir-commits] [mlir] [MLIR][Python] Rename `GreedyRewriteDriverConfig` to `GreedyRewriteConfig` (PR #175409)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 10 21:09:26 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

This rename serves two purposes:
1. to keep the Python name consistent with the C++ class name `mlir::GreedyRewriteConfig`, and
2. to make the name shorter.

Since this type was added only a few days ago (commit 654b3e844f21d3f64521e9cb028efdfebbf99bb4), the rename should not cause any significant compatibility issues.


---
Full diff: https://github.com/llvm/llvm-project/pull/175409.diff


2 Files Affected:

- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+22-24) 
- (modified) mlir/test/python/rewrite.py (+12-12) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index e143f118a1f01..2b649f79c5982 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -242,14 +242,14 @@ enum class PyGreedySimplifyRegionLevel : std::underlying_type_t<
 };
 
 /// Owning Wrapper around a GreedyRewriteDriverConfig.
-class PyGreedyRewriteDriverConfig {
+class PyGreedyRewriteConfig {
 public:
-  PyGreedyRewriteDriverConfig()
+  PyGreedyRewriteConfig()
       : config(mlirGreedyRewriteDriverConfigCreate().ptr,
-               PyGreedyRewriteDriverConfig::customDeleter) {}
-  PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
+               PyGreedyRewriteConfig::customDeleter) {}
+  PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept
       : config(std::move(other.config)) {}
-  PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept
+  PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
       : config(other.config) {}
 
   MlirGreedyRewriteDriverConfig get() {
@@ -470,34 +470,32 @@ void populateRewriteSubmodule(nb::module_ &m) {
           nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
-  nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
+  nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig")
       .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
-      .def_prop_rw("max_iterations",
-                   &PyGreedyRewriteDriverConfig::getMaxIterations,
-                   &PyGreedyRewriteDriverConfig::setMaxIterations,
+      .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations,
+                   &PyGreedyRewriteConfig::setMaxIterations,
                    "Maximum number of iterations")
       .def_prop_rw("max_num_rewrites",
-                   &PyGreedyRewriteDriverConfig::getMaxNumRewrites,
-                   &PyGreedyRewriteDriverConfig::setMaxNumRewrites,
+                   &PyGreedyRewriteConfig::getMaxNumRewrites,
+                   &PyGreedyRewriteConfig::setMaxNumRewrites,
                    "Maximum number of rewrites per iteration")
       .def_prop_rw("use_top_down_traversal",
-                   &PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
-                   &PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
+                   &PyGreedyRewriteConfig::getUseTopDownTraversal,
+                   &PyGreedyRewriteConfig::setUseTopDownTraversal,
                    "Whether to use top-down traversal")
-      .def_prop_rw("enable_folding",
-                   &PyGreedyRewriteDriverConfig::isFoldingEnabled,
-                   &PyGreedyRewriteDriverConfig::enableFolding,
+      .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled,
+                   &PyGreedyRewriteConfig::enableFolding,
                    "Enable or disable folding")
-      .def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness,
-                   &PyGreedyRewriteDriverConfig::setStrictness,
+      .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness,
+                   &PyGreedyRewriteConfig::setStrictness,
                    "Rewrite strictness level")
       .def_prop_rw("region_simplification_level",
-                   &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
-                   &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
+                   &PyGreedyRewriteConfig::getRegionSimplificationLevel,
+                   &PyGreedyRewriteConfig::setRegionSimplificationLevel,
                    "Region simplification level")
       .def_prop_rw("enable_constant_cse",
-                   &PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
-                   &PyGreedyRewriteDriverConfig::enableConstantCSE,
+                   &PyGreedyRewriteConfig::isConstantCSEEnabled,
+                   &PyGreedyRewriteConfig::enableConstantCSE,
                    "Enable or disable constant CSE");
 
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
@@ -508,7 +506,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
   m.def(
        "apply_patterns_and_fold_greedily",
        [](PyModule &module, PyFrozenRewritePatternSet &set,
-          std::optional<PyGreedyRewriteDriverConfig> config) {
+          std::optional<PyGreedyRewriteConfig> config) {
          MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
              module.get(), set.get(),
              config.has_value() ? config->get()
@@ -522,7 +520,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "apply_patterns_and_fold_greedily",
           [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
-             std::optional<PyGreedyRewriteDriverConfig> config) {
+             std::optional<PyGreedyRewriteConfig> config) {
             MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
                 op.getOperation(), set.get(),
                 config.has_value() ? config->get()
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 43e9b761a0ea2..8ef49981a8b3c 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -89,19 +89,19 @@ def constant_1_to_2(op, rewriter):
         print(module)
 
 
-# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigCreation
+# CHECK-LABEL: TEST: testGreedyRewriteConfigCreation
 @run
-def testGreedyRewriteDriverConfigCreation():
+def testGreedyRewriteConfigCreation():
     # Test basic config creation and destruction
-    config = GreedyRewriteDriverConfig()
+    config = GreedyRewriteConfig()
     # CHECK: Config created successfully
     print("Config created successfully")
 
 
-# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigGetters
+# CHECK-LABEL: TEST: testGreedyRewriteConfigGetters
 @run
-def testGreedyRewriteDriverConfigGetters():
-    config = GreedyRewriteDriverConfig()
+def testGreedyRewriteConfigGetters():
+    config = GreedyRewriteConfig()
 
     # Set some values
     config.max_iterations = 5
@@ -139,7 +139,7 @@ def testGreedyRewriteDriverConfigGetters():
 # CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum
 @run
 def testGreedyRewriteStrictnessEnum():
-    config = GreedyRewriteDriverConfig()
+    config = GreedyRewriteConfig()
 
     # Test ANY_OP
     # CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP
@@ -163,7 +163,7 @@ def testGreedyRewriteStrictnessEnum():
 # CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum
 @run
 def testGreedySimplifyRegionLevelEnum():
-    config = GreedyRewriteDriverConfig()
+    config = GreedyRewriteConfig()
 
     # Test DISABLED
     # CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED
@@ -184,9 +184,9 @@ def testGreedySimplifyRegionLevelEnum():
     print(f"region_level AGGRESSIVE: {level}")
 
 
-# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
+# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteConfig
 @run
-def testRewriteWithGreedyRewriteDriverConfig():
+def testRewriteWithGreedyRewriteConfig():
     def constant_1_to_2(op, rewriter):
         c = op.value.value
         if c != 1:
@@ -212,7 +212,7 @@ def constant_1_to_2(op, rewriter):
             """
         )
 
-        config = GreedyRewriteDriverConfig()
+        config = GreedyRewriteConfig()
         config.enable_constant_cse = False
         apply_patterns_and_fold_greedily(module, frozen, config)
         # CHECK: %c2_i64 = arith.constant 2 : i64
@@ -220,7 +220,7 @@ def constant_1_to_2(op, rewriter):
         # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
         print(module)
 
-        config = GreedyRewriteDriverConfig()
+        config = GreedyRewriteConfig()
         config.enable_constant_cse = True
         apply_patterns_and_fold_greedily(module, frozen, config)
         # CHECK: %c2_i64 = arith.constant 2 : i64

``````````

</details>


https://github.com/llvm/llvm-project/pull/175409


More information about the Mlir-commits mailing list