[Mlir-commits] [mlir] [mlir] add a fluent API to GreedyRewriteConfig (PR #132253)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 20 10:03:49 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
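This adds chainable (fluent) setters to `GreedyRewriteConfig` and switches existing call sites to use them, similar to other configuration objects used across MLIR.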
---
Full diff: https://github.com/llvm/llvm-project/pull/132253.diff
11 Files Affected:
- (modified) mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h (+37)
- (modified) mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp (+6-5)
- (modified) mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp (+4-3)
- (modified) mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp (+4-3)
- (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+5-4)
- (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+7-9)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+7-7)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3-3)
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+3-4)
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+4-4)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 110b4f64856eb..aff2616d11276 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -50,6 +50,11 @@ class GreedyRewriteConfig {
///
/// Note: Only applicable when simplifying entire regions.
bool useTopDownTraversal = false;
+ /// Sets `useTopDownTraversal` (true if omitted); returns `*this`.
+ GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
+ useTopDownTraversal = use;
+ return *this;
+ }
/// Perform control flow optimizations to the region tree after applying all
/// patterns.
@@ -57,6 +62,12 @@ class GreedyRewriteConfig {
/// Note: Only applicable when simplifying entire regions.
GreedySimplifyRegionLevel enableRegionSimplification =
GreedySimplifyRegionLevel::Aggressive;
+ /// Sets `enableRegionSimplification`; returns `*this`.
+ GreedyRewriteConfig &
+ setEnableRegionSimplification(GreedySimplifyRegionLevel level) {
+ enableRegionSimplification = level;
+ return *this;
+ }
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
@@ -64,10 +75,20 @@ class GreedyRewriteConfig {
///
/// Note: Only applicable when simplifying entire regions.
int64_t maxIterations = 10;
+ /// Sets `maxIterations` (see `kNoLimit`); returns `*this`.
+ GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
+ maxIterations = iterations;
+ return *this;
+ }
/// This specifies the maximum number of rewrites within an iteration. Use
/// `kNoLimit` to disable this limit.
int64_t maxNumRewrites = kNoLimit;
+ /// Sets `maxNumRewrites` (`kNoLimit` disables the limit); returns `*this`.
+ GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
+ maxNumRewrites = limit;
+ return *this;
+ }
static constexpr int64_t kNoLimit = -1;
@@ -76,6 +97,11 @@ class GreedyRewriteConfig {
/// (or the specified region, depending on which greedy rewrite entry point
/// is used) is used as a scope.
Region *scope = nullptr;
+ /// Sets `scope`; returns `*this`.
+ GreedyRewriteConfig &setScope(Region *scope) {
+ this->scope = scope;
+ return *this;
+ }
/// Strict mode can restrict the ops that are added to the worklist during
/// the rewrite.
@@ -88,16 +114,36 @@ class GreedyRewriteConfig {
/// were on the worklist at the very beginning) enqueued. All other ops are
/// excluded.
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+ /// Sets `strictMode`; returns `*this`.
+ GreedyRewriteConfig &setStrictMode(GreedyRewriteStrictness mode) {
+ strictMode = mode;
+ return *this;
+ }
/// An optional listener that should be notified about IR modifications.
RewriterBase::Listener *listener = nullptr;
+ /// Sets `listener`; returns `*this`.
+ GreedyRewriteConfig &setListener(RewriterBase::Listener *listener) {
+ this->listener = listener;
+ return *this;
+ }
/// Whether this should fold while greedily rewriting.
bool fold = true;
+ /// Sets `fold` (true if omitted); returns `*this`.
+ GreedyRewriteConfig &setFold(bool enable = true) {
+ fold = enable;
+ return *this;
+ }
/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
bool cseConstants = true;
+ /// Sets `cseConstants` (true if omitted); returns `*this`.
+ GreedyRewriteConfig &setCSEConstants(bool enable = true) {
+ cseConstants = enable;
+ return *this;
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 9f7df7823d997..46c2d11d89d42 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -127,12 +127,13 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- GreedyRewriteConfig config;
- config.listener =
- static_cast<RewriterBase::Listener *>(rewriter.getListener());
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ auto *rewriteListener =
+ static_cast<RewriterBase::Listener *>(rewriter.getListener());
// Apply the simplification pattern to a fixpoint.
- if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
+ if (failed(applyOpPatternsGreedily(
+ targets, frozenPatterns,
+ GreedyRewriteConfig().setListener(rewriteListener).setStrictMode(
+ GreedyRewriteStrictness::ExistingAndNewOps)))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 4d30213cc6ec2..a7386af950f56 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -237,7 +237,8 @@ void AffineDataCopyGeneration::runOnOperation() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
- (void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
+ (void)applyOpPatternsGreedily(
+ copyOps, frozenPatterns,
+ GreedyRewriteConfig()
+ .setStrictMode(GreedyRewriteStrictness::ExistingAndNewOps));
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 31711ade3153b..96a7300ee7de6 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -109,7 +109,8 @@ void SimplifyAffineStructures::runOnOperation() {
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
opsToSimplify.push_back(op);
});
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
- (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
+ (void)applyOpPatternsGreedily(
+ opsToSimplify, frozenPatterns,
+ GreedyRewriteConfig()
+ .setStrictMode(GreedyRewriteStrictness::ExistingAndNewOps));
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index dd539ff685653..e52196b892dc5 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -315,11 +315,12 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
// Simplify/canonicalize the affine.for.
RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
- (void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
- config, /*changed=*/nullptr, &erased);
+ (void)applyOpPatternsGreedily(
+ res.getOperation(), std::move(patterns),
+ GreedyRewriteConfig().setStrictMode(
+ GreedyRewriteStrictness::ExistingOps),
+ /*changed=*/nullptr, &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index f866c91ef6e39..72b7f09344384 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -490,10 +490,9 @@ struct IntRangeOptimizationsPass final
RewritePatternSet patterns(ctx);
populateIntRangeOptimizationsPatterns(patterns, solver);
- GreedyRewriteConfig config;
- config.listener = &listener;
-
- if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+ if (failed(applyPatternsGreedily(op, std::move(patterns),
+ GreedyRewriteConfig()
+ .setListener(&listener))))
signalPassFailure();
}
};
@@ -516,13 +515,12 @@ struct IntRangeNarrowingPass final
RewritePatternSet patterns(ctx);
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
- GreedyRewriteConfig config;
// We specifically need bottom-up traversal as cmpi pattern needs range
// data, attached to its original argument values.
- config.useTopDownTraversal = false;
- config.listener = &listener;
-
- if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+ if (failed(applyPatternsGreedily(op, std::move(patterns),
+ GreedyRewriteConfig()
+ .setUseTopDownTraversal(false)
+ .setListener(&listener))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 35f86a62ae592..a60fd50b6c7b9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,15 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
- // We don't want that the block structure changes invalidating the
- // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
- // region simplification
- GreedyRewriteConfig config;
- config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
- populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
-
- if (failed(
- applyPatternsGreedily(getOperation(), std::move(patterns), config)))
+
+ populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
+ // We don't want block-structure changes to invalidate the
+ // `BufferOriginAnalysis`, so we apply the rewrites with a `Normal` level of
+ // region simplification.
+ if (failed(applyPatternsGreedily(
+ getOperation(), std::move(patterns),
+ GreedyRewriteConfig().setEnableRegionSimplification(
+ GreedySimplifyRegionLevel::Normal))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ef7b4757a04b4..a6d05a65f68e4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3523,9 +3523,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
vector::populateVectorStepLoweringPatterns(patterns);
TrackingListener listener(state, *this);
- GreedyRewriteConfig config;
- config.listener = &listener;
- if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
+ if (failed(applyPatternsGreedily(target, std::move(patterns),
+ GreedyRewriteConfig()
+ .setListener(&listener))))
return emitDefaultDefiniteFailure(target);
results.push_back(target);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 927ce066038d8..5e2b9877f90eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2325,10 +2325,10 @@ struct LinalgElementwiseOpFusionPass
// Add constant folding patterns.
populateConstantFoldLinalgOperations(patterns, defaultControlFn);
- // Use TopDownTraversal for compile time reasons
- GreedyRewriteConfig grc;
- grc.useTopDownTraversal = true;
- (void)applyPatternsGreedily(op, std::move(patterns), grc);
+ // Use TopDownTraversal for compile time reasons.
+ (void)applyPatternsGreedily(
+ op, std::move(patterns),
+ GreedyRewriteConfig().setUseTopDownTraversal(/*use=*/true));
}
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index af87fb7a79d04..e02048944c6f3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1438,10 +1438,10 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
if (!patterns)
return success();
- GreedyRewriteConfig config;
- config.listener = this;
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
- return applyOpPatternsGreedily(ops, patterns.value(), config);
+ return applyOpPatternsGreedily(ops, patterns.value(),
+ GreedyRewriteConfig()
+ .setListener(this)
+ .setStrictMode(GreedyRewriteStrictness::ExistingAndNewOps));
}
void SliceTrackingListener::notifyOperationInserted(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index a60410d01ac57..15a988e93c479 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1352,9 +1352,9 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
// We only want to apply signature conversion once to the existing func ops.
// Without specifying strictMode, the greedy pattern rewriter will keep
// looking for newly created func ops.
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- return applyPatternsGreedily(op, std::move(patterns), config);
+ return applyPatternsGreedily(op, std::move(patterns),
+ GreedyRewriteConfig()
+ .setStrictMode(GreedyRewriteStrictness::ExistingOps));
}
LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/132253
More information about the Mlir-commits mailing list