[Mlir-commits] [mlir] b6eb26f - [mlir][NFC] Move around the code related to PatternRewriting to improve layering
River Riddle
llvmlistbot at llvm.org
Mon Oct 26 18:05:25 PDT 2020
Author: River Riddle
Date: 2020-10-26T18:01:06-07:00
New Revision: b6eb26fd0e316b36e3750f7cba7ebdb10219790c
URL: https://github.com/llvm/llvm-project/commit/b6eb26fd0e316b36e3750f7cba7ebdb10219790c
DIFF: https://github.com/llvm/llvm-project/commit/b6eb26fd0e316b36e3750f7cba7ebdb10219790c.diff
LOG: [mlir][NFC] Move around the code related to PatternRewriting to improve layering
There are several pieces of pattern rewriting infra in IR/ that really shouldn't be there. This revision moves those pieces to a better location such that they are easier to evolve in the future (e.g. with PDL). More concretely this revision does the following:
* Create a Transforms/GreedyPatternRewriteDriver.h and move the apply*andFold methods there.
The definitions for these methods are already in Transforms/ so it doesn't make sense for the declarations to be in IR.
* Create a new lib/Rewrite library and move PatternApplicator there.
This new library will be focused on applying rewrites, and will also include compiling rewrites with PDL.
Differential Revision: https://reviews.llvm.org/D89103
Added:
mlir/include/mlir/Rewrite/PatternApplicator.h
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Rewrite/CMakeLists.txt
mlir/lib/Rewrite/PatternApplicator.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/CMakeLists.txt
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/Canonicalizer.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/lib/Transforms/Utils/CMakeLists.txt
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Dialect/Test/TestTraits.cpp
mlir/test/lib/Transforms/TestConvVectorization.cpp
mlir/test/lib/Transforms/TestExpandTanh.cpp
mlir/test/lib/Transforms/TestGpuRewrite.cpp
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
mlir/lib/Transforms/DialectConversion.cpp
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index ef6e3bd86258..74300f2a1882 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -416,12 +416,9 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
};
-//===----------------------------------------------------------------------===//
-// Pattern-driven rewriters
-//===----------------------------------------------------------------------===//
-
//===----------------------------------------------------------------------===//
// OwningRewritePatternList
+//===----------------------------------------------------------------------===//
class OwningRewritePatternList {
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
@@ -481,98 +478,6 @@ class OwningRewritePatternList {
PatternListT patterns;
};
-//===----------------------------------------------------------------------===//
-// PatternApplicator
-
-/// This class manages the application of a group of rewrite patterns, with a
-/// user-provided cost model.
-class PatternApplicator {
-public:
- /// The cost model dynamically assigns a PatternBenefit to a particular
- /// pattern. Users can query contained patterns and pass analysis results to
- /// applyCostModel. Patterns to be discarded should have a benefit of
- /// `impossibleToMatch`.
- using CostModel = function_ref<PatternBenefit(const Pattern &)>;
-
- explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
- : owningPatternList(owningPatternList) {}
-
- /// Attempt to match and rewrite the given op with any pattern, allowing a
- /// predicate to decide if a pattern can be applied or not, and hooks for if
- /// the pattern match was a success or failure.
- ///
- /// canApply: called before each match and rewrite attempt; return false to
- /// skip pattern.
- /// onFailure: called when a pattern fails to match to perform cleanup.
- /// onSuccess: called when a pattern match succeeds; return failure() to
- /// invalidate the match and try another pattern.
- LogicalResult
- matchAndRewrite(Operation *op, PatternRewriter &rewriter,
- function_ref<bool(const Pattern &)> canApply = {},
- function_ref<void(const Pattern &)> onFailure = {},
- function_ref<LogicalResult(const Pattern &)> onSuccess = {});
-
- /// Apply a cost model to the patterns within this applicator.
- void applyCostModel(CostModel model);
-
- /// Apply the default cost model that solely uses the pattern's static
- /// benefit.
- void applyDefaultCostModel() {
- applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
- }
-
- /// Walk all of the patterns within the applicator.
- void walkAllPatterns(function_ref<void(const Pattern &)> walk);
-
-private:
- /// Attempt to match and rewrite the given op with the given pattern, allowing
- /// a predicate to decide if a pattern can be applied or not, and hooks for if
- /// the pattern match was a success or failure.
- LogicalResult
- matchAndRewrite(Operation *op, const RewritePattern &pattern,
- PatternRewriter &rewriter,
- function_ref<bool(const Pattern &)> canApply,
- function_ref<void(const Pattern &)> onFailure,
- function_ref<LogicalResult(const Pattern &)> onSuccess);
-
- /// The list that owns the patterns used within this applicator.
- const OwningRewritePatternList &owningPatternList;
-
- /// The set of patterns to match for each operation, stable sorted by benefit.
- DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
- /// The set of patterns that may match against any operation type, stable
- /// sorted by benefit.
- SmallVector<RewritePattern *, 1> anyOpPatterns;
-};
-
-//===----------------------------------------------------------------------===//
-// applyPatternsGreedily
-//===----------------------------------------------------------------------===//
-
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return success if no more patterns can be matched
-/// in the result operation regions.
-/// Note: This does not apply patterns to the top-level operation itself. Note:
-/// These methods also perform folding and simple dead-code elimination
-/// before attempting to match any of the provided patterns.
-///
-LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
- const OwningRewritePatternList &patterns);
-/// Rewrite the given regions, which must be isolated from above.
-LogicalResult
-applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
- const OwningRewritePatternList &patterns);
-
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead. Note: This does not
-/// apply any patterns recursively to the regions of `op`.
-LogicalResult applyOpPatternsAndFold(Operation *op,
- const OwningRewritePatternList &patterns,
- bool *erased = nullptr);
} // end namespace mlir
#endif // MLIR_PATTERN_MATCH_H
diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h
new file mode 100644
index 000000000000..be5911966e06
--- /dev/null
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -0,0 +1,85 @@
+//===- PatternApplicator.h - PatternApplicator -------==---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares an applicator that applies pattern rewrites based upon a
+// user defined cost model.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
+#define MLIR_REWRITE_PATTERNAPPLICATOR_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class PatternRewriter;
+
+/// This class manages the application of a group of rewrite patterns, with a
+/// user-provided cost model.
+class PatternApplicator {
+public:
+ /// The cost model dynamically assigns a PatternBenefit to a particular
+ /// pattern. Users can query contained patterns and pass analysis results to
+ /// applyCostModel. Patterns to be discarded should have a benefit of
+ /// `impossibleToMatch`.
+ using CostModel = function_ref<PatternBenefit(const Pattern &)>;
+
+ explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
+ : owningPatternList(owningPatternList) {}
+
+ /// Attempt to match and rewrite the given op with any pattern, allowing a
+ /// predicate to decide if a pattern can be applied or not, and hooks for if
+ /// the pattern match was a success or failure.
+ ///
+ /// canApply: called before each match and rewrite attempt; return false to
+ /// skip pattern.
+ /// onFailure: called when a pattern fails to match to perform cleanup.
+ /// onSuccess: called when a pattern match succeeds; return failure() to
+ /// invalidate the match and try another pattern.
+ LogicalResult
+ matchAndRewrite(Operation *op, PatternRewriter &rewriter,
+ function_ref<bool(const Pattern &)> canApply = {},
+ function_ref<void(const Pattern &)> onFailure = {},
+ function_ref<LogicalResult(const Pattern &)> onSuccess = {});
+
+ /// Apply a cost model to the patterns within this applicator.
+ void applyCostModel(CostModel model);
+
+ /// Apply the default cost model that solely uses the pattern's static
+ /// benefit.
+ void applyDefaultCostModel() {
+ applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
+ }
+
+ /// Walk all of the patterns within the applicator.
+ void walkAllPatterns(function_ref<void(const Pattern &)> walk);
+
+private:
+ /// Attempt to match and rewrite the given op with the given pattern, allowing
+ /// a predicate to decide if a pattern can be applied or not, and hooks for if
+ /// the pattern match was a success or failure.
+ LogicalResult
+ matchAndRewrite(Operation *op, const RewritePattern &pattern,
+ PatternRewriter &rewriter,
+ function_ref<bool(const Pattern &)> canApply,
+ function_ref<void(const Pattern &)> onFailure,
+ function_ref<LogicalResult(const Pattern &)> onSuccess);
+
+ /// The list that owns the patterns used within this applicator.
+ const OwningRewritePatternList &owningPatternList;
+
+ /// The set of patterns to match for each operation, stable sorted by benefit.
+ DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
+ /// The set of patterns that may match against any operation type, stable
+ /// sorted by benefit.
+ SmallVector<RewritePattern *, 1> anyOpPatterns;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
new file mode 100644
index 000000000000..ab88f6a1e871
--- /dev/null
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -0,0 +1,52 @@
+//===- GreedyPatternRewriteDriver.h - Greedy pattern driver -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares methods for applying a set of patterns greedily, choosing
+// the patterns with the highest local benefit, until a fixed point is reached.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
+#define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// applyPatternsGreedily
+//===----------------------------------------------------------------------===//
+
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return success if no more patterns can be matched
+/// in the result operation regions.
+/// Note: This does not apply patterns to the top-level operation itself. Note:
+/// These methods also perform folding and simple dead-code elimination
+/// before attempting to match any of the provided patterns.
+///
+LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+ const OwningRewritePatternList &patterns);
+/// Rewrite the given regions, which must be isolated from above.
+LogicalResult
+applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+ const OwningRewritePatternList &patterns);
+
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefits patterns in a greedy manner. Returns
+/// success if no more patterns can be matched. `erased` is set to true if `op`
+/// was folded away or erased as a result of becoming dead. Note: This does not
+/// apply any patterns recursively to the regions of `op`.
+LogicalResult applyOpPatternsAndFold(Operation *op,
+ const OwningRewritePatternList &patterns,
+ bool *erased = nullptr);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index df3b2db98fcd..e7dc02070801 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(Interfaces)
add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Reducer)
+add_subdirectory(Rewrite)
add_subdirectory(Support)
add_subdirectory(TableGen)
add_subdirectory(Target)
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 69786823dd32..2837ec14ef61 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "../GPUCommon/GPUOpsLowering.h"
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e9b44a9fef52..62746663071f 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "../GPUCommon/GPUOpsLowering.h"
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index 3ea6233700c3..120f9879f9bd 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -123,7 +124,7 @@ class ConvertShapeConstraints
OwningRewritePatternList patterns;
populateConvertShapeConstraintsConversionPatterns(patterns, context);
- if (failed(applyPatternsAndFoldGreedily(func, patterns)))
+ if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
return signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index 1cf3a326367c..13d74088ee9c 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -17,8 +17,8 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 71eaf0d59e68..cab9526b5c7f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,16 +15,12 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 8a767661d6d7..70bacf397109 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,14 +24,10 @@
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 70c45b4d37af..20d3b3e9ccb9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -24,7 +24,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/CommandLine.h"
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index d8ffb9742fae..4d50aa1ba1e5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -15,7 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Utils.h"
#define DEBUG_TYPE "simplify-affine-structure"
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 390b3f0b4d4c..8f8cac9fe57e 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -16,7 +16,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index d27985e048da..5cf17855e356 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 03fdfd4555f2..4104da6965f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -20,9 +20,8 @@
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 8542c2afb086..00eb9a2fe834 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -23,9 +23,9 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index cf56b0e551a0..139de8cc4651 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index ee1512c8ec89..368f4f2c66dd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -22,6 +22,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index f7becae6e328..0cfffa79f73c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -22,8 +22,8 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 71e3108b2b58..8638d705d132 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -21,9 +21,9 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
index 88eb314a852b..0879b73846d7 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
@@ -12,10 +12,9 @@
#include "mlir/Dialect/Quant/QuantizeUtils.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::quant;
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
index 9810ea2a5358..055e4759b87b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
@@ -11,9 +11,8 @@
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::quant;
diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
index 641b4bc38e43..41c372c67f6a 100644
--- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 136d01966688..02545e18642e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -8,14 +8,9 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Value.h"
-#include "llvm/Support/Debug.h"
using namespace mlir;
-#define DEBUG_TYPE "pattern-match"
-
//===----------------------------------------------------------------------===//
// PatternBenefit
//===----------------------------------------------------------------------===//
@@ -205,135 +200,3 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
-//===----------------------------------------------------------------------===//
-// PatternApplicator
-//===----------------------------------------------------------------------===//
-
-void PatternApplicator::applyCostModel(CostModel model) {
- // Separate patterns by root kind to simplify lookup later on.
- patterns.clear();
- anyOpPatterns.clear();
- for (const auto &pat : owningPatternList) {
- // If the pattern is always impossible to match, just ignore it.
- if (pat->getBenefit().isImpossibleToMatch()) {
- LLVM_DEBUG({
- llvm::dbgs()
- << "Ignoring pattern '" << pat->getRootKind()
- << "' because it is impossible to match (by pattern benefit)\n";
- });
- continue;
- }
- if (Optional<OperationName> opName = pat->getRootKind())
- patterns[*opName].push_back(pat.get());
- else
- anyOpPatterns.push_back(pat.get());
- }
-
- // Sort the patterns using the provided cost model.
- llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
- auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
- return benefits[lhs] > benefits[rhs];
- };
- auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
- // Special case for one pattern in the list, which is the most common case.
- if (list.size() == 1) {
- if (model(*list.front()).isImpossibleToMatch()) {
- LLVM_DEBUG({
- llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
- << "' because it is impossible to match or cannot lead "
- "to legal IR (by cost model)\n";
- });
- list.clear();
- }
- return;
- }
-
- // Collect the dynamic benefits for the current pattern list.
- benefits.clear();
- for (RewritePattern *pat : list)
- benefits.try_emplace(pat, model(*pat));
-
- // Sort patterns with highest benefit first, and remove those that are
- // impossible to match.
- std::stable_sort(list.begin(), list.end(), cmp);
- while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
- LLVM_DEBUG({
- llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
- << "' because it is impossible to match or cannot lead to "
- "legal IR (by cost model)\n";
- });
- list.pop_back();
- }
- };
- for (auto &it : patterns)
- processPatternList(it.second);
- processPatternList(anyOpPatterns);
-}
-
-void PatternApplicator::walkAllPatterns(
- function_ref<void(const Pattern &)> walk) {
- for (auto &it : owningPatternList)
- walk(*it);
-}
-
-LogicalResult PatternApplicator::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter,
- function_ref<bool(const Pattern &)> canApply,
- function_ref<void(const Pattern &)> onFailure,
- function_ref<LogicalResult(const Pattern &)> onSuccess) {
- // Check to see if there are patterns matching this specific operation type.
- MutableArrayRef<RewritePattern *> opPatterns;
- auto patternIt = patterns.find(op->getName());
- if (patternIt != patterns.end())
- opPatterns = patternIt->second;
-
- // Process the patterns for that match the specific operation type, and any
- // operation type in an interleaved fashion.
- // FIXME: It'd be nice to just write an llvm::make_merge_range utility
- // and pass in a comparison function. That would make this code trivial.
- auto opIt = opPatterns.begin(), opE = opPatterns.end();
- auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
- while (opIt != opE && anyIt != anyE) {
- // Try to match the pattern providing the most benefit.
- RewritePattern *pattern;
- if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
- pattern = *(opIt++);
- else
- pattern = *(anyIt++);
-
- // Otherwise, try to match the generic pattern.
- if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
- onSuccess)))
- return success();
- }
- // If we break from the loop, then only one of the ranges can still have
- // elements. Loop over both without checking given that we don't need to
- // interleave anymore.
- for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
- llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
- if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
- onSuccess)))
- return success();
- }
- return failure();
-}
-
-LogicalResult PatternApplicator::matchAndRewrite(
- Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
- function_ref<bool(const Pattern &)> canApply,
- function_ref<void(const Pattern &)> onFailure,
- function_ref<LogicalResult(const Pattern &)> onSuccess) {
- // Check that the pattern can be applied.
- if (canApply && !canApply(pattern))
- return failure();
-
- // Try to match and rewrite this pattern. The patterns are sorted by
- // benefit, so if we match we can immediately rewrite.
- rewriter.setInsertionPoint(op);
- if (succeeded(pattern.matchAndRewrite(op, rewriter)))
- return success(!onSuccess || succeeded(onSuccess(pattern)));
-
- if (onFailure)
- onFailure(pattern);
- return failure();
-}
diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt
new file mode 100644
index 000000000000..aa6632b63f09
--- /dev/null
+++ b/mlir/lib/Rewrite/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_library(MLIRRewrite
+ PatternApplicator.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
+
+ DEPENDS
+ mlir-generic-headers
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ )
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
new file mode 100644
index 000000000000..f9c0cbfed880
--- /dev/null
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -0,0 +1,148 @@
+//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an applicator that applies pattern rewrites based upon a
+// user defined cost model.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "pattern-match"
+
+void PatternApplicator::applyCostModel(CostModel model) {
+ // Separate patterns by root kind to simplify lookup later on.
+ patterns.clear();
+ anyOpPatterns.clear();
+ for (const auto &pat : owningPatternList) {
+ // If the pattern is always impossible to match, just ignore it.
+ if (pat->getBenefit().isImpossibleToMatch()) {
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "Ignoring pattern '" << pat->getRootKind()
+ << "' because it is impossible to match (by pattern benefit)\n";
+ });
+ continue;
+ }
+ if (Optional<OperationName> opName = pat->getRootKind())
+ patterns[*opName].push_back(pat.get());
+ else
+ anyOpPatterns.push_back(pat.get());
+ }
+
+ // Sort the patterns using the provided cost model.
+ llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
+ auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
+ return benefits[lhs] > benefits[rhs];
+ };
+ auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
+ // Special case for one pattern in the list, which is the most common case.
+ if (list.size() == 1) {
+ if (model(*list.front()).isImpossibleToMatch()) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
+ << "' because it is impossible to match or cannot lead "
+ "to legal IR (by cost model)\n";
+ });
+ list.clear();
+ }
+ return;
+ }
+
+ // Collect the dynamic benefits for the current pattern list.
+ benefits.clear();
+ for (RewritePattern *pat : list)
+ benefits.try_emplace(pat, model(*pat));
+
+ // Sort patterns with highest benefit first, and remove those that are
+ // impossible to match.
+ std::stable_sort(list.begin(), list.end(), cmp);
+ while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
+ << "' because it is impossible to match or cannot lead to "
+ "legal IR (by cost model)\n";
+ });
+ list.pop_back();
+ }
+ };
+ for (auto &it : patterns)
+ processPatternList(it.second);
+ processPatternList(anyOpPatterns);
+}
+
+void PatternApplicator::walkAllPatterns(
+ function_ref<void(const Pattern &)> walk) {
+ for (auto &it : owningPatternList)
+ walk(*it);
+}
+
+LogicalResult PatternApplicator::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter,
+ function_ref<bool(const Pattern &)> canApply,
+ function_ref<void(const Pattern &)> onFailure,
+ function_ref<LogicalResult(const Pattern &)> onSuccess) {
+ // Check to see if there are patterns matching this specific operation type.
+ MutableArrayRef<RewritePattern *> opPatterns;
+ auto patternIt = patterns.find(op->getName());
+ if (patternIt != patterns.end())
+ opPatterns = patternIt->second;
+
+ // Process the patterns that match the specific operation type, and any
+ // operation type in an interleaved fashion.
+ // FIXME: It'd be nice to just write an llvm::make_merge_range utility
+ // and pass in a comparison function. That would make this code trivial.
+ auto opIt = opPatterns.begin(), opE = opPatterns.end();
+ auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
+ while (opIt != opE && anyIt != anyE) {
+ // Try to match the pattern providing the most benefit.
+ RewritePattern *pattern;
+ if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
+ pattern = *(opIt++);
+ else
+ pattern = *(anyIt++);
+
+ // Try to match and rewrite using the selected pattern.
+ if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
+ onSuccess)))
+ return success();
+ }
+ // If we break from the loop, then only one of the ranges can still have
+ // elements. Loop over both without checking given that we don't need to
+ // interleave anymore.
+ for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
+ llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
+ if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
+ onSuccess)))
+ return success();
+ }
+ return failure();
+}
+
+LogicalResult PatternApplicator::matchAndRewrite(
+ Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
+ function_ref<bool(const Pattern &)> canApply,
+ function_ref<void(const Pattern &)> onFailure,
+ function_ref<LogicalResult(const Pattern &)> onSuccess) {
+ // Check that the pattern can be applied.
+ if (canApply && !canApply(pattern))
+ return failure();
+
+ // Try to match and rewrite this pattern. The patterns are sorted by
+ // benefit, so if we match we can immediately rewrite.
+ rewriter.setInsertionPoint(op);
+ if (succeeded(pattern.matchAndRewrite(op, rewriter)))
+ return success(!onSuccess || succeeded(onSuccess(pattern)));
+
+ if (onFailure)
+ onFailure(pattern);
+ return failure();
+}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index a2b96b90b543..b36f150134e0 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -7,7 +7,6 @@ add_mlir_library(MLIRTransforms
Canonicalizer.cpp
CopyRemoval.cpp
CSE.cpp
- DialectConversion.cpp
Inliner.cpp
LocationSnapshot.cpp
LoopCoalescing.cpp
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 9b028bfa2525..de435440607f 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -12,8 +12,8 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index d4ab3cc4d549..ff104071f444 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -17,6 +17,7 @@
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 9fa59bbde55a..6dbb74b4905a 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_library(MLIRTransformUtils
+ DialectConversion.cpp
FoldUtils.cpp
GreedyPatternRewriteDriver.cpp
InliningUtils.cpp
@@ -19,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
MLIRLoopAnalysis
MLIRSCF
MLIRPass
+ MLIRRewrite
MLIRStandard
)
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
similarity index 98%
rename from mlir/lib/Transforms/DialectConversion.cpp
rename to mlir/lib/Transforms/Utils/DialectConversion.cpp
index 692cd494324e..1d649fd3a02e 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
+#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -74,8 +75,7 @@ computeConversionSet(iterator_range<Region::iterator> region,
/// A utility function to log a successful result for the given reason.
template <typename... Args>
-static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
- Args &&... args) {
+static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
LLVM_DEBUG({
os.unindent();
os.startLine() << "} -> SUCCESS";
@@ -88,8 +88,7 @@ static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
/// A utility function to log a failure result for the given reason.
template <typename... Args>
-static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
- Args &&... args) {
+static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
LLVM_DEBUG({
os.unindent();
os.startLine() << "} -> FAILURE : "
@@ -2033,21 +2032,21 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
return minDepth;
// Sort the patterns by those likely to be the most beneficial.
- llvm::array_pod_sort(
- patternsByDepth.begin(), patternsByDepth.end(),
- [](const std::pair<const Pattern *, unsigned> *lhs,
- const std::pair<const Pattern *, unsigned> *rhs) {
- // First sort by the smaller pattern legalization depth.
- if (lhs->second != rhs->second)
- return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
- &rhs->second);
-
- // Then sort by the larger pattern benefit.
- auto lhsBenefit = lhs->first->getBenefit();
- auto rhsBenefit = rhs->first->getBenefit();
- return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
- &lhsBenefit);
- });
+ llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
+ [](const std::pair<const Pattern *, unsigned> *lhs,
+ const std::pair<const Pattern *, unsigned> *rhs) {
+ // First sort by the smaller pattern legalization
+ // depth.
+ if (lhs->second != rhs->second)
+ return llvm::array_pod_sort_comparator<unsigned>(
+ &lhs->second, &rhs->second);
+
+ // Then sort by the larger pattern benefit.
+ auto lhsBenefit = lhs->first->getBenefit();
+ auto rhsBenefit = rhs->first->getBenefit();
+ return llvm::array_pod_sort_comparator<PatternBenefit>(
+ &rhsBenefit, &lhsBenefit);
+ });
// Update the legalization pattern to use the new sorted list.
patterns.clear();
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index c6276395d2c2..199cfbf3d1f0 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -10,8 +10,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 8f4b0f800c05..201ba22e2959 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -23,8 +23,8 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index b6df06ed9c54..03433f19f8d6 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -13,8 +13,8 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 04a21ec01fe6..2099b368eddb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -10,10 +10,10 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
index 3cbc95ce6c74..b78884a5479d 100644
--- a/mlir/test/lib/Dialect/Test/TestTraits.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -7,9 +7,8 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -25,9 +24,9 @@ OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
ArrayRef<Attribute> operands) {
- auto argument_op = getOperand();
+ auto argumentOp = getOperand();
// The success case should cause the trait fold to be supressed.
- return argument_op.getDefiningOp() ? argument_op : OpFoldResult{};
+ return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};
}
namespace {
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
index 79b6464f3b4c..5c16677c8849 100644
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -14,6 +14,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
@@ -93,7 +94,7 @@ void TestConvVectorization::runOnOperation() {
// VectorTransforms.cpp
vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
- applyPatternsAndFoldGreedily(module, vectorTransferPatterns);
+ applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
// Programmatic controlled lowering of linalg.copy and linalg.fill.
PassManager pm(context);
@@ -105,13 +106,14 @@ void TestConvVectorization::runOnOperation() {
OwningRewritePatternList vectorContractLoweringPatterns;
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
context, vectorTransformsOptions);
- applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
+ applyPatternsAndFoldGreedily(module,
+ std::move(vectorContractLoweringPatterns));
// Programmatic controlled lowering of vector.transfer only.
OwningRewritePatternList vectorToLoopsPatterns;
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
VectorTransferToSCFOptions());
- applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
+ applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
// Ensure we drop the marker in the end.
module.walk([](linalg::LinalgOp op) {
diff --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp
index c5f6e3a5ce30..ab485e58e895 100644
--- a/mlir/test/lib/Transforms/TestExpandTanh.cpp
+++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp
@@ -11,8 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
diff --git a/mlir/test/lib/Transforms/TestGpuRewrite.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
index 13f0d2e39aae..eaa7149fa994 100644
--- a/mlir/test/lib/Transforms/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
@@ -12,8 +12,8 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 9a376c548900..0fa4e22ebf37 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index ffb0f92dae99..cc8c09d89a03 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -17,8 +17,8 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 7dae0254c487..989e8cda34f9 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -14,9 +14,8 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
More information about the Mlir-commits
mailing list