[Mlir-commits] [mlir] [mlir][c] Enable creating and setting greedy rewrite config. (PR #162429)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 1 21:54:51 PST 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/162429
>From 4f7dfcc8539b96cdbdd080bf5f8b82f7f4e79f4d Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Mon, 13 Oct 2025 03:51:57 +0000
Subject: [PATCH 1/3] [mlir][c] Enable creating and setting greedy rewrite
config.
---
mlir/include/mlir-c/Rewrite.h | 98 +++++++++++++++++-
mlir/lib/Bindings/Python/Rewrite.cpp | 119 ++++++++++++++++++++++
mlir/lib/CAPI/Transforms/Rewrite.cpp | 145 ++++++++++++++++++++++++++-
mlir/test/CAPI/rewrite.c | 47 +++++++++
mlir/test/python/rewrite.py | 99 +++++++++++++++++-
5 files changed, 502 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 9e2685719fe4d..dec533882aeea 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -36,6 +36,26 @@ extern "C" {
DEFINE_C_API_STRUCT(MlirRewriterBase, void);
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
+
+/// Greedy rewrite strictness levels.
+typedef enum {
+ /// No restrictions wrt. which ops are processed.
+ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
+ /// Only pre-existing and newly created ops are processed.
+ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
+ /// Only pre-existing ops are processed.
+ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
+} MlirGreedyRewriteStrictness;
+
+/// Greedy simplify region levels.
+typedef enum {
+ /// Disable region control-flow simplification.
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED,
+ /// Run the normal simplification (e.g. dead args elimination).
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL,
+ /// Run extra simplifications (e.g. block merging).
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
+} MlirGreedySimplifyRegionLevel;
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
@@ -319,7 +339,83 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
- MlirGreedyRewriteDriverConfig);
+ MlirGreedyRewriteDriverConfig config);
+
+//===----------------------------------------------------------------------===//
+/// GreedyRewriteDriverConfig API
+//===----------------------------------------------------------------------===//
+
+/// Creates a greedy rewrite driver configuration with default settings.
+MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig
+mlirGreedyRewriteDriverConfigCreate();
+
+/// Destroys a greedy rewrite driver configuration.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config);
+
+/// Sets the maximum number of iterations for the greedy rewrite driver.
+/// Use -1 for no limit.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(
+ MlirGreedyRewriteDriverConfig config, int64_t maxIterations);
+
+/// Sets the maximum number of rewrites within an iteration.
+/// Use -1 for no limit.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
+ MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites);
+
+/// Sets whether to use top-down traversal for the initial population of the
+/// worklist.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
+ MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal);
+
+/// Enables or disables folding during greedy rewriting.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config,
+ bool enable);
+
+/// Sets the strictness level for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(
+ MlirGreedyRewriteDriverConfig config,
+ MlirGreedyRewriteStrictness strictness);
+
+/// Sets the region simplification level.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+ MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level);
+
+/// Enables or disables constant CSE.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(
+ MlirGreedyRewriteDriverConfig config, bool enable);
+
+/// Gets the maximum number of iterations for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
+ MlirGreedyRewriteDriverConfig config);
+
+/// Gets the maximum number of rewrites within an iteration.
+MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
+ MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether top-down traversal is used for initial worklist population.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
+ MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether folding is enabled during greedy rewriting.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
+ MlirGreedyRewriteDriverConfig config);
+
+/// Gets the strictness level for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness
+mlirGreedyRewriteDriverConfigGetStrictness(
+ MlirGreedyRewriteDriverConfig config);
+
+/// Gets the region simplification level.
+MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel
+mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
+ MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether constant CSE is enabled.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
+ MlirGreedyRewriteDriverConfig config);
/// Applies the given patterns to the given op by a fast walk-based pattern
/// rewrite driver.
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0df9d0cbc7ffc..97da49bb6aac8 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -223,10 +223,98 @@ class PyRewritePatternSet {
MlirContext ctx;
};
+/// Owning Wrapper around a GreedyRewriteDriverConfig.
+class PyGreedyRewriteDriverConfig {
+public:
+ PyGreedyRewriteDriverConfig()
+ : config(mlirGreedyRewriteDriverConfigCreate()) {}
+ PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
+ : config(other.config) {
+ other.config.ptr = nullptr;
+ }
+ ~PyGreedyRewriteDriverConfig() {
+ if (config.ptr != nullptr)
+ mlirGreedyRewriteDriverConfigDestroy(config);
+ }
+ MlirGreedyRewriteDriverConfig get() { return config; }
+
+ void setMaxIterations(int64_t maxIterations) {
+ mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
+ }
+
+ void setMaxNumRewrites(int64_t maxNumRewrites) {
+ mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
+ }
+
+ void setUseTopDownTraversal(bool useTopDownTraversal) {
+ mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config,
+ useTopDownTraversal);
+ }
+
+ void enableFolding(bool enable) {
+ mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
+ }
+
+ void setStrictness(MlirGreedyRewriteStrictness strictness) {
+ mlirGreedyRewriteDriverConfigSetStrictness(config, strictness);
+ }
+
+ void setRegionSimplificationLevel(MlirGreedySimplifyRegionLevel level) {
+ mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(config, level);
+ }
+
+ void enableConstantCSE(bool enable) {
+ mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
+ }
+
+ int64_t getMaxIterations() {
+ return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
+ }
+
+ int64_t getMaxNumRewrites() {
+ return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
+ }
+
+ bool getUseTopDownTraversal() {
+ return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
+ }
+
+ bool isFoldingEnabled() {
+ return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
+ }
+
+ MlirGreedyRewriteStrictness getStrictness() {
+ return mlirGreedyRewriteDriverConfigGetStrictness(config);
+ }
+
+ MlirGreedySimplifyRegionLevel getRegionSimplificationLevel() {
+ return mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config);
+ }
+
+ bool isConstantCSEEnabled() {
+ return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
+ }
+
+private:
+ MlirGreedyRewriteDriverConfig config;
+};
+
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+ // Enum definitions
+ nb::enum_<MlirGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
+ .value("ANY_OP", MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP)
+ .value("EXISTING_AND_NEW_OPS",
+ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS)
+ .value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
+
+ nb::enum_<MlirGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
+ .value("DISABLED", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED)
+ .value("NORMAL", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL)
+ .value("AGGRESSIVE", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE);
+
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
@@ -373,6 +461,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
},
nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+ nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
+ .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
+ .def_prop_rw("max_iterations",
+ &PyGreedyRewriteDriverConfig::getMaxIterations,
+ &PyGreedyRewriteDriverConfig::setMaxIterations,
+ "Maximum number of iterations")
+ .def_prop_rw("max_num_rewrites",
+ &PyGreedyRewriteDriverConfig::getMaxNumRewrites,
+ &PyGreedyRewriteDriverConfig::setMaxNumRewrites,
+ "Maximum number of rewrites per iteration")
+ .def_prop_rw("use_top_down_traversal",
+ &PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
+ &PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
+ "Whether to use top-down traversal")
+ .def_prop_rw("enable_folding",
+ &PyGreedyRewriteDriverConfig::isFoldingEnabled,
+ &PyGreedyRewriteDriverConfig::enableFolding,
+ "Enable or disable folding")
+ .def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness,
+ &PyGreedyRewriteDriverConfig::setStrictness,
+ "Rewrite strictness level")
+ .def_prop_rw("region_simplification_level",
+ &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
+ &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
+ "Region simplification level")
+ .def_prop_rw("enable_constant_cse",
+ &PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
+ &PyGreedyRewriteDriverConfig::enableConstantCSE,
+ "Enable or disable constant CSE");
+
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyFrozenRewritePatternSet::getCapsule)
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index fd4ae6ffc72be..798ca1de651c1 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -283,18 +283,155 @@ void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
set.ptr = nullptr;
}
+//===----------------------------------------------------------------------===//
+/// GreedyRewriteDriverConfig API
+//===----------------------------------------------------------------------===//
+
+inline mlir::GreedyRewriteConfig *unwrap(MlirGreedyRewriteDriverConfig config) {
+ assert(config.ptr && "unexpected null config");
+ return static_cast<mlir::GreedyRewriteConfig *>(config.ptr);
+}
+
+inline MlirGreedyRewriteDriverConfig wrap(mlir::GreedyRewriteConfig *config) {
+ return {config};
+}
+
+MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate() {
+ return wrap(new mlir::GreedyRewriteConfig());
+}
+
+void mlirGreedyRewriteDriverConfigDestroy(
+ MlirGreedyRewriteDriverConfig config) {
+ delete unwrap(config);
+}
+
+void mlirGreedyRewriteDriverConfigSetMaxIterations(
+ MlirGreedyRewriteDriverConfig config, int64_t maxIterations) {
+ unwrap(config)->setMaxIterations(maxIterations);
+}
+
+void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
+ MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites) {
+ unwrap(config)->setMaxNumRewrites(maxNumRewrites);
+}
+
+void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
+ MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal) {
+ unwrap(config)->setUseTopDownTraversal(useTopDownTraversal);
+}
+
+void mlirGreedyRewriteDriverConfigEnableFolding(
+ MlirGreedyRewriteDriverConfig config, bool enable) {
+ unwrap(config)->enableFolding(enable);
+}
+
+void mlirGreedyRewriteDriverConfigSetStrictness(
+ MlirGreedyRewriteDriverConfig config,
+ MlirGreedyRewriteStrictness strictness) {
+ mlir::GreedyRewriteStrictness cppStrictness;
+ switch (strictness) {
+ case MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP:
+ cppStrictness = mlir::GreedyRewriteStrictness::AnyOp;
+ break;
+ case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS:
+ cppStrictness = mlir::GreedyRewriteStrictness::ExistingAndNewOps;
+ break;
+ case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS:
+ cppStrictness = mlir::GreedyRewriteStrictness::ExistingOps;
+ break;
+ }
+ unwrap(config)->setStrictness(cppStrictness);
+}
+
+void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+ MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level) {
+ mlir::GreedySimplifyRegionLevel cppLevel;
+ switch (level) {
+ case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED:
+ cppLevel = mlir::GreedySimplifyRegionLevel::Disabled;
+ break;
+ case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL:
+ cppLevel = mlir::GreedySimplifyRegionLevel::Normal;
+ break;
+ case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE:
+ cppLevel = mlir::GreedySimplifyRegionLevel::Aggressive;
+ break;
+ }
+ unwrap(config)->setRegionSimplificationLevel(cppLevel);
+}
+
+void mlirGreedyRewriteDriverConfigEnableConstantCSE(
+ MlirGreedyRewriteDriverConfig config, bool enable) {
+ unwrap(config)->enableConstantCSE(enable);
+}
+
+int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
+ MlirGreedyRewriteDriverConfig config) {
+ return unwrap(config)->getMaxIterations();
+}
+
+int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
+ MlirGreedyRewriteDriverConfig config) {
+ return unwrap(config)->getMaxNumRewrites();
+}
+
+bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
+ MlirGreedyRewriteDriverConfig config) {
+ return unwrap(config)->getUseTopDownTraversal();
+}
+
+bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
+ MlirGreedyRewriteDriverConfig config) {
+ return unwrap(config)->isFoldingEnabled();
+}
+
+MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(
+ MlirGreedyRewriteDriverConfig config) {
+ mlir::GreedyRewriteStrictness cppStrictness = unwrap(config)->getStrictness();
+ switch (cppStrictness) {
+ case mlir::GreedyRewriteStrictness::AnyOp:
+ return MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP;
+ case mlir::GreedyRewriteStrictness::ExistingAndNewOps:
+ return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS;
+ case mlir::GreedyRewriteStrictness::ExistingOps:
+ return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS;
+ }
+}
+
+MlirGreedySimplifyRegionLevel
+mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
+ MlirGreedyRewriteDriverConfig config) {
+ mlir::GreedySimplifyRegionLevel cppLevel =
+ unwrap(config)->getRegionSimplificationLevel();
+ switch (cppLevel) {
+ case mlir::GreedySimplifyRegionLevel::Disabled:
+ return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED;
+ case mlir::GreedySimplifyRegionLevel::Normal:
+ return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL;
+ case mlir::GreedySimplifyRegionLevel::Aggressive:
+ return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE;
+ }
+}
+
+bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
+ MlirGreedyRewriteDriverConfig config) {
+ return unwrap(config)->isConstantCSEEnabled();
+}
+
MlirLogicalResult
mlirApplyPatternsAndFoldGreedily(MlirModule op,
MlirFrozenRewritePatternSet patterns,
- MlirGreedyRewriteDriverConfig) {
- return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+ MlirGreedyRewriteDriverConfig config) {
+ return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns),
+ *unwrap(config)));
}
MlirLogicalResult
mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
MlirFrozenRewritePatternSet patterns,
- MlirGreedyRewriteDriverConfig) {
- return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+ MlirGreedyRewriteDriverConfig config) {
+ return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns),
+ *unwrap(config)));
}
void mlirWalkAndApplyPatterns(MlirOperation op,
diff --git a/mlir/test/CAPI/rewrite.c b/mlir/test/CAPI/rewrite.c
index b33d225767046..0745eb496c1d7 100644
--- a/mlir/test/CAPI/rewrite.c
+++ b/mlir/test/CAPI/rewrite.c
@@ -534,6 +534,52 @@ void testReplaceUses(MlirContext ctx) {
mlirModuleDestroy(module);
}
+void testGreedyRewriteDriverConfig(MlirContext ctx) {
+ // CHECK-LABEL: @testGreedyRewriteDriverConfig
+ fprintf(stderr, "@testGreedyRewriteDriverConfig\n");
+
+ // Test config creation and destruction
+ MlirGreedyRewriteDriverConfig config = mlirGreedyRewriteDriverConfigCreate();
+
+ // Test all configuration setters
+ mlirGreedyRewriteDriverConfigSetMaxIterations(config, 5);
+ mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, 100);
+ mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config, true);
+ mlirGreedyRewriteDriverConfigEnableFolding(config, false);
+ mlirGreedyRewriteDriverConfigSetStrictness(
+ config, MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
+ mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+ config, MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL);
+ mlirGreedyRewriteDriverConfigEnableConstantCSE(config, false);
+
+ // Test all configuration getters and verify values
+ // CHECK: MaxIterations: 5
+ fprintf(stderr, "MaxIterations: %ld\n",
+ mlirGreedyRewriteDriverConfigGetMaxIterations(config));
+ // CHECK: MaxNumRewrites: 100
+ fprintf(stderr, "MaxNumRewrites: %ld\n",
+ mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config));
+ // CHECK: UseTopDownTraversal: 1
+ fprintf(stderr, "UseTopDownTraversal: %d\n",
+ mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config));
+ // CHECK: FoldingEnabled: 0
+ fprintf(stderr, "FoldingEnabled: %d\n",
+ mlirGreedyRewriteDriverConfigIsFoldingEnabled(config));
+ // CHECK: Strictness: 2
+ fprintf(stderr, "Strictness: %d\n",
+ mlirGreedyRewriteDriverConfigGetStrictness(config));
+ // CHECK: RegionSimplificationLevel: 1
+ fprintf(stderr, "RegionSimplificationLevel: %d\n",
+ mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
+ // CHECK: ConstantCSEEnabled: 0
+ fprintf(stderr, "ConstantCSEEnabled: %d\n",
+ mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config));
+
+ // CHECK: Config test completed successfully
+ fprintf(stderr, "Config test completed successfully\n");
+ mlirGreedyRewriteDriverConfigDestroy(config);
+}
+
int main(void) {
MlirContext ctx = mlirContextCreate();
mlirContextSetAllowUnregisteredDialects(ctx, true);
@@ -547,6 +593,7 @@ int main(void) {
testMove(ctx);
testOpModification(ctx);
testReplaceUses(ctx);
+ testGreedyRewriteDriverConfig(ctx);
mlirContextDestroy(ctx);
return 0;
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index e40d5eb92b86f..269ace8c0bba0 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -1,15 +1,17 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s
+import gc
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
-from mlir.rewrite import *
def run(f):
print("\nTEST:", f.__name__)
f()
+ gc.collect()
+ return f
# CHECK-LABEL: TEST: testRewritePattern
@@ -84,3 +86,98 @@ def constant_1_to_2(op, rewriter):
# CHECK: %0 = arith.muli %arg0, %arg1 : i64
# CHECK: return %0 : i64
print(module)
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigCreation
+@run
+def testGreedyRewriteDriverConfigCreation():
+ # Test basic config creation and destruction
+ config = GreedyRewriteDriverConfig()
+ # CHECK: Config created successfully
+ print("Config created successfully")
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigGetters
+@run
+def testGreedyRewriteDriverConfigGetters():
+ config = GreedyRewriteDriverConfig()
+
+ # Set some values
+ config.max_iterations = 5
+ config.max_num_rewrites = 50
+ config.use_top_down_traversal = True
+ config.enable_folding = False
+ config.strictness = GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+ config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
+ config.enable_constant_cse = True
+
+ # Test all getter methods and print results
+ # CHECK: max_iterations: 5
+ max_iterations = config.max_iterations
+ print(f"max_iterations: {max_iterations}")
+ # CHECK: max_rewrites: 50
+ max_rewrites = config.max_num_rewrites
+ print(f"max_rewrites: {max_rewrites}")
+ # CHECK: use_top_down: True
+ use_top_down = config.use_top_down_traversal
+ print(f"use_top_down: {use_top_down}")
+ # CHECK: folding_enabled: False
+ folding_enabled = config.enable_folding
+ print(f"folding_enabled: {folding_enabled}")
+ # CHECK: strictness: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+ strictness = config.strictness
+ print(f"strictness: {strictness}")
+ # CHECK: region_level: GreedySimplifyRegionLevel.AGGRESSIVE
+ region_level = config.region_simplification_level
+ print(f"region_level: {region_level}")
+ # CHECK: cse_enabled: True
+ cse_enabled = config.enable_constant_cse
+ print(f"cse_enabled: {cse_enabled}")
+
+
+# CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum
+@run
+def testGreedyRewriteStrictnessEnum():
+ config = GreedyRewriteDriverConfig()
+
+ # Test ANY_OP
+ # CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP
+ config.strictness = GreedyRewriteStrictness.ANY_OP
+ strictness = config.strictness
+ print(f"strictness ANY_OP: {strictness}")
+
+ # Test EXISTING_AND_NEW_OPS
+ # CHECK: strictness EXISTING_AND_NEW_OPS: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+ config.strictness = GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
+ strictness = config.strictness
+ print(f"strictness EXISTING_AND_NEW_OPS: {strictness}")
+
+ # Test EXISTING_OPS
+ # CHECK: strictness EXISTING_OPS: GreedyRewriteStrictness.EXISTING_OPS
+ config.strictness = GreedyRewriteStrictness.EXISTING_OPS
+ strictness = config.strictness
+ print(f"strictness EXISTING_OPS: {strictness}")
+
+
+# CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum
+@run
+def testGreedySimplifyRegionLevelEnum():
+ config = GreedyRewriteDriverConfig()
+
+ # Test DISABLED
+ # CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED
+ config.region_simplification_level = GreedySimplifyRegionLevel.DISABLED
+ level = config.region_simplification_level
+ print(f"region_level DISABLED: {level}")
+
+ # Test NORMAL
+ # CHECK: region_level NORMAL: GreedySimplifyRegionLevel.NORMAL
+ config.region_simplification_level = GreedySimplifyRegionLevel.NORMAL
+ level = config.region_simplification_level
+ print(f"region_level NORMAL: {level}")
+
+ # Test AGGRESSIVE
+ # CHECK: region_level AGGRESSIVE: GreedySimplifyRegionLevel.AGGRESSIVE
+ config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
+ level = config.region_simplification_level
+ print(f"region_level AGGRESSIVE: {level}")
>From 9b33f416114ded015eeee5361b2ca3d887c7ae30 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Fri, 2 Jan 2026 04:34:27 +0000
Subject: [PATCH 2/3] Fix bad merge & create config rather than default
constructor
---
mlir/lib/Bindings/Python/Rewrite.cpp | 13 +++++++------
mlir/test/python/rewrite.py | 1 +
2 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 97da49bb6aac8..c3a17d57d43ad 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -500,8 +500,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
m.def(
"apply_patterns_and_fold_greedily",
[](PyModule &module, PyFrozenRewritePatternSet &set) {
- auto status =
- mlirApplyPatternsAndFoldGreedily(module.get(), set.get(), {});
+ auto status = mlirApplyPatternsAndFoldGreedily(
+ module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("pattern application failed to converge");
},
@@ -514,8 +514,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"apply_patterns_and_fold_greedily",
[](PyModule &module, MlirFrozenRewritePatternSet set) {
- auto status =
- mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
+ auto status = mlirApplyPatternsAndFoldGreedily(
+ module.get(), set, mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
@@ -531,7 +531,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
- op.getOperation(), set.get(), {});
+ op.getOperation(), set.get(),
+ mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
@@ -546,7 +547,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
- op.getOperation(), set, {});
+ op.getOperation(), set, mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 269ace8c0bba0..c2af43201bc89 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -5,6 +5,7 @@
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
+from mlir.rewrite import *
def run(f):
>From 124ece7291f5b8d049fa479cfde8e14da16283b8 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Fri, 2 Jan 2026 13:54:43 +0800
Subject: [PATCH 3/3] Update mlir/include/mlir-c/Rewrite.h
---
mlir/include/mlir-c/Rewrite.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index dec533882aeea..26f7f08535b41 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -347,7 +347,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
/// Creates a greedy rewrite driver configuration with default settings.
MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig
-mlirGreedyRewriteDriverConfigCreate();
+mlirGreedyRewriteDriverConfigCreate(void);
/// Destroys a greedy rewrite driver configuration.
MLIR_CAPI_EXPORTED void
More information about the Mlir-commits
mailing list