[Mlir-commits] [mlir] b6eb26f - [mlir][NFC] Move around the code related to PatternRewriting to improve layering

River Riddle llvmlistbot at llvm.org
Mon Oct 26 18:05:25 PDT 2020


Author: River Riddle
Date: 2020-10-26T18:01:06-07:00
New Revision: b6eb26fd0e316b36e3750f7cba7ebdb10219790c

URL: https://github.com/llvm/llvm-project/commit/b6eb26fd0e316b36e3750f7cba7ebdb10219790c
DIFF: https://github.com/llvm/llvm-project/commit/b6eb26fd0e316b36e3750f7cba7ebdb10219790c.diff

LOG: [mlir][NFC] Move around the code related to PatternRewriting to improve layering

There are several pieces of pattern rewriting infra in IR/ that really shouldn't be there. This revision moves those pieces to a better location so that they are easier to evolve in the future (e.g. with PDL). More concretely, this revision does the following:

* Create a Transforms/GreedyPatternRewriteDriver.h and move the apply*AndFold methods there.
The definitions for these methods are already in Transforms/, so it doesn't make sense for the declarations to be in IR/.

* Create a new lib/Rewrite library and move PatternApplicator there.
This new library will focus on applying rewrites, and will later also handle compiling rewrites with PDL.

Differential Revision: https://reviews.llvm.org/D89103

Added: 
    mlir/include/mlir/Rewrite/PatternApplicator.h
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Rewrite/CMakeLists.txt
    mlir/lib/Rewrite/PatternApplicator.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/CMakeLists.txt
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
    mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
    mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
    mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
    mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Canonicalizer.cpp
    mlir/lib/Transforms/Inliner.cpp
    mlir/lib/Transforms/Utils/CMakeLists.txt
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/Dialect/Test/TestTraits.cpp
    mlir/test/lib/Transforms/TestConvVectorization.cpp
    mlir/test/lib/Transforms/TestExpandTanh.cpp
    mlir/test/lib/Transforms/TestGpuRewrite.cpp
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    mlir/lib/Transforms/DialectConversion.cpp


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index ef6e3bd86258..74300f2a1882 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -416,12 +416,9 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
   void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
 };
 
-//===----------------------------------------------------------------------===//
-// Pattern-driven rewriters
-//===----------------------------------------------------------------------===//
-
 //===----------------------------------------------------------------------===//
 // OwningRewritePatternList
+//===----------------------------------------------------------------------===//
 
 class OwningRewritePatternList {
   using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
@@ -481,98 +478,6 @@ class OwningRewritePatternList {
   PatternListT patterns;
 };
 
-//===----------------------------------------------------------------------===//
-// PatternApplicator
-
-/// This class manages the application of a group of rewrite patterns, with a
-/// user-provided cost model.
-class PatternApplicator {
-public:
-  /// The cost model dynamically assigns a PatternBenefit to a particular
-  /// pattern. Users can query contained patterns and pass analysis results to
-  /// applyCostModel. Patterns to be discarded should have a benefit of
-  /// `impossibleToMatch`.
-  using CostModel = function_ref<PatternBenefit(const Pattern &)>;
-
-  explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
-      : owningPatternList(owningPatternList) {}
-
-  /// Attempt to match and rewrite the given op with any pattern, allowing a
-  /// predicate to decide if a pattern can be applied or not, and hooks for if
-  /// the pattern match was a success or failure.
-  ///
-  /// canApply:  called before each match and rewrite attempt; return false to
-  ///            skip pattern.
-  /// onFailure: called when a pattern fails to match to perform cleanup.
-  /// onSuccess: called when a pattern match succeeds; return failure() to
-  ///            invalidate the match and try another pattern.
-  LogicalResult
-  matchAndRewrite(Operation *op, PatternRewriter &rewriter,
-                  function_ref<bool(const Pattern &)> canApply = {},
-                  function_ref<void(const Pattern &)> onFailure = {},
-                  function_ref<LogicalResult(const Pattern &)> onSuccess = {});
-
-  /// Apply a cost model to the patterns within this applicator.
-  void applyCostModel(CostModel model);
-
-  /// Apply the default cost model that solely uses the pattern's static
-  /// benefit.
-  void applyDefaultCostModel() {
-    applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
-  }
-
-  /// Walk all of the patterns within the applicator.
-  void walkAllPatterns(function_ref<void(const Pattern &)> walk);
-
-private:
-  /// Attempt to match and rewrite the given op with the given pattern, allowing
-  /// a predicate to decide if a pattern can be applied or not, and hooks for if
-  /// the pattern match was a success or failure.
-  LogicalResult
-  matchAndRewrite(Operation *op, const RewritePattern &pattern,
-                  PatternRewriter &rewriter,
-                  function_ref<bool(const Pattern &)> canApply,
-                  function_ref<void(const Pattern &)> onFailure,
-                  function_ref<LogicalResult(const Pattern &)> onSuccess);
-
-  /// The list that owns the patterns used within this applicator.
-  const OwningRewritePatternList &owningPatternList;
-
-  /// The set of patterns to match for each operation, stable sorted by benefit.
-  DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
-  /// The set of patterns that may match against any operation type, stable
-  /// sorted by benefit.
-  SmallVector<RewritePattern *, 1> anyOpPatterns;
-};
-
-//===----------------------------------------------------------------------===//
-// applyPatternsGreedily
-//===----------------------------------------------------------------------===//
-
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return success if no more patterns can be matched
-/// in the result operation regions.
-/// Note: This does not apply patterns to the top-level operation itself. Note:
-///       These methods also perform folding and simple dead-code elimination
-///       before attempting to match any of the provided patterns.
-///
-LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
-                             const OwningRewritePatternList &patterns);
-/// Rewrite the given regions, which must be isolated from above.
-LogicalResult
-applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const OwningRewritePatternList &patterns);
-
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead. Note: This does not
-/// apply any patterns recursively to the regions of `op`.
-LogicalResult applyOpPatternsAndFold(Operation *op,
-                                     const OwningRewritePatternList &patterns,
-                                     bool *erased = nullptr);
 } // end namespace mlir
 
 #endif // MLIR_PATTERN_MATCH_H
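
The cost-model contract documented for PatternApplicator above (bucket patterns by root operation, rank each bucket by a dynamically computed benefit, and discard anything rated `impossibleToMatch`) can be illustrated with a small standalone C++17 sketch. It does not use MLIR: `Pattern`, `Benefit`, `CostModel`, and `ApplicatorSketch` are hypothetical stand-ins, `std::nullopt` plays the role of `PatternBenefit::impossibleToMatch()`, and `std::string` replaces `OperationName`.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::PatternBenefit: std::nullopt plays the
// role of PatternBenefit::impossibleToMatch().
using Benefit = std::optional<unsigned>;

// Hypothetical stand-in for mlir::RewritePattern.
struct Pattern {
  std::string name;
  std::optional<std::string> rootKind; // std::nullopt: may match any op
  Benefit staticBenefit;
};

// Assigns a (possibly different) benefit to each pattern at runtime.
using CostModel = std::function<Benefit(const Pattern &)>;

struct ApplicatorSketch {
  // Pointers into the caller's pattern list, which must outlive this object.
  // Patterns keyed by root operation name, highest benefit first.
  std::map<std::string, std::vector<const Pattern *>> patterns;
  // Patterns that may match any operation, highest benefit first.
  std::vector<const Pattern *> anyOpPatterns;

  void applyCostModel(const std::vector<Pattern> &all, const CostModel &model) {
    patterns.clear();
    anyOpPatterns.clear();
    // Bucket by root kind, ignoring statically impossible patterns.
    for (const Pattern &pat : all) {
      if (!pat.staticBenefit)
        continue;
      if (pat.rootKind)
        patterns[*pat.rootKind].push_back(&pat);
      else
        anyOpPatterns.push_back(&pat);
    }

    auto processList = [&model](std::vector<const Pattern *> &list) {
      // Query the cost model once per pattern.
      std::map<const Pattern *, Benefit> benefits;
      for (const Pattern *pat : list)
        benefits[pat] = model(*pat);
      // Highest benefit first; stable_sort keeps registration order on ties.
      // std::nullopt compares below every value, so impossible patterns
      // sink to the back, where they are dropped.
      std::stable_sort(list.begin(), list.end(),
                       [&benefits](const Pattern *lhs, const Pattern *rhs) {
                         return benefits[lhs] > benefits[rhs];
                       });
      while (!list.empty() && !benefits[list.back()])
        list.pop_back();
    };
    for (auto &entry : patterns)
      processList(entry.second);
    processList(anyOpPatterns);
  }
};
```

With a model that simply returns each pattern's static benefit, this corresponds to what `applyDefaultCostModel()` does; a custom model can demote or drop patterns based on analysis results without modifying the pattern list itself.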

diff  --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h
new file mode 100644
index 000000000000..be5911966e06
--- /dev/null
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -0,0 +1,85 @@
+//===- PatternApplicator.h - PatternApplicator -------==---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an applicator that applies pattern rewrites based upon a
+// user defined cost model.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
+#define MLIR_REWRITE_PATTERNAPPLICATOR_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class PatternRewriter;
+
+/// This class manages the application of a group of rewrite patterns, with a
+/// user-provided cost model.
+class PatternApplicator {
+public:
+  /// The cost model dynamically assigns a PatternBenefit to a particular
+  /// pattern. Users can query contained patterns and pass analysis results to
+  /// applyCostModel. Patterns to be discarded should have a benefit of
+  /// `impossibleToMatch`.
+  using CostModel = function_ref<PatternBenefit(const Pattern &)>;
+
+  explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
+      : owningPatternList(owningPatternList) {}
+
+  /// Attempt to match and rewrite the given op with any pattern, allowing a
+  /// predicate to decide if a pattern can be applied or not, and hooks for if
+  /// the pattern match was a success or failure.
+  ///
+  /// canApply:  called before each match and rewrite attempt; return false to
+  ///            skip pattern.
+  /// onFailure: called when a pattern fails to match to perform cleanup.
+  /// onSuccess: called when a pattern match succeeds; return failure() to
+  ///            invalidate the match and try another pattern.
+  LogicalResult
+  matchAndRewrite(Operation *op, PatternRewriter &rewriter,
+                  function_ref<bool(const Pattern &)> canApply = {},
+                  function_ref<void(const Pattern &)> onFailure = {},
+                  function_ref<LogicalResult(const Pattern &)> onSuccess = {});
+
+  /// Apply a cost model to the patterns within this applicator.
+  void applyCostModel(CostModel model);
+
+  /// Apply the default cost model that solely uses the pattern's static
+  /// benefit.
+  void applyDefaultCostModel() {
+    applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
+  }
+
+  /// Walk all of the patterns within the applicator.
+  void walkAllPatterns(function_ref<void(const Pattern &)> walk);
+
+private:
+  /// Attempt to match and rewrite the given op with the given pattern, allowing
+  /// a predicate to decide if a pattern can be applied or not, and hooks for if
+  /// the pattern match was a success or failure.
+  LogicalResult
+  matchAndRewrite(Operation *op, const RewritePattern &pattern,
+                  PatternRewriter &rewriter,
+                  function_ref<bool(const Pattern &)> canApply,
+                  function_ref<void(const Pattern &)> onFailure,
+                  function_ref<LogicalResult(const Pattern &)> onSuccess);
+
+  /// The list that owns the patterns used within this applicator.
+  const OwningRewritePatternList &owningPatternList;
+
+  /// The set of patterns to match for each operation, stable sorted by benefit.
+  DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
+  /// The set of patterns that may match against any operation type, stable
+  /// sorted by benefit.
+  SmallVector<RewritePattern *, 1> anyOpPatterns;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
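
The `canApply` / `onFailure` / `onSuccess` hooks, together with the interleaving of op-specific and any-op patterns, can be modeled as a two-way merge over two benefit-sorted lists. What follows is a simplified sketch, not the MLIR implementation: `Pattern` and `matchAndRewrite` here are hypothetical, a pattern is reduced to a benefit plus a predicate on an operation name, and `std::function` stands in for `function_ref`. Ties go to the op-specific pattern, matching the `>=` comparison in the removed `PatternApplicator::matchAndRewrite`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical pattern: a benefit plus a predicate that reports whether the
// pattern matched (and "rewrote") the given operation.
struct Pattern {
  std::string name;
  unsigned benefit;
  std::function<bool(const std::string &)> match;
};

using CanApplyFn = std::function<bool(const Pattern &)>;
using OnFailureFn = std::function<void(const Pattern &)>;
using OnSuccessFn = std::function<bool(const Pattern &)>;

// Both lists are assumed to be sorted by descending benefit already.
bool matchAndRewrite(const std::string &op,
                     const std::vector<const Pattern *> &opPatterns,
                     const std::vector<const Pattern *> &anyOpPatterns,
                     const CanApplyFn &canApply = nullptr,
                     const OnFailureFn &onFailure = nullptr,
                     const OnSuccessFn &onSuccess = nullptr) {
  auto tryPattern = [&](const Pattern &pat) {
    // canApply returning false skips the pattern without calling onFailure.
    if (canApply && !canApply(pat))
      return false;
    if (!pat.match(op)) {
      if (onFailure)
        onFailure(pat);
      return false;
    }
    // onSuccess may veto the match, in which case the search continues.
    return !onSuccess || onSuccess(pat);
  };

  // Two-way merge by benefit; on ties, prefer the op-specific pattern.
  std::size_t i = 0, j = 0;
  while (i < opPatterns.size() || j < anyOpPatterns.size()) {
    const Pattern *next;
    if (j == anyOpPatterns.size() ||
        (i < opPatterns.size() &&
         opPatterns[i]->benefit >= anyOpPatterns[j]->benefit))
      next = opPatterns[i++];
    else
      next = anyOpPatterns[j++];
    if (tryPattern(*next))
      return true;
  }
  return false;
}
```

Folding the exhaustion check into the single loop condition removes the need for a separate tail loop over whichever list is left over; it is one way to get the "merge range" behavior that the FIXME in the original implementation asks for.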

diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
new file mode 100644
index 000000000000..ab88f6a1e871
--- /dev/null
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -0,0 +1,52 @@
+//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares methods for applying a set of patterns greedily, choosing
+// the patterns with the highest local benefit, until a fixed point is reached.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
+#define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// applyPatternsGreedily
+//===----------------------------------------------------------------------===//
+
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return success if no more patterns can be matched
+/// in the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself. Note:
+///       These methods also perform folding and simple dead-code elimination
+///       before attempting to match any of the provided patterns.
+///
+LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+                             const OwningRewritePatternList &patterns);
+/// Rewrite the given regions, which must be isolated from above.
+LogicalResult
+applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+                             const OwningRewritePatternList &patterns);
+
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefits patterns in a greedy manner. Returns
+/// success if no more patterns can be matched. `erased` is set to true if `op`
+/// was folded away or erased as a result of becoming dead. Note: This does not
+/// apply any patterns recursively to the regions of `op`.
+LogicalResult applyOpPatternsAndFold(Operation *op,
+                                     const OwningRewritePatternList &patterns,
+                                     bool *erased = nullptr);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
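
The greedy driver's documented contract (apply the highest-benefit matching pattern, revisit what changed, and report success only once nothing more matches) can be illustrated on a toy problem. This is a simplified sketch under assumptions: each "operation" is just an `int`, `ToyPattern` and `applyPatternsGreedilySketch` are hypothetical names, and `maxRewrites` is an assumed safety budget. It omits the folding and dead-code elimination described above, and it re-queues only the rewritten op rather than also re-queuing related ops as a real IR driver would.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <deque>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical toy pattern: rewrites an op's value, or declines (nullopt).
struct ToyPattern {
  unsigned benefit;
  std::function<std::optional<int>(int)> rewrite;
};

// Worklist-driven greedy rewriting. Returns true once no pattern matches any
// op (a fixed point), or false if the assumed budget `maxRewrites` runs out.
bool applyPatternsGreedilySketch(std::vector<int> &ops,
                                 std::vector<ToyPattern> patterns,
                                 unsigned maxRewrites = 1000) {
  // Try higher-benefit patterns first; ties keep registration order.
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &lhs, const ToyPattern &rhs) {
                     return lhs.benefit > rhs.benefit;
                   });

  std::deque<std::size_t> worklist;
  for (std::size_t i = 0; i < ops.size(); ++i)
    worklist.push_back(i);

  unsigned rewrites = 0;
  while (!worklist.empty()) {
    std::size_t idx = worklist.front();
    worklist.pop_front();
    for (const ToyPattern &pat : patterns) {
      if (std::optional<int> result = pat.rewrite(ops[idx])) {
        if (++rewrites > maxRewrites)
          return false; // did not converge within the budget
        ops[idx] = *result;
        // The rewritten op may match again, so revisit it later.
        worklist.push_back(idx);
        break;
      }
    }
  }
  return true;
}
```

Returning `false` when the budget is exhausted mirrors the documented behavior that these entry points return success only if no more patterns can be matched.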

diff  --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index df3b2db98fcd..e7dc02070801 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(Interfaces)
 add_subdirectory(Parser)
 add_subdirectory(Pass)
 add_subdirectory(Reducer)
+add_subdirectory(Rewrite)
 add_subdirectory(Support)
 add_subdirectory(TableGen)
 add_subdirectory(Target)

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 69786823dd32..2837ec14ef61 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/FormatVariadic.h"
 
 #include "../GPUCommon/GPUOpsLowering.h"

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e9b44a9fef52..62746663071f 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/FormatVariadic.h"
 
 #include "../GPUCommon/GPUOpsLowering.h"

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index 3ea6233700c3..120f9879f9bd 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
@@ -123,7 +124,7 @@ class ConvertShapeConstraints
     OwningRewritePatternList patterns;
     populateConvertShapeConstraintsConversionPatterns(patterns, context);
 
-    if (failed(applyPatternsAndFoldGreedily(func, patterns)))
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
       return signalPassFailure();
   }
 };

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index 1cf3a326367c..13d74088ee9c 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -17,8 +17,8 @@
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 71eaf0d59e68..cab9526b5c7f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,16 +15,12 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
 #include "mlir/Target/LLVMIR/TypeTranslation.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Module.h"

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 8a767661d6d7..70bacf397109 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,14 +24,10 @@
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Types.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 using namespace mlir;

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 70c45b4d37af..20d3b3e9ccb9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -24,7 +24,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/Support/CommandLine.h"

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index d8ffb9742fae..4d50aa1ba1e5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -15,7 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Utils.h"
 
 #define DEBUG_TYPE "simplify-affine-structure"

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 390b3f0b4d4c..8f8cac9fe57e 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -16,7 +16,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index d27985e048da..5cf17855e356 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Passes.h"
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 03fdfd4555f2..4104da6965f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -20,9 +20,8 @@
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 8542c2afb086..00eb9a2fe834 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -23,9 +23,9 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index cf56b0e551a0..139de8cc4651 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::linalg;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index ee1512c8ec89..368f4f2c66dd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/TypeSwitch.h"
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index f7becae6e328..0cfffa79f73c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -22,8 +22,8 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
-#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/Support/CommandLine.h"
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 71e3108b2b58..8638d705d132 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -21,9 +21,9 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>

diff  --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
index 88eb314a852b..0879b73846d7 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
@@ -12,10 +12,9 @@
 #include "mlir/Dialect/Quant/QuantizeUtils.h"
 #include "mlir/Dialect/Quant/UniformSupport.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::quant;

diff  --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
index 9810ea2a5358..055e4759b87b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
@@ -11,9 +11,8 @@
 #include "mlir/Dialect/Quant/Passes.h"
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::quant;

diff  --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
index 641b4bc38e43..41c372c67f6a 100644
--- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 136d01966688..02545e18642e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -8,14 +8,9 @@
 
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Value.h"
-#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 
-#define DEBUG_TYPE "pattern-match"
-
 //===----------------------------------------------------------------------===//
 // PatternBenefit
 //===----------------------------------------------------------------------===//
@@ -205,135 +200,3 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
 
-//===----------------------------------------------------------------------===//
-// PatternApplicator
-//===----------------------------------------------------------------------===//
-
-void PatternApplicator::applyCostModel(CostModel model) {
-  // Separate patterns by root kind to simplify lookup later on.
-  patterns.clear();
-  anyOpPatterns.clear();
-  for (const auto &pat : owningPatternList) {
-    // If the pattern is always impossible to match, just ignore it.
-    if (pat->getBenefit().isImpossibleToMatch()) {
-      LLVM_DEBUG({
-        llvm::dbgs()
-            << "Ignoring pattern '" << pat->getRootKind()
-            << "' because it is impossible to match (by pattern benefit)\n";
-      });
-      continue;
-    }
-    if (Optional<OperationName> opName = pat->getRootKind())
-      patterns[*opName].push_back(pat.get());
-    else
-      anyOpPatterns.push_back(pat.get());
-  }
-
-  // Sort the patterns using the provided cost model.
-  llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
-  auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
-    return benefits[lhs] > benefits[rhs];
-  };
-  auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
-    // Special case for one pattern in the list, which is the most common case.
-    if (list.size() == 1) {
-      if (model(*list.front()).isImpossibleToMatch()) {
-        LLVM_DEBUG({
-          llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
-                       << "' because it is impossible to match or cannot lead "
-                          "to legal IR (by cost model)\n";
-        });
-        list.clear();
-      }
-      return;
-    }
-
-    // Collect the dynamic benefits for the current pattern list.
-    benefits.clear();
-    for (RewritePattern *pat : list)
-      benefits.try_emplace(pat, model(*pat));
-
-    // Sort patterns with highest benefit first, and remove those that are
-    // impossible to match.
-    std::stable_sort(list.begin(), list.end(), cmp);
-    while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
-      LLVM_DEBUG({
-        llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
-                     << "' because it is impossible to match or cannot lead to "
-                        "legal IR (by cost model)\n";
-      });
-      list.pop_back();
-    }
-  };
-  for (auto &it : patterns)
-    processPatternList(it.second);
-  processPatternList(anyOpPatterns);
-}
-
-void PatternApplicator::walkAllPatterns(
-    function_ref<void(const Pattern &)> walk) {
-  for (auto &it : owningPatternList)
-    walk(*it);
-}
-
-LogicalResult PatternApplicator::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter,
-    function_ref<bool(const Pattern &)> canApply,
-    function_ref<void(const Pattern &)> onFailure,
-    function_ref<LogicalResult(const Pattern &)> onSuccess) {
-  // Check to see if there are patterns matching this specific operation type.
-  MutableArrayRef<RewritePattern *> opPatterns;
-  auto patternIt = patterns.find(op->getName());
-  if (patternIt != patterns.end())
-    opPatterns = patternIt->second;
-
-  // Process the patterns for that match the specific operation type, and any
-  // operation type in an interleaved fashion.
-  // FIXME: It'd be nice to just write an llvm::make_merge_range utility
-  // and pass in a comparison function. That would make this code trivial.
-  auto opIt = opPatterns.begin(), opE = opPatterns.end();
-  auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
-  while (opIt != opE && anyIt != anyE) {
-    // Try to match the pattern providing the most benefit.
-    RewritePattern *pattern;
-    if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
-      pattern = *(opIt++);
-    else
-      pattern = *(anyIt++);
-
-    // Otherwise, try to match the generic pattern.
-    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
-                                  onSuccess)))
-      return success();
-  }
-  // If we break from the loop, then only one of the ranges can still have
-  // elements. Loop over both without checking given that we don't need to
-  // interleave anymore.
-  for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
-           llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
-    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
-                                  onSuccess)))
-      return success();
-  }
-  return failure();
-}
-
-LogicalResult PatternApplicator::matchAndRewrite(
-    Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
-    function_ref<bool(const Pattern &)> canApply,
-    function_ref<void(const Pattern &)> onFailure,
-    function_ref<LogicalResult(const Pattern &)> onSuccess) {
-  // Check that the pattern can be applied.
-  if (canApply && !canApply(pattern))
-    return failure();
-
-  // Try to match and rewrite this pattern. The patterns are sorted by
-  // benefit, so if we match we can immediately rewrite.
-  rewriter.setInsertionPoint(op);
-  if (succeeded(pattern.matchAndRewrite(op, rewriter)))
-    return success(!onSuccess || succeeded(onSuccess(pattern)));
-
-  if (onFailure)
-    onFailure(pattern);
-  return failure();
-}

diff  --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt
new file mode 100644
index 000000000000..aa6632b63f09
--- /dev/null
+++ b/mlir/lib/Rewrite/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_library(MLIRRewrite
+  PatternApplicator.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
+
+  DEPENDS
+  mlir-generic-headers
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  )

diff  --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
new file mode 100644
index 000000000000..f9c0cbfed880
--- /dev/null
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -0,0 +1,148 @@
+//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an applicator that applies pattern rewrites based upon a
+// user defined cost model.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "pattern-match"
+
+void PatternApplicator::applyCostModel(CostModel model) {
+  // Separate patterns by root kind to simplify lookup later on.
+  patterns.clear();
+  anyOpPatterns.clear();
+  for (const auto &pat : owningPatternList) {
+    // If the pattern is always impossible to match, just ignore it.
+    if (pat->getBenefit().isImpossibleToMatch()) {
+      LLVM_DEBUG({
+        llvm::dbgs()
+            << "Ignoring pattern '" << pat->getRootKind()
+            << "' because it is impossible to match (by pattern benefit)\n";
+      });
+      continue;
+    }
+    if (Optional<OperationName> opName = pat->getRootKind())
+      patterns[*opName].push_back(pat.get());
+    else
+      anyOpPatterns.push_back(pat.get());
+  }
+
+  // Sort the patterns using the provided cost model.
+  llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
+  auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
+    return benefits[lhs] > benefits[rhs];
+  };
+  auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
+    // Special case for one pattern in the list, which is the most common case.
+    if (list.size() == 1) {
+      if (model(*list.front()).isImpossibleToMatch()) {
+        LLVM_DEBUG({
+          llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
+                       << "' because it is impossible to match or cannot lead "
+                          "to legal IR (by cost model)\n";
+        });
+        list.clear();
+      }
+      return;
+    }
+
+    // Collect the dynamic benefits for the current pattern list.
+    benefits.clear();
+    for (RewritePattern *pat : list)
+      benefits.try_emplace(pat, model(*pat));
+
+    // Sort patterns with highest benefit first, and remove those that are
+    // impossible to match.
+    std::stable_sort(list.begin(), list.end(), cmp);
+    while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
+                     << "' because it is impossible to match or cannot lead to "
+                        "legal IR (by cost model)\n";
+      });
+      list.pop_back();
+    }
+  };
+  for (auto &it : patterns)
+    processPatternList(it.second);
+  processPatternList(anyOpPatterns);
+}
+
+void PatternApplicator::walkAllPatterns(
+    function_ref<void(const Pattern &)> walk) {
+  for (auto &it : owningPatternList)
+    walk(*it);
+}
+
+LogicalResult PatternApplicator::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter,
+    function_ref<bool(const Pattern &)> canApply,
+    function_ref<void(const Pattern &)> onFailure,
+    function_ref<LogicalResult(const Pattern &)> onSuccess) {
+  // Check to see if there are patterns matching this specific operation type.
+  MutableArrayRef<RewritePattern *> opPatterns;
+  auto patternIt = patterns.find(op->getName());
+  if (patternIt != patterns.end())
+    opPatterns = patternIt->second;
+
+  // Process the patterns that match the specific operation type, and those
+  // that match any operation type, in an interleaved fashion.
+  // FIXME: It'd be nice to just write an llvm::make_merge_range utility
+  // and pass in a comparison function. That would make this code trivial.
+  auto opIt = opPatterns.begin(), opE = opPatterns.end();
+  auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
+  while (opIt != opE && anyIt != anyE) {
+    // Try to match the pattern providing the most benefit.
+    RewritePattern *pattern;
+    if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
+      pattern = *(opIt++);
+    else
+      pattern = *(anyIt++);
+
+    // Otherwise, try to match the generic pattern.
+    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
+                                  onSuccess)))
+      return success();
+  }
+  // If we break from the loop, then only one of the ranges can still have
+  // elements. Loop over both without checking given that we don't need to
+  // interleave anymore.
+  for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
+           llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
+    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
+                                  onSuccess)))
+      return success();
+  }
+  return failure();
+}
+
+LogicalResult PatternApplicator::matchAndRewrite(
+    Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
+    function_ref<bool(const Pattern &)> canApply,
+    function_ref<void(const Pattern &)> onFailure,
+    function_ref<LogicalResult(const Pattern &)> onSuccess) {
+  // Check that the pattern can be applied.
+  if (canApply && !canApply(pattern))
+    return failure();
+
+  // Try to match and rewrite this pattern. The patterns are sorted by
+  // benefit, so if we match we can immediately rewrite.
+  rewriter.setInsertionPoint(op);
+  if (succeeded(pattern.matchAndRewrite(op, rewriter)))
+    return success(!onSuccess || succeeded(onSuccess(pattern)));
+
+  if (onFailure)
+    onFailure(pattern);
+  return failure();
+}

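The new PatternApplicator.cpp above does two things. It sorts each pattern list by a cost model and drops patterns that are impossible to match. It then interleaves the op-specific and any-op lists by benefit. Below is a minimal, self-contained sketch of both steps outside MLIR; it is not the MLIR implementation. `SimplePattern`, `SimpleCostModel`, `sortByCostModel`, and `mergeByBenefit` are hypothetical names, and `std::nullopt` stands in for `PatternBenefit::impossibleToMatch()`. The merge compares the static `benefit` field, just as the loop compares `getBenefit()`.

```cpp
#include <algorithm>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for mlir::RewritePattern: a name and a
// static benefit (higher is better). This is not an MLIR type.
struct SimplePattern {
  std::string name;
  unsigned benefit;
};

// Hypothetical cost model: std::nullopt plays the role of a benefit for
// which PatternBenefit::isImpossibleToMatch() would return true.
using SimpleCostModel =
    std::function<std::optional<unsigned>(const SimplePattern &)>;

// Intended to mirror processPatternList in applyCostModel: stable-sort by the
// model-computed benefit (highest first), then drop impossible patterns.
void sortByCostModel(std::vector<SimplePattern> &list,
                     const SimpleCostModel &model) {
  // Evaluate the model once per pattern, like the `benefits` map in the diff.
  std::vector<std::pair<std::optional<unsigned>, SimplePattern>> scored;
  for (const SimplePattern &p : list)
    scored.emplace_back(model(p), p);
  // Possible patterns come first (by descending benefit); impossible ones
  // sort to the back, where they are removed below.
  std::stable_sort(scored.begin(), scored.end(),
                   [](const auto &lhs, const auto &rhs) {
                     if (!rhs.first)
                       return lhs.first.has_value();
                     if (!lhs.first)
                       return false;
                     return *lhs.first > *rhs.first;
                   });
  list.clear();
  for (const auto &entry : scored)
    if (entry.first)
      list.push_back(entry.second);
}

// Returns the order in which candidates would be tried, given op-specific and
// any-op lists that are each sorted by descending benefit. std::merge takes
// from the first range unless comp(second, first) holds, so on a tie the
// op-specific pattern goes first, matching the `>=` in the interleaving loop.
std::vector<std::string>
mergeByBenefit(const std::vector<SimplePattern> &opPatterns,
               const std::vector<SimplePattern> &anyOpPatterns) {
  std::vector<SimplePattern> merged;
  std::merge(opPatterns.begin(), opPatterns.end(), anyOpPatterns.begin(),
             anyOpPatterns.end(), std::back_inserter(merged),
             [](const SimplePattern &lhs, const SimplePattern &rhs) {
               return lhs.benefit > rhs.benefit;
             });
  std::vector<std::string> names;
  for (const SimplePattern &p : merged)
    names.push_back(p.name);
  return names;
}
```

With op patterns `{opA: 3, opB: 1}` and any-op patterns `{anyA: 3, anyB: 2}`, the order is `opA, anyA, anyB, opB`. A standard merge with a descending-benefit comparator already reproduces the loop's tie-breaking, which is roughly what the FIXME about `llvm::make_merge_range` asks for. Unlike the real loop, this sketch materializes the whole order instead of stopping at the first successful rewrite.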
diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index a2b96b90b543..b36f150134e0 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -7,7 +7,6 @@ add_mlir_library(MLIRTransforms
   Canonicalizer.cpp
   CopyRemoval.cpp
   CSE.cpp
-  DialectConversion.cpp
   Inliner.cpp
   LocationSnapshot.cpp
   LoopCoalescing.cpp

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 9b028bfa2525..de435440607f 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -12,8 +12,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 using namespace mlir;

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index d4ab3cc4d549..ff104071f444 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Analysis/CallGraph.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/SCCIterator.h"

diff  --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 9fa59bbde55a..6dbb74b4905a 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRTransformUtils
+  DialectConversion.cpp
   FoldUtils.cpp
   GreedyPatternRewriteDriver.cpp
   InliningUtils.cpp
@@ -19,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
   MLIRLoopAnalysis
   MLIRSCF
   MLIRPass
+  MLIRRewrite
   MLIRStandard
   )

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
similarity index 98%
rename from mlir/lib/Transforms/DialectConversion.cpp
rename to mlir/lib/Transforms/Utils/DialectConversion.cpp
index 692cd494324e..1d649fd3a02e 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
+#include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -74,8 +75,7 @@ computeConversionSet(iterator_range<Region::iterator> region,
 
 /// A utility function to log a successful result for the given reason.
 template <typename... Args>
-static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
-                       Args &&... args) {
+static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   LLVM_DEBUG({
     os.unindent();
     os.startLine() << "} -> SUCCESS";
@@ -88,8 +88,7 @@ static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
 
 /// A utility function to log a failure result for the given reason.
 template <typename... Args>
-static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
-                       Args &&... args) {
+static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   LLVM_DEBUG({
     os.unindent();
     os.startLine() << "} -> FAILURE : "
@@ -2033,21 +2032,21 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
     return minDepth;
 
   // Sort the patterns by those likely to be the most beneficial.
-  llvm::array_pod_sort(
-      patternsByDepth.begin(), patternsByDepth.end(),
-      [](const std::pair<const Pattern *, unsigned> *lhs,
-         const std::pair<const Pattern *, unsigned> *rhs) {
-        // First sort by the smaller pattern legalization depth.
-        if (lhs->second != rhs->second)
-          return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
-                                                           &rhs->second);
-
-        // Then sort by the larger pattern benefit.
-        auto lhsBenefit = lhs->first->getBenefit();
-        auto rhsBenefit = rhs->first->getBenefit();
-        return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
-                                                               &lhsBenefit);
-      });
+  llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
+                       [](const std::pair<const Pattern *, unsigned> *lhs,
+                          const std::pair<const Pattern *, unsigned> *rhs) {
+                         // First sort by the smaller pattern legalization
+                         // depth.
+                         if (lhs->second != rhs->second)
+                           return llvm::array_pod_sort_comparator<unsigned>(
+                               &lhs->second, &rhs->second);
+
+                         // Then sort by the larger pattern benefit.
+                         auto lhsBenefit = lhs->first->getBenefit();
+                         auto rhsBenefit = rhs->first->getBenefit();
+                         return llvm::array_pod_sort_comparator<PatternBenefit>(
+                             &rhsBenefit, &lhsBenefit);
+                       });
 
   // Update the legalization pattern to use the new sorted list.
   patterns.clear();

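The reformatted comparator in `OperationLegalizer::applyCostModelToPatterns` orders candidates by smaller legalization depth first, then by larger benefit. Below is a minimal sketch of that ordering in standard C++. It uses a hypothetical `DepthPattern` record in place of the `std::pair<const Pattern *, unsigned>` entries. It also calls `std::stable_sort` instead of `llvm::array_pod_sort`, which is built on `qsort` and does not guarantee a stable order, so in this sketch equal entries keep their input order.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for a (pattern, depth) entry: a pattern name, its
// static benefit, and the legalization depth computed for it.
struct DepthPattern {
  std::string name;
  unsigned benefit;
  unsigned depth;
};

// Smaller legalization depth first; among equal depths, larger benefit first.
void sortByDepthThenBenefit(std::vector<DepthPattern> &patterns) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const DepthPattern &lhs, const DepthPattern &rhs) {
                     if (lhs.depth != rhs.depth)
                       return lhs.depth < rhs.depth;
                     return lhs.benefit > rhs.benefit;
                   });
}
```

For example, `{a: benefit 1, depth 2}`, `{b: 5, 2}`, `{c: 1, 1}`, and `{d: 3, 1}` sort to `d, c, b, a`. Both depth-1 patterns come before either depth-2 pattern, whatever their benefit.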
diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index c6276395d2c2..199cfbf3d1f0 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -10,8 +10,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/DenseMap.h"

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 8f4b0f800c05..201ba22e2959 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -23,8 +23,8 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/DenseMap.h"

diff  --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index b6df06ed9c54..03433f19f8d6 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -13,8 +13,8 @@
 
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Passes.h"
 

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 04a21ec01fe6..2099b368eddb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -10,10 +10,10 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
 #include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 

diff  --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
index 3cbc95ce6c74..b78884a5479d 100644
--- a/mlir/test/lib/Dialect/Test/TestTraits.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -7,9 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDialect.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
@@ -25,9 +24,9 @@ OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
 
 OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
     ArrayRef<Attribute> operands) {
-  auto argument_op = getOperand();
+  auto argumentOp = getOperand();
   // The success case should cause the trait fold to be supressed.
-  return argument_op.getDefiningOp() ? argument_op : OpFoldResult{};
+  return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};
 }
 
 namespace {

diff  --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
index 79b6464f3b4c..5c16677c8849 100644
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Passes.h"
 
@@ -93,7 +94,7 @@ void TestConvVectorization::runOnOperation() {
   // VectorTransforms.cpp
   vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
       context, vectorTransformsOptions);
-  applyPatternsAndFoldGreedily(module, vectorTransferPatterns);
+  applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
 
   // Programmatic controlled lowering of linalg.copy and linalg.fill.
   PassManager pm(context);
@@ -105,13 +106,14 @@ void TestConvVectorization::runOnOperation() {
   OwningRewritePatternList vectorContractLoweringPatterns;
   populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
                                          context, vectorTransformsOptions);
-  applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
+  applyPatternsAndFoldGreedily(module,
+                               std::move(vectorContractLoweringPatterns));
 
   // Programmatic controlled lowering of vector.transfer only.
   OwningRewritePatternList vectorToLoopsPatterns;
   populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
                                         VectorTransferToSCFOptions());
-  applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
+  applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
 
   // Ensure we drop the marker in the end.
   module.walk([](linalg::LinalgOp op) {

diff  --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp
index c5f6e3a5ce30..ab485e58e895 100644
--- a/mlir/test/lib/Transforms/TestExpandTanh.cpp
+++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp
@@ -11,8 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 

diff  --git a/mlir/test/lib/Transforms/TestGpuRewrite.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
index 13f0d2e39aae..eaa7149fa994 100644
--- a/mlir/test/lib/Transforms/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
@@ -12,8 +12,8 @@
 
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 9a376c548900..0fa4e22ebf37 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::linalg;

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index ffb0f92dae99..cc8c09d89a03 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -17,8 +17,8 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/SetVector.h"
 

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 7dae0254c487..989e8cda34f9 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -14,9 +14,8 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::vector;
