[Mlir-commits] [mlir] 549e190 - [PatternRewriter] Rename OwningRewritePatternList -> RewritePatternSet and insert -> add
Chris Lattner
llvmlistbot at llvm.org
Mon Mar 22 16:33:24 PDT 2021
Author: Chris Lattner
Date: 2021-03-22T16:33:18-07:00
New Revision: 549e190236f638c087fca664d8823a268efdf5c8
URL: https://github.com/llvm/llvm-project/commit/549e190236f638c087fca664d8823a268efdf5c8
DIFF: https://github.com/llvm/llvm-project/commit/549e190236f638c087fca664d8823a268efdf5c8.diff
LOG: [PatternRewriter] Rename OwningRewritePatternList -> RewritePatternSet and insert -> add
This keeps the old name as an alias to minimize source impact on downstream code, and
does not include the huge mechanical patch. I expect the huge mechanical patch to land
sometime this week, but we can keep the old names around for a couple of weeks to reduce
the impact on downstream projects.
Differential Revision: https://reviews.llvm.org/D99119
Added:
Modified:
mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/include/mlir/Dialect/AMX/Transforms.h
mlir/include/mlir/Dialect/AVX512/Transforms.h
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/include/mlir/Dialect/SCF/Transforms.h
mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
index 8d3301c3b451a..8058f5d7f12a6 100644
--- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
+++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
@@ -24,8 +24,8 @@ class RewritePattern;
class Value;
class ValueRange;
-// Owning list of rewriting patterns.
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Emit code that computes the given affine expression using standard
/// arithmetic operations applied to the provided dimension and symbol values.
diff --git a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
index 8cba4e9be5d5c..70170f8c5f99d 100644
--- a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
@@ -12,7 +12,8 @@
namespace mlir {
class LLVMTypeConverter;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to convert from the ArmSVE dialect to LLVM.
void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
index 670942a464ab5..cf3763f449a1f 100644
--- a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
+++ b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
@@ -18,8 +18,9 @@ class ModuleOp;
template <typename T>
class OperationPass;
class MLIRContext;
-class OwningRewritePatternList;
class TypeConverter;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Create a pass to convert Async operations to the LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index 878861e406e4f..708a3fe0b23ee 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -24,7 +24,8 @@ class Location;
struct LogicalResult;
class ModuleOp;
class Operation;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
template <typename T>
class OperationPass;
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 233b947bcfede..cdfe5fa07a640 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -13,10 +13,12 @@
namespace mlir {
class LLVMTypeConverter;
-class OwningRewritePatternList;
class ConversionTarget;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
-template <typename OpT> class OperationPass;
+template <typename OpT>
+class OperationPass;
namespace gpu {
class GPUModuleOp;
diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index 5fa798bf28342..e298d2d73efbb 100644
--- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -13,8 +13,9 @@
namespace mlir {
class LLVMTypeConverter;
-class OwningRewritePatternList;
class ConversionTarget;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
template <typename OpT>
class OperationPass;
diff --git a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
index 8f94597323e7a..f05e9d53ff455 100644
--- a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
+++ b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
@@ -15,8 +15,9 @@
namespace mlir {
class MLIRContext;
-class OwningRewritePatternList;
class SPIRVTypeConverter;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Appends to a pattern list additional patterns for translating Linalg ops to
/// SPIR-V ops.
diff --git a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
index 4eae84cd0135c..5092322286d65 100644
--- a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
+++ b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
@@ -8,7 +8,7 @@
#ifndef MLIR_CONVERSION_OPENMPTOLLVM_OPENMPTOLLVM_H_
#define MLIR_CONVERSION_OPENMPTOLLVM_OPENMPTOLLVM_H_
-#include<memory>
+#include <memory>
namespace mlir {
class LLVMTypeConverter;
@@ -16,7 +16,8 @@ class MLIRContext;
class ModuleOp;
template <typename T>
class OperationPass;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Populate the given list with patterns that convert from OpenMP to LLVM.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
index 14c16088270ff..a27c408b9d5a5 100644
--- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
+++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
@@ -15,8 +15,9 @@ class AffineForOp;
class ConversionTarget;
struct LogicalResult;
class MLIRContext;
-class OwningRewritePatternList;
class Value;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
namespace scf {
class ForOp;
diff --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
index 5a14c9b2d35a4..14679f4abb7cf 100644
--- a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
@@ -18,9 +18,10 @@ namespace mlir {
class Pass;
// Owning list of rewriting patterns.
-class OwningRewritePatternList;
class SPIRVTypeConverter;
struct ScfToSPIRVContextImpl;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
struct ScfToSPIRVContext {
ScfToSPIRVContext();
diff --git a/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h b/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
index 95667d86133ab..fc120798d8063 100644
--- a/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
+++ b/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
@@ -15,10 +15,9 @@
namespace mlir {
struct LogicalResult;
class Pass;
-class RewritePattern;
-// Owning list of rewriting patterns.
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular
diff --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
index 7c94470f4d269..3ab3ee7144f3f 100644
--- a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
+++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -17,7 +17,8 @@ class FuncOp;
class ModuleOp;
template <typename T>
class OperationPass;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
void populateShapeToStandardConversionPatterns(
OwningRewritePatternList &patterns);
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index ca623a3c050e0..e9ee9e953477d 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -18,7 +18,8 @@ class LLVMTypeConverter;
class ModuleOp;
template <typename T>
class OperationPass;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Value to pass as bitwidth for the index type when the converter is expected
/// to derive the bitwidth from the LLVM data layout.
diff --git a/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
index 660de02ee36fa..7f0859cc5f581 100644
--- a/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
+++ b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
@@ -12,10 +12,11 @@
namespace mlir {
class LLVMTypeConverter;
-class OwningRewritePatternList;
class ModuleOp;
template <typename OpT>
class OperationPass;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to convert from the GPU dialect to ROCDL.
void populateVectorToROCDLConversionPatterns(
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index e7478cf4c196f..561a3e9ca2c69 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -13,8 +13,9 @@
namespace mlir {
class MLIRContext;
-class OwningRewritePatternList;
class Pass;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Control whether unrolling is used when lowering vector transfer ops to SCF.
///
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 11b3004292d4e..1fccbb5815149 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -13,7 +13,8 @@ namespace mlir {
class LLVMConversionTarget;
class LLVMTypeConverter;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
diff --git a/mlir/include/mlir/Dialect/AVX512/Transforms.h b/mlir/include/mlir/Dialect/AVX512/Transforms.h
index 3506f50dc2582..541833652a49f 100644
--- a/mlir/include/mlir/Dialect/AVX512/Transforms.h
+++ b/mlir/include/mlir/Dialect/AVX512/Transforms.h
@@ -13,7 +13,8 @@ namespace mlir {
class LLVMConversionTarget;
class LLVMTypeConverter;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to lower AVX512 ops to ops that map to LLVM
/// intrinsics.
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 3ce88a135899c..10635667a5fcc 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -11,12 +11,11 @@
namespace mlir {
-class OwningRewritePatternList;
+class RewritePatternSet;
-void populateExpandTanhPattern(OwningRewritePatternList &patterns);
+void populateExpandTanhPattern(RewritePatternSet &patterns);
-void populateMathPolynomialApproximationPatterns(
- OwningRewritePatternList &patterns);
+void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
index 914a1a0cb8ac7..94473af864693 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -19,9 +19,10 @@ namespace mlir {
class ConversionTarget;
class MLIRContext;
-class OwningRewritePatternList;
class Region;
class TypeConverter;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
namespace scf {
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
index a7eb59a45dae5..6e0abfcc7f0ea 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -19,8 +19,9 @@ namespace mlir {
class ConversionTarget;
class MLIRContext;
class Operation;
-class OwningRewritePatternList;
class TypeConverter;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
/// Add a pattern to the given pattern list to convert the operand and result
/// types of a CallOp with the given type converter.
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 1e04b22985376..6e95daed621f8 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -19,7 +19,8 @@
namespace mlir {
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
void populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
index 72539c8e25727..dc1fd7e948422 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -14,7 +14,8 @@
namespace mlir {
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
void populateTensorBufferizePatterns(BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 7d20e64b23792..456cc88430a6b 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -28,7 +28,8 @@
namespace mlir {
class MLIRContext;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
namespace vector {
class VectorDialect;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index ff3dbfdcad1f9..9a0d5537f1738 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -16,8 +16,9 @@
namespace mlir {
class MLIRContext;
-class OwningRewritePatternList;
class VectorTransferOpInterface;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
namespace scf {
class IfOp;
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 60af4b09e0e10..19173d16757a0 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -51,9 +51,11 @@ class RewritePattern;
class Type;
class Value;
class ValueRange;
-template <typename ValueRangeT> class ValueTypeRange;
+template <typename ValueRangeT>
+class ValueTypeRange;
-class OwningRewritePatternList;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
//===----------------------------------------------------------------------===//
// AbstractOperation
@@ -132,12 +134,14 @@ class AbstractOperation {
/// Returns an instance of the concept object for the given interface if it
/// was registered to this operation, null otherwise. This should not be used
/// directly.
- template <typename T> typename T::Concept *getInterface() const {
+ template <typename T>
+ typename T::Concept *getInterface() const {
return interfaceMap.lookup<T>();
}
/// Returns true if the operation has a particular trait.
- template <template <typename T> class Trait> bool hasTrait() const {
+ template <template <typename T> class Trait>
+ bool hasTrait() const {
return hasTraitFn(TypeID::get<Trait>());
}
@@ -148,7 +152,8 @@ class AbstractOperation {
/// This constructor is used by Dialect objects when they register the list of
/// operations they contain.
- template <typename T> static void insert(Dialect &dialect) {
+ template <typename T>
+ static void insert(Dialect &dialect) {
insert(T::getOperationName(), dialect, TypeID::get<T>(),
T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
T::getVerifyInvariantsFn(), T::getFoldHookFn(),
@@ -220,7 +225,8 @@ class NamedAttrList {
void append(NamedAttribute attr) { push_back(attr); }
/// Add an array of named attributes.
- template <typename RangeT> void append(RangeT &&newAttributes) {
+ template <typename RangeT>
+ void append(RangeT &&newAttributes) {
append(std::begin(newAttributes), std::end(newAttributes));
}
@@ -851,7 +857,8 @@ LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
namespace llvm {
// Identifiers hash just like pointers, there is no need to hash the bytes.
-template <> struct DenseMapInfo<mlir::OperationName> {
+template <>
+struct DenseMapInfo<mlir::OperationName> {
static mlir::OperationName getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::OperationName::getFromOpaquePointer(pointer);
@@ -871,7 +878,8 @@ template <> struct DenseMapInfo<mlir::OperationName> {
/// The pointer inside of an identifier comes from a StringMap, so its alignment
/// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to
/// steal the low bits.
-template <> struct PointerLikeTypeTraits<mlir::OperationName> {
+template <>
+struct PointerLikeTypeTraits<mlir::OperationName> {
public:
static inline void *getAsVoidPointer(mlir::OperationName I) {
return const_cast<void *>(I.getAsOpaquePointer());
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index aac321dece61a..115ad5f039bc0 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -715,22 +715,22 @@ class PatternRewriter : public RewriterBase {
};
//===----------------------------------------------------------------------===//
-// OwningRewritePatternList
+// RewritePatternSet
//===----------------------------------------------------------------------===//
-class OwningRewritePatternList {
+class RewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
- OwningRewritePatternList(MLIRContext *context) : context(context) {}
+ RewritePatternSet(MLIRContext *context) : context(context) {}
- /// Construct a OwningRewritePatternList populated with the given pattern.
- OwningRewritePatternList(MLIRContext *context,
- std::unique_ptr<RewritePattern> pattern)
+ /// Construct a RewritePatternSet populated with the given pattern.
+ RewritePatternSet(MLIRContext *context,
+ std::unique_ptr<RewritePattern> pattern)
: context(context) {
nativePatterns.emplace_back(std::move(pattern));
}
- OwningRewritePatternList(PDLPatternModule &&pattern)
+ RewritePatternSet(PDLPatternModule &&pattern)
: context(pattern.getModule()->getContext()),
pdlPatterns(std::move(pattern)) {}
@@ -748,51 +748,114 @@ class OwningRewritePatternList {
pdlPatterns.clear();
}
+ //===--------------------------------------------------------------------===//
+ // 'add' methods for adding patterns to the set.
+ //===--------------------------------------------------------------------===//
+
+ /// Add an instance of each of the pattern types 'Ts' to the pattern list with
+ /// the given arguments. Return a reference to `this` for chaining insertions.
+ /// Note: ConstructorArg is necessary here to separate the two variadic lists.
+ template <typename... Ts, typename ConstructorArg,
+ typename... ConstructorArgs,
+ typename = std::enable_if_t<sizeof...(Ts) != 0>>
+ RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&... args) {
+ // The following expands a call to emplace_back for each of the pattern
+ // types 'Ts'. This magic is necessary due to a limitation in the places
+ // that a parameter pack can be expanded in c++11.
+ // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+ (void)std::initializer_list<int>{0, (addImpl<Ts>(arg, args...), 0)...};
+ return *this;
+ }
+
+ /// Add an instance of each of the pattern types 'Ts'. Return a reference to
+ /// `this` for chaining insertions.
+ template <typename... Ts>
+ RewritePatternSet &add() {
+ (void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...};
+ return *this;
+ }
+
+ /// Add the given native pattern to the pattern list. Return a reference to
+ /// `this` for chaining insertions.
+ RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
+ nativePatterns.emplace_back(std::move(pattern));
+ return *this;
+ }
+
+ /// Add the given PDL pattern to the pattern list. Return a reference to
+ /// `this` for chaining insertions.
+ RewritePatternSet &add(PDLPatternModule &&pattern) {
+ pdlPatterns.mergeIn(std::move(pattern));
+ return *this;
+ }
+
+ // Add a matchAndRewrite style pattern represented as a C function pointer.
+ template <typename OpType>
+ RewritePatternSet &add(LogicalResult (*implFn)(OpType,
+ PatternRewriter &rewriter)) {
+ struct FnPattern final : public OpRewritePattern<OpType> {
+ FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
+ MLIRContext *context)
+ : OpRewritePattern<OpType>(context), implFn(implFn) {}
+
+ LogicalResult matchAndRewrite(OpType op,
+ PatternRewriter &rewriter) const override {
+ return implFn(op, rewriter);
+ }
+
+ private:
+ LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
+ };
+ add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
+ return *this;
+ }
+
//===--------------------------------------------------------------------===//
// Pattern Insertion
//===--------------------------------------------------------------------===//
+ // TODO: These are soft deprecated in favor of the 'add' methods above.
+
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
/// the given arguments. Return a reference to `this` for chaining insertions.
/// Note: ConstructorArg is necessary here to separate the two variadic lists.
template <typename... Ts, typename ConstructorArg,
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
- OwningRewritePatternList &insert(ConstructorArg &&arg,
- ConstructorArgs &&... args) {
+ RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{0, (insertImpl<Ts>(arg, args...), 0)...};
+ (void)std::initializer_list<int>{0, (addImpl<Ts>(arg, args...), 0)...};
return *this;
}
/// Add an instance of each of the pattern types 'Ts'. Return a reference to
/// `this` for chaining insertions.
template <typename... Ts>
- OwningRewritePatternList &insert() {
- (void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...};
+ RewritePatternSet &insert() {
+ (void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...};
return *this;
}
/// Add the given native pattern to the pattern list. Return a reference to
/// `this` for chaining insertions.
- OwningRewritePatternList &insert(std::unique_ptr<RewritePattern> pattern) {
+ RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
nativePatterns.emplace_back(std::move(pattern));
return *this;
}
/// Add the given PDL pattern to the pattern list. Return a reference to
/// `this` for chaining insertions.
- OwningRewritePatternList &insert(PDLPatternModule &&pattern) {
+ RewritePatternSet &insert(PDLPatternModule &&pattern) {
pdlPatterns.mergeIn(std::move(pattern));
return *this;
}
// Add a matchAndRewrite style pattern represented as a C function pointer.
template <typename OpType>
- OwningRewritePatternList &
+ RewritePatternSet &
insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
struct FnPattern final : public OpRewritePattern<OpType> {
FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
@@ -816,13 +879,13 @@ class OwningRewritePatternList {
/// chaining insertions.
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
- insertImpl(Args &&... args) {
+ addImpl(Args &&... args) {
nativePatterns.emplace_back(
std::make_unique<T>(std::forward<Args>(args)...));
}
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
- insertImpl(Args &&... args) {
+ addImpl(Args &&... args) {
pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
}
@@ -831,6 +894,10 @@ class OwningRewritePatternList {
PDLPatternModule pdlPatterns;
};
+// TODO: OwningRewritePatternList is soft-deprecated and will be removed in the
+// future.
+using OwningRewritePatternList = RewritePatternSet;
+
} // end namespace mlir
#endif // MLIR_PATTERN_MATCH_H
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
index c795ad55a3563..994501cfa0e54 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
@@ -51,6 +51,6 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
return success();
}
-void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns) {
- patterns.insert(convertTanhOp);
+void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
+ patterns.add(convertTanhOp);
}
More information about the Mlir-commits mailing list