[flang-commits] [flang] [mlir] [mlir] add a fluent API to GreedyRewriterConfig (PR #137122)
Oleksandr Alex Zinenko via flang-commits
flang-commits at lists.llvm.org
Wed Apr 23 23:51:43 PDT 2025
https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/137122
>From 20c6df85388381f3df42d24e5a1741745fd94920 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Fri, 4 Apr 2025 14:37:25 +0200
Subject: [PATCH] [mlir] add a fluent API to GreedyRewriterConfig
This is similar to other configuration objects used across MLIR.
Rename some fields to better reflect that they are no longer booleans.
Reland 04d261101b4f229189463136a794e3e362a793af / #132253.
---
.../Optimizer/CodeGen/LowerRepackArrays.cpp | 4 +-
.../HLFIR/Transforms/InlineElementals.cpp | 4 +-
.../HLFIR/Transforms/InlineHLFIRAssign.cpp | 4 +-
.../HLFIR/Transforms/LowerHLFIRIntrinsics.cpp | 4 +-
.../Transforms/OptimizedBufferization.cpp | 4 +-
.../Transforms/SimplifyHLFIRIntrinsics.cpp | 4 +-
flang/lib/Optimizer/Passes/Pipelines.cpp | 6 +-
.../Transforms/AssumedRankOpConversion.cpp | 4 +-
.../ConstantArgumentGlobalisation.cpp | 6 +-
.../Transforms/SimplifyFIROperations.cpp | 3 +-
.../lib/Optimizer/Transforms/StackArrays.cpp | 3 +-
.../Transforms/GreedyPatternRewriteDriver.h | 76 ++++++++++---
mlir/include/mlir/Transforms/Passes.td | 2 +-
.../TransformOps/AffineTransformOps.cpp | 11 +-
.../Transforms/AffineDataCopyGeneration.cpp | 7 +-
.../Transforms/SimplifyAffineStructures.cpp | 7 +-
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 9 +-
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 8 +-
.../Transforms/IntRangeOptimizations.cpp | 16 ++-
.../BufferDeallocationSimplification.cpp | 14 +--
.../TransformOps/LinalgTransformOps.cpp | 6 +-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 7 +-
.../SCF/Transforms/TileUsingInterface.cpp | 8 +-
.../SPIRV/Transforms/SPIRVConversion.cpp | 6 +-
.../lib/Dialect/Transform/IR/TransformOps.cpp | 14 +--
mlir/lib/Reducer/ReductionTreePass.cpp | 8 +-
mlir/lib/Transforms/Canonicalizer.cpp | 16 +--
.../Utils/GreedyPatternRewriteDriver.cpp | 101 +++++++++---------
.../lib/Dialect/Affine/TestAffineDataCopy.cpp | 2 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 28 ++---
30 files changed, 224 insertions(+), 168 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
index 7deed3d44ae5b..7fb713ff1a6c7 100644
--- a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
+++ b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
@@ -357,8 +357,8 @@ class LowerRepackArraysPass
patterns.insert<PackArrayConversion>(context);
patterns.insert<UnpackArrayConversion>(context);
mlir::GreedyRewriteConfig config;
- config.enableRegionSimplification =
- mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
(void)applyPatternsGreedily(module, std::move(patterns), config);
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
index b68fe6ee0c747..c42b895946d19 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
@@ -119,8 +119,8 @@ class InlineElementalsPass
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
- config.enableRegionSimplification =
- mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineElementalConversion>(context);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
index 249976d5509b0..6e209cce07ad4 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
@@ -135,8 +135,8 @@ class InlineHLFIRAssignPass
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
- config.enableRegionSimplification =
- mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineHLFIRAssignConversion>(context);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 7c0fcba806869..31e5bc1193e22 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -557,8 +557,8 @@ class LowerHLFIRIntrinsics
// Pattern rewriting only requires that the resulting IR is still valid
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
- config.enableRegionSimplification =
- mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 79aabd2981e1a..2f6ee2592a84f 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -875,8 +875,8 @@ class OptimizedBufferizationPass
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
- config.enableRegionSimplification =
- mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
// TODO: right now the patterns are non-conflicting,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index e9d820adbd22b..1dea7d89e180d 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -2132,8 +2132,8 @@ class SimplifyHLFIRIntrinsics
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
- config.enableRegionSimplification =
- mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 81ff6bf9b2c6a..7a06a27748ebd 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -35,7 +35,8 @@ void addNestedPassToAllTopLevelOperationsConditionally(
void addCanonicalizerPassWithoutRegionSimplification(mlir::OpPassManager &pm) {
mlir::GreedyRewriteConfig config;
- config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
pm.addPass(mlir::createCanonicalizerPass(config));
}
@@ -163,7 +164,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
// simplify the IR
mlir::GreedyRewriteConfig config;
- config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
pm.addPass(mlir::createCSEPass());
fir::addAVC(pm, pc.OptLevel);
addNestedPassToAllTopLevelOperations<PassConstructor>(
diff --git a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
index d0bd67a236419..6af1cb988a4c1 100644
--- a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
@@ -152,8 +152,8 @@ class AssumedRankOpConversion
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
mlir::GreedyRewriteConfig config;
- config.enableRegionSimplification =
- mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
(void)applyPatternsGreedily(mod, std::move(patterns), config);
}
};
diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 562f3058f20f3..239a7cdaa4cf2 100644
--- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -168,9 +168,9 @@ class ConstantArgumentGlobalisationOpt
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
mlir::GreedyRewriteConfig config;
- config.enableRegionSimplification =
- mlir::GreedySimplifyRegionLevel::Disabled;
- config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
+ config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps);
patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(
diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
index 212de2f2286db..6d106046b70f2 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
@@ -205,7 +205,8 @@ void SimplifyFIROperationsPass::runOnOperation() {
fir::populateSimplifyFIROperationsPatterns(patterns,
preferInlineImplementation);
mlir::GreedyRewriteConfig config;
- config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index 9a6566bef50f1..f9b9b4f4ff385 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -806,7 +806,8 @@ void StackArraysPass::runOnOperation() {
mlir::RewritePatternSet patterns(&context);
mlir::GreedyRewriteConfig config;
// prevent the pattern driver form merging blocks
- config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
+ config.setRegionSimplificationLevel(
+ mlir::GreedySimplifyRegionLevel::Disabled);
patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsGreedily(
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 110b4f64856eb..45e61b68f5db2 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -49,25 +49,43 @@ class GreedyRewriteConfig {
/// larger patterns when given an ambiguous pattern set.
///
/// Note: Only applicable when simplifying entire regions.
- bool useTopDownTraversal = false;
+ bool getUseTopDownTraversal() const { return useTopDownTraversal; }
+ GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
+ useTopDownTraversal = use;
+ return *this;
+ }
/// Perform control flow optimizations to the region tree after applying all
/// patterns.
///
/// Note: Only applicable when simplifying entire regions.
- GreedySimplifyRegionLevel enableRegionSimplification =
- GreedySimplifyRegionLevel::Aggressive;
+ GreedySimplifyRegionLevel getRegionSimplificationLevel() const {
+ return regionSimplificationLevel;
+ }
+ GreedyRewriteConfig &
+ setRegionSimplificationLevel(GreedySimplifyRegionLevel level) {
+ regionSimplificationLevel = level;
+ return *this;
+ }
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
///
/// Note: Only applicable when simplifying entire regions.
- int64_t maxIterations = 10;
+ int64_t getMaxIterations() const { return maxIterations; }
+ GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
+ maxIterations = iterations;
+ return *this;
+ }
/// This specifies the maximum number of rewrites within an iteration. Use
/// `kNoLimit` to disable this limit.
- int64_t maxNumRewrites = kNoLimit;
+ int64_t getMaxNumRewrites() const { return maxNumRewrites; }
+ GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
+ maxNumRewrites = limit;
+ return *this;
+ }
static constexpr int64_t kNoLimit = -1;
@@ -75,7 +93,11 @@ class GreedyRewriteConfig {
/// specified, the closest enclosing region around the initial list of ops
/// (or the specified region, depending on which greedy rewrite entry point
/// is used) is used as a scope.
- Region *scope = nullptr;
+ Region *getScope() const { return scope; }
+ GreedyRewriteConfig &setScope(Region *scope) {
+ this->scope = scope;
+ return *this;
+ }
/// Strict mode can restrict the ops that are added to the worklist during
/// the rewrite.
@@ -87,16 +109,44 @@ class GreedyRewriteConfig {
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
/// were on the worklist at the very beginning) enqueued. All other ops are
/// excluded.
- GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+ GreedyRewriteStrictness getStrictness() const { return strictness; }
+ GreedyRewriteConfig &setStrictness(GreedyRewriteStrictness mode) {
+ strictness = mode;
+ return *this;
+ }
/// An optional listener that should be notified about IR modifications.
- RewriterBase::Listener *listener = nullptr;
+ RewriterBase::Listener *getListener() const { return listener; }
+ GreedyRewriteConfig &setListener(RewriterBase::Listener *listener) {
+ this->listener = listener;
+ return *this;
+ }
/// Whether this should fold while greedily rewriting.
- bool fold = true;
+ bool isFoldingEnabled() const { return fold; }
+ GreedyRewriteConfig &enableFolding(bool enable = true) {
+ fold = enable;
+ return *this;
+ }
/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
+ bool isConstantCSEEnabled() const { return cseConstants; }
+ GreedyRewriteConfig &enableConstantCSE(bool enable = true) {
+ cseConstants = enable;
+ return *this;
+ }
+
+private:
+ Region *scope = nullptr;
+ bool useTopDownTraversal = false;
+ GreedySimplifyRegionLevel regionSimplificationLevel =
+ GreedySimplifyRegionLevel::Aggressive;
+ int64_t maxIterations = 10;
+ int64_t maxNumRewrites = kNoLimit;
+ GreedyRewriteStrictness strictness = GreedyRewriteStrictness::AnyOp;
+ RewriterBase::Listener *listener = nullptr;
+ bool fold = true;
bool cseConstants = true;
};
@@ -128,14 +178,14 @@ applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
/// Same as `applyPatternsAndGreedily` above with folding.
-/// FIXME: Remove this once transition to above is complieted.
+/// FIXME: Remove this once transition to above is completed.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
- config.fold = true;
+ config.enableFolding();
return applyPatternsGreedily(region, patterns, config, changed);
}
@@ -187,7 +237,7 @@ applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
- config.fold = true;
+ config.enableFolding();
return applyPatternsGreedily(op, patterns, config, changed);
}
@@ -233,7 +283,7 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr) {
- config.fold = true;
+ config.enableFolding();
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
}
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..1e89a78912e99 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -33,7 +33,7 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
"Seed the worklist in general top-down order">,
- Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
+ Option<"regionSimplifyLevel", "region-simplify", "mlir::GreedySimplifyRegionLevel",
/*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
"Perform control flow optimizations to the region tree",
[{::llvm::cl::values(
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 9f7df7823d997..43d37ee3332ef 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -127,12 +127,13 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- GreedyRewriteConfig config;
- config.listener =
- static_cast<RewriterBase::Listener *>(rewriter.getListener());
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
- if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
+ if (failed(applyOpPatternsGreedily(
+ targets, frozenPatterns,
+ GreedyRewriteConfig()
+ .setListener(
+ static_cast<RewriterBase::Listener *>(rewriter.getListener()))
+ .setStrictness(GreedyRewriteStrictness::ExistingAndNewOps)))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 4d30213cc6ec2..62c1857e4b1da 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -237,7 +237,8 @@ void AffineDataCopyGeneration::runOnOperation() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
- (void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
+ (void)applyOpPatternsGreedily(
+ copyOps, frozenPatterns,
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingAndNewOps));
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 31711ade3153b..9e9096c2e3186 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -109,7 +109,8 @@ void SimplifyAffineStructures::runOnOperation() {
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
opsToSimplify.push_back(op);
});
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
- (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
+ (void)applyOpPatternsGreedily(
+ opsToSimplify, frozenPatterns,
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingAndNewOps));
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index dd539ff685653..0d4ba3940c48e 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -315,11 +315,12 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
// Simplify/canonicalize the affine.for.
RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
- (void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
- config, /*changed=*/nullptr, &erased);
+ (void)applyOpPatternsGreedily(
+ res.getOperation(), std::move(patterns),
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingOps),
+ /*changed=*/nullptr, &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 2925aa918cb1c..11798b99fa879 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -426,11 +426,11 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
RewritePatternSet patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
- (void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
- /*changed=*/nullptr, &erased);
+ (void)applyOpPatternsGreedily(
+ ifOp.getOperation(), frozenPatterns,
+ GreedyRewriteConfig().setStrictness(GreedyRewriteStrictness::ExistingOps),
+ /*changed=*/nullptr, &erased);
if (erased) {
if (folded)
*folded = true;
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 602d80a45993e..f2f93883eb2b7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -494,10 +494,9 @@ struct IntRangeOptimizationsPass final
RewritePatternSet patterns(ctx);
populateIntRangeOptimizationsPatterns(patterns, solver);
- GreedyRewriteConfig config;
- config.listener = &listener;
-
- if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+ if (failed(applyPatternsGreedily(
+ op, std::move(patterns),
+ GreedyRewriteConfig().setListener(&listener))))
signalPassFailure();
}
};
@@ -520,13 +519,12 @@ struct IntRangeNarrowingPass final
RewritePatternSet patterns(ctx);
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
- GreedyRewriteConfig config;
// We specifically need bottom-up traversal as cmpi pattern needs range
// data, attached to its original argument values.
- config.useTopDownTraversal = false;
- config.listener = &listener;
-
- if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+ if (failed(applyPatternsGreedily(
+ op, std::move(patterns),
+ GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
+ &listener))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 35f86a62ae592..c5fab80ecaa08 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,15 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
+
+ populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
// We don't want that the block structure changes invalidating the
- // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+ // `BufferOriginAnalysis` so we apply the rewrites with `Normal` level of
// region simplification
- GreedyRewriteConfig config;
- config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
- populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
-
- if (failed(
- applyPatternsGreedily(getOperation(), std::move(patterns), config)))
+ if (failed(applyPatternsGreedily(
+ getOperation(), std::move(patterns),
+ GreedyRewriteConfig().setRegionSimplificationLevel(
+ GreedySimplifyRegionLevel::Normal))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c90ebe4487ca4..b20e6050fb4f8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3587,9 +3587,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
vector::populateVectorStepLoweringPatterns(patterns);
TrackingListener listener(state, *this);
- GreedyRewriteConfig config;
- config.listener = &listener;
- if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
+ if (failed(
+ applyPatternsGreedily(target, std::move(patterns),
+ GreedyRewriteConfig().setListener(&listener))))
return emitDefaultDefiniteFailure(target);
results.push_back(target);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..62d016b87d627 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2327,10 +2327,9 @@ struct LinalgElementwiseOpFusionPass
// Add constant folding patterns.
populateConstantFoldLinalgOperations(patterns, defaultControlFn);
- // Use TopDownTraversal for compile time reasons
- GreedyRewriteConfig grc;
- grc.useTopDownTraversal = true;
- (void)applyPatternsGreedily(op, std::move(patterns), grc);
+ // Use TopDownTraversal for compile time reasons.
+ (void)applyPatternsGreedily(op, std::move(patterns),
+ GreedyRewriteConfig().setUseTopDownTraversal());
}
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 91862d2e17d71..7edf19689d2e1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1438,10 +1438,10 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
if (!patterns)
return success();
- GreedyRewriteConfig config;
- config.listener = this;
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
- return applyOpPatternsGreedily(ops, patterns.value(), config);
+ return applyOpPatternsGreedily(
+ ops, patterns.value(),
+ GreedyRewriteConfig().setListener(this).setStrictness(
+ GreedyRewriteStrictness::ExistingAndNewOps));
}
void SliceTrackingListener::notifyOperationInserted(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 19b9af146f4a4..811f03abb3461 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1353,9 +1353,9 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
// We only want to apply signature conversion once to the existing func ops.
// Without specifying strictMode, the greedy pattern rewriter will keep
// looking for newly created func ops.
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- return applyPatternsGreedily(op, std::move(patterns), config);
+ return applyPatternsGreedily(op, std::move(patterns),
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingOps));
}
LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 4fe89f3f7fb9e..84d339a985c38 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -394,16 +394,16 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
// Configure the GreedyPatternRewriteDriver.
GreedyRewriteConfig config;
- config.listener =
- static_cast<RewriterBase::Listener *>(rewriter.getListener());
+ config.setListener(
+ static_cast<RewriterBase::Listener *>(rewriter.getListener()));
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
- ? GreedyRewriteConfig::kNoLimit
- : getMaxIterations();
- config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
+ config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
? GreedyRewriteConfig::kNoLimit
- : getMaxNumRewrites();
+ : getMaxIterations());
+ config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
+ ? GreedyRewriteConfig::kNoLimit
+ : getMaxNumRewrites());
// Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
// was requested, apply the greedy pattern rewrite only once. (The greedy
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 7292752c712ae..549e4f2bd813b 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -62,11 +62,11 @@ static void applyPatterns(Region &region,
// before that transform.
for (Operation *op : opsInRange) {
// `applyOpPatternsGreedily` with folding returns whether the op is
- // convered. Omit it because we don't have expectation this reduction will
+ // converted. Omit it because we don't have expectation this reduction will
// be success or not.
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- (void)applyOpPatternsGreedily(op, patterns, config);
+ (void)applyOpPatternsGreedily(op, patterns,
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingOps));
}
if (eraseOpNotInRange)
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 7ccd503fb0288..4b0ac28a03713 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -32,10 +32,10 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
ArrayRef<std::string> disabledPatterns,
ArrayRef<std::string> enabledPatterns)
: config(config) {
- this->topDownProcessingEnabled = config.useTopDownTraversal;
- this->enableRegionSimplification = config.enableRegionSimplification;
- this->maxIterations = config.maxIterations;
- this->maxNumRewrites = config.maxNumRewrites;
+ this->topDownProcessingEnabled = config.getUseTopDownTraversal();
+ this->regionSimplifyLevel = config.getRegionSimplificationLevel();
+ this->maxIterations = config.getMaxIterations();
+ this->maxNumRewrites = config.getMaxNumRewrites();
this->disabledPatterns = disabledPatterns;
this->enabledPatterns = enabledPatterns;
}
@@ -44,10 +44,10 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
/// execution.
LogicalResult initialize(MLIRContext *context) override {
// Set the config from possible pass options set in the meantime.
- config.useTopDownTraversal = topDownProcessingEnabled;
- config.enableRegionSimplification = enableRegionSimplification;
- config.maxIterations = maxIterations;
- config.maxNumRewrites = maxNumRewrites;
+ config.setUseTopDownTraversal(topDownProcessingEnabled);
+ config.setRegionSimplificationLevel(regionSimplifyLevel);
+ config.setMaxIterations(maxIterations);
+ config.setMaxNumRewrites(maxNumRewrites);
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7c1cfd91f85e6..5a719200e0026 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -416,7 +416,8 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
// clang-format off
, expensiveChecks(
/*driver=*/this,
- /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
+ /*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
+ : nullptr)
// clang-format on
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
{
@@ -455,8 +456,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
bool changed = false;
int64_t numRewrites = 0;
while (!worklist.empty() &&
- (numRewrites < config.maxNumRewrites ||
- config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
+ (numRewrites < config.getMaxNumRewrites() ||
+ config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
auto *op = worklist.pop();
LLVM_DEBUG({
@@ -488,7 +489,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// infinite folding loop, as every constant op would be folded to an
// Attribute and then immediately be rematerialized as a constant op, which
// is then put on the worklist.
- if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
+ if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
@@ -574,21 +575,21 @@ bool GreedyPatternRewriteDriver::processWorklist() {
logger.getOStream() << ")' {\n";
logger.indent();
});
- if (config.listener)
- config.listener->notifyPatternBegin(pattern, op);
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyPatternBegin(pattern, op);
return true;
};
function_ref<bool(const Pattern &)> canApply = canApplyCallback;
auto onFailureCallback = [&](const Pattern &pattern) {
LLVM_DEBUG(logResult("failure", "pattern failed to match"));
- if (config.listener)
- config.listener->notifyPatternEnd(pattern, failure());
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyPatternEnd(pattern, failure());
};
function_ref<void(const Pattern &)> onFailure = onFailureCallback;
auto onSuccessCallback = [&](const Pattern &pattern) {
LLVM_DEBUG(logResult("success", "pattern applied successfully"));
- if (config.listener)
- config.listener->notifyPatternEnd(pattern, success());
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyPatternEnd(pattern, success());
return success();
};
function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
@@ -596,7 +597,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#ifdef NDEBUG
// Optimization: PatternApplicator callbacks are not needed when running in
// optimized mode and without a listener.
- if (!config.listener) {
+ if (!config.getListener()) {
canApply = nullptr;
onFailure = nullptr;
onSuccess = nullptr;
@@ -604,8 +605,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#endif // NDEBUG
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope) {
- expensiveChecks.computeFingerPrints(config.scope->getParentOp());
+ if (config.getScope()) {
+ expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
}
auto clearFingerprints =
llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
@@ -640,7 +641,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
do {
ancestors.push_back(op);
region = op->getParentRegion();
- if (config.scope == region) {
+ if (config.getScope() == region) {
// Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
for (Operation *op : ancestors)
addSingleOpToWorklist(op);
@@ -652,20 +653,20 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
}
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
- if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
+ if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op))
worklist.push(op);
}
void GreedyPatternRewriteDriver::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- if (config.listener)
- config.listener->notifyBlockInserted(block, previous, previousIt);
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyBlockInserted(block, previous, previousIt);
}
void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
- if (config.listener)
- config.listener->notifyBlockErased(block);
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyBlockErased(block);
}
void GreedyPatternRewriteDriver::notifyOperationInserted(
@@ -674,9 +675,9 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
- if (config.listener)
- config.listener->notifyOperationInserted(op, previous);
- if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyOperationInserted(op, previous);
+ if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
addToWorklist(op);
}
@@ -686,8 +687,8 @@ void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
<< ")\n";
});
- if (config.listener)
- config.listener->notifyOperationModified(op);
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyOperationModified(op);
addToWorklist(op);
}
@@ -736,18 +737,18 @@ void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
// the part of the IR that is taken into account for the "expensive checks".
// A greedy pattern rewrite is not allowed to erase the parent op of the scope
// region, as that would break the worklist handling and the expensive checks.
- if (config.scope && config.scope->getParentOp() == op)
+ if (Region *scope = config.getScope(); scope->getParentOp() == op)
llvm_unreachable(
"scope region must not be erased during greedy pattern rewrite");
#endif // NDEBUG
- if (config.listener)
- config.listener->notifyOperationErased(op);
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyOperationErased(op);
addOperandsToWorklist(op);
worklist.remove(op);
- if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+ if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
}
@@ -757,8 +758,8 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
<< ")\n";
});
- if (config.listener)
- config.listener->notifyOperationReplaced(op, replacement);
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyOperationReplaced(op, replacement);
}
void GreedyPatternRewriteDriver::notifyMatchFailure(
@@ -768,8 +769,8 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
reasonCallback(diag);
logger.startLine() << "** Match Failure : " << diag.str() << "\n";
});
- if (config.listener)
- config.listener->notifyMatchFailure(loc, reasonCallback);
+ if (RewriterBase::Listener *listener = config.getListener())
+ listener->notifyMatchFailure(loc, reasonCallback);
}
//===----------------------------------------------------------------------===//
@@ -800,7 +801,7 @@ RegionPatternRewriteDriver::RegionPatternRewriteDriver(
const GreedyRewriteConfig &config, Region &region)
: GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
// Populate strict mode ops.
- if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
+ if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
}
}
@@ -829,8 +830,8 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
MLIRContext *ctx = rewriter.getContext();
do {
// Check if the iteration limit was reached.
- if (++iteration > config.maxIterations &&
- config.maxIterations != GreedyRewriteConfig::kNoLimit)
+ if (++iteration > config.getMaxIterations() &&
+ config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
break;
// New iteration: start with an empty worklist.
@@ -849,16 +850,16 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
return false;
};
- if (!config.useTopDownTraversal) {
+ if (!config.getUseTopDownTraversal()) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
- if (!config.cseConstants || !insertKnownConstant(op))
+ if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
addToWorklist(op);
});
} else {
// Add all nested operations to the worklist in preorder.
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (!config.cseConstants || !insertKnownConstant(op)) {
+ if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
addToWorklist(op);
return WalkResult::advance();
}
@@ -875,11 +876,11 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
- if (config.enableRegionSimplification !=
+ if (config.getRegionSimplificationLevel() !=
GreedySimplifyRegionLevel::Disabled) {
continueRewrites |= succeeded(simplifyRegions(
rewriter, region,
- /*mergeBlocks=*/config.enableRegionSimplification ==
+ /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
GreedySimplifyRegionLevel::Aggressive));
}
},
@@ -904,11 +905,11 @@ mlir::applyPatternsGreedily(Region &region,
"patterns can only be applied to operations IsolatedFromAbove");
// Set scope if not specified.
- if (!config.scope)
- config.scope = &region;
+ if (!config.getScope())
+ config.setScope(&region);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (failed(verify(config.scope->getParentOp())))
+ if (failed(verify(config.getScope()->getParentOp())))
llvm::report_fatal_error(
"greedy pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -919,7 +920,7 @@ mlir::applyPatternsGreedily(Region &region,
LogicalResult converged = std::move(driver).simplify(changed);
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
- << config.maxIterations << " times\n";
+ << config.getMaxIterations() << " times\n";
});
return converged;
}
@@ -960,7 +961,7 @@ MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
llvm::SmallDenseSet<Operation *, 4> *survivingOps)
: GreedyPatternRewriteDriver(ctx, patterns, config),
survivingOps(survivingOps) {
- if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+ if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.insert_range(ops);
if (survivingOps) {
@@ -1024,22 +1025,22 @@ LogicalResult mlir::applyOpPatternsGreedily(
}
// Determine scope of rewrite.
- if (!config.scope) {
+ if (!config.getScope()) {
// Compute scope if none was provided. The scope will remain `nullptr` if
// there is a top-level op among `ops`.
- config.scope = findCommonAncestor(ops);
+ config.setScope(findCommonAncestor(ops));
} else {
// If a scope was provided, make sure that all ops are in scope.
#ifndef NDEBUG
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
- return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
+ return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
});
assert(allOpsInScope && "ops must be within the specified scope");
#endif // NDEBUG
}
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
+ if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
llvm::report_fatal_error(
"greedy pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -1054,7 +1055,7 @@ LogicalResult mlir::applyOpPatternsGreedily(
*allErased = surviving.empty();
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after "
- << config.maxNumRewrites << " rewrites";
+ << config.getMaxNumRewrites() << " rewrites";
});
return converged;
}
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index d6aaa6faf94cb..2a54e0c28f71f 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -144,7 +144,7 @@ void TestAffineDataCopy::runOnOperation() {
}
}
GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
(void)applyOpPatternsGreedily(copyOps, std::move(patterns), config);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index db02a122872d9..d073843484d81 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -386,26 +386,26 @@ struct TestGreedyPatternDriver
patterns.insert<IncrementIntAttribute<3>>(&getContext());
GreedyRewriteConfig config;
- config.useTopDownTraversal = this->useTopDownTraversal;
- config.maxIterations = this->maxIterations;
- config.fold = this->fold;
- config.cseConstants = this->cseConstants;
+ config.setUseTopDownTraversal(useTopDownTraversal)
+ .setMaxIterations(this->maxIterations)
+ .enableFolding(this->fold)
+ .enableConstantCSE(this->cseConstants);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
Option<bool> useTopDownTraversal{
*this, "top-down",
llvm::cl::desc("Seed the worklist in general top-down order"),
- llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
+ llvm::cl::init(GreedyRewriteConfig().getUseTopDownTraversal())};
Option<int> maxIterations{
*this, "max-iterations",
llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
- llvm::cl::init(GreedyRewriteConfig().maxIterations)};
+ llvm::cl::init(GreedyRewriteConfig().getMaxIterations())};
Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
- llvm::cl::init(GreedyRewriteConfig().fold)};
- Option<bool> cseConstants{*this, "cse-constants",
- llvm::cl::desc("Whether to CSE constants"),
- llvm::cl::init(GreedyRewriteConfig().cseConstants)};
+ llvm::cl::init(GreedyRewriteConfig().isFoldingEnabled())};
+ Option<bool> cseConstants{
+ *this, "cse-constants", llvm::cl::desc("Whether to CSE constants"),
+ llvm::cl::init(GreedyRewriteConfig().isConstantCSEEnabled())};
};
struct DumpNotifications : public RewriterBase::Listener {
@@ -501,13 +501,13 @@ struct TestStrictPatternDriver
DumpNotifications dumpNotifications;
GreedyRewriteConfig config;
- config.listener = &dumpNotifications;
+ config.setListener(&dumpNotifications);
if (strictMode == "AnyOp") {
- config.strictMode = GreedyRewriteStrictness::AnyOp;
+ config.setStrictness(GreedyRewriteStrictness::AnyOp);
} else if (strictMode == "ExistingAndNewOps") {
- config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
} else if (strictMode == "ExistingOps") {
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ config.setStrictness(GreedyRewriteStrictness::ExistingOps);
} else {
llvm_unreachable("invalid strictness option");
}
More information about the flang-commits
mailing list