[Mlir-commits] [mlir] dc4e913 - [PatternMatch] Big mechanical rename OwningRewritePatternList -> RewritePatternSet and insert -> add. NFC
Chris Lattner
llvmlistbot at llvm.org
Mon Mar 22 17:21:04 PDT 2021
Author: Chris Lattner
Date: 2021-03-22T17:20:50-07:00
New Revision: dc4e913be9c3d1c37f66348d4b5047a107499b53
URL: https://github.com/llvm/llvm-project/commit/dc4e913be9c3d1c37f66348d4b5047a107499b53
DIFF: https://github.com/llvm/llvm-project/commit/dc4e913be9c3d1c37f66348d4b5047a107499b53.diff
LOG: [PatternMatch] Big mechanical rename OwningRewritePatternList -> RewritePatternSet and insert -> add. NFC
This doesn't change any APIs; it just cleans up the many in-tree uses of these
names to use the new preferred names. We'll keep the old names around for a
couple of weeks to help with transitions.
Differential Revision: https://reviews.llvm.org/D99127
Added:
Modified:
mlir/docs/Bufferization.md
mlir/docs/Canonicalization.md
mlir/docs/PatternRewriter.md
mlir/docs/Tutorials/QuickstartRewrites.md
mlir/docs/Tutorials/Toy/Ch-3.md
mlir/docs/Tutorials/Toy/Ch-5.md
mlir/docs/Tutorials/Toy/Ch-6.md
mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
mlir/include/mlir/Dialect/AMX/Transforms.h
mlir/include/mlir/Dialect/AVX512/Transforms.h
mlir/include/mlir/Dialect/GPU/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/SCF/Transforms.h
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h
mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
mlir/include/mlir/Transforms/Bufferize.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
mlir/lib/Dialect/SCF/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Rewrite/FrozenRewritePatternList.cpp
mlir/lib/Transforms/Bufferize.cpp
mlir/lib/Transforms/Canonicalizer.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Dialect/Test/TestTraits.cpp
mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
mlir/test/lib/Transforms/TestConvVectorization.cpp
mlir/test/lib/Transforms/TestConvertCallOp.cpp
mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
mlir/test/lib/Transforms/TestExpandTanh.cpp
mlir/test/lib/Transforms/TestGpuRewrite.cpp
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
mlir/test/lib/Transforms/TestSparsification.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
mlir/test/mlir-tblgen/op-decl-and-defs.td
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
mlir/unittests/Rewrite/PatternBenefit.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md
index de0648deea7a5..eba93c7a6d348 100644
--- a/mlir/docs/Bufferization.md
+++ b/mlir/docs/Bufferization.md
@@ -156,19 +156,19 @@ is very small, and follows the basic pattern of any dialect conversion pass.
```
void mlir::populateTensorBufferizePatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeCastOp, BufferizeExtractOp>(typeConverter, context);
+ BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<BufferizeCastOp, BufferizeExtractOp>(typeConverter,
+ patterns.getContext());
}
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ RewritePatternSet patterns(context);
ConversionTarget target(*context);
- populateTensorBufferizePatterns(context, typeConverter, patterns);
+ populateTensorBufferizePatterns(typeConverter, patterns);
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp>();
target.addLegalDialect<StandardOpsDialect>();
@@ -180,7 +180,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
```
The pass has all the hallmarks of a dialect conversion pass that does type
-conversions: a `TypeConverter`, a `OwningRewritePatternList`, and a
+conversions: a `TypeConverter`, a `RewritePatternSet`, and a
`ConversionTarget`, and a call to `applyPartialConversion`. Note that a function
`populateTensorBufferizePatterns` is separated, so that power users can use the
patterns independently, if necessary (such as to combine multiple sets of
diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md
index 3e1c9d11ecabd..4549369a4ccbb 100644
--- a/mlir/docs/Canonicalization.md
+++ b/mlir/docs/Canonicalization.md
@@ -79,9 +79,9 @@ def MyOp : ... {
Canonicalization patterns can then be provided in the source file:
```c++
-void MyOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+void MyOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.insert<...>(...);
+ patterns.add<...>(...);
}
```
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 590c9ffc95a1b..f452230584c0d 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -154,10 +154,10 @@ creation, as well as many useful attribute and type construction methods.
After a set of patterns have been defined, they are collected and provided to a
specific driver for application. A driver consists of several high-level parts:
-* Input `OwningRewritePatternList`
+* Input `RewritePatternSet`
The input patterns to a driver are provided in the form of an
-`OwningRewritePatternList`. This class provides a simplified API for building a
+`RewritePatternSet`. This class provides a simplified API for building a
list of patterns.
* Driver-specific `PatternRewriter`
@@ -173,7 +173,7 @@ mutation directly.
Each driver is responsible for defining its own operation visitation order as
well as pattern cost model, but the final application is performed via a
`PatternApplicator` class. This class takes as input the
-`OwningRewritePatternList` and transforms the patterns based upon a provided
+`RewritePatternSet` and transforms the patterns based upon a provided
cost model. This cost model computes a final benefit for a given pattern, using
whatever driver specific information necessary. After a cost model has been
computed, the driver may begin to match patterns against operations using
@@ -189,8 +189,8 @@ public:
};
/// Populate the pattern list.
-void collectMyPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<MyPattern>(/*benefit=*/1, ctx);
+void collectMyPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {
+ patterns.add<MyPattern>(/*benefit=*/1, ctx);
}
/// Define a custom PatternRewriter for use by the driver.
@@ -203,7 +203,7 @@ public:
/// Apply the custom driver to `op`.
void applyMyPatternDriver(Operation *op,
- const OwningRewritePatternList &patterns) {
+ const RewritePatternSet &patterns) {
// Initialize the custom PatternRewriter.
MyPatternRewriter rewriter(op->getContext());
diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index 3dea430826ae8..d537050f1a325 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -155,7 +155,7 @@ add_public_tablegen_target(<name-of-the-cmake-target>)
Then you can `#include` the generated file in any C++ implementation file you
like. (You will also need to make sure the library depends on the CMake target
defined in the above.) The generated file will have a `populateWithGenerated(
-OwningRewritePatternList &patterns)` function that you can
+RewritePatternSet &patterns)` function that you can
use to collect all the generated patterns inside `patterns` and then use
`patterns` in any pass you would like.
diff --git a/mlir/docs/Tutorials/Toy/Ch-3.md b/mlir/docs/Tutorials/Toy/Ch-3.md
index 7976d7c30db59..abdb419f534fa 100644
--- a/mlir/docs/Tutorials/Toy/Ch-3.md
+++ b/mlir/docs/Tutorials/Toy/Ch-3.md
@@ -114,8 +114,8 @@ pattern with the canonicalization framework.
```c++
// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SimplifyRedundantTranspose>(context);
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<SimplifyRedundantTranspose>(context);
}
```
diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md
index b8964f93e1a3d..9cd1533d184d7 100644
--- a/mlir/docs/Tutorials/Toy/Ch-5.md
+++ b/mlir/docs/Tutorials/Toy/Ch-5.md
@@ -147,8 +147,8 @@ void ToyToAffineLoweringPass::runOnFunction() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
- mlir::OwningRewritePatternList patterns;
- patterns.insert<..., TransposeOpLowering>(&getContext());
+ mlir::RewritePatternSet patterns;
+ patterns.add<..., TransposeOpLowering>(&getContext());
...
```
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index bddd93688ddb8..b490421fab4aa 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -90,14 +90,14 @@ into LLVM dialect. These patterns allow for lowering the IR in multiple stages
by relying on [transitive lowering](../../../getting_started/Glossary.md#transitive-lowering).
```c++
- mlir::OwningRewritePatternList patterns;
+ mlir::RewritePatternSet patterns;
mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation, to lower from the `toy` dialect, is the
// PrintOp.
- patterns.insert<PrintOpLowering>(&getContext());
+ patterns.add<PrintOpLowering>(&getContext());
```
### Full Lowering
diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
index 0af4cbfc11f12..5e74d95d573b5 100644
--- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
@@ -54,15 +54,15 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
-void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyRedundantTranspose>(context);
+ results.add<SimplifyRedundantTranspose>(context);
}
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
-void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
- FoldConstantReshapeOptPattern>(context);
+ results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+ FoldConstantReshapeOptPattern>(context);
}
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index 0af4cbfc11f12..5e74d95d573b5 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -54,15 +54,15 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
-void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyRedundantTranspose>(context);
+ results.add<SimplifyRedundantTranspose>(context);
}
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
-void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
- FoldConstantReshapeOptPattern>(context);
+ results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+ FoldConstantReshapeOptPattern>(context);
}
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 6cd97f6b65cbf..8acbc37c77f7c 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -297,9 +297,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
- OwningRewritePatternList patterns(&getContext());
- patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
- ReturnOpLowering, TransposeOpLowering>(&getContext());
+ RewritePatternSet patterns(&getContext());
+ patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
+ ReturnOpLowering, TransposeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
index 0af4cbfc11f12..5e74d95d573b5 100644
--- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
@@ -54,15 +54,15 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
-void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyRedundantTranspose>(context);
+ results.add<SimplifyRedundantTranspose>(context);
}
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
-void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
- FoldConstantReshapeOptPattern>(context);
+ results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+ FoldConstantReshapeOptPattern>(context);
}
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index 28d7245802b13..c1ad4dc66e996 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -296,9 +296,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
- OwningRewritePatternList patterns(&getContext());
- patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
- ReturnOpLowering, TransposeOpLowering>(&getContext());
+ RewritePatternSet patterns(&getContext());
+ patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
+ ReturnOpLowering, TransposeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index d0c2412bd9e75..3fd48c5fd892f 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -191,14 +191,14 @@ void ToyToLLVMLoweringPass::runOnOperation() {
// lowerings. Transitive lowering, or A->B->C lowering, is when multiple
// patterns must be applied to fully transform an illegal operation into a
// set of legal ones.
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns);
populateLoopToStdConversionPatterns(patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
// PrintOp.
- patterns.insert<PrintOpLowering>(&getContext());
+ patterns.add<PrintOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
index 0af4cbfc11f12..5e74d95d573b5 100644
--- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp
@@ -54,15 +54,15 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
-void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyRedundantTranspose>(context);
+ results.add<SimplifyRedundantTranspose>(context);
}
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
-void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
- FoldConstantReshapeOptPattern>(context);
+ results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+ FoldConstantReshapeOptPattern>(context);
}
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 6cd97f6b65cbf..8acbc37c77f7c 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -297,9 +297,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
- OwningRewritePatternList patterns(&getContext());
- patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
- ReturnOpLowering, TransposeOpLowering>(&getContext());
+ RewritePatternSet patterns(&getContext());
+ patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
+ ReturnOpLowering, TransposeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index d0c2412bd9e75..3fd48c5fd892f 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -191,14 +191,14 @@ void ToyToLLVMLoweringPass::runOnOperation() {
// lowerings. Transitive lowering, or A->B->C lowering, is when multiple
// patterns must be applied to fully transform an illegal operation into a
// set of legal ones.
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns);
populateLoopToStdConversionPatterns(patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
// PrintOp.
- patterns.insert<PrintOpLowering>(&getContext());
+ patterns.add<PrintOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index bfbd36b40fa03..95072eeef1d2e 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -72,15 +72,15 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
-void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyRedundantTranspose>(context);
+ results.add<SimplifyRedundantTranspose>(context);
}
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
-void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
- FoldConstantReshapeOptPattern>(context);
+ results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+ FoldConstantReshapeOptPattern>(context);
}
diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
index 8058f5d7f12a6..b8afecdaa93de 100644
--- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
+++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
@@ -42,12 +42,11 @@ Optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
/// Collect a set of patterns to convert from the Affine dialect to the Standard
/// dialect, in particular convert structured affine control flow into CFG
/// branch-based control flow.
-void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns);
+void populateAffineToStdConversionPatterns(RewritePatternSet &patterns);
/// Collect a set of patterns to convert vector-related Affine ops to the Vector
/// dialect.
-void populateAffineToVectorConversionPatterns(
- OwningRewritePatternList &patterns);
+void populateAffineToVectorConversionPatterns(RewritePatternSet &patterns);
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.
diff --git a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
index 70170f8c5f99d..1cbe3f69e36ae 100644
--- a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
@@ -17,7 +17,7 @@ using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to convert from the ArmSVE dialect to LLVM.
void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
index cf3763f449a1f..0878c633ec4f8 100644
--- a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
+++ b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
@@ -34,7 +34,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
/// the TypeConverter, but otherwise don't care what type conversions are
/// happening.
void populateAsyncStructuralTypeConversionsAndLegality(
- TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
index 3dab2a136b289..378eb006b969a 100644
--- a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
+++ b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
@@ -18,8 +18,8 @@ template <typename T>
class OperationPass;
/// Populate the given list with patterns that convert from Complex to LLVM.
-void populateComplexToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Create a pass to convert Complex operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertComplexToLLVMPass();
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index cdfe5fa07a640..a005fb50226f5 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -29,7 +29,7 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target);
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The
/// index bitwidth used for the lowering of the device side index computations
diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index e298d2d73efbb..bcec880dfc6d4 100644
--- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -26,7 +26,7 @@ class GPUModuleOp;
/// Collect a set of patterns to convert from the GPU dialect to ROCDL.
void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Configure target to convert from the GPU dialect to ROCDL.
void configureGpuToROCDLConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index e679b86325992..add196d441ea8 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -22,7 +22,7 @@ class SPIRVTypeConverter;
/// SPIR-V ops. For a gpu.func to be converted, it should have a
/// spv.entry_point_abi attribute.
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
index 948c2a4be6f2d..a1f56048cd7f4 100644
--- a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
+++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
@@ -14,11 +14,12 @@
namespace mlir {
class MLIRContext;
class ModuleOp;
-template <typename T> class OperationPass;
+template <typename T>
+class OperationPass;
/// Populate the given list with patterns that convert from Linalg to LLVM.
void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Create a pass to convert Linalg operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToLLVMPass();
diff --git a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
index f05e9d53ff455..64b612ed6b129 100644
--- a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
+++ b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
@@ -22,7 +22,7 @@ using OwningRewritePatternList = RewritePatternSet;
/// Appends to a pattern list additional patterns for translating Linalg ops to
/// SPIR-V ops.
void populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index 240bc1f8dd1b8..f66a29250aa21 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -69,8 +69,7 @@ class IndexedGenericOpToLibraryCallRewrite
};
/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(
- OwningRewritePatternList &patterns);
+void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
} // namespace linalg
diff --git a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
index 5092322286d65..b51b7f3b59383 100644
--- a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
+++ b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
@@ -21,7 +21,7 @@ using OwningRewritePatternList = RewritePatternSet;
/// Populate the given list with patterns that convert from OpenMP to LLVM.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Create a pass to convert OpenMP operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenMPToLLVMPass();
diff --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
index a27c408b9d5a5..ac1ba0e2f24b4 100644
--- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
+++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
@@ -43,7 +43,7 @@ LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp,
/// Adds the conversion pattern from `scf.parallel` to `gpu.launch` to the
/// provided pattern list.
-void populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns);
+void populateParallelLoopToGPUPatterns(RewritePatternSet &patterns);
/// Configures the rewrite target such that only `scf.parallel` operations that
/// are not rewritten by the provided patterns are legal.
diff --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
index 14679f4abb7cf..284500df796c5 100644
--- a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
@@ -37,7 +37,7 @@ struct ScfToSPIRVContext {
/// loop.terminator to CFG operations within the SPIR-V dialect.
void populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_
diff --git a/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h b/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
index fc120798d8063..880d4ae6e7454 100644
--- a/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
+++ b/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
@@ -22,7 +22,7 @@ using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular
/// convert structured control flow into CFG branch-based control flow.
-void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns);
+void populateLoopToStdConversionPatterns(RewritePatternSet &patterns);
/// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG.
std::unique_ptr<Pass> createLowerToCFGPass();
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
index 2f6b6d7ae4de7..c135eee27ee72 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
@@ -41,16 +41,16 @@ void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Populates the given list with patterns for function conversion from SPIR-V
/// to LLVM.
void populateSPIRVToLLVMFunctionConversionPatterns(
- LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns);
+ LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
/// Populates the given patterns for module conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(
- LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns);
+ LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
index 3ab3ee7144f3f..a26d4dd2e314f 100644
--- a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
+++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -20,13 +20,12 @@ class OperationPass;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
-void populateShapeToStandardConversionPatterns(
- OwningRewritePatternList &patterns);
+void populateShapeToStandardConversionPatterns(RewritePatternSet &patterns);
std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
void populateConvertShapeConstraintsConversionPatterns(
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index e9ee9e953477d..1d14fb9d0fd20 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -49,27 +49,27 @@ struct LowerToLLVMOptions {
/// Collect a set of patterns to convert memory-related operations from the
/// Standard dialect to the LLVM dialect, excluding non-memory-related
/// operations and FuncOp.
-void populateStdToLLVMMemoryConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateStdToLLVMMemoryConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Collect a set of patterns to convert from the Standard dialect to the LLVM
/// dialect, excluding the memory-related operations.
-void populateStdToLLVMNonMemoryConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateStdToLLVMNonMemoryConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
/// `emitCWrappers` is set, the pattern will also produce functions
/// that pass memref descriptors by pointer-to-structure in addition to the
/// default unpacked form.
-void populateStdToLLVMFuncOpConversionPattern(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Collect the patterns to convert from the Standard dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
/// by reference meaning the references have to remain alive during the entire
/// pattern lifetime.
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
/// stdlib malloc/free is used by default for allocating memrefs allocated with
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
index 18cf4f3efd9b2..165ba0081b776 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
@@ -22,7 +22,7 @@ class SPIRVTypeConverter;
/// to SPIR-V ops. Also adds the patterns to legalize ops not directly
/// translated to SPIR-V dialect.
void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Appends to a pattern list additional patterns for translating tensor ops
/// to SPIR-V ops.
@@ -38,12 +38,12 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
/// threshold is used to control when the patterns should kick in.
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Appends to a pattern list patterns to legalize ops that are not directly
/// lowered to SPIR-V.
void populateStdLegalizationPatternsForSPIRVLowering(
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 75538394bfe81..8ee5316c10cc1 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -28,7 +28,7 @@ void addTosaToLinalgOnTensorsPasses(OpPassManager &pm);
/// Populates conversion patterns from TOSA dialect to Linalg dialect.
void populateTosaToLinalgOnTensorsConversionPatterns(
- OwningRewritePatternList *patterns);
+ RewritePatternSet *patterns);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
index 08b2fe9f5fd54..e3b2e04dd61fd 100644
--- a/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
+++ b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
@@ -20,7 +20,7 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToSCF();
-void populateTosaToSCFConversionPatterns(OwningRewritePatternList *patterns);
+void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns);
/// Populates passes to convert from TOSA to SCF.
void addTosaToSCFPasses(OpPassManager &pm);
diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
index f13047187fa26..fc1284417896c 100644
--- a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
+++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
@@ -20,11 +20,10 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToStandard();
-void populateTosaToStandardConversionPatterns(
- OwningRewritePatternList *patterns);
+void populateTosaToStandardConversionPatterns(RewritePatternSet *patterns);
void populateTosaRescaleToStandardConversionPatterns(
- OwningRewritePatternList *patterns);
+ RewritePatternSet *patterns);
/// Populates passes to convert from TOSA to Standard.
void addTosaToStandardPasses(OpPassManager &pm);
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 91ded03f84b06..efd26ff8808c4 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -62,12 +62,12 @@ struct LowerVectorToLLVMOptions {
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
/// will be needed when invoking LLVM.
-void populateVectorToLLVMMatrixConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+ LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions = false, bool enableIndexOptimizations = true);
/// Create a pass to convert vector operations to the LLVMIR dialect.
diff --git a/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
index 7f0859cc5f581..2b935cdc3dab7 100644
--- a/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
+++ b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
@@ -19,8 +19,8 @@ class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to convert from the Vector dialect to ROCDL.
-void populateVectorToROCDLConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateVectorToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Create a pass to convert vector operations to the ROCDL dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToROCDLPass();
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index 561a3e9ca2c69..e8c7e651cc860 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -163,7 +163,7 @@ struct VectorTransferRewriter : public RewritePattern {
/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
void populateVectorToSCFConversionPatterns(
- OwningRewritePatternList &patterns,
+ RewritePatternSet &patterns,
const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
/// Create a pass to convert a subset of vector ops to SCF.
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index 8fc606f8bcd5b..bfadb83a921e8 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -21,7 +21,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating Vector Ops to
/// SPIR-V ops.
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 1fccbb5815149..16dff0df13817 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -18,8 +18,8 @@ using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Configure the target to support lowering AMX ops to ops that map to LLVM
/// intrinsics.
diff --git a/mlir/include/mlir/Dialect/AVX512/Transforms.h b/mlir/include/mlir/Dialect/AVX512/Transforms.h
index 541833652a49f..0ea3e627d78ca 100644
--- a/mlir/include/mlir/Dialect/AVX512/Transforms.h
+++ b/mlir/include/mlir/Dialect/AVX512/Transforms.h
@@ -18,8 +18,8 @@ using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to lower AVX512 ops to ops that map to LLVM
/// intrinsics.
-void populateAVX512LegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateAVX512LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Configure the target to support lowering AVX512 ops to ops that map to LLVM
/// intrinsics.
diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
index 327f9d689d9c0..a207c6b2279e1 100644
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -31,10 +31,10 @@ std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass();
std::unique_ptr<OperationPass<FuncOp>> createGpuAsyncRegionPass();
/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
-void populateGpuAllReducePatterns(OwningRewritePatternList &patterns);
+void populateGpuAllReducePatterns(RewritePatternSet &patterns);
/// Collect all patterns to rewrite ops within the GPU dialect.
-inline void populateGpuRewritePatterns(OwningRewritePatternList &patterns) {
+inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 24f49b5235471..ecec2a3c05d24 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -52,8 +52,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
-void populateElementwiseToLinalgConversionPatterns(
- OwningRewritePatternList &patterns);
+void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
/// Create a pass to convert named Linalg operations to Linalg generic
/// operations.
@@ -66,15 +65,13 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(
- OwningRewritePatternList &patterns);
+void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// indexing map used to access the source (target) of the reshape operation in
/// the generic/indexed_generic operation.
-void populateFoldReshapeOpsByLinearizationPatterns(
- OwningRewritePatternList &patterns);
+void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -83,15 +80,14 @@ void populateFoldReshapeOpsByLinearizationPatterns(
/// the tensor reshape involved is collapsing (introducing) unit-extent
/// dimensions.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(OwningRewritePatternList &patterns);
+void populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
-void populateLinalgFoldUnitExtentDimsPatterns(
- OwningRewritePatternList &patterns);
+void populateLinalgFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index 421a5446ad6c3..d005cc310abee 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -24,7 +24,7 @@ struct Transformation {
explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f)
: filter(f) {}
virtual ~Transformation() = default;
- virtual OwningRewritePatternList
+ virtual RewritePatternSet
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) = 0;
linalg::LinalgTransformationFilter::FilterFunction filter = nullptr;
@@ -35,33 +35,32 @@ template <template <typename> class PatternType, typename ConcreteOpType,
typename OptionsType,
typename = std::enable_if_t<std::is_member_function_pointer<
decltype(&ConcreteOpType::getOperationName)>::value>>
-void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
+void sfinae_enqueue(RewritePatternSet &patternList, OptionsType options,
StringRef opName, linalg::LinalgTransformationFilter m) {
assert(opName == ConcreteOpType::getOperationName() &&
"explicit name must match ConcreteOpType::getOperationName");
- patternList.insert<PatternType<ConcreteOpType>>(patternList.getContext(),
- options, m);
+ patternList.add<PatternType<ConcreteOpType>>(patternList.getContext(),
+ options, m);
}
/// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
/// (e.g. LinalgOp, other interfaces, Operation*).
template <template <typename> class PatternType, typename OpType,
typename OptionsType>
-void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
+void sfinae_enqueue(RewritePatternSet &patternList, OptionsType options,
StringRef opName, linalg::LinalgTransformationFilter m) {
assert(!opName.empty() && "opName must not be empty");
- patternList.insert<PatternType<OpType>>(opName, patternList.getContext(),
- options, m);
+ patternList.add<PatternType<OpType>>(opName, patternList.getContext(),
+ options, m);
}
template <typename PatternType, typename OpType, typename OptionsType>
-void enqueue(OwningRewritePatternList &patternList, OptionsType options,
+void enqueue(RewritePatternSet &patternList, OptionsType options,
StringRef opName, linalg::LinalgTransformationFilter m) {
if (!opName.empty())
- patternList.insert<PatternType>(opName, patternList.getContext(), options,
- m);
+ patternList.add<PatternType>(opName, patternList.getContext(), options, m);
else
- patternList.insert<PatternType>(m.addOpFilter<OpType>(), options);
+ patternList.add<PatternType>(m.addOpFilter<OpType>(), options);
}
/// Promotion transformation enqueues a particular stage-1 pattern for
@@ -77,10 +76,10 @@ struct Tile : public Transformation {
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
- OwningRewritePatternList
+ RewritePatternSet
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList tilingPatterns(context);
+ RewritePatternSet tilingPatterns(context);
sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
tilingPatterns, options, opName, m);
return tilingPatterns;
@@ -105,10 +104,10 @@ struct Promote : public Transformation {
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
- OwningRewritePatternList
+ RewritePatternSet
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList promotionPatterns(context);
+ RewritePatternSet promotionPatterns(context);
sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
promotionPatterns, options, opName, m);
return promotionPatterns;
@@ -133,14 +132,14 @@ struct Vectorize : public Transformation {
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name), options(options) {}
- OwningRewritePatternList
+ RewritePatternSet
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList vectorizationPatterns(context);
+ RewritePatternSet vectorizationPatterns(context);
enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
vectorizationPatterns, options, opName, m);
- vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
- linalg::LinalgCopyVTWForwardingPattern>(
+ vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
+ linalg::LinalgCopyVTWForwardingPattern>(
context, /*benefit=*/2);
return vectorizationPatterns;
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 80468c3ae26d4..e1a136c7e65bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -33,12 +33,12 @@ using LinalgLoops = SmallVector<Operation *, 4>;
/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
- MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
+ MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
/// Populates the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
@@ -441,10 +441,8 @@ struct LinalgTilingOptions {
/// Canonicalization patterns relevant to apply after tiling patterns. These are
/// applied automatically by the tiling pass but need to be applied manually
/// when tiling is called programmatically.
-OwningRewritePatternList
-getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
-void populateLinalgTilingCanonicalizationPatterns(
- OwningRewritePatternList &patterns);
+RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
+void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// Base pattern that applies the tiling transformation specified by `options`.
/// Abort and return failure in 2 cases:
@@ -690,10 +688,10 @@ template <
typename OpType,
typename = std::enable_if_t<detect_has_get_operation_name<OpType>::value>,
typename = void>
-void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
+void insertVectorizationPatternImpl(RewritePatternSet &patternList,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f) {
- patternList.insert<linalg::LinalgVectorizationPattern>(
+ patternList.add<linalg::LinalgVectorizationPattern>(
OpType::getOperationName(), patternList.getContext(), options, f);
}
@@ -701,16 +699,16 @@ void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
/// an OpInterface).
template <typename OpType, typename = std::enable_if_t<
!detect_has_get_operation_name<OpType>::value>>
-void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
+void insertVectorizationPatternImpl(RewritePatternSet &patternList,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f) {
- patternList.insert<linalg::LinalgVectorizationPattern>(
- f.addOpFilter<OpType>(), options);
+ patternList.add<linalg::LinalgVectorizationPattern>(f.addOpFilter<OpType>(),
+ options);
}
/// Variadic helper function to insert vectorization patterns for C++ ops.
template <typename... OpTypes>
-void insertVectorizationPatterns(OwningRewritePatternList &patternList,
+void insertVectorizationPatterns(RewritePatternSet &patternList,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f =
linalg::LinalgTransformationFilter()) {
@@ -789,13 +787,13 @@ struct LinalgLoweringPattern : public RewritePattern {
/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
- OwningRewritePatternList &patterns,
+ RewritePatternSet &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
/// Populates `patterns` with patterns to convert linalg.conv ops to
/// linalg.generic ops.
void populateLinalgConvGeneralizationPatterns(
- OwningRewritePatternList &patterns,
+ RewritePatternSet &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
//===----------------------------------------------------------------------===//
@@ -1056,12 +1054,11 @@ struct SparsificationOptions {
/// Sets up sparsification rewriting rules with the given options.
void populateSparsificationPatterns(
- OwningRewritePatternList &patterns,
+ RewritePatternSet &patterns,
const SparsificationOptions &options = SparsificationOptions());
/// Sets up sparsification conversion rules with the given options.
-void populateSparsificationConversionPatterns(
- OwningRewritePatternList &patterns);
+void populateSparsificationConversionPatterns(RewritePatternSet &patterns);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
index 94473af864693..8ab13e42f477d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -61,7 +61,7 @@ tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
/// corresponding scf.yield ops need to update their types accordingly to the
/// TypeConverter, but otherwise don't care what type conversions are happening.
void populateSCFStructuralTypeConversionsAndLegality(
- TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
} // namespace scf
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
index 098d4fd563276..7d6f0ce74031a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
@@ -24,7 +24,7 @@
namespace mlir {
namespace spirv {
void populateSPIRVGLSLCanonicalizationPatterns(
- mlir::OwningRewritePatternList &results);
+ mlir::RewritePatternSet &results);
} // namespace spirv
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index d7cd76bc0f0f0..881f8e90fa0db 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -68,7 +68,7 @@ class SPIRVTypeConverter : public TypeConverter {
/// interface/ABI; they convert function parameters to be of SPIR-V allowed
/// types.
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
namespace spirv {
class AccessChainOp;
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 9e4b4af633f5e..1cf83e6b0beff 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -28,7 +28,7 @@ namespace mlir {
std::unique_ptr<Pass> createShapeToShapeLowering();
/// Collects a set of patterns to rewrite ops within the Shape dialect.
-void populateShapeRewritePatterns(OwningRewritePatternList &patterns);
+void populateShapeRewritePatterns(RewritePatternSet &patterns);
// Collects a set of patterns to replace all constraints with passing witnesses.
// This is intended to then allow all ShapeConstraint related ops and data to
@@ -36,7 +36,7 @@ void populateShapeRewritePatterns(OwningRewritePatternList &patterns);
// canonicalization and dead code elimination.
//
// After this pass, no cstr_ operations exist.
-void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns);
+void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
/// Populates patterns for shape dialect structural type conversions and sets up
@@ -51,7 +51,7 @@ std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
/// do for a structural type conversion is to update both of their types
/// consistently to the new types prescribed by the TypeConverter.
void populateShapeStructuralTypeConversionsAndLegality(
- TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
// Bufferizes shape dialect ops.
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h
index 49895acd9d241..c453e80ee71de 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h
@@ -81,9 +81,10 @@ class ValueDecomposer {
/// Populates the patterns needed to drive the conversion process for
/// decomposing call graph types with the given `ValueDecomposer`.
-void populateDecomposeCallGraphTypesPatterns(
- MLIRContext *context, TypeConverter &typeConverter,
- ValueDecomposer &decomposer, OwningRewritePatternList &patterns);
+void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
+ TypeConverter &typeConverter,
+ ValueDecomposer &decomposer,
+ RewritePatternSet &patterns);
} // end namespace mlir
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
index 6e0abfcc7f0ea..b932d1e009834 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -25,15 +25,15 @@ using OwningRewritePatternList = RewritePatternSet;
/// Add a pattern to the given pattern list to convert the operand and result
/// types of a CallOp with the given type converter.
-void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
+void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
TypeConverter &converter);
/// Add a pattern to the given pattern list to rewrite branch operations to use
/// operands that have been legalized by the conversion framework. This can only
/// be done if the branch operation implements the BranchOpInterface. Only
/// needed for partial conversions.
-void populateBranchOpInterfaceTypeConversionPattern(
- OwningRewritePatternList &patterns, TypeConverter &converter);
+void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns,
+ TypeConverter &converter);
/// Return true if op is a BranchOpInterface op whose operands are all legal
/// according to converter.
@@ -42,7 +42,7 @@ bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op,
/// Add a pattern to the given pattern list to rewrite `return` ops to use
/// operands that have been legalized by the conversion framework.
-void populateReturnOpTypeConversionPattern(OwningRewritePatternList &patterns,
+void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
TypeConverter &converter);
/// For ReturnLike ops (except `return`), return True. If op is a `return` &&
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 6e95daed621f8..2b7f3da150cdf 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -23,7 +23,7 @@ class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
void populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Creates an instance of std bufferization pass.
std::unique_ptr<Pass> createStdBufferizePass();
@@ -42,7 +42,7 @@ std::unique_ptr<Pass> createTensorConstantBufferizePass();
std::unique_ptr<Pass> createStdExpandOpsPass();
/// Collects a set of patterns to rewrite ops within the Std dialect.
-void populateStdExpandOpsPatterns(OwningRewritePatternList &patterns);
+void populateStdExpandOpsPatterns(RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
index dc1fd7e948422..6cb2758459eba 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -18,7 +18,7 @@ class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
void populateTensorBufferizePatterns(BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Creates an instance of `tensor` dialect bufferization pass.
std::unique_ptr<Pass> createTensorBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 456cc88430a6b..111cb7370eefc 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,11 +40,10 @@ struct BitmaskEnumStorage;
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
- OwningRewritePatternList &patterns);
+ RewritePatternSet &patterns);
/// Collect a set of vector-to-vector transformation patterns.
-void populateVectorToVectorTransformationPatterns(
- OwningRewritePatternList &patterns);
+void populateVectorToVectorTransformationPatterns(RewritePatternSet &patterns);
/// Collect a set of patterns to split transfer read/write ops.
///
@@ -55,7 +54,7 @@ void populateVectorToVectorTransformationPatterns(
/// of being generic canonicalization patterns. Also one can let the
/// `ignoreFilter` to return true to fail matching for fine-grained control.
void populateSplitVectorTransferPatterns(
- OwningRewritePatternList &patterns,
+ RewritePatternSet &patterns,
std::function<bool(Operation *)> ignoreFilter = nullptr);
/// Collect a set of leading one dimension removal patterns.
@@ -64,15 +63,14 @@ void populateSplitVectorTransferPatterns(
/// to expose more canonical forms of read/write/insert/extract operations.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
-void populateCastAwayVectorLeadingOneDimPatterns(
- OwningRewritePatternList &patterns);
+void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
-void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns);
+void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
/// Collect a set of vector slices transformation patterns:
/// ExtractSlicesOpLowering, InsertSlicesOpLowering
@@ -82,13 +80,13 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns);
/// use for "slices" ops), this lowering removes all tuple related
/// operations as well (through DCE and folding). If tuple values
/// "leak" coming in, however, some tuple related ops will remain.
-void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns);
+void populateVectorSlicesLoweringPatterns(RewritePatternSet &patterns);
/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`.
-void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns);
+void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
@@ -172,7 +170,7 @@ struct VectorTransformsOptions {
/// These transformation express higher level vector ops in terms of more
/// elementary extraction, insertion, reduction, product, and broadcast ops.
void populateVectorContractLoweringPatterns(
- OwningRewritePatternList &patterns,
+ RewritePatternSet &patterns,
VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
/// Returns the integer type required for subscripts in the vector dialect.
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 9a0d5537f1738..35eb83d8f03ae 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -27,7 +27,7 @@ class IfOp;
/// Collect a set of patterns to convert from the Vector dialect to itself.
/// Should be merged with populateVectorToSCFLoweringPattern.
void populateVectorToVectorConversionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ MLIRContext *context, RewritePatternSet &patterns,
ArrayRef<int64_t> coarseVectorShape = {},
ArrayRef<int64_t> fineVectorShape = {});
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 8b3b052590a97..145b4cd989e50 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -185,7 +185,7 @@ class OpState {
public:
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
- static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+ static void getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {}
protected:
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 19173d16757a0..c2241df46191d 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -67,7 +67,7 @@ using OwningRewritePatternList = RewritePatternSet;
/// the concrete operation types.
class AbstractOperation {
public:
- using GetCanonicalizationPatternsFn = void (*)(OwningRewritePatternList &,
+ using GetCanonicalizationPatternsFn = void (*)(RewritePatternSet &,
MLIRContext *);
using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &);
@@ -126,7 +126,7 @@ class AbstractOperation {
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
- void getCanonicalizationPatterns(OwningRewritePatternList &results,
+ void getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) const {
return getCanonicalizationPatternsFn(results, context);
}
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 115ad5f039bc0..514b7ae06938d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -894,7 +894,7 @@ class RewritePatternSet {
PDLPatternModule pdlPatterns;
};
-// TODO: OwningRewritePatternList is soft-deprecated and will be removed in the
+// TODO: OwningRewritePatternList is soft-deprecated and will be removed in the
// future.
using OwningRewritePatternList = RewritePatternSet;
diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
index 0e583aab3dc46..a20030cd08da1 100644
--- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
+++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
@@ -27,7 +27,7 @@ class FrozenRewritePatternList {
public:
/// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternList();
- FrozenRewritePatternList(OwningRewritePatternList &&patterns);
+ FrozenRewritePatternList(RewritePatternSet &&patterns);
FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default;
FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default;
FrozenRewritePatternList &
diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
index 9f2c0e3f31a68..22155f7de1d4f 100644
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -56,7 +56,7 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
///
/// In particular, these are the tensor_load/buffer_cast ops.
void populateEliminateBufferizeMaterializationsPatterns(
- BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns);
+ BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
} // end namespace mlir
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b93fffa131a18..ae86b2679eb3c 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -423,20 +423,20 @@ struct OpConversionPattern : public ConversionPattern {
/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionLike op with the given type converter. This only supports
/// FunctionLike ops which use FunctionType to represent their type.
-void populateFunctionLikeTypeConversionPattern(
- StringRef functionLikeOpName, OwningRewritePatternList &patterns,
- TypeConverter &converter);
+void populateFunctionLikeTypeConversionPattern(StringRef functionLikeOpName,
+ RewritePatternSet &patterns,
+ TypeConverter &converter);
template <typename FuncOpT>
-void populateFunctionLikeTypeConversionPattern(
- OwningRewritePatternList &patterns, TypeConverter &converter) {
+void populateFunctionLikeTypeConversionPattern(RewritePatternSet &patterns,
+ TypeConverter &converter) {
populateFunctionLikeTypeConversionPattern(FuncOpT::getOperationName(),
patterns, converter);
}
/// Add a pattern to the given pattern list to convert the signature of a FuncOp
/// with the given type converter.
-void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
+void populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
TypeConverter &converter);
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 4c741d46c9efe..1ad07b2f6e068 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -746,10 +746,9 @@ class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
} // end namespace
-void mlir::populateAffineToStdConversionPatterns(
- OwningRewritePatternList &patterns) {
+void mlir::populateAffineToStdConversionPatterns(RewritePatternSet &patterns) {
// clang-format off
- patterns.insert<
+ patterns.add<
AffineApplyLowering,
AffineDmaStartLowering,
AffineDmaWaitLowering,
@@ -766,9 +765,9 @@ void mlir::populateAffineToStdConversionPatterns(
}
void mlir::populateAffineToVectorConversionPatterns(
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
// clang-format off
- patterns.insert<
+ patterns.add<
AffineVectorLoadLowering,
AffineVectorStoreLowering>(patterns.getContext());
// clang-format on
@@ -777,7 +776,7 @@ void mlir::populateAffineToVectorConversionPatterns(
namespace {
class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
void runOnOperation() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns);
populateAffineToVectorConversionPatterns(patterns);
ConversionTarget target(getContext());
diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
index 1d95f73327fd8..7ac8fa2b6c99d 100644
--- a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
@@ -96,19 +96,19 @@ static Optional<Value> addUnrealizedCast(OpBuilder &builder,
}
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
-void mlir::populateArmSVEToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+void mlir::populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
converter.addConversion([&converter](ScalableVectorType svType) {
return convertScalableVectorTypeToLLVM(svType, converter);
});
converter.addSourceMaterialization(addUnrealizedCast);
// clang-format off
- patterns.insert<ForwardOperands<CallOp>,
+ patterns.add<ForwardOperands<CallOp>,
ForwardOperands<CallIndirectOp>,
ForwardOperands<ReturnOp>>(converter,
&converter.getContext());
- patterns.insert<SdotOpLowering,
+ patterns.add<SdotOpLowering,
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 23a826a873063..4452dda43f331 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -875,7 +875,7 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
// Convert async dialect types and operations to LLVM dialect.
AsyncRuntimeTypeConverter converter;
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
// We use conversion to LLVM type to lower async.runtime load and store
// operations.
@@ -887,24 +887,24 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
populateCallOpTypeConversionPattern(patterns, converter);
// Convert return operations inside async.execute regions.
- patterns.insert<ReturnOpOpConversion>(converter, ctx);
+ patterns.add<ReturnOpOpConversion>(converter, ctx);
// Lower async.runtime operations to the async runtime API calls.
- patterns.insert<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
- RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
- RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
- RuntimeDropRefOpLowering>(converter, ctx);
+ patterns.add<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
+ RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
+ RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
+ RuntimeDropRefOpLowering>(converter, ctx);
// Lower async.runtime operations that rely on LLVM type converter to convert
// from async value payload type to the LLVM type.
- patterns.insert<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
- RuntimeLoadOpLowering>(llvmConverter, ctx);
+ patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
+ RuntimeLoadOpLowering>(llvmConverter, ctx);
// Lower async coroutine operations to LLVM coroutine intrinsics.
- patterns.insert<CoroIdOpConversion, CoroBeginOpConversion,
- CoroFreeOpConversion, CoroEndOpConversion,
- CoroSaveOpConversion, CoroSuspendOpConversion>(converter,
- ctx);
+ patterns
+ .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
+ CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
+ converter, ctx);
ConversionTarget target(*ctx);
target.addLegalOp<ConstantOp>();
@@ -985,16 +985,15 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
}
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
- TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
typeConverter.addConversion([&](TokenType type) { return type; });
typeConverter.addConversion([&](ValueType type) {
return ValueType::get(typeConverter.convertType(type.getValueType()));
});
- patterns
- .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
- typeConverter, patterns.getContext());
+ patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
+ typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 71b2fc05ed288..d5fb64a8eb650 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -258,9 +258,9 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
} // namespace
void mlir::populateComplexToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
- patterns.insert<
+ patterns.add<
AbsOpConversion,
AddOpConversion,
CreateOpConversion,
@@ -284,7 +284,7 @@ void ConvertComplexToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateComplexToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index dde968ced455e..81c9398759539 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -308,7 +308,7 @@ class ConvertMemcpyOpToGpuRuntimeCallPattern
void GpuToLLVMConversionPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
LLVMConversionTarget target(getContext());
populateVectorToLLVMConversionPatterns(converter, patterns);
@@ -320,16 +320,16 @@ void GpuToLLVMConversionPass::runOnOperation() {
[context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
});
- patterns.insert<ConvertAllocOpToGpuRuntimeCallPattern,
- ConvertDeallocOpToGpuRuntimeCallPattern,
- ConvertHostRegisterOpToGpuRuntimeCallPattern,
- ConvertMemcpyOpToGpuRuntimeCallPattern,
- ConvertWaitAsyncOpToGpuRuntimeCallPattern,
- ConvertWaitOpToGpuRuntimeCallPattern,
- ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
- patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
- converter, gpuBinaryAnnotation);
- patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());
+ patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
+ ConvertDeallocOpToGpuRuntimeCallPattern,
+ ConvertHostRegisterOpToGpuRuntimeCallPattern,
+ ConvertMemcpyOpToGpuRuntimeCallPattern,
+ ConvertWaitAsyncOpToGpuRuntimeCallPattern,
+ ConvertWaitOpToGpuRuntimeCallPattern,
+ ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
+ patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
+ gpuBinaryAnnotation);
+ patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 034d8e9c6b274..d5f89f7e7095a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -125,8 +125,8 @@ struct LowerGpuOpsToNVVMOpsPass
return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
});
- OwningRewritePatternList patterns(m.getContext());
- OwningRewritePatternList llvmPatterns(m.getContext());
+ RewritePatternSet patterns(m.getContext());
+ RewritePatternSet llvmPatterns(m.getContext());
// Apply in-dialect lowering first. In-dialect lowering will replace ops
// which need to be lowered further, which is not supported by a single
@@ -158,62 +158,62 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}
-void mlir::populateGpuToNVVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
populateWithGenerated(patterns);
patterns
- .insert<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
- NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
- GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
- NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
- GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
- NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
- GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
- NVVM::GridDimYOp, NVVM::GridDimZOp>,
- GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+ .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
+ NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
+ NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
+ NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
+ NVVM::GridDimYOp, NVVM::GridDimZOp>,
+ GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
// memory space and does not support `alloca`s with addrspace(5).
- patterns.insert<GPUFuncOpLowering>(
+ patterns.add<GPUFuncOpLowering>(
converter, /*allocaAddrSpace=*/0,
Identifier::get(NVVM::NVVMDialect::getKernelFuncAttrName(),
&converter.getContext()));
- patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
- "__nv_fabs");
- patterns.insert<OpToFuncCallLowering<math::AtanOp>>(converter, "__nv_atanf",
- "__nv_atan");
- patterns.insert<OpToFuncCallLowering<math::Atan2Op>>(converter, "__nv_atan2f",
- "__nv_atan2");
- patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf",
- "__nv_ceil");
- patterns.insert<OpToFuncCallLowering<math::CosOp>>(converter, "__nv_cosf",
- "__nv_cos");
- patterns.insert<OpToFuncCallLowering<math::ExpOp>>(converter, "__nv_expf",
- "__nv_exp");
- patterns.insert<OpToFuncCallLowering<math::ExpM1Op>>(converter, "__nv_expm1f",
- "__nv_expm1");
- patterns.insert<OpToFuncCallLowering<FloorFOp>>(converter, "__nv_floorf",
- "__nv_floor");
- patterns.insert<OpToFuncCallLowering<math::LogOp>>(converter, "__nv_logf",
- "__nv_log");
- patterns.insert<OpToFuncCallLowering<math::Log1pOp>>(converter, "__nv_log1pf",
- "__nv_log1p");
- patterns.insert<OpToFuncCallLowering<math::Log10Op>>(converter, "__nv_log10f",
- "__nv_log10");
- patterns.insert<OpToFuncCallLowering<math::Log2Op>>(converter, "__nv_log2f",
- "__nv_log2");
- patterns.insert<OpToFuncCallLowering<math::PowFOp>>(converter, "__nv_powf",
- "__nv_pow");
- patterns.insert<OpToFuncCallLowering<math::RsqrtOp>>(converter, "__nv_rsqrtf",
- "__nv_rsqrt");
- patterns.insert<OpToFuncCallLowering<math::SinOp>>(converter, "__nv_sinf",
- "__nv_sin");
- patterns.insert<OpToFuncCallLowering<math::SqrtOp>>(converter, "__nv_sqrtf",
- "__nv_sqrt");
- patterns.insert<OpToFuncCallLowering<math::TanhOp>>(converter, "__nv_tanhf",
- "__nv_tanh");
+ patterns.add<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
+ "__nv_fabs");
+ patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__nv_atanf",
+ "__nv_atan");
+ patterns.add<OpToFuncCallLowering<math::Atan2Op>>(converter, "__nv_atan2f",
+ "__nv_atan2");
+ patterns.add<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf",
+ "__nv_ceil");
+ patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__nv_cosf",
+ "__nv_cos");
+ patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__nv_expf",
+ "__nv_exp");
+ patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(converter, "__nv_expm1f",
+ "__nv_expm1");
+ patterns.add<OpToFuncCallLowering<FloorFOp>>(converter, "__nv_floorf",
+ "__nv_floor");
+ patterns.add<OpToFuncCallLowering<math::LogOp>>(converter, "__nv_logf",
+ "__nv_log");
+ patterns.add<OpToFuncCallLowering<math::Log1pOp>>(converter, "__nv_log1pf",
+ "__nv_log1p");
+ patterns.add<OpToFuncCallLowering<math::Log10Op>>(converter, "__nv_log10f",
+ "__nv_log10");
+ patterns.add<OpToFuncCallLowering<math::Log2Op>>(converter, "__nv_log2f",
+ "__nv_log2");
+ patterns.add<OpToFuncCallLowering<math::PowFOp>>(converter, "__nv_powf",
+ "__nv_pow");
+ patterns.add<OpToFuncCallLowering<math::RsqrtOp>>(converter, "__nv_rsqrtf",
+ "__nv_rsqrt");
+ patterns.add<OpToFuncCallLowering<math::SinOp>>(converter, "__nv_sinf",
+ "__nv_sin");
+ patterns.add<OpToFuncCallLowering<math::SqrtOp>>(converter, "__nv_sqrtf",
+ "__nv_sqrt");
+ patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__nv_tanhf",
+ "__nv_tanh");
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 1b5a80720cc91..6cbf3c2798b02 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -60,8 +60,8 @@ struct LowerGpuOpsToROCDLOpsPass
/*useAlignedAlloc =*/false};
LLVMTypeConverter converter(m.getContext(), options);
- OwningRewritePatternList patterns(m.getContext());
- OwningRewritePatternList llvmPatterns(m.getContext());
+ RewritePatternSet patterns(m.getContext());
+ RewritePatternSet llvmPatterns(m.getContext());
populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
@@ -92,57 +92,57 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}
-void mlir::populateGpuToROCDLConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+void mlir::populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
populateWithGenerated(patterns);
- patterns.insert<
- GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
- ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>,
- GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
- ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
- GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, ROCDL::BlockIdXOp,
- ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>,
- GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
- ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
- GPUReturnOpLowering>(converter);
- patterns.insert<GPUFuncOpLowering>(
+ patterns
+ .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
+ ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
+ ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, ROCDL::BlockIdXOp,
+ ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
+ ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
+ GPUReturnOpLowering>(converter);
+ patterns.add<GPUFuncOpLowering>(
converter, /*allocaAddrSpace=*/5,
Identifier::get(ROCDL::ROCDLDialect::getKernelFuncAttrName(),
&converter.getContext()));
- patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__ocml_fabs_f32",
- "__ocml_fabs_f64");
- patterns.insert<OpToFuncCallLowering<math::AtanOp>>(
- converter, "__ocml_atan_f32", "__ocml_atan_f64");
- patterns.insert<OpToFuncCallLowering<math::Atan2Op>>(
+ patterns.add<OpToFuncCallLowering<AbsFOp>>(converter, "__ocml_fabs_f32",
+ "__ocml_fabs_f64");
+ patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__ocml_atan_f32",
+ "__ocml_atan_f64");
+ patterns.add<OpToFuncCallLowering<math::Atan2Op>>(
converter, "__ocml_atan2_f32", "__ocml_atan2_f64");
- patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__ocml_ceil_f32",
- "__ocml_ceil_f64");
- patterns.insert<OpToFuncCallLowering<math::CosOp>>(
- converter, "__ocml_cos_f32", "__ocml_cos_f64");
- patterns.insert<OpToFuncCallLowering<math::ExpOp>>(
- converter, "__ocml_exp_f32", "__ocml_exp_f64");
- patterns.insert<OpToFuncCallLowering<math::ExpM1Op>>(
+ patterns.add<OpToFuncCallLowering<CeilFOp>>(converter, "__ocml_ceil_f32",
+ "__ocml_ceil_f64");
+ patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__ocml_cos_f32",
+ "__ocml_cos_f64");
+ patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__ocml_exp_f32",
+ "__ocml_exp_f64");
+ patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(
converter, "__ocml_expm1_f32", "__ocml_expm1_f64");
- patterns.insert<OpToFuncCallLowering<FloorFOp>>(converter, "__ocml_floor_f32",
- "__ocml_floor_f64");
- patterns.insert<OpToFuncCallLowering<math::LogOp>>(
- converter, "__ocml_log_f32", "__ocml_log_f64");
- patterns.insert<OpToFuncCallLowering<math::Log10Op>>(
+ patterns.add<OpToFuncCallLowering<FloorFOp>>(converter, "__ocml_floor_f32",
+ "__ocml_floor_f64");
+ patterns.add<OpToFuncCallLowering<math::LogOp>>(converter, "__ocml_log_f32",
+ "__ocml_log_f64");
+ patterns.add<OpToFuncCallLowering<math::Log10Op>>(
converter, "__ocml_log10_f32", "__ocml_log10_f64");
- patterns.insert<OpToFuncCallLowering<math::Log1pOp>>(
+ patterns.add<OpToFuncCallLowering<math::Log1pOp>>(
converter, "__ocml_log1p_f32", "__ocml_log1p_f64");
- patterns.insert<OpToFuncCallLowering<math::Log2Op>>(
- converter, "__ocml_log2_f32", "__ocml_log2_f64");
- patterns.insert<OpToFuncCallLowering<math::PowFOp>>(
- converter, "__ocml_pow_f32", "__ocml_pow_f64");
- patterns.insert<OpToFuncCallLowering<math::RsqrtOp>>(
+ patterns.add<OpToFuncCallLowering<math::Log2Op>>(converter, "__ocml_log2_f32",
+ "__ocml_log2_f64");
+ patterns.add<OpToFuncCallLowering<math::PowFOp>>(converter, "__ocml_pow_f32",
+ "__ocml_pow_f64");
+ patterns.add<OpToFuncCallLowering<math::RsqrtOp>>(
converter, "__ocml_rsqrt_f32", "__ocml_rsqrt_f64");
- patterns.insert<OpToFuncCallLowering<math::SinOp>>(
- converter, "__ocml_sin_f32", "__ocml_sin_f64");
- patterns.insert<OpToFuncCallLowering<math::SqrtOp>>(
- converter, "__ocml_sqrt_f32", "__ocml_sqrt_f64");
- patterns.insert<OpToFuncCallLowering<math::TanhOp>>(
- converter, "__ocml_tanh_f32", "__ocml_tanh_f64");
+ patterns.add<OpToFuncCallLowering<math::SinOp>>(converter, "__ocml_sin_f32",
+ "__ocml_sin_f64");
+ patterns.add<OpToFuncCallLowering<math::SqrtOp>>(converter, "__ocml_sqrt_f32",
+ "__ocml_sqrt_f64");
+ patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__ocml_tanh_f32",
+ "__ocml_tanh_f64");
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 5175a877ec396..c2cd4baea631c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -330,9 +330,9 @@ namespace {
}
void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
populateWithGenerated(patterns);
- patterns.insert<
+ patterns.add<
GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index a8644c851b489..1f23f7ce380ef 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -57,7 +57,7 @@ void GPUToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
populateStandardToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index e49d6b88191ce..f55c5a814bed7 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -200,9 +200,9 @@ class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
} // namespace
/// Populate the given list with patterns that convert from Linalg to LLVM.
-void mlir::populateLinalgToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<RangeOpConversion, ReshapeOpConversion, YieldOpConversion>(
+void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ patterns.add<RangeOpConversion, ReshapeOpConversion, YieldOpConversion>(
converter);
// Populate the type conversions for the linalg types.
@@ -221,7 +221,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index 052dea406a52c..a94435c043584 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -204,7 +204,6 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
//===----------------------------------------------------------------------===//
void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<SingleWorkgroupReduction>(typeConverter,
- patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
index d9df551e33af0..d91444d42af88 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -30,7 +30,7 @@ void LinalgToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
populateLinalgToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index ce4fe8aafeb0f..72237fdafadaa 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -192,15 +192,15 @@ mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
/// Populate the given list with patterns that convert from Linalg to Standard.
void mlir::linalg::populateLinalgToStandardConversionPatterns(
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
// TODO: ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
// clang-format off
- patterns.insert<
+ patterns.add<
CopyOpToLibraryCallRewrite,
CopyTransposeRewrite,
IndexedGenericOpToLibraryCallRewrite>(patterns.getContext());
- patterns.insert<LinalgOpToLibraryCallRewrite>();
+ patterns.add<LinalgOpToLibraryCallRewrite>();
// clang-format on
}
@@ -218,7 +218,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
StandardOpsDialect>();
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateLinalgToStandardConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 833d51f1bc6db..878e11ae6c5aa 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -41,10 +41,10 @@ struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
};
} // namespace
-void mlir::populateOpenMPToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<RegionOpConversion<omp::ParallelOp>,
- RegionOpConversion<omp::WsLoopOp>>(converter);
+void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ patterns.add<RegionOpConversion<omp::ParallelOp>,
+ RegionOpConversion<omp::WsLoopOp>>(converter);
}
namespace {
@@ -58,7 +58,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to OpenMP operations with LLVM IR dialect
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index b9602ddb70b46..d13cebe3c3a2f 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -642,9 +642,8 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
return success();
}
-void mlir::populateParallelLoopToGPUPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<ParallelToGpuLaunchLowering>(patterns.getContext());
+void mlir::populateParallelLoopToGPUPatterns(RewritePatternSet &patterns) {
+ patterns.add<ParallelToGpuLaunchLowering>(patterns.getContext());
}
void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
index a6ab449b3b6a1..43c6798091e73 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
@@ -47,7 +47,7 @@ struct ForLoopMapper : public ConvertAffineForToGPUBase<ForLoopMapper> {
struct ParallelLoopToGpuPass
: public ConvertParallelLoopToGpuBase<ParallelLoopToGpuPass> {
void runOnOperation() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateParallelLoopToGPUPatterns(patterns);
ConversionTarget target(getContext());
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 46e67e5e24cc4..1d4d9fe84cf1e 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -90,8 +90,8 @@ static LogicalResult applyPatterns(FuncOp func) {
[](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
target.addLegalDialect<omp::OpenMPDialect>();
- OwningRewritePatternList patterns(func.getContext());
- patterns.insert<ParallelOpLowering>(func.getContext());
+ RewritePatternSet patterns(func.getContext());
+ patterns.add<ParallelOpLowering>(func.getContext());
FrozenRewritePatternList frozen(std::move(patterns));
return applyPartialConversion(func, target, frozen);
}
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 344af6853cbc5..08e3d3f727627 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -321,7 +321,7 @@ LogicalResult TerminatorOpConversion::matchAndRewrite(
void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
- OwningRewritePatternList &patterns) {
- patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
+ RewritePatternSet &patterns) {
+ patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
patterns.getContext(), typeConverter, scfToSPIRVContext.getImpl());
}
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index 024ff2c0e4c83..637e6a7501b71 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -37,7 +37,7 @@ void SCFToSPIRVPass::runOnOperation() {
SPIRVTypeConverter typeConverter(targetAttr);
ScfToSPIRVContext scfContext;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
populateSCFToSPIRVPatterns(typeConverter, scfContext, patterns);
populateStandardToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index 5250d53f2d494..6efba3fc816ce 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -568,15 +568,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}
-void mlir::populateLoopToStdConversionPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
+void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
patterns.getContext());
- patterns.insert<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
+ patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
void SCFToStandardPass::runOnOperation() {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateLoopToStdConversionPatterns(patterns);
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
// scf.while. Anything else is fine.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 7f3752f11e045..f10b29a620261 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -278,10 +278,10 @@ class LowerHostCodeToLLVM
/*emitCWrappers=*/true,
/*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
auto *context = module.getContext();
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
LLVMTypeConverter typeConverter(context, options);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
- patterns.insert<GPULaunchLowering>(typeConverter);
+ patterns.add<GPULaunchLowering>(typeConverter);
// Pull in SPIR-V type conversion patterns to convert SPIR-V global
// variable's type to LLVM dialect type.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 6f6d56f5f936f..d3fc60a5eb6b3 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1385,8 +1385,8 @@ void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
}
void mlir::populateSPIRVToLLVMConversionPatterns(
- LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
- patterns.insert<
+ LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<
// Arithmetic ops
DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
@@ -1499,13 +1499,13 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
}
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
- LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
- patterns.insert<FuncConversionPattern>(patterns.getContext(), typeConverter);
+ LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
}
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
- LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
- patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
+ LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<ModuleConversionPattern, ModuleEndConversionPattern>(
patterns.getContext(), typeConverter);
}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
index a807b319a0701..f064bb4fc2ad9 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
@@ -36,7 +36,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
// Encode global variable's descriptor set and binding if they exist.
encodeBindAttribute(module);
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
populateSPIRVToLLVMTypeConversion(converter);
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index 28697ba1ddc42..d5388dfd4040c 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -37,10 +37,10 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
} // namespace
void mlir::populateConvertShapeConstraintsConversionPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<CstrBroadcastableToRequire>(patterns.getContext());
- patterns.insert<CstrEqToRequire>(patterns.getContext());
- patterns.insert<ConvertCstrRequireOp>(patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns.add<CstrBroadcastableToRequire>(patterns.getContext());
+ patterns.add<CstrEqToRequire>(patterns.getContext());
+ patterns.add<ConvertCstrRequireOp>(patterns.getContext());
}
namespace {
@@ -54,7 +54,7 @@ class ConvertShapeConstraints
auto func = getOperation();
auto *context = &getContext();
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
populateConvertShapeConstraintsConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 07c5dbefff0a3..2626995b3c93c 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -678,7 +678,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
// Setup conversion patterns.
- OwningRewritePatternList patterns(&ctx);
+ RewritePatternSet patterns(&ctx);
populateShapeToStandardConversionPatterns(patterns);
// Apply conversion.
@@ -688,10 +688,10 @@ void ConvertShapeToStandardPass::runOnOperation() {
}
void mlir::populateShapeToStandardConversionPatterns(
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
// clang-format off
populateWithGenerated(patterns);
- patterns.insert<
+ patterns.add<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 63036c4508a49..5ac7fdd6f5eff 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -3856,10 +3856,10 @@ struct GenericAtomicRMWOpLowering
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
void mlir::populateStdToLLVMNonMemoryConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// FIXME: this should be tablegen'ed
// clang-format off
- patterns.insert<
+ patterns.add<
AbsFOpLowering,
AddFOpLowering,
AddIOpLowering,
@@ -3926,9 +3926,9 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
}
void mlir::populateStdToLLVMMemoryConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
- patterns.insert<
+ patterns.add<
AssumeAlignmentOpLowering,
DeallocOpLowering,
DimOpLowering,
@@ -3945,21 +3945,21 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
ViewOpLowering>(converter);
// clang-format on
if (converter.getOptions().useAlignedAlloc)
- patterns.insert<AlignedAllocOpLowering>(converter);
+ patterns.add<AlignedAllocOpLowering>(converter);
else
- patterns.insert<AllocOpLowering>(converter);
+ patterns.add<AllocOpLowering>(converter);
}
void mlir::populateStdToLLVMFuncOpConversionPattern(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
if (converter.getOptions().useBarePtrCallConv)
- patterns.insert<BarePtrFuncOpConversion>(converter);
+ patterns.add<BarePtrFuncOpConversion>(converter);
else
- patterns.insert<FuncOpConversion>(converter);
+ patterns.add<FuncOpConversion>(converter);
}
-void mlir::populateStdToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
populateStdToLLVMFuncOpConversionPattern(converter, patterns);
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
populateStdToLLVMMemoryConversionPatterns(converter, patterns);
@@ -4079,7 +4079,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
llvm::DataLayout(this->dataLayout)};
LLVMTypeConverter typeConverter(&getContext(), options);
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
LLVMConversionTarget target(getContext());
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index 57f1b1733e3bf..fc5f474671414 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -193,11 +193,11 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
//===----------------------------------------------------------------------===//
void mlir::populateStdLegalizationPatternsForSPIRVLowering(
- OwningRewritePatternList &patterns) {
- patterns.insert<LoadOpOfSubViewFolder<memref::LoadOp>,
- LoadOpOfSubViewFolder<vector::TransferReadOp>,
- StoreOpOfSubViewFolder<memref::StoreOp>,
- StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
+ RewritePatternSet &patterns) {
+ patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
+ LoadOpOfSubViewFolder<vector::TransferReadOp>,
+ StoreOpOfSubViewFolder<memref::StoreOp>,
+ StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
patterns.getContext());
}
@@ -213,7 +213,7 @@ struct SPIRVLegalization final
} // namespace
void SPIRVLegalization::runOnOperation() {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateStdLegalizationPatternsForSPIRVLowering(patterns);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 8552db488e614..ed66252e20ae5 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -1225,10 +1225,10 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
namespace mlir {
void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- patterns.insert<
+ patterns.add<
// Math dialect operations.
// TODO: Move to separate pass.
UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
@@ -1290,15 +1290,15 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
// capability is available.
- patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context,
- /*benefit=*/2);
+ patterns.add<CmpFOpNanKernelPattern>(typeConverter, context,
+ /*benefit=*/2);
}
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
- OwningRewritePatternList &patterns) {
- patterns.insert<TensorExtractPattern>(typeConverter, patterns.getContext(),
- byteCountThreshold);
+ RewritePatternSet &patterns) {
+ patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(),
+ byteCountThreshold);
}
} // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
index a1c6f9831277c..c738537f74382 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
@@ -35,7 +35,7 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
populateStandardToSPIRVPatterns(typeConverter, patterns);
populateTensorToSPIRVPatterns(typeConverter,
/*byteCountThreshold=*/64, patterns);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d6cc45c4ee600..e0117e0f694fa 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1016,8 +1016,8 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
- OwningRewritePatternList *patterns) {
- patterns->insert<
+ RewritePatternSet *patterns) {
+ patterns->add<
PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::RsqrtOp>,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 7d6815ee50a06..5c0dbc50c2d75 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -37,7 +37,7 @@ struct TosaToLinalgOnTensors
}
void runOnFunction() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
StandardOpsDialect>();
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 4fb06d12d68cc..ef5ccf93765b5 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -103,7 +103,7 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
} // namespace
void mlir::tosa::populateTosaToSCFConversionPatterns(
- OwningRewritePatternList *patterns) {
- patterns->insert<IfOpConverter>(patterns->getContext());
- patterns->insert<WhileOpConverter>(patterns->getContext());
+ RewritePatternSet *patterns) {
+ patterns->add<IfOpConverter>(patterns->getContext());
+ patterns->add<WhileOpConverter>(patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
index 9b562faa64960..6563a446aaa21 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -29,7 +29,7 @@ namespace {
struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
public:
void runOnOperation() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 8db7868652b7b..668548d3dec39 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -154,12 +154,12 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
} // namespace
void mlir::tosa::populateTosaToStandardConversionPatterns(
- OwningRewritePatternList *patterns) {
- patterns->insert<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
+ RewritePatternSet *patterns) {
+ patterns->add<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
patterns->getContext());
}
void mlir::tosa::populateTosaRescaleToStandardConversionPatterns(
- OwningRewritePatternList *patterns) {
- patterns->insert<ApplyScaleOpConverter>(patterns->getContext());
+ RewritePatternSet *patterns) {
+ patterns->add<ApplyScaleOpConverter>(patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
index de8768bbe893e..af639cb42e52e 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -29,7 +29,7 @@ namespace {
struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
public:
void runOnOperation() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addIllegalOp<tosa::ConstOp>();
target.addIllegalOp<tosa::SliceOp>();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 15553bbd9be52..24c5092894f02 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1482,47 +1482,37 @@ class VectorExtractStridedSliceOpConversion
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+ LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool enableIndexOptimizations) {
MLIRContext *ctx = converter.getDialect()->getContext();
- // clang-format off
- patterns.insert<VectorFMAOpNDRewritePattern,
- VectorInsertStridedSliceOpDifferentRankRewritePattern,
- VectorInsertStridedSliceOpSameRankRewritePattern,
- VectorExtractStridedSliceOpConversion>(ctx);
- patterns.insert<VectorReductionOpConversion>(
- converter, reassociateFPReductions);
- patterns.insert<VectorCreateMaskOpConversion,
- VectorTransferConversion<TransferReadOp>,
- VectorTransferConversion<TransferWriteOp>>(
+ patterns.add<VectorFMAOpNDRewritePattern,
+ VectorInsertStridedSliceOpDifferentRankRewritePattern,
+ VectorInsertStridedSliceOpSameRankRewritePattern,
+ VectorExtractStridedSliceOpConversion>(ctx);
+ patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
+ patterns.add<VectorCreateMaskOpConversion,
+ VectorTransferConversion<TransferReadOp>,
+ VectorTransferConversion<TransferWriteOp>>(
converter, enableIndexOptimizations);
patterns
- .insert<VectorBitCastOpConversion,
- VectorShuffleOpConversion,
- VectorExtractElementOpConversion,
- VectorExtractOpConversion,
- VectorFMAOp1DConversion,
- VectorInsertElementOpConversion,
- VectorInsertOpConversion,
- VectorPrintOpConversion,
- VectorTypeCastOpConversion,
- VectorLoadStoreConversion<vector::LoadOp,
- vector::LoadOpAdaptor>,
- VectorLoadStoreConversion<vector::MaskedLoadOp,
- vector::MaskedLoadOpAdaptor>,
- VectorLoadStoreConversion<vector::StoreOp,
- vector::StoreOpAdaptor>,
- VectorLoadStoreConversion<vector::MaskedStoreOp,
- vector::MaskedStoreOpAdaptor>,
- VectorGatherOpConversion,
- VectorScatterOpConversion,
- VectorExpandLoadOpConversion,
- VectorCompressStoreOpConversion>(converter);
- // clang-format on
+ .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
+ VectorExtractElementOpConversion, VectorExtractOpConversion,
+ VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorInsertOpConversion, VectorPrintOpConversion,
+ VectorTypeCastOpConversion,
+ VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
+ VectorLoadStoreConversion<vector::MaskedLoadOp,
+ vector::MaskedLoadOpAdaptor>,
+ VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
+ VectorLoadStoreConversion<vector::MaskedStoreOp,
+ vector::MaskedStoreOpAdaptor>,
+ VectorGatherOpConversion, VectorScatterOpConversion,
+ VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
+ converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<VectorMatmulOpConversion>(converter);
- patterns.insert<VectorFlatTransposeOpConversion>(converter);
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<VectorMatmulOpConversion>(converter);
+ patterns.add<VectorFlatTransposeOpConversion>(converter);
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index b8c43c8c70c84..abddcd73af1ee 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -61,7 +61,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
{
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorSlicesLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns);
@@ -70,7 +70,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, enableIndexOptimizations);
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 4b097c5c15aa3..5ebed5f80e4a4 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -144,9 +144,9 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
} // end anonymous namespace
void mlir::populateVectorToROCDLConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<VectorTransferConversion<TransferReadOp>,
- VectorTransferConversion<TransferWriteOp>>(converter);
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<VectorTransferConversion<TransferReadOp>,
+ VectorTransferConversion<TransferWriteOp>>(converter);
}
namespace {
@@ -158,7 +158,7 @@ struct LowerVectorToROCDLPass
void LowerVectorToROCDLPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateVectorToROCDLConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3c7c4570f6a7f..1c8e05b2d623f 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -694,10 +694,9 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
}
void populateVectorToSCFConversionPatterns(
- OwningRewritePatternList &patterns,
- const VectorTransferToSCFOptions &options) {
- patterns.insert<VectorTransferRewriter<vector::TransferReadOp>,
- VectorTransferRewriter<vector::TransferWriteOp>>(
+ RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
+ patterns.add<VectorTransferRewriter<vector::TransferReadOp>,
+ VectorTransferRewriter<vector::TransferWriteOp>>(
options, patterns.getContext());
}
@@ -713,7 +712,7 @@ struct ConvertVectorToSCFPass
}
void runOnFunction() override {
- OwningRewritePatternList patterns(getFunction().getContext());
+ RewritePatternSet patterns(getFunction().getContext());
populateVectorToSCFConversionPatterns(
patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll));
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2d8ffc0b12a73..4cfcb4148cf0f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -242,11 +242,11 @@ struct VectorInsertStridedSliceOpConvert final
} // namespace
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
- VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
- VectorInsertElementOpConvert, VectorInsertOpConvert,
- VectorInsertStridedSliceOpConvert>(typeConverter,
- patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
+ VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
+ VectorInsertElementOpConvert, VectorInsertOpConvert,
+ VectorInsertStridedSliceOpConvert>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index b3c63848ea96d..1915b499fbdb1 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -37,7 +37,7 @@ void LowerVectorToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
populateVectorToSPIRVPatterns(typeConverter, patterns);
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 7db57d383ba3a..d9cb5034a8dc0 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -216,9 +216,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
} // namespace
void mlir::populateAMXLegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<TileZeroConversion, TileLoadConversion, TileStoreConversion,
- TileMulFConversion, TileMulIConversion>(converter);
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
+ TileMulFConversion, TileMulIConversion>(converter);
}
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
diff --git a/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp
index cfe9f2b3ac028..eaa6498f17529 100644
--- a/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp
@@ -105,10 +105,10 @@ struct RegistryImpl {
/// Registers the patterns specializing the "main" op to one of the
/// "intrinsic" ops depending on elemental type.
static void registerPatterns(LLVMTypeConverter &converter,
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
patterns
- .insert<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
- typename Args::Intr64Op>...>(converter);
+ .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
+ typename Args::Intr64Op>...>(converter);
}
/// Configures the conversion target to lower out "main" ops.
@@ -128,9 +128,9 @@ using Registry = RegistryImpl<
/// Populate the given list with patterns that convert from AVX512 to LLVM.
void mlir::populateAVX512LegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Registry::registerPatterns(converter, patterns);
- patterns.insert<MaskCompressOpConversion>(converter);
+ patterns.add<MaskCompressOpConversion>(converter);
}
void mlir::configureAVX512LegalizeForExportTarget(
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7962ec21b5ded..930d0bce96c66 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -905,9 +905,9 @@ void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
}
} // end anonymous namespace.
-void AffineApplyOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
+void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1611,9 +1611,9 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
};
} // end anonymous namespace
-void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<AffineForEmptyLoopFolder>(context);
+ results.add<AffineForEmptyLoopFolder>(context);
}
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
@@ -2033,9 +2033,9 @@ LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
return failure();
}
-void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyDeadElse>(context);
+ results.add<SimplifyDeadElse>(context);
}
//===----------------------------------------------------------------------===//
@@ -2149,9 +2149,9 @@ LogicalResult verify(AffineLoadOp op) {
return success();
}
-void AffineLoadOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
+void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
@@ -2239,9 +2239,9 @@ LogicalResult verify(AffineStoreOp op) {
return success();
}
-void AffineStoreOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
+void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
@@ -2338,9 +2338,9 @@ OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
return foldMinMaxOp(*this, operands);
}
-void AffineMinOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
+void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<SimplifyAffineOp<AffineMinOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -2354,9 +2354,9 @@ OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
return foldMinMaxOp(*this, operands);
}
-void AffineMaxOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
+void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<SimplifyAffineOp<AffineMaxOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -2454,10 +2454,10 @@ static LogicalResult verify(AffinePrefetchOp op) {
return success();
}
-void AffinePrefetchOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
+void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
// prefetch(memrefcast) -> prefetch
- results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
+ results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 62cad1f33157f..cd966d404a47a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -227,7 +227,7 @@ void AffineDataCopyGeneration::runOnFunction() {
// Promoting single iteration loops could lead to simplification of
// contained load's/store's, and the latter could anyway also be
// canonicalized.
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternList frozenPatterns(std::move(patterns));
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 512ecd6ee3cd0..c3ec72f51b3fa 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -79,7 +79,7 @@ mlir::createSimplifyAffineStructuresPass() {
void SimplifyAffineStructures::runOnFunction() {
auto func = getFunction();
simplifiedAttributes.clear();
- OwningRewritePatternList patterns(func.getContext());
+ RewritePatternSet patterns(func.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 12d3a73e2a44a..8e2645a2d44ae 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -188,7 +188,7 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
// effective (no unused operands). Since the pattern rewriter's folding is
// entangled with application of patterns, we may fold/end up erasing the op,
// in which case we return with `folded` being set.
- OwningRewritePatternList patterns(ifOp.getContext());
+ RewritePatternSet patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
bool erased;
FrozenRewritePatternList frozenPatterns(std::move(patterns));
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index cb124e374ae65..3627635ed0606 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -270,8 +270,8 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
void AsyncParallelForPass::runOnFunction() {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
- patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
+ RewritePatternSet patterns(ctx);
+ patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
signalPassFailure();
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 99cc0b0e3a409..d511b4f8be5d7 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -485,14 +485,14 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
// Lower async operations to async.runtime operations.
MLIRContext *ctx = module->getContext();
- OwningRewritePatternList asyncPatterns(ctx);
+ RewritePatternSet asyncPatterns(ctx);
// Async lowering does not use type converter because it must preserve all
// types for async.runtime operations.
- asyncPatterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
- asyncPatterns.insert<AwaitTokenOpLowering, AwaitValueOpLowering,
- AwaitAllOpLowering, YieldOpLowering>(ctx,
- outlinedFunctions);
+ asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
+ asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
+ AwaitAllOpLowering, YieldOpLowering>(ctx,
+ outlinedFunctions);
// All high level async operations must be lowered to the runtime operations.
ConversionTarget runtimeTarget(*ctx);
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 3e4189df585a9..13f455f1cd6e7 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -401,6 +401,6 @@ struct GpuAllReduceConversion : public RewritePattern {
};
} // namespace
-void mlir::populateGpuAllReducePatterns(OwningRewritePatternList &patterns) {
- patterns.insert<GpuAllReduceConversion>(patterns.getContext());
+void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
+ patterns.add<GpuAllReduceConversion>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f456c588ffaf6..5d9bb1f5cf03d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -842,10 +842,9 @@ struct FoldInitTensorWithTensorReshapeOp
};
} // namespace
-void InitTensorOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results
- .insert<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
+void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
}
@@ -1546,9 +1545,9 @@ static LogicalResult verify(ReshapeOp op) {
return success();
}
-void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<CollapseReshapeOps<ReshapeOp>>(context);
+ results.add<CollapseReshapeOps<ReshapeOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1661,10 +1660,10 @@ struct ReplaceDimOfReshapeOpResult : OpRewritePattern<memref::DimOp> {
};
} // namespace
-void TensorReshapeOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
- ReplaceDimOfReshapeOpResult>(context);
+void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
+ ReplaceDimOfReshapeOpResult>(context);
}
//===----------------------------------------------------------------------===//
@@ -2654,11 +2653,11 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
} // namespace
#define CANONICALIZERS_AND_FOLDERS(XXX) \
- void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
+ void XXX::getCanonicalizationPatterns(RewritePatternSet &results, \
MLIRContext *context) { \
- results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
- RemoveIdentityLinalgOps>(); \
- results.insert<ReplaceDimOfLinalgOpResult>(context); \
+ results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
+ RemoveIdentityLinalgOps>(); \
+ results.add<ReplaceDimOfLinalgOpResult>(context); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index df195af580c73..0f50e13b0acd2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -323,7 +323,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
- OwningRewritePatternList patterns(&context);
+ RewritePatternSet patterns(&context);
populateLinalgBufferizePatterns(typeConverter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -337,11 +337,11 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
}
void mlir::linalg::populateLinalgBufferizePatterns(
- BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
+ BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<BufferizeAnyLinalgOp>(typeConverter);
// TODO: Drop this once tensor constants work in standard.
// clang-format off
- patterns.insert<
+ patterns.add<
BufferizeFillOp,
BufferizeInitTensorOp,
SubTensorOpConverter,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index a7e13325262a4..cd4f65953c0a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -45,9 +45,9 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
currentState = nextState;
}
- OwningRewritePatternList stage2Patterns =
+ RewritePatternSet stage2Patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
- stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
+ stage2Patterns.add<AffineMinSCFCanonicalizationPattern>(context);
auto stage3Transforms = [&](Operation *op) {
// Some of these may be too aggressive as a stage 3 that is applied on each
@@ -76,18 +76,18 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Programmatic splitting of slow/fast path vector transfers.
if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) {
- OwningRewritePatternList patterns(context);
- patterns.insert<vector::VectorTransferFullPartialRewriter>(
+ RewritePatternSet patterns(context);
+ patterns.add<vector::VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
// Programmatic controlled lowering of vector.contract only.
if (lateCodegenStrategyOptions.enableVectorContractLowering) {
- OwningRewritePatternList vectorContractLoweringPatterns(context);
+ RewritePatternSet vectorContractLoweringPatterns(context);
vectorContractLoweringPatterns
- .insert<ContractionOpToOuterProductOpLowering,
- ContractionOpToMatmulOpLowering, ContractionOpLowering>(
+ .add<ContractionOpToOuterProductOpLowering,
+ ContractionOpToMatmulOpLowering, ContractionOpLowering>(
vectorTransformsOptions, context);
(void)applyPatternsAndFoldGreedily(
func, std::move(vectorContractLoweringPatterns));
@@ -95,7 +95,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Programmatic controlled lowering of vector.transfer only.
if (lateCodegenStrategyOptions.enableVectorToSCFConversion) {
- OwningRewritePatternList vectorToLoopsPatterns(context);
+ RewritePatternSet vectorToLoopsPatterns(context);
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
vectorToSCFOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index cc95218d870fb..aece769721ca5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -163,7 +163,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
void runOnFunction() override {
auto *context = &getContext();
DetensorizeTypeConverter typeConverter;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
@@ -194,9 +194,9 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
op, typeConverter, /*returnOpAlwaysLegal*/ true);
});
- patterns.insert<DetensorizeGenericOp>(typeConverter, context);
- patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
- context, typeConverter);
+ patterns.add<DetensorizeGenericOp>(typeConverter, context);
+ patterns.add<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
+ context, typeConverter);
// Since non-entry block arguments get detensorized, we also need to update
// the control flow inside the function to reflect the correct types.
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
@@ -204,8 +204,8 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
- OwningRewritePatternList canonPatterns(context);
- canonPatterns.insert<ExtractFromReshapeFromElements>(context);
+ RewritePatternSet canonPatterns(context);
+ canonPatterns.add<ExtractFromReshapeFromElements>(context);
if (failed(applyPatternsAndFoldGreedily(getFunction(),
std::move(canonPatterns))))
signalPassFailure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index a8db840fbd0f0..b771420318e50 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -490,14 +490,13 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns
- .insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
- ReplaceUnitExtentTensors<GenericOp>,
- ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
+ patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
+ ReplaceUnitExtentTensors<GenericOp>,
+ ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
- patterns.insert<FoldReshapeOpWithUnitExtent>(context);
+ patterns.add<FoldReshapeOpWithUnitExtent>(context);
populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}
@@ -508,10 +507,11 @@ struct LinalgFoldUnitExtentDimsPass
void runOnFunction() override {
FuncOp funcOp = getFunction();
MLIRContext *context = funcOp.getContext();
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
if (foldOneTripLoopsOnly)
- patterns.insert<FoldUnitDimLoops<GenericOp>,
- FoldUnitDimLoops<IndexedGenericOp>>(context);
+ patterns
+ .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
+ context);
else
populateLinalgFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 48677dffbc7a3..321961d2deac9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -116,8 +116,8 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
} // namespace
void mlir::populateElementwiseToLinalgConversionPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<ConvertAnyElementwiseMappableOpOnRankedTensors>();
+ RewritePatternSet &patterns) {
+ patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>();
}
namespace {
@@ -128,7 +128,7 @@ class ConvertElementwiseToLinalgPass
auto func = getOperation();
auto *context = &getContext();
ConversionTarget target(*context);
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
populateElementwiseToLinalgConversionPatterns(patterns);
target.markUnknownOpDynamicallyLegal([](Operation *op) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index a6d0fd5dd7b75..4b0951ea4c1c4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1133,7 +1133,7 @@ struct FusionOfTensorOpsPass
: public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
void runOnOperation() override {
Operation *op = getOperation();
- OwningRewritePatternList patterns(op->getContext());
+ RewritePatternSet patterns(op->getContext());
populateLinalgTensorOpsFusionPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
@@ -1146,7 +1146,7 @@ struct FoldReshapeOpsByLinearizationPass
FoldReshapeOpsByLinearizationPass> {
void runOnOperation() override {
Operation *op = getOperation();
- OwningRewritePatternList patterns(op->getContext());
+ RewritePatternSet patterns(op->getContext());
populateFoldReshapeOpsByLinearizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
@@ -1155,35 +1155,35 @@ struct FoldReshapeOpsByLinearizationPass
} // namespace
void mlir::populateFoldReshapeOpsByLinearizationPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
- FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
- FoldConsumerReshapeOpByLinearization<false>>(
+ RewritePatternSet &patterns) {
+ patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
+ FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
+ FoldConsumerReshapeOpByLinearization<false>>(
patterns.getContext());
}
void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
- FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
- FoldConsumerReshapeOpByLinearization<true>>(
+ RewritePatternSet &patterns) {
+ patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
+ FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
+ FoldConsumerReshapeOpByLinearization<true>>(
patterns.getContext());
}
void mlir::populateFoldReshapeOpsByExpansionPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<FoldReshapeWithGenericOpByExpansion,
- FoldWithProducerReshapeOpByExpansion<GenericOp>,
- FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
+ RewritePatternSet &patterns) {
+ patterns.add<FoldReshapeWithGenericOpByExpansion,
+ FoldWithProducerReshapeOpByExpansion<GenericOp>,
+ FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
patterns.getContext());
}
-void mlir::populateLinalgTensorOpsFusionPatterns(
- OwningRewritePatternList &patterns) {
+void mlir::populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
- FoldSplatConstants<GenericOp>,
- FoldSplatConstants<IndexedGenericOp>>(context);
+ patterns
+ .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+ FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
+ context);
populateFoldReshapeOpsByExpansionPatterns(patterns);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3783ef54a31aa..cb959a866935e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -143,7 +143,7 @@ struct LinalgGeneralizationPass
void LinalgGeneralizationPass::runOnFunction() {
FuncOp func = getFunction();
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
linalg::populateLinalgConvGeneralizationPatterns(patterns);
linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
@@ -167,16 +167,14 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
}
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
- OwningRewritePatternList &patterns,
- linalg::LinalgTransformationFilter marker) {
- patterns.insert<GeneralizeConvOp>(patterns.getContext(), marker);
+ RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) {
+ patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
}
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
- OwningRewritePatternList &patterns,
- linalg::LinalgTransformationFilter marker) {
- patterns.insert<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
- marker);
+ RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) {
+ patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
+ marker);
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 635855fdecfe6..1fcd3f4ed8753 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -378,7 +378,7 @@ void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
// Apply canonicalization so the newForOp + yield folds immediately, thus
// cleaning up the IR and potentially enabling more hoisting.
if (changed) {
- OwningRewritePatternList patterns(func->getContext());
+ RewritePatternSet patterns(func->getContext());
scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 10b4cacb2df8e..5bc6cefe489a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -545,11 +545,11 @@ template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp,
ArrayRef<unsigned> interchangeVector) {
MLIRContext *context = funcOp.getContext();
- OwningRewritePatternList patterns(context);
- patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
+ RewritePatternSet patterns(context);
+ patterns.add<LinalgRewritePattern<LoopType>>(interchangeVector);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
- patterns.insert<FoldAffineOp>(context);
+ patterns.add<FoldAffineOp>(context);
// Just apply the patterns greedily.
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
index 1fc82d5383ac5..f4c84fc4366b0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
@@ -137,8 +137,8 @@ class TensorToValuesConverter
/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void linalg::populateSparsificationConversionPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<TensorFromPointerConverter, TensorToDimSizeConverter,
- TensorToPointersConverter, TensorToIndicesConverter,
- TensorToValuesConverter>(patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns.add<TensorFromPointerConverter, TensorToDimSizeConverter,
+ TensorToPointersConverter, TensorToIndicesConverter,
+ TensorToValuesConverter>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index c74024110cc82..c0c1970290fca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -1361,6 +1361,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void linalg::populateSparsificationPatterns(
- OwningRewritePatternList &patterns, const SparsificationOptions &options) {
- patterns.insert<GenericOpSparsifier>(patterns.getContext(), options);
+ RewritePatternSet &patterns, const SparsificationOptions &options) {
+ patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 3f4c698304829..aaf00721732dc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -511,13 +511,13 @@ class CanonicalizationPatternList;
template <>
class CanonicalizationPatternList<> {
public:
- static void insert(OwningRewritePatternList &patterns) {}
+ static void insert(RewritePatternSet &patterns) {}
};
template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
- static void insert(OwningRewritePatternList &patterns) {
+ static void insert(RewritePatternSet &patterns) {
OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
CanonicalizationPatternList<OpTypes...>::insert(patterns);
}
@@ -530,17 +530,17 @@ class RewritePatternList;
template <>
class RewritePatternList<> {
public:
- static void insert(OwningRewritePatternList &patterns,
+ static void insert(RewritePatternSet &patterns,
const LinalgTilingOptions &options) {}
};
template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
- static void insert(OwningRewritePatternList &patterns,
+ static void insert(RewritePatternSet &patterns,
const LinalgTilingOptions &options) {
auto *ctx = patterns.getContext();
- patterns.insert<LinalgTilingPattern<OpTy>>(
+ patterns.add<LinalgTilingPattern<OpTy>>(
ctx, options,
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("tiled", ctx)));
@@ -549,15 +549,15 @@ class RewritePatternList<OpTy, OpTypes...> {
};
} // namespace
-OwningRewritePatternList
+RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
populateLinalgTilingCanonicalizationPatterns(patterns);
return patterns;
}
void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
auto *ctx = patterns.getContext();
AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
AffineForOp::getCanonicalizationPatterns(patterns, ctx);
@@ -577,7 +577,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
}
/// Populate the given list with patterns that apply Linalg tiling.
-static void insertTilingPatterns(OwningRewritePatternList &patterns,
+static void insertTilingPatterns(RewritePatternSet &patterns,
const LinalgTilingOptions &options) {
RewritePatternList<GenericOp, IndexedGenericOp,
#define GET_OP_LIST
@@ -591,7 +591,7 @@ static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
auto options =
LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType);
MLIRContext *ctx = funcOp.getContext();
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
insertTilingPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsAndFoldGreedily(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b56072cf0d08a..d4581013ae695 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -576,23 +576,21 @@ using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
/// Inserts tiling, promotion and vectorization pattern for ConvOp
/// conversion into corresponding pattern lists.
template <typename ConvOp, unsigned N>
-static void
-populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
- OwningRewritePatternList &promotionPatterns,
- OwningRewritePatternList &vectorizationPatterns,
- ArrayRef<int64_t> tileSizes) {
+static void populateVectorizationPatterns(
+ RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
+ RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
auto *context = tilingPatterns.getContext();
if (tileSizes.size() < N)
return;
constexpr static StringRef kTiledMarker = "TILED";
constexpr static StringRef kPromotedMarker = "PROMOTED";
- tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
+ tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
context, LinalgTilingOptions().setTileSizes(tileSizes),
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get(kTiledMarker, context)));
- promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
+ promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
Identifier::get(kPromotedMarker, context)));
@@ -602,15 +600,15 @@ populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
[](int64_t i) -> bool { return i > 1; });
- vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
+ vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
}
void mlir::linalg::populateConvVectorizationPatterns(
- MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
+ MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes) {
- OwningRewritePatternList tiling(context);
- OwningRewritePatternList promotion(context);
- OwningRewritePatternList vectorization(context);
+ RewritePatternSet tiling(context);
+ RewritePatternSet promotion(context);
+ RewritePatternSet vectorization(context);
populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
tileSizes);
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 6c5d74f81598c..d9c78f527c8a2 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -532,7 +532,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
//----------------------------------------------------------------------------//
void mlir::populateMathPolynomialApproximationPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<TanhApproximation, LogApproximation, Log2Approximation,
- ExpApproximation>(patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
+ ExpApproximation>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index fddab63e3e987..e0e273d85669a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -209,14 +209,14 @@ struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
};
} // end anonymous namespace.
-void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
+ results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
}
-void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyAllocConst<AllocaOp>>(context);
+ results.add<SimplifyAllocConst<AllocaOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -290,9 +290,9 @@ struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
} // namespace
-void BufferCastOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<BufferCast, TensorLoadToMemRef>(context);
+void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<BufferCast, TensorLoadToMemRef>(context);
}
//===----------------------------------------------------------------------===//
@@ -498,9 +498,9 @@ static LogicalResult verify(DeallocOp op) {
return success();
}
-void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyDeadDealloc>(context);
+ results.add<SimplifyDeadDealloc>(context);
}
LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
@@ -677,10 +677,10 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
};
} // end anonymous namespace.
-void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
- DimOfCastOp<tensor::CastOp>>(context);
+ results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
+ DimOfCastOp<tensor::CastOp>>(context);
}
// ---------------------------------------------------------------------------
@@ -1069,9 +1069,9 @@ struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
};
} // end anonymous namespace.
-void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<LoadOfBufferCast>(context);
+ results.add<LoadOfBufferCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -1802,11 +1802,11 @@ struct SubViewCanonicalizer {
}
};
-void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<
- SubViewOp, SubViewCanonicalizer>,
- SubViewOpMemRefCastFolder>(context);
+ results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+ SubViewOp, SubViewCanonicalizer>,
+ SubViewOpMemRefCastFolder>(context);
}
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
@@ -2085,9 +2085,9 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
} // end anonymous namespace
-void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
+ results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
index 44d8be993a1d7..6c5ef29e82283 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
@@ -91,10 +91,10 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
}
void ConvertConstPass::runOnFunction() {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
auto func = getFunction();
auto *context = &getContext();
- patterns.insert<QuantizedConstRewrite>(context);
+ patterns.add<QuantizedConstRewrite>(context);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
index ac28ce6ee9c23..c50d09a2c0653 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
@@ -125,9 +125,9 @@ class ConstFakeQuantPerAxisRewrite
void ConvertSimulatedQuantPass::runOnFunction() {
bool hadFailure = false;
auto func = getFunction();
- OwningRewritePatternList patterns(func.getContext());
+ RewritePatternSet patterns(func.getContext());
auto ctx = func.getContext();
- patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
+ patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
ctx, &hadFailure);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
if (hadFailure)
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 78c72953ee6f5..2d1ad054bf052 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -703,10 +703,10 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
};
} // namespace
-void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops,
- LastTensorLoadCanonicalization>(context);
+ results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
+ LastTensorLoadCanonicalization>(context);
}
//===----------------------------------------------------------------------===//
@@ -973,10 +973,10 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
};
} // namespace
-void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<RemoveUnusedResults, RemoveStaticCondition,
- ConvertTrivialIfToSelect>(context);
+ results.add<RemoveUnusedResults, RemoveStaticCondition,
+ ConvertTrivialIfToSelect>(context);
}
//===----------------------------------------------------------------------===//
@@ -1275,10 +1275,9 @@ struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
} // namespace
-void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(
- context);
+ results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 15a5abaa13561..9ee8075892426 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -25,7 +25,7 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateBufferizeMaterializationLegality(target);
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 0029c3b70a0e0..107c32779e926 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -134,9 +134,9 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
} // namespace
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
- TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
- patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
+ patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
return typeConverter.isLegal(op->getResultTypes());
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index b791695c1dd9e..437c762b5f550 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -111,17 +111,17 @@ struct CombineChainedAccessChain
} // end anonymous namespace
void spirv::AccessChainOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<CombineChainedAccessChain>(context);
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<CombineChainedAccessChain>(context);
}
//===----------------------------------------------------------------------===//
// spv.BitcastOp
//===----------------------------------------------------------------------===//
-void spirv::BitcastOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ConvertChainedBitcast>(context);
+void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ConvertChainedBitcast>(context);
}
//===----------------------------------------------------------------------===//
@@ -230,10 +230,11 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
void spirv::LogicalNotOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
- ConvertLogicalNotOfLogicalEqual,
- ConvertLogicalNotOfLogicalNotEqual>(context);
+ RewritePatternSet &results, MLIRContext *context) {
+ results
+ .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
+ ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
+ context);
}
//===----------------------------------------------------------------------===//
@@ -415,7 +416,7 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
}
} // end anonymous namespace
-void spirv::SelectionOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ConvertSelectionOpToSelect>(context);
+void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ConvertSelectionOpToSelect>(context);
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
index c5eeb8a0b836a..9c72dbb6f7aa3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
@@ -22,14 +22,13 @@ namespace {
namespace mlir {
namespace spirv {
-void populateSPIRVGLSLCanonicalizationPatterns(
- OwningRewritePatternList &results) {
- results.insert<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
- ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
- ConvertComparisonIntoClampSPV_SLessThanOp,
- ConvertComparisonIntoClampSPV_SLessThanEqualOp,
- ConvertComparisonIntoClampSPV_ULessThanOp,
- ConvertComparisonIntoClampSPV_ULessThanEqualOp>(
+void populateSPIRVGLSLCanonicalizationPatterns(RewritePatternSet &results) {
+ results.add<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
+ ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
+ ConvertComparisonIntoClampSPV_SLessThanOp,
+ ConvertComparisonIntoClampSPV_SLessThanEqualOp,
+ ConvertComparisonIntoClampSPV_ULessThanOp,
+ ConvertComparisonIntoClampSPV_ULessThanEqualOp>(
results.getContext());
}
} // namespace spirv
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index afaadb08788e1..87aa623b7abc9 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -74,10 +74,9 @@ class SPIRVAddressOfOpLayoutInfoDecoration
};
} // namespace
-static void
-populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns) {
- patterns.insert<SPIRVGlobalVariableOpLayoutInfoDecoration,
- SPIRVAddressOfOpLayoutInfoDecoration>(patterns.getContext());
+static void populateSPIRVLayoutInfoPatterns(RewritePatternSet &patterns) {
+ patterns.add<SPIRVGlobalVariableOpLayoutInfoDecoration,
+ SPIRVAddressOfOpLayoutInfoDecoration>(patterns.getContext());
}
namespace {
@@ -90,7 +89,7 @@ class DecorateSPIRVCompositeTypeLayoutPass
void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
auto module = getOperation();
- OwningRewritePatternList patterns(module.getContext());
+ RewritePatternSet patterns(module.getContext());
populateSPIRVLayoutInfoPatterns(patterns);
ConversionTarget target(*(module.getContext()));
target.addLegalDialect<spirv::SPIRVDialect>();
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 71ebf8c53b354..cd2c9c52eeae8 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -246,8 +246,8 @@ void LowerABIAttributesPass::runOnOperation() {
return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
});
- OwningRewritePatternList patterns(context);
- patterns.insert<ProcessInterfaceVarABI>(typeConverter, context);
+ RewritePatternSet patterns(context);
+ patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
ConversionTarget target(*context);
// "Legal" function ops should have no interface variable ABI attributes.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4aa8bd4ecd29f..7539e9a050765 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -514,9 +514,9 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
return success();
}
-void mlir::populateBuiltinFuncToSPIRVPatterns(
- SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
- patterns.insert<FuncOpConversion>(typeConverter, patterns.getContext());
+void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f3a66a7bc67ae..89f3422e126f1 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -270,10 +270,10 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
};
} // namespace
-void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// If taking a passing witness, inline region.
- patterns.insert<AssumingWithTrue>(context);
+ patterns.add<AssumingWithTrue>(context);
}
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
@@ -315,9 +315,9 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
// AssumingAllOp
//===----------------------------------------------------------------------===//
-void AssumingAllOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<AssumingAllOneOp>(context);
+void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<AssumingAllOneOp>(context);
}
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
@@ -430,10 +430,10 @@ struct BroadcastForwardSingleOperandPattern
};
} // namespace
-void BroadcastOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<BroadcastForwardSingleOperandPattern,
- RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<BroadcastForwardSingleOperandPattern,
+ RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -500,9 +500,9 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
-void ConstShapeOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<TensorCastConstShape>(context);
+void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<TensorCastConstShape>(context);
}
//===----------------------------------------------------------------------===//
@@ -528,11 +528,11 @@ LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
} // namespace
void CstrBroadcastableOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ RewritePatternSet &patterns, MLIRContext *context) {
// Canonicalization patterns have overlap with the considerations during
// folding in case additional shape information is inferred at some point that
// does not result in folding.
- patterns.insert<CstrBroadcastableEqOps>(context);
+ patterns.add<CstrBroadcastableEqOps>(context);
}
// Return true if there is exactly one attribute not representing a scalar
@@ -595,10 +595,10 @@ static LogicalResult verify(CstrBroadcastableOp op) {
// CstrEqOp
//===----------------------------------------------------------------------===//
-void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// If inputs are equal, return passing witness
- patterns.insert<CstrEqEqOps>(context);
+ patterns.add<CstrEqEqOps>(context);
}
OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
@@ -697,9 +697,9 @@ OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
return {};
}
-void IndexToSizeOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<SizeToIndexToSizeCanonicalization>(context);
+void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
//===----------------------------------------------------------------------===//
@@ -817,9 +817,9 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
// IsBroadcastableOp
//===----------------------------------------------------------------------===//
-void IsBroadcastableOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
+void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
@@ -885,9 +885,9 @@ struct RankShapeOfCanonicalizationPattern
};
} // namespace
-void shape::RankOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<RankShapeOfCanonicalizationPattern>(context);
+void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
//===----------------------------------------------------------------------===//
@@ -970,9 +970,9 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
};
} // namespace
-void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.insert<ShapeOfWithTensor>(context);
+ patterns.add<ShapeOfWithTensor>(context);
}
//===----------------------------------------------------------------------===//
@@ -987,9 +987,9 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
-void SizeToIndexOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<IndexToSizeToIndexCanonicalization>(context);
+void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
index 779993c01f75c..ea2f97d7d0569 100644
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -19,7 +19,7 @@ struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns(&ctx);
+ RewritePatternSet patterns(&ctx);
BufferizeTypeConverter typeConverter;
ConversionTarget target(ctx);
diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
index b71226465c342..dc403b0ceacd8 100644
--- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
@@ -46,7 +46,7 @@ class RemoveShapeConstraintsPass
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns(&ctx);
+ RewritePatternSet patterns(&ctx);
populateRemoveShapeConstraintsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
@@ -55,9 +55,8 @@ class RemoveShapeConstraintsPass
} // namespace
-void mlir::populateRemoveShapeConstraintsPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
+void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) {
+ patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index 479ce71ac2cdf..66c4e5048f105 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -61,7 +61,7 @@ struct ShapeToShapeLowering
void ShapeToShapeLowering::runOnFunction() {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns(&ctx);
+ RewritePatternSet patterns(&ctx);
populateShapeRewritePatterns(patterns);
ConversionTarget target(getContext());
@@ -72,8 +72,8 @@ void ShapeToShapeLowering::runOnFunction() {
signalPassFailure();
}
-void mlir::populateShapeRewritePatterns(OwningRewritePatternList &patterns) {
- patterns.insert<NumElementsOpConverter>(patterns.getContext());
+void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<NumElementsOpConverter>(patterns.getContext());
}
std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
index 6ebf9fc5b0cd2..b58fa4dfdc87d 100644
--- a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
@@ -57,9 +57,9 @@ class ConvertAssumingYieldOpTypes
} // namespace
void mlir::populateShapeStructuralTypeConversionsAndLegality(
- TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
- patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
+ patterns.add<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
return typeConverter.isLegal(op.getResultTypes());
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 4830a51827a53..5f331eef241c6 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -324,9 +324,9 @@ struct EraseRedundantAssertions : public OpRewritePattern<AssertOp> {
};
} // namespace
-void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+void AssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.insert<EraseRedundantAssertions>(context);
+ patterns.add<EraseRedundantAssertions>(context);
}
//===----------------------------------------------------------------------===//
@@ -553,10 +553,9 @@ void BranchOp::setDest(Block *block) { return setSuccessor(block); }
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
-void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
- context);
+ results.add<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(context);
}
Optional<MutableOperandRange>
@@ -631,9 +630,9 @@ struct SimplifyIndirectCallWithKnownCallee
};
} // end anonymous namespace.
-void CallIndirectOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SimplifyIndirectCallWithKnownCallee>(context);
+void CallIndirectOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SimplifyIndirectCallWithKnownCallee>(context);
}
//===----------------------------------------------------------------------===//
@@ -965,11 +964,11 @@ struct SimplifyCondBranchFromCondBranchOnSameCondition
};
} // end anonymous namespace
-void CondBranchOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
- SimplifyCondBranchIdenticalSuccessors,
- SimplifyCondBranchFromCondBranchOnSameCondition>(context);
+void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
+ SimplifyCondBranchIdenticalSuccessors,
+ SimplifyCondBranchFromCondBranchOnSameCondition>(context);
}
Optional<MutableOperandRange>
@@ -2017,11 +2016,11 @@ struct SubTensorCanonicalizer {
}
};
-void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void SubTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<OpWithOffsetSizesAndStridesConstantArgumentFolder<
- SubTensorOp, SubTensorCanonicalizer>,
- SubTensorOpCastFolder>(context);
+ results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+ SubTensorOp, SubTensorCanonicalizer>,
+ SubTensorOpCastFolder>(context);
}
//
@@ -2188,10 +2187,10 @@ struct SubTensorInsertOpCastFolder final
};
} // namespace
-void SubTensorInsertOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SubTensorInsertOpConstantArgumentFolder,
- SubTensorInsertOpCastFolder>(context);
+void SubTensorInsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SubTensorInsertOpConstantArgumentFolder,
+ SubTensorInsertOpCastFolder>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 6eeb39e661ec5..040bdc81f23b4 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -55,9 +55,9 @@ class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
} // namespace
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
- patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns.add<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
+ patterns.getContext());
}
namespace {
@@ -65,7 +65,7 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>();
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
index 8d1d6befa66fa..a3dd9a4be5ec8 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
@@ -188,9 +188,9 @@ struct DecomposeCallGraphTypesForCallOp
void mlir::populateDecomposeCallGraphTypesPatterns(
MLIRContext *context, TypeConverter &typeConverter,
- ValueDecomposer &decomposer, OwningRewritePatternList &patterns) {
- patterns.insert<DecomposeCallGraphTypesForCallOp,
- DecomposeCallGraphTypesForFuncArgs,
- DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
- decomposer);
+ ValueDecomposer &decomposer, RewritePatternSet &patterns) {
+ patterns
+ .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
+ DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
+ decomposer);
}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index 3f2504e0142ba..fd1b36907ae1e 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -211,7 +211,7 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns(&ctx);
+ RewritePatternSet patterns(&ctx);
populateStdExpandOpsPatterns(patterns);
ConversionTarget target(getContext());
@@ -234,9 +234,9 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
} // namespace
-void mlir::populateStdExpandOpsPatterns(OwningRewritePatternList &patterns) {
- patterns.insert<AtomicRMWOpConverter, MemRefReshapeOpConverter,
- SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
+void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
+ patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter,
+ SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
patterns.getContext());
}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
index 04424c75613f0..21ca1c3a82c2a 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -28,7 +28,7 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateFuncOpTypeConversionPattern(patterns, typeConverter);
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
index 40086769e8892..b0283fe2601bf 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -37,9 +37,9 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
};
} // end anonymous namespace
-void mlir::populateCallOpTypeConversionPattern(
- OwningRewritePatternList &patterns, TypeConverter &converter) {
- patterns.insert<CallOpSignatureConversion>(converter, patterns.getContext());
+void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
+ TypeConverter &converter) {
+ patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
}
namespace {
@@ -102,9 +102,9 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
} // end anonymous namespace
void mlir::populateBranchOpInterfaceTypeConversionPattern(
- OwningRewritePatternList &patterns, TypeConverter &typeConverter) {
- patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter,
- patterns.getContext());
+ RewritePatternSet &patterns, TypeConverter &typeConverter) {
+ patterns.add<BranchOpInterfaceTypeConversion>(typeConverter,
+ patterns.getContext());
}
bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
@@ -123,9 +123,9 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
return false;
}
-void mlir::populateReturnOpTypeConversionPattern(
- OwningRewritePatternList &patterns, TypeConverter &typeConverter) {
- patterns.insert<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
+void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
+ TypeConverter &typeConverter) {
+ patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
}
bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index 625bdc1d453c1..b40e47c944141 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -90,11 +90,11 @@ struct TensorConstantBufferizePass
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>();
- patterns.insert<BufferizeTensorConstantOp>(globals, typeConverter, context);
+ patterns.add<BufferizeTensorConstantOp>(globals, typeConverter, context);
target.addDynamicallyLegalOp<ConstantOp>(
[&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
if (failed(applyPartialConversion(module, target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3da606131a411..9dc9240cc4623 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -177,9 +177,9 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
} // namespace
-void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ChainedTensorCast>(context);
+ results.add<ChainedTensorCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -275,9 +275,9 @@ struct ExtractElementFromTensorFromElements
} // namespace
-void FromElementsOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ExtractElementFromTensorFromElements>(context);
+void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ExtractElementFromTensorFromElements>(context);
}
//===----------------------------------------------------------------------===//
@@ -435,11 +435,11 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
} // namespace
-void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO: Move extract patterns to tensor::ExtractOp.
- results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
- StaticTensorGenerate>(context);
+ results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
+ StaticTensorGenerate>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 4c1d0b729ee35..a52b5f69c08f3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -138,9 +138,9 @@ class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
} // namespace
void mlir::populateTensorBufferizePatterns(
- BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
- BufferizeGenerateOp>(typeConverter, patterns.getContext());
+ BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
+ BufferizeGenerateOp>(typeConverter, patterns.getContext());
}
namespace {
@@ -148,7 +148,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateBufferizeMaterializationLegality(target);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 2ab1a648f8c4d..e2e19c2d0d1e5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -251,20 +251,20 @@ struct TosaMakeBroadcastable
public:
void runOnFunction() override {
auto func = getFunction();
- OwningRewritePatternList patterns(func.getContext());
+ RewritePatternSet patterns(func.getContext());
MLIRContext *ctx = func.getContext();
// Add the generated patterns to the list.
- patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::SubOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::MulOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::MaximumOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::MinimumOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::EqualOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::GreaterOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
- patterns.insert<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
};
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 23b194d293a54..d1703caccc462 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -714,11 +714,10 @@ struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
}
};
-void ContractionOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results
- .insert<CanonicalizeContractAdd<AddIOp>, CanonicalizeContractAdd<AddFOp>>(
- context);
+void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<CanonicalizeContractAdd<AddIOp>, CanonicalizeContractAdd<AddFOp>>(
+ context);
}
//===----------------------------------------------------------------------===//
@@ -1332,9 +1331,9 @@ class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
} // namespace
-void BroadcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<BroadcastToShapeCast>(context);
+ results.add<BroadcastToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -2150,11 +2149,11 @@ class StridedSliceBroadcast final
} // end anonymous namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
+ RewritePatternSet &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
- results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
- StridedSliceBroadcast>(context);
+ results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
+ StridedSliceBroadcast>(context);
}
//===----------------------------------------------------------------------===//
@@ -2778,9 +2777,9 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
};
} // namespace
-void MaskedLoadOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<MaskedLoadFolder>(context);
+void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<MaskedLoadFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -2823,9 +2822,9 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
};
} // namespace
-void MaskedStoreOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<MaskedStoreFolder>(context);
+void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<MaskedStoreFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -2871,9 +2870,9 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
};
} // namespace
-void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<GatherFolder>(context);
+ results.add<GatherFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -2917,9 +2916,9 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
};
} // namespace
-void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ScatterFolder>(context);
+ results.add<ScatterFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -2965,9 +2964,9 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
};
} // namespace
-void ExpandLoadOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ExpandLoadFolder>(context);
+void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ExpandLoadFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -3011,9 +3010,9 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
};
} // namespace
-void CompressStoreOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<CompressStoreFolder>(context);
+void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<CompressStoreFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -3147,10 +3146,10 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
} // namespace
-void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
- results.insert<ShapeCastConstantFolder>(context);
+ results.add<ShapeCastConstantFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -3393,8 +3392,8 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
} // end anonymous namespace
void vector::TransposeOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<TransposeFolder>(context);
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<TransposeFolder>(context);
}
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
@@ -3528,17 +3527,18 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
} // end anonymous namespace
-void CreateMaskOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<CreateMaskFolder>(context);
+void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<CreateMaskFolder>(context);
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder,
- GatherFolder, ScatterFolder, ExpandLoadFolder,
- CompressStoreFolder, StridedSliceConstantMaskFolder,
- TransposeFolder>(patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns
+ .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
+ ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
+ StridedSliceConstantMaskFolder, TransposeFolder>(
+ patterns.getContext());
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 16664b1742694..8766efa406c28 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3263,52 +3263,52 @@ struct BubbleUpBitCastForStridedSliceInsert
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
- TransferReadExtractPattern, TransferWriteInsertPattern>(
+ RewritePatternSet &patterns) {
+ patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
+ TransferReadExtractPattern, TransferWriteInsertPattern>(
patterns.getContext());
}
void mlir::vector::populateSplitVectorTransferPatterns(
- OwningRewritePatternList &patterns,
+ RewritePatternSet &patterns,
std::function<bool(Operation *)> ignoreFilter) {
- patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(
- patterns.getContext(), ignoreFilter);
+ patterns.add<SplitTransferReadOp, SplitTransferWriteOp>(patterns.getContext(),
+ ignoreFilter);
}
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
- CastAwayInsertStridedSliceLeadingOneDim,
- CastAwayTransferReadLeadingOneDim,
- CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
+ RewritePatternSet &patterns) {
+ patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
+ CastAwayInsertStridedSliceLeadingOneDim,
+ CastAwayTransferReadLeadingOneDim,
+ CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
patterns.getContext());
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<BubbleDownVectorBitCastForExtract,
- BubbleDownBitCastForStridedSliceExtract,
- BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns.add<BubbleDownVectorBitCastForExtract,
+ BubbleDownBitCastForStridedSliceExtract,
+ BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}
void mlir::vector::populateVectorSlicesLoweringPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(
+ RewritePatternSet &patterns) {
+ patterns.add<ExtractSlicesOpLowering, InsertSlicesOpLowering>(
patterns.getContext());
}
void mlir::vector::populateVectorContractLoweringPatterns(
- OwningRewritePatternList &patterns, VectorTransformsOptions parameters) {
+ RewritePatternSet &patterns, VectorTransformsOptions parameters) {
// clang-format off
- patterns.insert<BroadcastOpLowering,
+ patterns.add<BroadcastOpLowering,
CreateMaskOpLowering,
ConstantMaskOpLowering,
OuterProductOpLowering,
ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern,
ShapeCastOpRewritePattern>(patterns.getContext());
- patterns.insert<TransposeOpLowering,
+ patterns.add<TransposeOpLowering,
ContractionOpLowering,
ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
@@ -3316,7 +3316,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
}
void mlir::vector::populateVectorTransferLoweringPatterns(
- OwningRewritePatternList &patterns) {
- patterns.insert<TransferReadToVectorLoadLowering,
- TransferWriteToVectorStoreLowering>(patterns.getContext());
+ RewritePatternSet &patterns) {
+ patterns.add<TransferReadToVectorLoadLowering,
+ TransferWriteToVectorStoreLowering>(patterns.getContext());
}
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
index c2de51a647dd4..b61307b81b9f9 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
@@ -53,8 +53,7 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
FrozenRewritePatternList::FrozenRewritePatternList()
: impl(std::make_shared<Impl>()) {}
-FrozenRewritePatternList::FrozenRewritePatternList(
- OwningRewritePatternList &&patterns)
+FrozenRewritePatternList::FrozenRewritePatternList(RewritePatternSet &&patterns)
: impl(std::make_shared<Impl>()) {
impl->nativePatterns = std::move(patterns.getNativePatterns());
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index ba1f566abf6f6..7ed7526549fb1 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -84,9 +84,9 @@ class BufferizeCastOp : public OpConversionPattern<memref::BufferCastOp> {
} // namespace
void mlir::populateEliminateBufferizeMaterializationsPatterns(
- BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeTensorLoadOp, BufferizeCastOp>(
- typeConverter, patterns.getContext());
+ BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<BufferizeTensorLoadOp, BufferizeCastOp>(typeConverter,
+ patterns.getContext());
}
namespace {
@@ -100,7 +100,7 @@ struct FinalizingBufferizePass
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 900d89c8080ba..5b6edf9894ab3 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -25,7 +25,7 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
LogicalResult initialize(MLIRContext *context) override {
- OwningRewritePatternList owningPatterns(context);
+ RewritePatternSet owningPatterns(context);
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(owningPatterns, context);
patterns = std::move(owningPatterns);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 113ba467cd5f3..d6037f563f874 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2612,14 +2612,14 @@ struct FunctionLikeSignatureConversion : public ConversionPattern {
} // end anonymous namespace
void mlir::populateFunctionLikeTypeConversionPattern(
- StringRef functionLikeOpName, OwningRewritePatternList &patterns,
+ StringRef functionLikeOpName, RewritePatternSet &patterns,
TypeConverter &converter) {
- patterns.insert<FunctionLikeSignatureConversion>(
+ patterns.add<FunctionLikeSignatureConversion>(
functionLikeOpName, patterns.getContext(), converter);
}
-void mlir::populateFuncOpTypeConversionPattern(
- OwningRewritePatternList &patterns, TypeConverter &converter) {
+void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
+ TypeConverter &converter) {
populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, converter);
}
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index cd58ec9a624f5..5aaa8fa0daade 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -403,7 +403,7 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
if (res) {
// Simplify/canonicalize the affine.for.
- OwningRewritePatternList patterns(res.getContext());
+ RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
bool erased;
(void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index b8aa7da65ac8f..b2d454bb23e04 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -110,7 +110,7 @@ void TestAffineDataCopy::runOnFunction() {
// Promoting single iteration loops could lead to simplification of
// generated load's/store's, and the latter could anyway also be
// canonicalized.
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
for (auto op : copyOps) {
patterns.clear();
if (isa<AffineLoadOp>(op)) {
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index f66ac8ca97227..530318cbb53f9 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -139,10 +139,10 @@ void ConvertToTargetEnv::runOnFunction() {
auto target = spirv::SPIRVConversionTarget::get(targetEnv);
- OwningRewritePatternList patterns(context);
- patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
- ConvertToGroupNonUniformBallot, ConvertToModule,
- ConvertToSubgroupBallot>(context);
+ RewritePatternSet patterns(context);
+ patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
+ ConvertToGroupNonUniformBallot, ConvertToModule,
+ ConvertToSubgroupBallot>(context);
if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
return signalPassFailure();
diff --git a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
index 75bc52a608cb9..ba6c94bc46dbe 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
@@ -25,7 +25,7 @@ class TestGLSLCanonicalizationPass
} // namespace
void TestGLSLCanonicalizationPass::runOnOperation() {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 6da619652c072..eee0a9be75dbe 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -235,9 +235,9 @@ struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
};
} // end anonymous namespace
-void FoldToCallOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<FoldToCallOpPattern>(context);
+void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldToCallOpPattern>(context);
}
//===----------------------------------------------------------------------===//
@@ -615,8 +615,8 @@ struct TestRemoveOpWithInnerOps
} // end anonymous namespace
void TestOpWithRegionPattern::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<TestRemoveOpWithInnerOps>(context);
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<TestRemoveOpWithInnerOps>(context);
}
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c72e7fee32363..e34e52a9ef4c8 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -79,11 +79,11 @@ struct FoldingPattern : public RewritePattern {
struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns(&getContext());
+ mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
// Verify named pattern is generated with expected name.
- patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
+ patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
@@ -557,17 +557,17 @@ struct TestLegalizePatternDriver
void runOnOperation() override {
TestTypeConverter converter;
- mlir::OwningRewritePatternList patterns(&getContext());
+ mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
- patterns.insert<
- TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
- TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
- TestPassthroughInvalidOp, TestSplitReturnType,
- TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
- TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
- TestNonRootReplacement, TestBoundedRecursiveRewrite,
- TestNestedOpCreationUndoRewrite>(&getContext());
- patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
+ patterns
+ .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+ TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
+ TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite,
+ TestNestedOpCreationUndoRewrite>(&getContext());
+ patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -698,8 +698,8 @@ struct OneVResOneVOperandOp1Converter
struct TestRemappedValue
: public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns(&getContext());
- patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
mlir::ConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
@@ -740,8 +740,8 @@ struct RemoveTestDialectOps : public RewritePattern {
struct TestUnknownRootOpDriver
: public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns(&getContext());
- patterns.insert<RemoveTestDialectOps>();
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add<RemoveTestDialectOps>();
mlir::ConversionTarget target(getContext());
target.addIllegalDialect<TestDialect>();
@@ -876,10 +876,10 @@ struct TestTypeConversionDriver
});
// Initialize the set of rewrite patterns.
- OwningRewritePatternList patterns(&getContext());
- patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer,
- TestSignatureConversionUndo>(converter, &getContext());
- patterns.insert<TestTypeConversionAnotherProducer>(&getContext());
+ RewritePatternSet patterns(&getContext());
+ patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
+ TestSignatureConversionUndo>(converter, &getContext());
+ patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
if (failed(applyPartialConversion(getOperation(), target,
@@ -964,10 +964,9 @@ struct TestMergeBlocksPatternDriver
OperationPass<ModuleOp>> {
void runOnOperation() override {
MLIRContext *context = &getContext();
- mlir::OwningRewritePatternList patterns(context);
- patterns
- .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
- context);
+ mlir::RewritePatternSet patterns(context);
+ patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
+ context);
ConversionTarget target(*context);
target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp,
TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp,
@@ -1033,8 +1032,8 @@ struct TestSelectiveReplacementPatternDriver
OperationPass<>> {
void runOnOperation() override {
MLIRContext *context = &getContext();
- mlir::OwningRewritePatternList patterns(context);
- patterns.insert<TestSelectiveOpReplacementPattern>(context);
+ mlir::RewritePatternSet patterns(context);
+ patterns.add<TestSelectiveOpReplacementPattern>(context);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
}
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
index e1f151fe61544..1e675aec0d8ed 100644
--- a/mlir/test/lib/Dialect/Test/TestTraits.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -34,7 +34,7 @@ namespace {
struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
void runOnFunction() override {
(void)applyPatternsAndFoldGreedily(getFunction(),
- OwningRewritePatternList(&getContext()));
+ RewritePatternSet(&getContext()));
}
};
} // end anonymous namespace
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index 06777ea039d75..da890c9123055 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -184,11 +184,11 @@ struct TosaTestQuantUtilAPI
void TosaTestQuantUtilAPI::runOnFunction() {
auto *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
auto func = getFunction();
- patterns.insert<ConvertTosaNegateOp>(ctx);
- patterns.insert<ConvertTosaConv2DOp>(ctx);
+ patterns.add<ConvertTosaNegateOp>(ctx);
+ patterns.add<ConvertTosaConv2DOp>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index bc45d7b083aa1..b20224a50b83d 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -96,7 +96,7 @@ struct TestPDLByteCodePass
pdlPattern.registerRewriteFunction("type_creator", customCreateType);
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
- OwningRewritePatternList patternList(std::move(pdlPattern));
+ RewritePatternSet patternList(std::move(pdlPattern));
// Invoke the pattern driver with the provided patterns.
(void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
index cd741d047791e..55464283ff7de 100644
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -59,14 +59,14 @@ void TestConvVectorization::runOnOperation() {
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
- SmallVector<OwningRewritePatternList, 4> stage1Patterns;
+ SmallVector<RewritePatternSet, 4> stage1Patterns;
linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
SmallVector<FrozenRewritePatternList, 4> frozenStage1Patterns;
llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
- OwningRewritePatternList stage2Patterns =
+ RewritePatternSet stage2Patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
- stage2Patterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
+ stage2Patterns.add<linalg::AffineMinSCFCanonicalizationPattern>(context);
auto stage3Transforms = [](Operation *op) {
PassManager pm(op->getContext());
@@ -91,11 +91,11 @@ void TestConvVectorization::runOnOperation() {
VectorTransformsOptions vectorTransformsOptions{
VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
- OwningRewritePatternList vectorTransferPatterns(context);
+ RewritePatternSet vectorTransferPatterns(context);
// Pattern is not applied because rank-reducing vector transfer is not yet
// supported as can be seen in splitFullAndPartialTransferPrecondition,
// VectorTransforms.cpp
- vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
+ vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
@@ -106,14 +106,14 @@ void TestConvVectorization::runOnOperation() {
llvm_unreachable("Unexpected failure in linalg to loops pass.");
// Programmatic controlled lowering of vector.contract only.
- OwningRewritePatternList vectorContractLoweringPatterns(context);
+ RewritePatternSet vectorContractLoweringPatterns(context);
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(module,
std::move(vectorContractLoweringPatterns));
// Programmatic controlled lowering of vector.transfer only.
- OwningRewritePatternList vectorToLoopsPatterns(context);
+ RewritePatternSet vectorToLoopsPatterns(context);
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
VectorTransferToSCFOptions());
(void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
index dbe1a319fd26c..6f69e54fd6b5e 100644
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -49,9 +49,9 @@ class TestConvertCallOp
});
// Populate patterns.
- OwningRewritePatternList patterns(m.getContext());
+ RewritePatternSet patterns(m.getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
- patterns.insert<TestTypeProducerOpConverter>(typeConverter);
+ patterns.add<TestTypeProducerOpConverter>(typeConverter);
// Set target.
ConversionTarget target(getContext());
diff --git a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
index 13c01a106b394..f2e1fa264bcf9 100644
--- a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
@@ -33,7 +33,7 @@ struct TestDecomposeCallGraphTypes
TypeConverter typeConverter;
ConversionTarget target(*context);
ValueDecomposer decomposer;
- OwningRewritePatternList patterns(context);
+ RewritePatternSet patterns(context);
target.addLegalDialect<test::TestDialect>();
diff --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp
index dc54a4be83556..0241bddc982bb 100644
--- a/mlir/test/lib/Transforms/TestExpandTanh.cpp
+++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp
@@ -24,7 +24,7 @@ struct TestExpandTanhPass
} // end anonymous namespace
void TestExpandTanhPass::runOnFunction() {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateExpandTanhPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/lib/Transforms/TestGpuRewrite.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
index 5f87a9f877283..27ecae96707bf 100644
--- a/mlir/test/lib/Transforms/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
@@ -25,7 +25,7 @@ struct TestGpuRewritePass
registry.insert<StandardOpsDialect, memref::MemRefDialect>();
}
void runOnOperation() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 8cb770287dbd5..23e6e0056627e 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -23,9 +23,9 @@ using namespace mlir::linalg;
template <LinalgTilingLoopType LoopType>
static void fillFusionPatterns(MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
- OwningRewritePatternList &patterns) {
- patterns.insert<LinalgTileAndFusePattern<MatmulOp>,
- LinalgTileAndFusePattern<ConvOp>>(
+ RewritePatternSet &patterns) {
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>,
+ LinalgTileAndFusePattern<ConvOp>>(
context, dependenceGraph,
LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({2}),
@@ -39,7 +39,7 @@ static void fillFusionPatterns(MLIRContext *context,
ArrayRef<Identifier>(),
Identifier::get("after_basic_fusion_original", context)));
- patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({0}),
@@ -52,7 +52,7 @@ static void fillFusionPatterns(MLIRContext *context,
ArrayRef<Identifier>(),
Identifier::get("after_lhs_fusion_original", context)));
- patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({1}),
@@ -65,7 +65,7 @@ static void fillFusionPatterns(MLIRContext *context,
ArrayRef<Identifier>(),
Identifier::get("after_rhs_fusion_original", context)));
- patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({0, 2}),
@@ -79,7 +79,7 @@ static void fillFusionPatterns(MLIRContext *context,
ArrayRef<Identifier>(),
Identifier::get("after_two_operand_fusion_original", context)));
- patterns.insert<LinalgTileAndFusePattern<GenericOp>>(
+ patterns.add<LinalgTileAndFusePattern<GenericOp>>(
context, dependenceGraph,
LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({0, 1}),
@@ -109,7 +109,7 @@ struct TestLinalgFusionTransforms
void runOnFunction() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getFunction();
- OwningRewritePatternList fusionPatterns(context);
+ RewritePatternSet fusionPatterns(context);
Aliases alias;
LinalgDependenceGraph dependenceGraph =
LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
@@ -181,9 +181,9 @@ struct TestLinalgGreedyFusion
: public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
void runOnFunction() override {
MLIRContext *context = &getContext();
- OwningRewritePatternList patterns =
+ RewritePatternSet patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
- patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
+ patterns.add<AffineMinSCFCanonicalizationPattern>(context);
FrozenRewritePatternList frozenPatterns(std::move(patterns));
while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
(void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 8e1cd2d3eca85..a9765ce8c9a46 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -92,36 +92,36 @@ struct TestLinalgTransforms
static void applyPatterns(FuncOp funcOp) {
MLIRContext *ctx = funcOp.getContext();
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
//===--------------------------------------------------------------------===//
// Linalg tiling patterns.
//===--------------------------------------------------------------------===//
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
LinalgTransformationFilter(Identifier::get("MEM", ctx),
Identifier::get("L3", ctx)));
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
LinalgTransformationFilter(Identifier::get("L3", ctx),
Identifier::get("L2", ctx)));
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgTransformationFilter(Identifier::get("L2", ctx),
Identifier::get("L1", ctx)));
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
LinalgTransformationFilter(Identifier::get("L1", ctx),
Identifier::get("REG", ctx)));
- patterns.insert<LinalgTilingPattern<MatvecOp>>(
+ patterns.add<LinalgTilingPattern<MatvecOp>>(
ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("L1", ctx)));
- patterns.insert<LinalgTilingPattern<DotOp>>(
+ patterns.add<LinalgTilingPattern<DotOp>>(
ctx, LinalgTilingOptions().setTileSizes(8000),
LinalgTransformationFilter(
ArrayRef<Identifier>{Identifier::get("MEM", ctx),
@@ -132,31 +132,31 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===--------------------------------------------------------------------===//
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({2000, 3000, 4000})
.setInterchange({1, 2, 0}),
LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
Identifier::get("L2__with_perm__", ctx)));
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({200, 300, 400})
.setInterchange({1, 0, 2}),
LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
Identifier::get("L1__with_perm__", ctx)));
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
Identifier::get("REG__with_perm__", ctx)));
- patterns.insert<LinalgTilingPattern<MatvecOp>>(
+ patterns.add<LinalgTilingPattern<MatvecOp>>(
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
Identifier::get("L1__with_perm__", ctx)));
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({16, 8, 4})
@@ -169,7 +169,7 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg to loops patterns.
//===--------------------------------------------------------------------===//
- patterns.insert<LinalgLoweringPattern<DotOp>>(
+ patterns.add<LinalgLoweringPattern<DotOp>>(
ctx,
/*loweringType=*/LinalgLoweringType::Loops,
LinalgTransformationFilter(Identifier::get("REG", ctx)));
@@ -182,19 +182,19 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===--------------------------------------------------------------------===//
- patterns.insert<LinalgVectorizationPattern>(
+ patterns.add<LinalgVectorizationPattern>(
LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
.addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
//===--------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===--------------------------------------------------------------------===//
- patterns.insert<LinalgInterchangePattern<GenericOp>>(
+ patterns.add<LinalgInterchangePattern<GenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("PERMUTED", ctx)));
- patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
+ patterns.add<LinalgInterchangePattern<IndexedGenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgTransformationFilter(ArrayRef<Identifier>{},
@@ -203,11 +203,11 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===--------------------------------------------------------------------===//
- patterns.insert<LinalgPromotionPattern<MatmulOp>>(
+ patterns.add<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
Identifier::get("_views_promoted_", ctx)));
- patterns.insert<LinalgPromotionPattern<MatmulOp>>(
+ patterns.add<LinalgPromotionPattern<MatmulOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
@@ -215,7 +215,7 @@ static void applyPatterns(FuncOp funcOp) {
LinalgTransformationFilter(
Identifier::get("_promote_first_view_", ctx),
Identifier::get("_first_view_promoted_", ctx)));
- patterns.insert<LinalgPromotionPattern<FillOp>>(
+ patterns.add<LinalgPromotionPattern<FillOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
@@ -235,7 +235,7 @@ static void applyPatterns(FuncOp funcOp) {
static void fillL1TilingAndMatmulToVectorPatterns(
FuncOp funcOp, StringRef startMarker,
- SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
+ SmallVectorImpl<RewritePatternSet> &patternsVector) {
MLIRContext *ctx = funcOp.getContext();
patternsVector.emplace_back(
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
@@ -257,7 +257,7 @@ static void fillL1TilingAndMatmulToVectorPatterns(
ctx, std::make_unique<LinalgVectorizationPattern>(
MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
LinalgTransformationFilter(Identifier::get("VEC", ctx))));
- patternsVector.back().insert<LinalgVectorizationPattern>(
+ patternsVector.back().add<LinalgVectorizationPattern>(
LinalgTransformationFilter().addFilter(
[](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
}
@@ -301,12 +301,12 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
}
static void fillPromotionCallBackPatterns(MLIRContext *ctx,
- OwningRewritePatternList &patterns) {
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ RewritePatternSet &patterns) {
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
LinalgTransformationFilter(Identifier::get("START", ctx),
Identifier::get("PROMOTE", ctx)));
- patterns.insert<LinalgPromotionPattern<MatmulOp>>(
+ patterns.add<LinalgPromotionPattern<MatmulOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0, 2})
@@ -335,14 +335,14 @@ getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
}
static void fillTileAndDistributePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
+ RewritePatternSet &patterns) {
{
LinalgLoopDistributionOptions cyclicNprocsEqNiters;
cyclicNprocsEqNiters.distributionMethod.resize(
2, DistributionMethod::CyclicNumProcsEqNumIters);
cyclicNprocsEqNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
@@ -359,7 +359,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
2, DistributionMethod::CyclicNumProcsGeNumIters);
cyclicNprocsGeNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
@@ -376,7 +376,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::Cyclic);
cyclicNprocsDefault.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
@@ -393,7 +393,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsEqNumIters,
DistributionMethod::CyclicNumProcsGeNumIters};
cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
@@ -410,7 +410,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsGeNumIters,
DistributionMethod::Cyclic};
cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
@@ -428,7 +428,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsEqNumIters};
cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
@@ -445,7 +445,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::Cyclic);
cyclicNprocsEqNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
- patterns.insert<LinalgTilingPattern<MatmulOp>>(
+ patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
@@ -462,7 +462,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
bool testMatmulToVectorPatterns1dTiling,
bool testMatmulToVectorPatterns2dTiling) {
MLIRContext *ctx = funcOp.getContext();
- SmallVector<OwningRewritePatternList, 4> stage1Patterns;
+ SmallVector<RewritePatternSet, 4> stage1Patterns;
if (testMatmulToVectorPatterns1dTiling) {
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
stage1Patterns);
@@ -487,24 +487,24 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
}
static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
- OwningRewritePatternList forwardPattern(funcOp.getContext());
- forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
- forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
+ RewritePatternSet forwardPattern(funcOp.getContext());
+ forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
+ forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}
static void applyLinalgToVectorPatterns(FuncOp funcOp) {
- OwningRewritePatternList patterns(funcOp.getContext());
- patterns.insert<LinalgVectorizationPattern>(
+ RewritePatternSet patterns(funcOp.getContext());
+ patterns.add<LinalgVectorizationPattern>(
LinalgTransformationFilter()
.addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
- patterns.insert<PadTensorOpVectorizationPattern>(funcOp.getContext());
+ patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
- OwningRewritePatternList foldPattern(funcOp.getContext());
- foldPattern.insert<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
+ RewritePatternSet foldPattern(funcOp.getContext());
+ foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
FrozenRewritePatternList frozenPatterns(std::move(foldPattern));
// Explicitly walk and apply the pattern locally to avoid more general folding
@@ -523,12 +523,12 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
static void applyTileAndPadPattern(FuncOp funcOp) {
MLIRContext *context = funcOp.getContext();
- OwningRewritePatternList tilingPattern(context);
+ RewritePatternSet tilingPattern(context);
auto linalgTilingOptions =
linalg::LinalgTilingOptions()
.setTileSizes({2, 3, 4})
.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
- tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
+ tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
context, linalgTilingOptions,
linalg::LinalgTransformationFilter(
Identifier::get("tile-and-pad", context)));
@@ -545,13 +545,13 @@ void TestLinalgTransforms::runOnFunction() {
std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
if (testPromotionOptions) {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
fillPromotionCallBackPatterns(&getContext(), patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
}
if (testTileAndDistributionOptions) {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
fillTileAndDistributePatterns(&getContext(), patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
diff --git a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
index c702301a293fb..fed76a0de5477 100644
--- a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
+++ b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
@@ -32,7 +32,7 @@ struct TestMathPolynomialApproximationPass
} // end anonymous namespace
void TestMathPolynomialApproximationPass::runOnFunction() {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateMathPolynomialApproximationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
index 8c58f6eb117ed..22900f7edd1b0 100644
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -101,7 +101,7 @@ struct TestSparsification
/// Runs the test on a function.
void runOnOperation() override {
auto *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
// Translate strategy flags to strategy options.
linalg::SparsificationOptions options(parallelOption(), vectorOption(),
vectorLength, typeOption(ptrType),
@@ -112,7 +112,7 @@ struct TestSparsification
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// Lower sparse primitives to calls into runtime support library.
if (lower) {
- OwningRewritePatternList conversionPatterns(ctx);
+ RewritePatternSet conversionPatterns(ctx);
ConversionTarget target(*ctx);
target.addIllegalOp<linalg::SparseTensorFromPointerOp,
linalg::SparseTensorToPointersMemRefOp,
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index ac0b099f96702..76c5c7c0e2a46 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -37,9 +37,9 @@ struct TestVectorToVectorConversion
void runOnFunction() override {
auto *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
if (unroll) {
- patterns.insert<UnrollVectorPattern>(
+ patterns.add<UnrollVectorPattern>(
ctx,
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
filter));
@@ -70,7 +70,7 @@ struct TestVectorToVectorConversion
struct TestVectorSlicesConversion
: public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
void runOnFunction() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateVectorSlicesLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
@@ -101,14 +101,14 @@ struct TestVectorContractionConversion
llvm::cl::init(false)};
void runOnFunction() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
// Test on one pattern in isolation.
if (lowerToOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
- patterns.insert<ContractionOpToOuterProductOpLowering>(options,
- &getContext());
+ patterns.add<ContractionOpToOuterProductOpLowering>(options,
+ &getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
}
@@ -117,7 +117,7 @@ struct TestVectorContractionConversion
if (lowerToFilterOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
- patterns.insert<ContractionOpToOuterProductOpLowering>(
+ patterns.add<ContractionOpToOuterProductOpLowering>(
options, &getContext(), [](vector::ContractionOp op) {
// Only lowers vector.contract where the lhs as a type vector<MxNx?>
// where M is not 4.
@@ -149,8 +149,8 @@ struct TestVectorUnrollingPatterns
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
- patterns.insert<UnrollVectorPattern>(
+ RewritePatternSet patterns(ctx);
+ patterns.add<UnrollVectorPattern>(
ctx, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
.setFilterConstraint([](Operation *op) {
@@ -171,14 +171,14 @@ struct TestVectorUnrollingPatterns
}
return nativeShape;
};
- patterns.insert<UnrollVectorPattern>(
+ patterns.add<UnrollVectorPattern>(
ctx, UnrollVectorOptions()
.setNativeShapeFn(nativeShapeFn)
.setFilterConstraint([](Operation *op) {
return success(isa<ContractionOp>(op));
}));
} else {
- patterns.insert<UnrollVectorPattern>(
+ patterns.add<UnrollVectorPattern>(
ctx, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2, 2})
.setFilterConstraint([](Operation *op) {
@@ -210,7 +210,7 @@ struct TestVectorDistributePatterns
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
OpBuilder builder(op);
@@ -240,7 +240,7 @@ struct TestVectorDistributePatterns
}
}
});
- patterns.insert<PointwiseExtractPattern>(ctx);
+ patterns.add<PointwiseExtractPattern>(ctx);
populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
@@ -260,7 +260,7 @@ struct TestVectorToLoopPatterns
llvm::cl::init(32)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
// Check that the operation type can be broken down into a loop.
@@ -300,7 +300,7 @@ struct TestVectorToLoopPatterns
}
return mlir::WalkResult::interrupt();
});
- patterns.insert<PointwiseExtractPattern>(ctx);
+ patterns.add<PointwiseExtractPattern>(ctx);
populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
@@ -313,8 +313,8 @@ struct TestVectorTransferUnrollingPatterns
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
- patterns.insert<UnrollVectorPattern>(
+ RewritePatternSet patterns(ctx);
+ patterns.add<UnrollVectorPattern>(
ctx,
UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
@@ -347,13 +347,13 @@ struct TestVectorTransferFullPartialSplitPatterns
llvm::cl::init(false)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns(ctx);
+ RewritePatternSet patterns(ctx);
VectorTransformsOptions options;
if (useLinalgOps)
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
else
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
- patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
+ patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
@@ -369,7 +369,7 @@ struct TestVectorTransferLoweringPatterns
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
- OwningRewritePatternList patterns(&getContext());
+ RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 5115cfed7db8a..4fb9ecb39730a 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -96,7 +96,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
// CHECK: void print(::mlir::OpAsmPrinter &p);
// CHECK: ::mlir::LogicalResult verify();
-// CHECK: static void getCanonicalizationPatterns(::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context);
+// CHECK: static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
// CHECK: ::mlir::LogicalResult fold(::llvm::ArrayRef<::mlir::Attribute> operands, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
// CHECK: // Display a graph for debugging purposes.
// CHECK: void displayGraph();
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 7165a0fe89fef..c1fa63c00eb71 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1518,7 +1518,6 @@ LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
}
-
auto parseExpr = [&]() -> LogicalResult {
std::unique_ptr<Expression> e;
if (failed(parseExpression(currentDefinition, e, state)))
@@ -2074,10 +2073,10 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
StringRef cppOpName) {
const char *canonicalizersAndFoldersFmt = R"FMT(
void {0}::getCanonicalizationPatterns(
- OwningRewritePatternList &results,
+ RewritePatternSet &results,
MLIRContext *context) {{
- results.insert<EraseDeadLinalgOp>();
- results.insert<FoldTensorCastOp>();
+ results.add<EraseDeadLinalgOp>();
+ results.add<FoldTensorCastOp>();
}
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index cea46325d54cd..e38e71bbd9260 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -519,10 +519,10 @@ ArrayAttr {0}::iterator_types() {
// {0}: Class name
const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
void {0}::getCanonicalizationPatterns(
- OwningRewritePatternList &results,
+ RewritePatternSet &results,
MLIRContext *context) {{
- results.insert<EraseDeadLinalgOp>();
- results.insert<FoldTensorCastOp>();
+ results.add<EraseDeadLinalgOp>();
+ results.add<FoldTensorCastOp>();
}
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index d2f2132b1a38c..a1853362dce22 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1678,7 +1678,7 @@ void OpEmitter::genCanonicalizerDecls() {
return;
SmallVector<OpMethodParameter, 2> paramList;
- paramList.emplace_back("::mlir::OwningRewritePatternList &", "results");
+ paramList.emplace_back("::mlir::RewritePatternSet &", "results");
paramList.emplace_back("::mlir::MLIRContext *", "context");
opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
OpMethod::MP_StaticDeclaration,
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 60d19fff1fc26..68dddc285f26f 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1291,9 +1291,9 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
// Emit function to add the generated matchers to the pattern list.
os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
- "::mlir::OwningRewritePatternList &patterns) {\n";
+ "::mlir::RewritePatternSet &patterns) {\n";
for (const auto &name : rewriterNames) {
- os << " patterns.insert<" << name << ">(patterns.getContext());\n";
+ os << " patterns.add<" << name << ">(patterns.getContext());\n";
}
os << "}\n";
}
diff --git a/mlir/unittests/Rewrite/PatternBenefit.cpp b/mlir/unittests/Rewrite/PatternBenefit.cpp
index ee36c6a653ea7..9461e2f0ff8b3 100644
--- a/mlir/unittests/Rewrite/PatternBenefit.cpp
+++ b/mlir/unittests/Rewrite/PatternBenefit.cpp
@@ -52,13 +52,13 @@ TEST(PatternBenefitTest, BenefitOrder) {
bool *called;
};
- OwningRewritePatternList patterns(&context);
+ RewritePatternSet patterns(&context);
bool called1 = false;
bool called2 = false;
- patterns.insert<Pattern1>(&context, &called1);
- patterns.insert<Pattern2>(&called2);
+ patterns.add<Pattern1>(&context, &called1);
+ patterns.add<Pattern2>(&called2);
FrozenRewritePatternList frozenPatterns(std::move(patterns));
PatternApplicator pa(frozenPatterns);
More information about the Mlir-commits
mailing list