[Mlir-commits] [mlir] 79d7f61 - Rename FrozenRewritePatternList -> FrozenRewritePatternSet; NFC.

Chris Lattner llvmlistbot at llvm.org
Mon Mar 22 17:41:02 PDT 2021


Author: Chris Lattner
Date: 2021-03-22T17:40:45-07:00
New Revision: 79d7f618af5f0362e6c4a8cccdeb251e82654907

URL: https://github.com/llvm/llvm-project/commit/79d7f618af5f0362e6c4a8cccdeb251e82654907
DIFF: https://github.com/llvm/llvm-project/commit/79d7f618af5f0362e6c4a8cccdeb251e82654907.diff

LOG: Rename FrozenRewritePatternList -> FrozenRewritePatternSet; NFC.

This nicely aligns the naming with RewritePatternSet.  This type isn't
as widely used, but we keep a `using` declaration for the old name
(FrozenRewritePatternList) to ease migration for downstream users.

Differential Revision: https://reviews.llvm.org/D99131

Added: 
    mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
    mlir/lib/Rewrite/FrozenRewritePatternSet.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Rewrite/PatternApplicator.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
    mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
    mlir/lib/Rewrite/CMakeLists.txt
    mlir/lib/Rewrite/PatternApplicator.cpp
    mlir/lib/Transforms/Canonicalizer.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/lib/Transforms/TestConvVectorization.cpp
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp
    mlir/unittests/Rewrite/PatternBenefit.cpp

Removed: 
    mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
    mlir/lib/Rewrite/FrozenRewritePatternList.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e1a136c7e65b..71dbe9fb24cd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -19,7 +19,7 @@
 
 namespace mlir {
 class BufferizeTypeConverter;
-class FrozenRewritePatternList;
+class FrozenRewritePatternSet;
 
 namespace linalg {
 
@@ -964,8 +964,8 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
 //===----------------------------------------------------------------------===//
 /// Helper function to allow applying rewrite patterns, interleaved with more
 /// global transformations, in a staged fashion:
-///   1. the first stage consists of a list of FrozenRewritePatternList. Each
-///   FrozenRewritePatternList in this list is applied once, in order.
+///   1. the first stage consists of a list of FrozenRewritePatternSet. Each
+///   FrozenRewritePatternSet in this list is applied once, in order.
 ///   2. the second stage consists of a single OwningRewritePattern that is
 ///   applied greedily until convergence.
 ///   3. the third stage consists of applying a lambda, generally used for
@@ -973,8 +973,8 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
 ///   transformations where patterns can be ordered and applied at a finer
 ///   granularity than a sequence of traditional compiler passes.
 LogicalResult applyStagedPatterns(
-    Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
-    const FrozenRewritePatternList &stage2Patterns,
+    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
+    const FrozenRewritePatternSet &stage2Patterns,
     function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 514b7ae06938..115ad5f039bc 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -894,7 +894,7 @@ class RewritePatternSet {
   PDLPatternModule pdlPatterns;
 };
 
-// TODO: RewritePatternSet is soft-deprecated and will be removed in the
+// TODO: OwningRewritePatternList is soft-deprecated and will be removed in the
 // future.
 using OwningRewritePatternList = RewritePatternSet;
 

diff  --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
similarity index 71%
rename from mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
rename to mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
index a20030cd08da..554bfd217534 100644
--- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
+++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
@@ -1,4 +1,4 @@
-//===- FrozenRewritePatternList.h - FrozenRewritePatternList ----*- C++ -*-===//
+//===- FrozenRewritePatternSet.h --------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H
-#define MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H
+#ifndef MLIR_REWRITE_FROZENREWRITEPATTERNSET_H
+#define MLIR_REWRITE_FROZENREWRITEPATTERNSET_H
 
 #include "mlir/IR/PatternMatch.h"
 
@@ -21,20 +21,20 @@ class PDLByteCode;
 /// such that they need not be continuously recomputed. Note that all copies of
 /// this class share the same compiled pattern list, allowing for a reduction in
 /// the number of duplicated patterns that need to be created.
-class FrozenRewritePatternList {
+class FrozenRewritePatternSet {
   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
 public:
   /// Freeze the patterns held in `patterns`, and take ownership.
-  FrozenRewritePatternList();
-  FrozenRewritePatternList(RewritePatternSet &&patterns);
-  FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default;
-  FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default;
-  FrozenRewritePatternList &
-  operator=(const FrozenRewritePatternList &patterns) = default;
-  FrozenRewritePatternList &
-  operator=(FrozenRewritePatternList &&patterns) = default;
-  ~FrozenRewritePatternList();
+  FrozenRewritePatternSet();
+  FrozenRewritePatternSet(RewritePatternSet &&patterns);
+  FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default;
+  FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default;
+  FrozenRewritePatternSet &
+  operator=(const FrozenRewritePatternSet &patterns) = default;
+  FrozenRewritePatternSet &
+  operator=(FrozenRewritePatternSet &&patterns) = default;
+  ~FrozenRewritePatternSet();
 
   /// Return the native patterns held by this list.
   iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
@@ -66,6 +66,10 @@ class FrozenRewritePatternList {
   std::shared_ptr<Impl> impl;
 };
 
+// TODO: FrozenRewritePatternList is soft-deprecated and will be removed in the
+// future.
+using FrozenRewritePatternList = FrozenRewritePatternSet;
+
 } // end namespace mlir
 
-#endif // MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H
+#endif // MLIR_REWRITE_FROZENREWRITEPATTERNSET_H

diff  --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h
index 9d197175b47d..9314496ecda1 100644
--- a/mlir/include/mlir/Rewrite/PatternApplicator.h
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -14,7 +14,7 @@
 #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
 #define MLIR_REWRITE_PATTERNAPPLICATOR_H
 
-#include "mlir/Rewrite/FrozenRewritePatternList.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
 
 namespace mlir {
 class PatternRewriter;
@@ -33,7 +33,7 @@ class PatternApplicator {
   /// `impossibleToMatch`.
   using CostModel = function_ref<PatternBenefit(const Pattern &)>;
 
-  explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList);
+  explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList);
   ~PatternApplicator();
 
   /// Attempt to match and rewrite the given op with any pattern, allowing a
@@ -65,7 +65,7 @@ class PatternApplicator {
 
 private:
   /// The list that owns the patterns used within this applicator.
-  const FrozenRewritePatternList &frozenPatternList;
+  const FrozenRewritePatternSet &frozenPatternList;
   /// The set of patterns to match for each operation, stable sorted by benefit.
   DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns;
   /// The set of patterns that may match against any operation type, stable

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ae86b2679eb3..7ebd07d8cb42 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -13,7 +13,7 @@
 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
 
-#include "mlir/Rewrite/FrozenRewritePatternList.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/StringMap.h"
 
@@ -842,11 +842,11 @@ class ConversionTarget {
 /// the `unconvertedOps` set will not necessarily be complete.)
 LogicalResult
 applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
-                       const FrozenRewritePatternList &patterns,
+                       const FrozenRewritePatternSet &patterns,
                        DenseSet<Operation *> *unconvertedOps = nullptr);
 LogicalResult
 applyPartialConversion(Operation *op, ConversionTarget &target,
-                       const FrozenRewritePatternList &patterns,
+                       const FrozenRewritePatternSet &patterns,
                        DenseSet<Operation *> *unconvertedOps = nullptr);
 
 /// Apply a complete conversion on the given operations, and all nested
@@ -855,9 +855,9 @@ applyPartialConversion(Operation *op, ConversionTarget &target,
 /// within 'ops'.
 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
                                   ConversionTarget &target,
-                                  const FrozenRewritePatternList &patterns);
+                                  const FrozenRewritePatternSet &patterns);
 LogicalResult applyFullConversion(Operation *op, ConversionTarget &target,
-                                  const FrozenRewritePatternList &patterns);
+                                  const FrozenRewritePatternSet &patterns);
 
 /// Apply an analysis conversion on the given operations, and all nested
 /// operations. This method analyzes which operations would be successfully
@@ -869,10 +869,10 @@ LogicalResult applyFullConversion(Operation *op, ConversionTarget &target,
 /// the regions nested within 'ops'.
 LogicalResult applyAnalysisConversion(ArrayRef<Operation *> ops,
                                       ConversionTarget &target,
-                                      const FrozenRewritePatternList &patterns,
+                                      const FrozenRewritePatternSet &patterns,
                                       DenseSet<Operation *> &convertedOps);
 LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target,
-                                      const FrozenRewritePatternList &patterns,
+                                      const FrozenRewritePatternSet &patterns,
                                       DenseSet<Operation *> &convertedOps);
 } // end namespace mlir
 

diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index aa06c02c2b9e..cbbe5c10948d 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -14,7 +14,7 @@
 #ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
 #define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
 
-#include "mlir/Rewrite/FrozenRewritePatternList.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
 
 namespace mlir {
 
@@ -35,25 +35,25 @@ namespace mlir {
 ///       before attempting to match any of the provided patterns.
 LogicalResult
 applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternList &patterns,
+                             const FrozenRewritePatternSet &patterns,
                              bool useTopDownTraversal = true);
 
 /// Rewrite the regions of the specified operation, with a user-provided limit
 /// on iterations to attempt before reaching convergence.
 LogicalResult applyPatternsAndFoldGreedily(
-    Operation *op, const FrozenRewritePatternList &patterns,
+    Operation *op, const FrozenRewritePatternSet &patterns,
     unsigned maxIterations, bool useTopDownTraversal = true);
 
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternList &patterns,
+                             const FrozenRewritePatternSet &patterns,
                              bool useTopDownTraversal = true);
 
 /// Rewrite the given regions, with a user-provided limit on iterations to
 /// attempt before reaching convergence.
 LogicalResult applyPatternsAndFoldGreedily(
-    MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
     unsigned maxIterations, bool useTopDownTraversal = true);
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
@@ -62,7 +62,7 @@ LogicalResult applyPatternsAndFoldGreedily(
 /// was folded away or erased as a result of becoming dead. Note: This does not
 /// apply any patterns recursively to the regions of `op`.
 LogicalResult applyOpPatternsAndFold(Operation *op,
-                                     const FrozenRewritePatternList &patterns,
+                                     const FrozenRewritePatternSet &patterns,
                                      bool *erased = nullptr);
 
 } // end namespace mlir

diff  --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 1d4d9fe84cf1..1dedc2c39d8f 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -92,7 +92,7 @@ static LogicalResult applyPatterns(FuncOp func) {
 
   RewritePatternSet patterns(func.getContext());
   patterns.add<ParallelOpLowering>(func.getContext());
-  FrozenRewritePatternList frozen(std::move(patterns));
+  FrozenRewritePatternSet frozen(std::move(patterns));
   return applyPartialConversion(func, target, frozen);
 }
 

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index cd966d404a47..851ec5051a6b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -230,7 +230,7 @@ void AffineDataCopyGeneration::runOnFunction() {
   RewritePatternSet patterns(&getContext());
   AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
   AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
-  FrozenRewritePatternList frozenPatterns(std::move(patterns));
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   for (Operation *op : copyOps)
     (void)applyOpPatternsAndFold(op, frozenPatterns);
 }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index c3ec72f51b3f..8f59074e6b79 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -83,7 +83,7 @@ void SimplifyAffineStructures::runOnFunction() {
   AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
   AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
   AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
-  FrozenRewritePatternList frozenPatterns(std::move(patterns));
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   func.walk([&](Operation *op) {
     for (auto attr : op->getAttrs()) {
       if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 8e2645a2d44a..522cfd7fca95 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -191,7 +191,7 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
   RewritePatternSet patterns(ifOp.getContext());
   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
   bool erased;
-  FrozenRewritePatternList frozenPatterns(std::move(patterns));
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
   if (erased) {
     if (folded)

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index cd4f65953c0a..e31a6b5210e3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -31,7 +31,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
   // Emplace patterns one at a time while also maintaining a simple chained
   // state transition.
   unsigned stepCount = 0;
-  SmallVector<FrozenRewritePatternList, 4> stage1Patterns;
+  SmallVector<FrozenRewritePatternSet, 4> stage1Patterns;
   auto zeroState = Identifier::get(std::to_string(stepCount), context);
   auto currentState = zeroState;
   for (const std::unique_ptr<Transformation> &t : transformationSequence) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 965275dc2bcc..4202cb268576 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -479,8 +479,8 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
 }
 
 LogicalResult mlir::linalg::applyStagedPatterns(
-    Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
-    const FrozenRewritePatternList &stage2Patterns,
+    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
+    const FrozenRewritePatternSet &stage2Patterns,
     function_ref<LogicalResult(Operation *)> stage3Lambda) {
   unsigned iteration = 0;
   (void)iteration;

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index 87aa623b7abc..372295a986af 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -106,7 +106,7 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
 
   // TODO: Change the type for the indirect users such as spv.Load, spv.Store,
   // spv.FunctionCall and so on.
-  FrozenRewritePatternList frozenPatterns(std::move(patterns));
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   for (auto spirvModule : module.getOps<spirv::ModuleOp>())
     if (failed(applyFullConversion(spirvModule, target, frozenPatterns)))
       signalPassFailure();

diff  --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt
index 5822789cc916..76bf64944d50 100644
--- a/mlir/lib/Rewrite/CMakeLists.txt
+++ b/mlir/lib/Rewrite/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_library(MLIRRewrite
   ByteCode.cpp
-  FrozenRewritePatternList.cpp
+  FrozenRewritePatternSet.cpp
   PatternApplicator.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
similarity index 87%
rename from mlir/lib/Rewrite/FrozenRewritePatternList.cpp
rename to mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index b61307b81b9f..9c81363f13f2 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -1,4 +1,4 @@
-//===- FrozenRewritePatternList.cpp - Frozen Pattern List -------*- C++ -*-===//
+//===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Rewrite/FrozenRewritePatternList.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "ByteCode.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Dialect/PDL/IR/PDLOps.h"
@@ -47,13 +47,13 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
 }
 
 //===----------------------------------------------------------------------===//
-// FrozenRewritePatternList
+// FrozenRewritePatternSet
 //===----------------------------------------------------------------------===//
 
-FrozenRewritePatternList::FrozenRewritePatternList()
+FrozenRewritePatternSet::FrozenRewritePatternSet()
     : impl(std::make_shared<Impl>()) {}
 
-FrozenRewritePatternList::FrozenRewritePatternList(RewritePatternSet &&patterns)
+FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns)
     : impl(std::make_shared<Impl>()) {
   impl->nativePatterns = std::move(patterns.getNativePatterns());
 
@@ -72,4 +72,4 @@ FrozenRewritePatternList::FrozenRewritePatternList(RewritePatternSet &&patterns)
       pdlPatterns.takeRewriteFunctions());
 }
 
-FrozenRewritePatternList::~FrozenRewritePatternList() {}
+FrozenRewritePatternSet::~FrozenRewritePatternSet() {}

diff  --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 5032f0203257..3db598883360 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -19,7 +19,7 @@ using namespace mlir;
 using namespace mlir::detail;
 
 PatternApplicator::PatternApplicator(
-    const FrozenRewritePatternList &frozenPatternList)
+    const FrozenRewritePatternSet &frozenPatternList)
     : frozenPatternList(frozenPatternList) {
   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
     mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 5b6edf9894ab..2d987778f22f 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -35,7 +35,7 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns);
   }
 
-  FrozenRewritePatternList patterns;
+  FrozenRewritePatternSet patterns;
 };
 } // end anonymous namespace
 

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d6037f563f87..41d3eabb07ea 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1506,7 +1506,7 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(ConversionTarget &targetInfo,
-                     const FrozenRewritePatternList &patterns);
+                     const FrozenRewritePatternSet &patterns);
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1602,7 +1602,7 @@ class OperationLegalizer {
 } // namespace
 
 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
-                                       const FrozenRewritePatternList &patterns)
+                                       const FrozenRewritePatternSet &patterns)
     : target(targetInfo), applicator(patterns) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
@@ -2125,7 +2125,7 @@ enum OpConversionMode {
 // conversion mode.
 struct OperationConverter {
   explicit OperationConverter(ConversionTarget &target,
-                              const FrozenRewritePatternList &patterns,
+                              const FrozenRewritePatternSet &patterns,
                               OpConversionMode mode,
                               DenseSet<Operation *> *trackedOps = nullptr)
       : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
@@ -2755,7 +2755,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
 LogicalResult
 mlir::applyPartialConversion(ArrayRef<Operation *> ops,
                              ConversionTarget &target,
-                             const FrozenRewritePatternList &patterns,
+                             const FrozenRewritePatternSet &patterns,
                              DenseSet<Operation *> *unconvertedOps) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
                                  unconvertedOps);
@@ -2763,7 +2763,7 @@ mlir::applyPartialConversion(ArrayRef<Operation *> ops,
 }
 LogicalResult
 mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
-                             const FrozenRewritePatternList &patterns,
+                             const FrozenRewritePatternSet &patterns,
                              DenseSet<Operation *> *unconvertedOps) {
   return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
                                 unconvertedOps);
@@ -2774,13 +2774,13 @@ mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
 /// operation fails.
 LogicalResult
 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
-                          const FrozenRewritePatternList &patterns) {
+                          const FrozenRewritePatternSet &patterns) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
   return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyFullConversion(Operation *op, ConversionTarget &target,
-                          const FrozenRewritePatternList &patterns) {
+                          const FrozenRewritePatternSet &patterns) {
   return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
 }
 
@@ -2793,7 +2793,7 @@ mlir::applyFullConversion(Operation *op, ConversionTarget &target,
 LogicalResult
 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
                               ConversionTarget &target,
-                              const FrozenRewritePatternList &patterns,
+                              const FrozenRewritePatternSet &patterns,
                               DenseSet<Operation *> &convertedOps) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
                                  &convertedOps);
@@ -2801,7 +2801,7 @@ mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
-                              const FrozenRewritePatternList &patterns,
+                              const FrozenRewritePatternSet &patterns,
                               DenseSet<Operation *> &convertedOps) {
   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
                                  convertedOps);

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index c4b5fe043e48..f28f228737a8 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -37,7 +37,7 @@ namespace {
 class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
-                                      const FrozenRewritePatternList &patterns,
+                                      const FrozenRewritePatternSet &patterns,
                                       bool useTopDownTraversal)
       : PatternRewriter(ctx), matcher(patterns), folder(ctx),
         useTopDownTraversal(useTopDownTraversal) {
@@ -242,13 +242,13 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 ///
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(Operation *op,
-                                   const FrozenRewritePatternList &patterns,
+                                   const FrozenRewritePatternSet &patterns,
                                    bool useTopDownTraversal) {
   return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
                                       useTopDownTraversal);
 }
 LogicalResult mlir::applyPatternsAndFoldGreedily(
-    Operation *op, const FrozenRewritePatternList &patterns,
+    Operation *op, const FrozenRewritePatternSet &patterns,
     unsigned maxIterations, bool useTopDownTraversal) {
   return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
                                       useTopDownTraversal);
@@ -256,13 +256,13 @@ LogicalResult mlir::applyPatternsAndFoldGreedily(
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                   const FrozenRewritePatternList &patterns,
+                                   const FrozenRewritePatternSet &patterns,
                                    bool useTopDownTraversal) {
   return applyPatternsAndFoldGreedily(
       regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
 }
 LogicalResult mlir::applyPatternsAndFoldGreedily(
-    MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
     unsigned maxIterations, bool useTopDownTraversal) {
   if (regions.empty())
     return success();
@@ -298,7 +298,7 @@ namespace {
 class OpPatternRewriteDriver : public PatternRewriter {
 public:
   explicit OpPatternRewriteDriver(MLIRContext *ctx,
-                                  const FrozenRewritePatternList &patterns)
+                                  const FrozenRewritePatternSet &patterns)
       : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
     // Apply a simple cost model based solely on pattern benefit.
     matcher.applyDefaultCostModel();
@@ -382,7 +382,7 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
 /// folding. `erased` is set to true if the op is erased as a result of being
 /// folded, replaced, or dead.
 LogicalResult mlir::applyOpPatternsAndFold(
-    Operation *op, const FrozenRewritePatternList &patterns, bool *erased) {
+    Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
   // Start the pattern driver.
   OpPatternRewriteDriver driver(op->getContext(), patterns);
   bool opErased;

diff  --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
index 55464283ff7d..7bf298904780 100644
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -61,7 +61,7 @@ void TestConvVectorization::runOnOperation() {
 
   SmallVector<RewritePatternSet, 4> stage1Patterns;
   linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
-  SmallVector<FrozenRewritePatternList, 4> frozenStage1Patterns;
+  SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
 
   RewritePatternSet stage2Patterns =

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 23e6e0056627..e752c46ecea9 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -184,7 +184,7 @@ struct TestLinalgGreedyFusion
     RewritePatternSet patterns =
         linalg::getLinalgTilingCanonicalizationPatterns(context);
     patterns.add<AffineMinSCFCanonicalizationPattern>(context);
-    FrozenRewritePatternList frozenPatterns(std::move(patterns));
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
     while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
       (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
       PassManager pm(context);

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index a9765ce8c9a4..0bb46455b9ca 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -478,9 +478,9 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
     fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                           stage1Patterns);
   }
-  SmallVector<FrozenRewritePatternList, 4> frozenStage1Patterns;
+  SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
   llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
-  FrozenRewritePatternList stage2Patterns =
+  FrozenRewritePatternSet stage2Patterns =
       getLinalgTilingCanonicalizationPatterns(ctx);
   (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
                             std::move(stage2Patterns));
@@ -505,7 +505,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
   RewritePatternSet foldPattern(funcOp.getContext());
   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
-  FrozenRewritePatternList frozenPatterns(std::move(foldPattern));
+  FrozenRewritePatternSet frozenPatterns(std::move(foldPattern));
 
   // Explicitly walk and apply the pattern locally to avoid more general folding
   // on the rest of the IR.

diff  --git a/mlir/unittests/Rewrite/PatternBenefit.cpp b/mlir/unittests/Rewrite/PatternBenefit.cpp
index 9461e2f0ff8b..0d2f74ae4890 100644
--- a/mlir/unittests/Rewrite/PatternBenefit.cpp
+++ b/mlir/unittests/Rewrite/PatternBenefit.cpp
@@ -60,7 +60,7 @@ TEST(PatternBenefitTest, BenefitOrder) {
   patterns.add<Pattern1>(&context, &called1);
   patterns.add<Pattern2>(&called2);
 
-  FrozenRewritePatternList frozenPatterns(std::move(patterns));
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   PatternApplicator pa(frozenPatterns);
   pa.applyDefaultCostModel();
 


        


More information about the Mlir-commits mailing list