[flang-commits] [flang] 4cb9a37 - Revert "[mlir] add a fluent API to GreedyRewriterConfig (#132253)"
Kazu Hirata via flang-commits
flang-commits at lists.llvm.org
Fri Apr 18 09:40:34 PDT 2025
Author: Kazu Hirata
Date: 2025-04-18T09:40:28-07:00
New Revision: 4cb9a3700c31357821e192124baeb3a3a35ff93b
URL: https://github.com/llvm/llvm-project/commit/4cb9a3700c31357821e192124baeb3a3a35ff93b
DIFF: https://github.com/llvm/llvm-project/commit/4cb9a3700c31357821e192124baeb3a3a35ff93b.diff
LOG: Revert "[mlir] add a fluent API to GreedyRewriterConfig (#132253)"
This reverts commit 63b8f1c9482ed0a964980df4aed89bef922b8078.
Buildbot failure:
https://lab.llvm.org/buildbot/#/builders/172/builds/12083/steps/5/logs/stdio
I've reproduced the error with a release build (-DCMAKE_BUILD_TYPE=Release).
Added:
Modified:
flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
flang/lib/Optimizer/Passes/Pipelines.cpp
flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
flang/lib/Optimizer/Transforms/StackArrays.cpp
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Reducer/ReductionTreePass.cpp
mlir/lib/Transforms/Canonicalizer.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
index 7fb713ff1a6c7..7deed3d44ae5b 100644
--- a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
+++ b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
@@ -357,8 +357,8 @@ class LowerRepackArraysPass
patterns.insert<PackArrayConversion>(context);
patterns.insert<UnpackArrayConversion>(context);
mlir::GreedyRewriteConfig config;
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification =
+ mlir::GreedySimplifyRegionLevel::Disabled;
(void)applyPatternsGreedily(module, std::move(patterns), config);
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
index c42b895946d19..b68fe6ee0c747 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
@@ -119,8 +119,8 @@ class InlineElementalsPass
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification =
+ mlir::GreedySimplifyRegionLevel::Disabled;
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineElementalConversion>(context);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
index 6e209cce07ad4..249976d5509b0 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
@@ -135,8 +135,8 @@ class InlineHLFIRAssignPass
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification =
+ mlir::GreedySimplifyRegionLevel::Disabled;
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineHLFIRAssignConversion>(context);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 31e5bc1193e22..7c0fcba806869 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -557,8 +557,8 @@ class LowerHLFIRIntrinsics
// Pattern rewriting only requires that the resulting IR is still valid
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification =
+ mlir::GreedySimplifyRegionLevel::Disabled;
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c0856082989c1..c489450384a35 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1327,8 +1327,8 @@ class OptimizedBufferizationPass
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification =
+ mlir::GreedySimplifyRegionLevel::Disabled;
mlir::RewritePatternSet patterns(context);
// TODO: right now the patterns are non-conflicting,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index d153db2afff07..bac10121a881b 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -1476,8 +1476,8 @@ class SimplifyHLFIRIntrinsics
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification =
+ mlir::GreedySimplifyRegionLevel::Disabled;
mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 7a06a27748ebd..81ff6bf9b2c6a 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -35,8 +35,7 @@ void addNestedPassToAllTopLevelOperationsConditionally(
void addCanonicalizerPassWithoutRegionSimplification(mlir::OpPassManager &pm) {
mlir::GreedyRewriteConfig config;
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
pm.addPass(mlir::createCanonicalizerPass(config));
}
@@ -164,8 +163,7 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
// simplify the IR
mlir::GreedyRewriteConfig config;
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
pm.addPass(mlir::createCSEPass());
fir::addAVC(pm, pc.OptLevel);
addNestedPassToAllTopLevelOperations<PassConstructor>(
diff --git a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
index 6af1cb988a4c1..d0bd67a236419 100644
--- a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
@@ -152,8 +152,8 @@ class AssumedRankOpConversion
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
mlir::GreedyRewriteConfig config;
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification =
+ mlir::GreedySimplifyRegionLevel::Disabled;
(void)applyPatternsGreedily(mod, std::move(patterns), config);
}
};
diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 239a7cdaa4cf2..562f3058f20f3 100644
--- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -168,9 +168,9 @@ class ConstantArgumentGlobalisationOpt
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
mlir::GreedyRewriteConfig config;
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
- config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps);
+ config.enableRegionSimplification =
+ mlir::GreedySimplifyRegionLevel::Disabled;
+ config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(
diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
index 6d106046b70f2..212de2f2286db 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
@@ -205,8 +205,7 @@ void SimplifyFIROperationsPass::runOnOperation() {
fir::populateSimplifyFIROperationsPatterns(patterns,
preferInlineImplementation);
mlir::GreedyRewriteConfig config;
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index f9b9b4f4ff385..9a6566bef50f1 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -806,8 +806,7 @@ void StackArraysPass::runOnOperation() {
mlir::RewritePatternSet patterns(&context);
mlir::GreedyRewriteConfig config;
// prevent the pattern driver form merging blocks
- config.setRegionSimplificationLevel(
- mlir::GreedySimplifyRegionLevel::Disabled);
+ config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsGreedily(
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 45e61b68f5db2..110b4f64856eb 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -49,43 +49,25 @@ class GreedyRewriteConfig {
/// larger patterns when given an ambiguous pattern set.
///
/// Note: Only applicable when simplifying entire regions.
- bool getUseTopDownTraversal() const { return useTopDownTraversal; }
- GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
- useTopDownTraversal = use;
- return *this;
- }
+ bool useTopDownTraversal = false;
/// Perform control flow optimizations to the region tree after applying all
/// patterns.
///
/// Note: Only applicable when simplifying entire regions.
- GreedySimplifyRegionLevel getRegionSimplificationLevel() const {
- return regionSimplificationLevel;
- }
- GreedyRewriteConfig &
- setRegionSimplificationLevel(GreedySimplifyRegionLevel level) {
- regionSimplificationLevel = level;
- return *this;
- }
+ GreedySimplifyRegionLevel enableRegionSimplification =
+ GreedySimplifyRegionLevel::Aggressive;
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
///
/// Note: Only applicable when simplifying entire regions.
- int64_t getMaxIterations() const { return maxIterations; }
- GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
- maxIterations = iterations;
- return *this;
- }
+ int64_t maxIterations = 10;
/// This specifies the maximum number of rewrites within an iteration. Use
/// `kNoLimit` to disable this limit.
- int64_t getMaxNumRewrites() const { return maxNumRewrites; }
- GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
- maxNumRewrites = limit;
- return *this;
- }
+ int64_t maxNumRewrites = kNoLimit;
static constexpr int64_t kNoLimit = -1;
@@ -93,11 +75,7 @@ class GreedyRewriteConfig {
/// specified, the closest enclosing region around the initial list of ops
/// (or the specified region, depending on which greedy rewrite entry point
/// is used) is used as a scope.
- Region *getScope() const { return scope; }
- GreedyRewriteConfig &setScope(Region *scope) {
- this->scope = scope;
- return *this;
- }
+ Region *scope = nullptr;
/// Strict mode can restrict the ops that are added to the worklist during
/// the rewrite.
@@ -109,44 +87,16 @@ class GreedyRewriteConfig {
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
/// were on the worklist at the very beginning) enqueued. All other ops are
/// excluded.
- GreedyRewriteStrictness getStrictness() const { return strictness; }
- GreedyRewriteConfig &setStrictness(GreedyRewriteStrictness mode) {
- strictness = mode;
- return *this;
- }
+ GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
/// An optional listener that should be notified about IR modifications.
- RewriterBase::Listener *getListener() const { return listener; }
- GreedyRewriteConfig &setListener(RewriterBase::Listener *listener) {
- this->listener = listener;
- return *this;
- }
+ RewriterBase::Listener *listener = nullptr;
/// Whether this should fold while greedily rewriting.
- bool isFoldingEnabled() const { return fold; }
- GreedyRewriteConfig &enableFolding(bool enable = true) {
- fold = enable;
- return *this;
- }
+ bool fold = true;
/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
- bool isConstantCSEEnabled() const { return cseConstants; }
- GreedyRewriteConfig &enableConstantCSE(bool enable = true) {
- cseConstants = enable;
- return *this;
- }
-
-private:
- Region *scope = nullptr;
- bool useTopDownTraversal = false;
- GreedySimplifyRegionLevel regionSimplificationLevel =
- GreedySimplifyRegionLevel::Aggressive;
- int64_t maxIterations = 10;
- int64_t maxNumRewrites = kNoLimit;
- GreedyRewriteStrictness strictness = GreedyRewriteStrictness::AnyOp;
- RewriterBase::Listener *listener = nullptr;
- bool fold = true;
bool cseConstants = true;
};
@@ -178,14 +128,14 @@ applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
/// Same as `applyPatternsAndGreedily` above with folding.
-/// FIXME: Remove this once transition to above is completed.
+/// FIXME: Remove this once transition to above is complieted.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
- config.enableFolding();
+ config.fold = true;
return applyPatternsGreedily(region, patterns, config, changed);
}
@@ -237,7 +187,7 @@ applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
- config.enableFolding();
+ config.fold = true;
return applyPatternsGreedily(op, patterns, config, changed);
}
@@ -283,7 +233,7 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr) {
- config.enableFolding();
+ config.fold = true;
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
}
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1e89a78912e99..a39ab77fc8fb3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -33,7 +33,7 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
"Seed the worklist in general top-down order">,
- Option<"regionSimplifyLevel", "region-simplify", "mlir::GreedySimplifyRegionLevel",
+ Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
/*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
"Perform control flow optimizations to the region tree",
[{::llvm::cl::values(
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 43d37ee3332ef..9f7df7823d997 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -127,13 +127,12 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+ GreedyRewriteConfig config;
+ config.listener =
+ static_cast<RewriterBase::Listener *>(rewriter.getListener());
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
- if (failed(applyOpPatternsGreedily(
- targets, frozenPatterns,
- GreedyRewriteConfig()
- .setListener(
- static_cast<RewriterBase::Listener *>(rewriter.getListener()))
- .setStrictness(GreedyRewriteStrictness::ExistingAndNewOps)))) {
+ if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 62c1857e4b1da..4d30213cc6ec2 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -237,8 +237,7 @@ void AffineDataCopyGeneration::runOnOperation() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- (void)applyOpPatternsGreedily(
- copyOps, frozenPatterns,
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingAndNewOps));
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ (void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 9e9096c2e3186..31711ade3153b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -109,8 +109,7 @@ void SimplifyAffineStructures::runOnOperation() {
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
opsToSimplify.push_back(op);
});
- (void)applyOpPatternsGreedily(
- opsToSimplify, frozenPatterns,
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingAndNewOps));
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 0d4ba3940c48e..dd539ff685653 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -315,12 +315,11 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
// Simplify/canonicalize the affine.for.
RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
- (void)applyOpPatternsGreedily(
- res.getOperation(), std::move(patterns),
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingAndNewOps),
- /*changed=*/nullptr, &erased);
+ (void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
+ config, /*changed=*/nullptr, &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 11798b99fa879..2925aa918cb1c 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -426,11 +426,11 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
RewritePatternSet patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
- (void)applyOpPatternsGreedily(
- ifOp.getOperation(), frozenPatterns,
- GreedyRewriteConfig().setStrictness(GreedyRewriteStrictness::ExistingOps),
- /*changed=*/nullptr, &erased);
+ (void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
+ /*changed=*/nullptr, &erased);
if (erased) {
if (folded)
*folded = true;
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index f2f93883eb2b7..602d80a45993e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -494,9 +494,10 @@ struct IntRangeOptimizationsPass final
RewritePatternSet patterns(ctx);
populateIntRangeOptimizationsPatterns(patterns, solver);
- if (failed(applyPatternsGreedily(
- op, std::move(patterns),
- GreedyRewriteConfig().setListener(&listener))))
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+
+ if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
signalPassFailure();
}
};
@@ -519,12 +520,13 @@ struct IntRangeNarrowingPass final
RewritePatternSet patterns(ctx);
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
+ GreedyRewriteConfig config;
// We specifically need bottom-up traversal as cmpi pattern needs range
// data, attached to its original argument values.
- if (failed(applyPatternsGreedily(
- op, std::move(patterns),
- GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
- &listener))))
+ config.useTopDownTraversal = false;
+ config.listener = &listener;
+
+ if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index c5fab80ecaa08..35f86a62ae592 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,15 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
-
- populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
// We don't want that the block structure changes invalidating the
- // `BufferOriginAnalysis` so we apply the rewrites with `Normal` level of
+ // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
// region simplification
- if (failed(applyPatternsGreedily(
- getOperation(), std::move(patterns),
- GreedyRewriteConfig().setRegionSimplificationLevel(
- GreedySimplifyRegionLevel::Normal))))
+ GreedyRewriteConfig config;
+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
+ populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
+
+ if (failed(
+ applyPatternsGreedily(getOperation(), std::move(patterns), config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b20e6050fb4f8..c90ebe4487ca4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3587,9 +3587,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
vector::populateVectorStepLoweringPatterns(patterns);
TrackingListener listener(state, *this);
- if (failed(
- applyPatternsGreedily(target, std::move(patterns),
- GreedyRewriteConfig().setListener(&listener))))
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+ if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
return emitDefaultDefiniteFailure(target);
results.push_back(target);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 62d016b87d627..bf70597d5ddfe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -2327,9 +2327,10 @@ struct LinalgElementwiseOpFusionPass
// Add constant folding patterns.
populateConstantFoldLinalgOperations(patterns, defaultControlFn);
- // Use TopDownTraversal for compile time reasons.
- (void)applyPatternsGreedily(op, std::move(patterns),
- GreedyRewriteConfig().setUseTopDownTraversal());
+ // Use TopDownTraversal for compile time reasons
+ GreedyRewriteConfig grc;
+ grc.useTopDownTraversal = true;
+ (void)applyPatternsGreedily(op, std::move(patterns), grc);
}
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 7edf19689d2e1..91862d2e17d71 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1438,10 +1438,10 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
if (!patterns)
return success();
- return applyOpPatternsGreedily(
- ops, patterns.value(),
- GreedyRewriteConfig().setListener(this).setStrictness(
- GreedyRewriteStrictness::ExistingAndNewOps));
+ GreedyRewriteConfig config;
+ config.listener = this;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ return applyOpPatternsGreedily(ops, patterns.value(), config);
}
void SliceTrackingListener::notifyOperationInserted(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 811f03abb3461..19b9af146f4a4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1353,9 +1353,9 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
// We only want to apply signature conversion once to the existing func ops.
// Without specifying strictMode, the greedy pattern rewriter will keep
// looking for newly created func ops.
- return applyPatternsGreedily(op, std::move(patterns),
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingOps));
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ return applyPatternsGreedily(op, std::move(patterns), config);
}
LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a811fae003584..798853a75441a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -393,16 +393,16 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
// Configure the GreedyPatternRewriteDriver.
GreedyRewriteConfig config;
- config.setListener(
- static_cast<RewriterBase::Listener *>(rewriter.getListener()));
+ config.listener =
+ static_cast<RewriterBase::Listener *>(rewriter.getListener());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
+ config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
+ ? GreedyRewriteConfig::kNoLimit
+ : getMaxIterations();
+ config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
? GreedyRewriteConfig::kNoLimit
- : getMaxIterations());
- config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
- ? GreedyRewriteConfig::kNoLimit
- : getMaxNumRewrites());
+ : getMaxNumRewrites();
// Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
// was requested, apply the greedy pattern rewrite only once. (The greedy
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 549e4f2bd813b..7292752c712ae 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -62,11 +62,11 @@ static void applyPatterns(Region &region,
// before that transform.
for (Operation *op : opsInRange) {
// `applyOpPatternsGreedily` with folding returns whether the op is
- // converted. Omit it because we don't have expectation this reduction will
+ // convered. Omit it because we don't have expectation this reduction will
// be success or not.
- (void)applyOpPatternsGreedily(op, patterns,
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingOps));
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ (void)applyOpPatternsGreedily(op, patterns, config);
}
if (eraseOpNotInRange)
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 4b0ac28a03713..7ccd503fb0288 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -32,10 +32,10 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
ArrayRef<std::string> disabledPatterns,
ArrayRef<std::string> enabledPatterns)
: config(config) {
- this->topDownProcessingEnabled = config.getUseTopDownTraversal();
- this->regionSimplifyLevel = config.getRegionSimplificationLevel();
- this->maxIterations = config.getMaxIterations();
- this->maxNumRewrites = config.getMaxNumRewrites();
+ this->topDownProcessingEnabled = config.useTopDownTraversal;
+ this->enableRegionSimplification = config.enableRegionSimplification;
+ this->maxIterations = config.maxIterations;
+ this->maxNumRewrites = config.maxNumRewrites;
this->disabledPatterns = disabledPatterns;
this->enabledPatterns = enabledPatterns;
}
@@ -44,10 +44,10 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
/// execution.
LogicalResult initialize(MLIRContext *context) override {
// Set the config from possible pass options set in the meantime.
- config.setUseTopDownTraversal(topDownProcessingEnabled);
- config.setRegionSimplificationLevel(regionSimplifyLevel);
- config.setMaxIterations(maxIterations);
- config.setMaxNumRewrites(maxNumRewrites);
+ config.useTopDownTraversal = topDownProcessingEnabled;
+ config.enableRegionSimplification = enableRegionSimplification;
+ config.maxIterations = maxIterations;
+ config.maxNumRewrites = maxNumRewrites;
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 84d547f03829a..7c1cfd91f85e6 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -455,8 +455,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
bool changed = false;
int64_t numRewrites = 0;
while (!worklist.empty() &&
- (numRewrites < config.getMaxNumRewrites() ||
- config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
+ (numRewrites < config.maxNumRewrites ||
+ config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
auto *op = worklist.pop();
LLVM_DEBUG({
@@ -488,7 +488,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// infinite folding loop, as every constant op would be folded to an
// Attribute and then immediately be rematerialized as a constant op, which
// is then put on the worklist.
- if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
+ if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
@@ -574,21 +574,21 @@ bool GreedyPatternRewriteDriver::processWorklist() {
logger.getOStream() << ")' {\n";
logger.indent();
});
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyPatternBegin(pattern, op);
+ if (config.listener)
+ config.listener->notifyPatternBegin(pattern, op);
return true;
};
function_ref<bool(const Pattern &)> canApply = canApplyCallback;
auto onFailureCallback = [&](const Pattern &pattern) {
LLVM_DEBUG(logResult("failure", "pattern failed to match"));
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyPatternEnd(pattern, failure());
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, failure());
};
function_ref<void(const Pattern &)> onFailure = onFailureCallback;
auto onSuccessCallback = [&](const Pattern &pattern) {
LLVM_DEBUG(logResult("success", "pattern applied successfully"));
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyPatternEnd(pattern, success());
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, success());
return success();
};
function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
@@ -640,7 +640,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
do {
ancestors.push_back(op);
region = op->getParentRegion();
- if (config.getScope() == region) {
+ if (config.scope == region) {
// Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
for (Operation *op : ancestors)
addSingleOpToWorklist(op);
@@ -652,20 +652,20 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
}
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
- if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
+ if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op))
worklist.push(op);
}
void GreedyPatternRewriteDriver::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyBlockInserted(block, previous, previousIt);
+ if (config.listener)
+ config.listener->notifyBlockInserted(block, previous, previousIt);
}
void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyBlockErased(block);
+ if (config.listener)
+ config.listener->notifyBlockErased(block);
}
void GreedyPatternRewriteDriver::notifyOperationInserted(
@@ -674,9 +674,9 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyOperationInserted(op, previous);
- if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
+ if (config.listener)
+ config.listener->notifyOperationInserted(op, previous);
+ if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
addToWorklist(op);
}
@@ -686,8 +686,8 @@ void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
<< ")\n";
});
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyOperationModified(op);
+ if (config.listener)
+ config.listener->notifyOperationModified(op);
addToWorklist(op);
}
@@ -736,18 +736,18 @@ void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
// the part of the IR that is taken into account for the "expensive checks".
// A greedy pattern rewrite is not allowed to erase the parent op of the scope
// region, as that would break the worklist handling and the expensive checks.
- if (Region *scope = config.getScope(); scope->getParentOp() == op)
+ if (config.scope && config.scope->getParentOp() == op)
llvm_unreachable(
"scope region must not be erased during greedy pattern rewrite");
#endif // NDEBUG
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyOperationErased(op);
+ if (config.listener)
+ config.listener->notifyOperationErased(op);
addOperandsToWorklist(op);
worklist.remove(op);
- if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
}
@@ -757,8 +757,8 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
<< ")\n";
});
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyOperationReplaced(op, replacement);
+ if (config.listener)
+ config.listener->notifyOperationReplaced(op, replacement);
}
void GreedyPatternRewriteDriver::notifyMatchFailure(
@@ -768,8 +768,8 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
reasonCallback(diag);
logger.startLine() << "** Match Failure : " << diag.str() << "\n";
});
- if (RewriterBase::Listener *listener = config.getListener())
- listener->notifyMatchFailure(loc, reasonCallback);
+ if (config.listener)
+ config.listener->notifyMatchFailure(loc, reasonCallback);
}
//===----------------------------------------------------------------------===//
@@ -800,7 +800,7 @@ RegionPatternRewriteDriver::RegionPatternRewriteDriver(
const GreedyRewriteConfig &config, Region &region)
: GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
// Populate strict mode ops.
- if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
}
}
@@ -829,8 +829,8 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
MLIRContext *ctx = rewriter.getContext();
do {
// Check if the iteration limit was reached.
- if (++iteration > config.getMaxIterations() &&
- config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
+ if (++iteration > config.maxIterations &&
+ config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;
// New iteration: start with an empty worklist.
@@ -849,16 +849,16 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
return false;
};
- if (!config.getUseTopDownTraversal()) {
+ if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
- if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
+ if (!config.cseConstants || !insertKnownConstant(op))
addToWorklist(op);
});
} else {
// Add all nested operations to the worklist in preorder.
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
+ if (!config.cseConstants || !insertKnownConstant(op)) {
addToWorklist(op);
return WalkResult::advance();
}
@@ -875,11 +875,11 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
- if (config.getRegionSimplificationLevel() !=
+ if (config.enableRegionSimplification !=
GreedySimplifyRegionLevel::Disabled) {
continueRewrites |= succeeded(simplifyRegions(
rewriter, region,
- /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
+ /*mergeBlocks=*/config.enableRegionSimplification ==
GreedySimplifyRegionLevel::Aggressive));
}
},
@@ -904,8 +904,8 @@ mlir::applyPatternsGreedily(Region &region,
"patterns can only be applied to operations IsolatedFromAbove");
// Set scope if not specified.
- if (!config.getScope())
- config.setScope(&region);
+ if (!config.scope)
+ config.scope = &region;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(config.scope->getParentOp())))
@@ -919,7 +919,7 @@ mlir::applyPatternsGreedily(Region &region,
LogicalResult converged = std::move(driver).simplify(changed);
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
- << config.getMaxIterations() << " times\n";
+ << config.maxIterations << " times\n";
});
return converged;
}
@@ -960,7 +960,7 @@ MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
llvm::SmallDenseSet<Operation *, 4> *survivingOps)
: GreedyPatternRewriteDriver(ctx, patterns, config),
survivingOps(survivingOps) {
- if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.insert_range(ops);
if (survivingOps) {
@@ -1024,15 +1024,15 @@ LogicalResult mlir::applyOpPatternsGreedily(
}
// Determine scope of rewrite.
- if (!config.getScope()) {
+ if (!config.scope) {
// Compute scope if none was provided. The scope will remain `nullptr` if
// there is a top-level op among `ops`.
- config.setScope(findCommonAncestor(ops));
+ config.scope = findCommonAncestor(ops);
} else {
// If a scope was provided, make sure that all ops are in scope.
#ifndef NDEBUG
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
- return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
+ return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
});
assert(allOpsInScope && "ops must be within the specified scope");
#endif // NDEBUG
@@ -1054,7 +1054,7 @@ LogicalResult mlir::applyOpPatternsGreedily(
*allErased = surviving.empty();
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after "
- << config.getMaxNumRewrites() << " rewrites";
+ << config.maxNumRewrites << " rewrites";
});
return converged;
}
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 2a54e0c28f71f..d6aaa6faf94cb 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -144,7 +144,7 @@ void TestAffineDataCopy::runOnOperation() {
}
}
GreedyRewriteConfig config;
- config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsGreedily(copyOps, std::move(patterns), config);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..db02a122872d9 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -386,26 +386,26 @@ struct TestGreedyPatternDriver
patterns.insert<IncrementIntAttribute<3>>(&getContext());
GreedyRewriteConfig config;
- config.setUseTopDownTraversal(useTopDownTraversal)
- .setMaxIterations(this->maxIterations)
- .enableFolding(this->fold)
- .enableConstantCSE(this->cseConstants);
+ config.useTopDownTraversal = this->useTopDownTraversal;
+ config.maxIterations = this->maxIterations;
+ config.fold = this->fold;
+ config.cseConstants = this->cseConstants;
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
Option<bool> useTopDownTraversal{
*this, "top-down",
llvm::cl::desc("Seed the worklist in general top-down order"),
- llvm::cl::init(GreedyRewriteConfig().getUseTopDownTraversal())};
+ llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
Option<int> maxIterations{
*this, "max-iterations",
llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
- llvm::cl::init(GreedyRewriteConfig().getMaxIterations())};
+ llvm::cl::init(GreedyRewriteConfig().maxIterations)};
Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
- llvm::cl::init(GreedyRewriteConfig().isFoldingEnabled())};
- Option<bool> cseConstants{
- *this, "cse-constants", llvm::cl::desc("Whether to CSE constants"),
- llvm::cl::init(GreedyRewriteConfig().isConstantCSEEnabled())};
+ llvm::cl::init(GreedyRewriteConfig().fold)};
+ Option<bool> cseConstants{*this, "cse-constants",
+ llvm::cl::desc("Whether to CSE constants"),
+ llvm::cl::init(GreedyRewriteConfig().cseConstants)};
};
struct DumpNotifications : public RewriterBase::Listener {
@@ -501,13 +501,13 @@ struct TestStrictPatternDriver
DumpNotifications dumpNotifications;
GreedyRewriteConfig config;
- config.setListener(&dumpNotifications);
+ config.listener = &dumpNotifications;
if (strictMode == "AnyOp") {
- config.setStrictness(GreedyRewriteStrictness::AnyOp);
+ config.strictMode = GreedyRewriteStrictness::AnyOp;
} else if (strictMode == "ExistingAndNewOps") {
- config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
} else if (strictMode == "ExistingOps") {
- config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
} else {
llvm_unreachable("invalid strictness option");
}
More information about the flang-commits
mailing list