[Mlir-commits] [mlir] 94a9565 - [MLIR][Python] Add GreedyRewriteDriverConfig parameter to apply_patterns_and_fold_greedily (#174913)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 8 03:48:34 PST 2026
Author: Maksim Levental
Date: 2026-01-08T03:48:30-08:00
New Revision: 94a95659c2f6b88b72d59fd0fba2000ba8b1fee1
URL: https://github.com/llvm/llvm-project/commit/94a95659c2f6b88b72d59fd0fba2000ba8b1fee1
DIFF: https://github.com/llvm/llvm-project/commit/94a95659c2f6b88b72d59fd0fba2000ba8b1fee1.diff
LOG: [MLIR][Python] Add GreedyRewriteDriverConfig parameter to apply_patterns_and_fold_greedily (#174913)
We already have `GreedyRewriteDriverConfig` on the Python side, but it
hasn’t yet been exposed as a parameter of
`apply_patterns_and_fold_greedily`. This PR does that.
Before:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None
```
After:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet,
config: GreedyRewriteDriverConfig | None = None) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet,
config: GreedyRewriteDriverConfig | None = None) -> None
```
This PR is adapted from
https://github.com/llvm/llvm-project/pull/174785 but uses
`std::optional` instead of `nb::object`. This required refactoring
`PyGreedyRewriteDriverConfig` to hold its config in a `std::shared_ptr`
so that it can support a copy constructor.
Co-authored-by: PragmaTwice <twice at apache.org>
Added:
Modified:
mlir/lib/Bindings/Python/Rewrite.cpp
mlir/test/python/rewrite.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 4b1ced572931d..e143f118a1f01 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -245,80 +245,83 @@ enum class PyGreedySimplifyRegionLevel : std::underlying_type_t<
class PyGreedyRewriteDriverConfig {
public:
PyGreedyRewriteDriverConfig()
- : config(mlirGreedyRewriteDriverConfigCreate()) {}
+ : config(mlirGreedyRewriteDriverConfigCreate().ptr,
+ PyGreedyRewriteDriverConfig::customDeleter) {}
PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
- : config(other.config) {
- other.config.ptr = nullptr;
- }
- ~PyGreedyRewriteDriverConfig() {
- if (config.ptr != nullptr)
- mlirGreedyRewriteDriverConfigDestroy(config);
+ : config(std::move(other.config)) {}
+ PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept
+ : config(other.config) {}
+
+ MlirGreedyRewriteDriverConfig get() {
+ return MlirGreedyRewriteDriverConfig{config.get()};
}
- MlirGreedyRewriteDriverConfig get() { return config; }
void setMaxIterations(int64_t maxIterations) {
- mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
+ mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations);
}
void setMaxNumRewrites(int64_t maxNumRewrites) {
- mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
+ mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites);
}
void setUseTopDownTraversal(bool useTopDownTraversal) {
- mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config,
+ mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(),
useTopDownTraversal);
}
void enableFolding(bool enable) {
- mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
+ mlirGreedyRewriteDriverConfigEnableFolding(get(), enable);
}
void setStrictness(PyGreedyRewriteStrictness strictness) {
mlirGreedyRewriteDriverConfigSetStrictness(
- config, static_cast<MlirGreedyRewriteStrictness>(strictness));
+ get(), static_cast<MlirGreedyRewriteStrictness>(strictness));
}
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
- config, static_cast<MlirGreedySimplifyRegionLevel>(level));
+ get(), static_cast<MlirGreedySimplifyRegionLevel>(level));
}
void enableConstantCSE(bool enable) {
- mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
+ mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable);
}
int64_t getMaxIterations() {
- return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
+ return mlirGreedyRewriteDriverConfigGetMaxIterations(get());
}
int64_t getMaxNumRewrites() {
- return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
+ return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get());
}
bool getUseTopDownTraversal() {
- return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
+ return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get());
}
bool isFoldingEnabled() {
- return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
+ return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get());
}
PyGreedyRewriteStrictness getStrictness() {
return static_cast<PyGreedyRewriteStrictness>(
- mlirGreedyRewriteDriverConfigGetStrictness(config));
+ mlirGreedyRewriteDriverConfigGetStrictness(get()));
}
PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
return static_cast<PyGreedySimplifyRegionLevel>(
- mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
+ mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get()));
}
bool isConstantCSEEnabled() {
- return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
+ return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get());
}
private:
- MlirGreedyRewriteDriverConfig config;
+ std::shared_ptr<void> config;
+ static void customDeleter(void *c) {
+ mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
+ }
};
/// Create the `mlir.rewrite` here.
@@ -504,26 +507,31 @@ void populateRewriteSubmodule(nb::module_ &m) {
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
"apply_patterns_and_fold_greedily",
- [](PyModule &module, PyFrozenRewritePatternSet &set) {
- auto status = mlirApplyPatternsAndFoldGreedily(
- module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
+ [](PyModule &module, PyFrozenRewritePatternSet &set,
+ std::optional<PyGreedyRewriteDriverConfig> config) {
+ MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
+ module.get(), set.get(),
+ config.has_value() ? config->get()
+ : mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("pattern application failed to converge");
},
- "module"_a, "set"_a,
+ "module"_a, "set"_a, "config"_a = nb::none(),
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
- [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
- auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
+ [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
+ std::optional<PyGreedyRewriteDriverConfig> config) {
+ MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
op.getOperation(), set.get(),
- mlirGreedyRewriteDriverConfigCreate());
+ config.has_value() ? config->get()
+ : mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
- "op"_a, "set"_a,
+ "op"_a, "set"_a, "config"_a = nb::none(),
"Applys the given patterns to the given op greedily while folding "
"results.")
.def(
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index a2fbbde38b8c0..43e9b761a0ea2 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum():
config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
level = config.region_simplification_level
print(f"region_level AGGRESSIVE: {level}")
+
+
+# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
+@run
+def testRewriteWithGreedyRewriteDriverConfig():
+ def constant_1_to_2(op, rewriter):
+ c = op.value.value
+ if c != 1:
+ return True # failed to match
+ with rewriter.ip:
+ new_op = arith.constant(op.type, 2, loc=op.location)
+ rewriter.replace_op(op, [new_op])
+
+ with Context():
+ patterns = RewritePatternSet()
+ patterns.add(arith.ConstantOp, constant_1_to_2)
+ frozen = patterns.freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @const() -> (i64, i64) {
+ %0 = arith.constant 1 : i64
+ %1 = arith.constant 1 : i64
+ return %0, %1 : i64, i64
+ }
+ }
+ """
+ )
+
+ config = GreedyRewriteDriverConfig()
+ config.enable_constant_cse = False
+ apply_patterns_and_fold_greedily(module, frozen, config)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: %c2_i64_0 = arith.constant 2 : i64
+ # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
+ print(module)
+
+ config = GreedyRewriteDriverConfig()
+ config.enable_constant_cse = True
+ apply_patterns_and_fold_greedily(module, frozen, config)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: return %c2_i64, %c2_i64 : i64
+ print(module)
More information about the Mlir-commits
mailing list