[Mlir-commits] [mlir] e468c60 - Fix canonicalizer to copy the entire GreedyRewriteConfig instead of selected fields

Mehdi Amini llvmlistbot at llvm.org
Wed Aug 9 19:59:22 PDT 2023


Author: Mehdi Amini
Date: 2023-08-09T19:59:10-07:00
New Revision: e468c60c96e1af1945179c2ca80b7a9dfcd38398

URL: https://github.com/llvm/llvm-project/commit/e468c60c96e1af1945179c2ca80b7a9dfcd38398
DIFF: https://github.com/llvm/llvm-project/commit/e468c60c96e1af1945179c2ca80b7a9dfcd38398.diff

LOG: Fix canonicalizer to copy the entire GreedyRewriteConfig instead of selected fields

It is surprising for the user that only some fields were honored.

Also hold the FrozenRewritePatternSet through a std::shared_ptr<const FrozenRewritePatternSet>.

Fixes #64543

Differential Revision: https://reviews.llvm.org/D157469

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Transforms/Canonicalizer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0eeb8bb1ec8da5..2131fe313f8c59 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -715,18 +715,20 @@ class AsmParser {
   //===--------------------------------------------------------------------===//
 
   /// This class represents a StringSwitch like class that is useful for parsing
-  /// expected keywords. On construction, it invokes `parseKeyword` and
-  /// processes each of the provided cases statements until a match is hit. The
-  /// provided `ResultT` must be assignable from `failure()`.
+  /// expected keywords. On construction, unless a non-empty keyword is
+  /// provided, it invokes `parseKeyword` and processes each of the provided
+  /// case statements until a match is hit. The provided `ResultT` must be
+  /// assignable from `failure()`.
   template <typename ResultT = ParseResult>
   class KeywordSwitch {
   public:
-    KeywordSwitch(AsmParser &parser)
+    KeywordSwitch(AsmParser &parser, StringRef *keyword = nullptr)
         : parser(parser), loc(parser.getCurrentLocation()) {
-      if (failed(parser.parseKeywordOrCompletion(&keyword)))
+      if (keyword && !keyword->empty())
+        this->keyword = *keyword;
+      else if (failed(parser.parseKeywordOrCompletion(&this->keyword)))
         result = failure();
     }
-
     /// Case that uses the provided value when true.
     KeywordSwitch &Case(StringLiteral str, ResultT value) {
       return Case(str, [&](StringRef, SMLoc) { return std::move(value); });

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index b4ad85c7c7dad2..d50019bd6aee55 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -29,7 +29,8 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
   Canonicalizer() = default;
   Canonicalizer(const GreedyRewriteConfig &config,
                 ArrayRef<std::string> disabledPatterns,
-                ArrayRef<std::string> enabledPatterns) {
+                ArrayRef<std::string> enabledPatterns)
+      : config(config) {
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
     this->maxIterations = config.maxIterations;
@@ -41,30 +42,31 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
+    // Update the config with any pass options set since construction.
+    config.useTopDownTraversal = topDownProcessingEnabled;
+    config.enableRegionSimplification = enableRegionSimplification;
+    config.maxIterations = maxIterations;
+    config.maxNumRewrites = maxNumRewrites;
+
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
       dialect->getCanonicalizationPatterns(owningPatterns);
     for (RegisteredOperationName op : context->getRegisteredOperations())
       op.getCanonicalizationPatterns(owningPatterns, context);
 
-    patterns = FrozenRewritePatternSet(std::move(owningPatterns),
-                                       disabledPatterns, enabledPatterns);
+    patterns = std::make_shared<FrozenRewritePatternSet>(
+        std::move(owningPatterns), disabledPatterns, enabledPatterns);
     return success();
   }
   void runOnOperation() override {
-    GreedyRewriteConfig config;
-    config.useTopDownTraversal = topDownProcessingEnabled;
-    config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
-    config.maxNumRewrites = maxNumRewrites;
     LogicalResult converged =
-        applyPatternsAndFoldGreedily(getOperation(), patterns, config);
+        applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
     if (testConvergence && failed(converged))
       signalPassFailure();
   }
-
-  FrozenRewritePatternSet patterns;
+  GreedyRewriteConfig config;
+  std::shared_ptr<const FrozenRewritePatternSet> patterns;
 };
 } // namespace
 


        


More information about the Mlir-commits mailing list