[Mlir-commits] [mlir] cd7af14 - Fix canonicalizer to copy the entire GreedyRewriteConfig instead of selected fields
Mehdi Amini
llvmlistbot at llvm.org
Tue Aug 22 20:38:42 PDT 2023
Author: Mehdi Amini
Date: 2023-08-22T20:38:15-07:00
New Revision: cd7af14cbc98e2af93c47ad48eba8f24a1791671
URL: https://github.com/llvm/llvm-project/commit/cd7af14cbc98e2af93c47ad48eba8f24a1791671
DIFF: https://github.com/llvm/llvm-project/commit/cd7af14cbc98e2af93c47ad48eba8f24a1791671.diff
LOG: Fix canonicalizer to copy the entire GreedyRewriteConfig instead of selected fields
It was surprising to users that only some of the config fields were honored.
Also hold the FrozenRewritePatternSet through a shared_ptr<const T>.
Fixes #64543
Differential Revision: https://reviews.llvm.org/D157469
Added:
Modified:
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Transforms/Canonicalizer.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0eeb8bb1ec8da5..2131fe313f8c59 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -715,18 +715,20 @@ class AsmParser {
//===--------------------------------------------------------------------===//
/// This class represents a StringSwitch like class that is useful for parsing
- /// expected keywords. On construction, it invokes `parseKeyword` and
- /// processes each of the provided cases statements until a match is hit. The
- /// provided `ResultT` must be assignable from `failure()`.
+ /// expected keywords. On construction, unless a non-empty keyword is
+ /// provided, it invokes `parseKeyword` and processes each of the provided
+ /// cases statements until a match is hit. The provided `ResultT` must be
+ /// assignable from `failure()`.
template <typename ResultT = ParseResult>
class KeywordSwitch {
public:
- KeywordSwitch(AsmParser &parser)
+ KeywordSwitch(AsmParser &parser, StringRef *keyword = nullptr)
: parser(parser), loc(parser.getCurrentLocation()) {
- if (failed(parser.parseKeywordOrCompletion(&keyword)))
+ if (keyword && !keyword->empty())
+ this->keyword = *keyword;
+ else if (failed(parser.parseKeywordOrCompletion(&this->keyword)))
result = failure();
}
-
/// Case that uses the provided value when true.
KeywordSwitch &Case(StringLiteral str, ResultT value) {
return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index b4ad85c7c7dad2..d50019bd6aee55 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -29,7 +29,8 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
Canonicalizer() = default;
Canonicalizer(const GreedyRewriteConfig &config,
ArrayRef<std::string> disabledPatterns,
- ArrayRef<std::string> enabledPatterns) {
+ ArrayRef<std::string> enabledPatterns)
+ : config(config) {
this->topDownProcessingEnabled = config.useTopDownTraversal;
this->enableRegionSimplification = config.enableRegionSimplification;
this->maxIterations = config.maxIterations;
@@ -41,30 +42,31 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
LogicalResult initialize(MLIRContext *context) override {
+ // Set the config from possible pass options set in the meantime.
+ config.useTopDownTraversal = topDownProcessingEnabled;
+ config.enableRegionSimplification = enableRegionSimplification;
+ config.maxIterations = maxIterations;
+ config.maxNumRewrites = maxNumRewrites;
+
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
- patterns = FrozenRewritePatternSet(std::move(owningPatterns),
- disabledPatterns, enabledPatterns);
+ patterns = std::make_shared<FrozenRewritePatternSet>(
+ std::move(owningPatterns), disabledPatterns, enabledPatterns);
return success();
}
void runOnOperation() override {
- GreedyRewriteConfig config;
- config.useTopDownTraversal = topDownProcessingEnabled;
- config.enableRegionSimplification = enableRegionSimplification;
- config.maxIterations = maxIterations;
- config.maxNumRewrites = maxNumRewrites;
LogicalResult converged =
- applyPatternsAndFoldGreedily(getOperation(), patterns, config);
+ applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
// Canonicalization is best-effort. Non-convergence is not a pass failure.
if (testConvergence && failed(converged))
signalPassFailure();
}
-
- FrozenRewritePatternSet patterns;
+ GreedyRewriteConfig config;
+ std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
More information about the Mlir-commits
mailing list