[Mlir-commits] [mlir] [MLIR][Python] Add GreedyRewriteDriverConfig parameter to apply_patterns_and_fold_greedily (PR #174913)

Maksim Levental llvmlistbot at llvm.org
Wed Jan 7 22:06:35 PST 2026


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/174913

>From 8be07893c123f343283efff6d84901974f46f31d Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 7 Jan 2026 22:05:03 -0800
Subject: [PATCH] [MLIR][Python] Add GreedyRewriteDriverConfig parameter to
 apply_patterns_and_fold_greedily

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 80 ++++++++++++++++------------
 mlir/test/python/rewrite.py          | 44 +++++++++++++++
 2 files changed, 90 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9830c277ac147..fbb4c7cb40dbf 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -250,80 +250,85 @@ enum PyGreedySimplifyRegionLevel : std::underlying_type_t<
 class PyGreedyRewriteDriverConfig {
 public:
+  /// Allocates a new C config via mlirGreedyRewriteDriverConfigCreate().
   PyGreedyRewriteDriverConfig()
-      : config(mlirGreedyRewriteDriverConfigCreate()) {}
-  PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
-      : config(other.config) {
-    other.config.ptr = nullptr;
-  }
-  ~PyGreedyRewriteDriverConfig() {
-    if (config.ptr != nullptr)
-      mlirGreedyRewriteDriverConfigDestroy(config);
-  }
-  MlirGreedyRewriteDriverConfig get() { return config; }
+      : config(mlirGreedyRewriteDriverConfigCreate().ptr, customDeleter) {}
+
+  /// Returns a non-owning handle to the underlying C config; it stays valid
+  /// while this object or any copy of it is alive.
+  MlirGreedyRewriteDriverConfig get() {
+    return MlirGreedyRewriteDriverConfig{config.get()};
+  }
 
   void setMaxIterations(int64_t maxIterations) {
-    mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
+    mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations);
   }
 
   void setMaxNumRewrites(int64_t maxNumRewrites) {
-    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
+    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites);
   }
 
   void setUseTopDownTraversal(bool useTopDownTraversal) {
-    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config,
+    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(),
                                                         useTopDownTraversal);
   }
 
   void enableFolding(bool enable) {
-    mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
+    mlirGreedyRewriteDriverConfigEnableFolding(get(), enable);
   }
 
   void setStrictness(PyGreedyRewriteStrictness strictness) {
     mlirGreedyRewriteDriverConfigSetStrictness(
-        config, static_cast<MlirGreedyRewriteStrictness>(strictness));
+        get(), static_cast<MlirGreedyRewriteStrictness>(strictness));
   }
 
   void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
     mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
-        config, static_cast<MlirGreedySimplifyRegionLevel>(level));
+        get(), static_cast<MlirGreedySimplifyRegionLevel>(level));
   }
 
   void enableConstantCSE(bool enable) {
-    mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
+    mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable);
   }
 
   int64_t getMaxIterations() {
-    return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
+    return mlirGreedyRewriteDriverConfigGetMaxIterations(get());
   }
 
   int64_t getMaxNumRewrites() {
-    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
+    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get());
   }
 
   bool getUseTopDownTraversal() {
-    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
+    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get());
   }
 
   bool isFoldingEnabled() {
-    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
+    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get());
   }
 
   PyGreedyRewriteStrictness getStrictness() {
     return static_cast<PyGreedyRewriteStrictness>(
-        mlirGreedyRewriteDriverConfigGetStrictness(config));
+        mlirGreedyRewriteDriverConfigGetStrictness(get()));
   }
 
   PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
     return static_cast<PyGreedySimplifyRegionLevel>(
-        mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
+        mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get()));
   }
 
   bool isConstantCSEEnabled() {
-    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
+    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get());
   }
 
 private:
-  MlirGreedyRewriteDriverConfig config;
+  /// Releases the C config once the last owning copy goes away.
+  static void customDeleter(void *c) {
+    mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
+  }
+
+  /// Shared ownership lets the implicit copy/move operations (used e.g. for
+  /// by-value `std::optional` arguments) alias one C config safely.
+  std::shared_ptr<void> config;
 };
 
 /// Create the `mlir.rewrite` here.
@@ -509,26 +514,33 @@ void populateRewriteSubmodule(nb::module_ &m) {
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
        "apply_patterns_and_fold_greedily",
-       [](PyModule &module, PyFrozenRewritePatternSet &set) {
-         auto status = mlirApplyPatternsAndFoldGreedily(
-             module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
+       [](PyModule &module, PyFrozenRewritePatternSet &set,
+          std::optional<PyGreedyRewriteDriverConfig> config) {
+         // Default-construct an owned config so it is freed on return.
+         if (!config)
+           config.emplace();
+         MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
+             module.get(), set.get(), config->get());
          if (mlirLogicalResultIsFailure(status))
            throw std::runtime_error("pattern application failed to converge");
        },
-       "module"_a, "set"_a,
+       "module"_a, "set"_a, "config"_a = nb::none(),
        "Applys the given patterns to the given module greedily while folding "
        "results.")
       .def(
           "apply_patterns_and_fold_greedily",
-          [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
-            auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
-                op.getOperation(), set.get(),
-                mlirGreedyRewriteDriverConfigCreate());
+          [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
+             std::optional<PyGreedyRewriteDriverConfig> config) {
+            // Default-construct an owned config so it is freed on return.
+            if (!config)
+              config.emplace();
+            MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
+                op.getOperation(), set.get(), config->get());
             if (mlirLogicalResultIsFailure(status))
               throw std::runtime_error(
                   "pattern application failed to converge");
           },
-          "op"_a, "set"_a,
+          "op"_a, "set"_a, "config"_a = nb::none(),
           "Applys the given patterns to the given op greedily while folding "
           "results.")
       .def(
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index a2fbbde38b8c0..43e9b761a0ea2 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum():
     config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
     level = config.region_simplification_level
     print(f"region_level AGGRESSIVE: {level}")
+
+
+# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
+@run
+def testRewriteWithGreedyRewriteDriverConfig():
+    def constant_1_to_2(op, rewriter):
+        c = op.value.value
+        if c != 1:
+            return True  # failed to match
+        with rewriter.ip:
+            new_op = arith.constant(op.type, 2, loc=op.location)
+        rewriter.replace_op(op, [new_op])
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add(arith.ConstantOp, constant_1_to_2)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @const() -> (i64, i64) {
+                %0 = arith.constant 1 : i64
+                %1 = arith.constant 1 : i64
+                return %0, %1 : i64, i64
+              }
+            }
+            """
+        )
+
+        config = GreedyRewriteDriverConfig()
+        config.enable_constant_cse = False
+        apply_patterns_and_fold_greedily(module, frozen, config)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: %c2_i64_0 = arith.constant 2 : i64
+        # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
+        print(module)
+
+        config = GreedyRewriteDriverConfig()
+        config.enable_constant_cse = True
+        apply_patterns_and_fold_greedily(module, frozen, config)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: return %c2_i64, %c2_i64 : i64, i64
+        print(module)



More information about the Mlir-commits mailing list