[flang-commits] [flang] a6151f4 - [mlir][IR] Move `match` and `rewrite` functions into separate class (#129861)
via flang-commits
flang-commits at lists.llvm.org
Wed Mar 5 23:48:56 PST 2025
Author: Matthias Springer
Date: 2025-03-06T08:48:51+01:00
New Revision: a6151f4e237075919c12a120c391a8b6c6a5000c
URL: https://github.com/llvm/llvm-project/commit/a6151f4e237075919c12a120c391a8b6c6a5000c
DIFF: https://github.com/llvm/llvm-project/commit/a6151f4e237075919c12a120c391a8b6c6a5000c.diff
LOG: [mlir][IR] Move `match` and `rewrite` functions into separate class (#129861)
The vast majority of rewrite / conversion patterns use a combined
`matchAndRewrite` instead of separate `match` and `rewrite` functions.
This PR optimizes the code base for the most common case where users
implement a combined `matchAndRewrite`. There are no longer any `match`
and `rewrite` functions in `RewritePattern`, `ConversionPattern` and
their derived classes. Instead, there is a `SplitMatchAndRewriteImpl`
class that implements `matchAndRewrite` in terms of `match` and
`rewrite`.
Details:
* The `RewritePattern` and `ConversionPattern` classes are simpler
(fewer functions). This is especially true for the `ConversionPattern`
class, which now has 5 fewer functions. (There were various `rewrite`
overloads to account for 1:1 / 1:N patterns.)
* There is a new class `SplitMatchAndRewriteImpl` that derives from
`RewritePattern` / `OpRewritePattern` / ..., along with a type alias
`RewritePattern::SplitMatchAndRewrite` for convenience.
* Fewer `llvm_unreachable` calls are needed throughout the code base.
Instead, pure virtual functions can be used in cases where users
previously had to implement `rewrite`, `matchAndRewrite`, etc.
* This PR may also reduce the number of [`-Woverloaded-virtual`
warnings](https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933)
that are produced by GCC. (To be confirmed...)
Note for LLVM integration: Patterns with separate `match` / `rewrite`
implementations must derive from `X::SplitMatchAndRewrite` instead of
`X`.
---------
Co-authored-by: River Riddle <riddleriver at gmail.com>
Added:
Modified:
flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
mlir/docs/PatternRewriter.md
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/unittests/IR/PatternMatchTest.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
index 35749dae5d7e9..53d16323beddf 100644
--- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
+++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
@@ -187,7 +187,6 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
const fir::FIRToLLVMPassOptions &options;
- using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};
@@ -206,20 +205,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
options, benefit) {}
/// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
- mlir::ConversionPatternRewriter &rewriter) const final {
- rewrite(mlir::cast<SourceOp>(op),
- OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter);
- }
- void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands,
- mlir::ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = llvm::cast<SourceOp>(op);
- rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp),
- rewriter);
- }
- llvm::LogicalResult match(mlir::Operation *op) const final {
- return match(mlir::cast<SourceOp>(op));
- }
llvm::LogicalResult
matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
@@ -235,28 +220,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
- virtual llvm::LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const {
- llvm::SmallVector<mlir::Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
- if (mlir::failed(match(op)))
- return mlir::failure();
- rewrite(op, adaptor, rewriter);
- return mlir::success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -268,7 +237,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
private:
using ConvertFIRToLLVMPattern::matchAndRewrite;
- using ConvertToLLVMPattern::match;
};
/// FIR conversion pattern template
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 9df4647299010..af0f56466e0cb 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -38,22 +38,23 @@ possible cost and use the predicate to guard the match.
### Root Operation Name (Optional)
The name of the root operation that this pattern matches against. If specified,
-only operations with the given root name will be provided to the `match` and
-`rewrite` implementation. If not specified, any operation type may be provided.
-The root operation name should be provided whenever possible, because it
-simplifies the analysis of patterns when applying a cost model. To match any
+only operations with the given root name will be provided to the
+`matchAndRewrite` implementation. If not specified, any operation type may be
+provided. The root operation name should be provided whenever possible, because
+it simplifies the analysis of patterns when applying a cost model. To match any
operation type, a special tag must be provided to make the intent explicit:
`MatchAnyOpTypeTag`.
-### `match` and `rewrite` implementation
+### `matchAndRewrite` implementation
This is the chunk of code that matches a given root `Operation` and performs a
rewrite of the IR. A `RewritePattern` can specify this implementation either via
-separate `match` and `rewrite` methods, or via a combined `matchAndRewrite`
-method. When using the combined `matchAndRewrite` method, no IR mutation should
-take place before the match is deemed successful. The combined `matchAndRewrite`
-is useful when non-trivially recomputable information is required by the
-matching and rewriting phase. See below for examples:
+the `matchAndRewrite` method or via separate `match` and `rewrite` methods when
+deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined
+`matchAndRewrite` method, no IR mutation should take place before the match is
+deemed successful. The combined `matchAndRewrite` is useful when non-trivially
+recomputable information is required by the matching and rewriting phase. See
+below for examples:
```c++
class MyPattern : public RewritePattern {
@@ -105,6 +106,10 @@ Within the `rewrite` section of a pattern, the following constraints apply:
`eraseOp`) should be used instead.
* The root operation is required to either be: updated in-place, replaced, or
erased.
+* `matchAndRewrite` must return "success" if and only if the IR was modified.
+ `match` must return "success" if and only if the IR is going to be modified
+ during `rewrite`.
+
### Application Recursion
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 86ea87b55af1c..8f82176f3b75f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
+
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
+ using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+ using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
+ ConvertOpToLLVMPattern<SourceOp>>;
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
benefit) {}
/// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
- }
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, adaptor, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
}
private:
- using ConvertToLLVMPattern::match;
using ConvertToLLVMPattern::matchAndRewrite;
};
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ab0405043a54..792b50d38817e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -234,41 +234,52 @@ class Pattern {
// RewritePattern
//===----------------------------------------------------------------------===//
-/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// There are two possible usages of this class:
-/// * Multi-step RewritePattern with "match" and "rewrite"
-/// - By overloading the "match" and "rewrite" functions, the user can
-/// separate the concerns of matching and rewriting.
-/// * Single-step RewritePattern with "matchAndRewrite"
-/// - By overloading the "matchAndRewrite" function, the user can perform
-/// the rewrite in the same call as the match.
-///
-class RewritePattern : public Pattern {
-public:
- virtual ~RewritePattern() = default;
+namespace detail {
+/// Helper class that derives from a RewritePattern class and provides separate
+/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+template <typename PatternT>
+class SplitMatchAndRewriteImpl : public PatternT {
+ using PatternT::PatternT;
+
+ /// Attempt to match against IR rooted at the specified operation, which is
+ /// the same operation kind as getRootKind().
+ ///
+ /// Note: This function must not modify the IR.
+ virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
- /// builder. If an unexpected error is encountered (an internal
- /// compiler error), it is emitted through the normal MLIR diagnostic
- /// hooks and the IR is left in a valid state.
- virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
-
- /// Attempt to match against code rooted at the specified operation,
- /// which is the same operation code as getRootKind().
- virtual LogicalResult match(Operation *op) const;
+ /// rewriter.
+ virtual void rewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const = 0;
- /// Attempt to match against code rooted at the specified operation,
- /// which is the same operation code as getRootKind(). If successful, this
- /// function will automatically perform the rewrite.
- virtual LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
+ LogicalResult matchAndRewrite(typename PatternT::OperationT op,
+ PatternRewriter &rewriter) const final {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return success();
}
return failure();
}
+};
+} // namespace detail
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+class RewritePattern : public Pattern {
+public:
+ using OperationT = Operation *;
+ using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
+
+ virtual ~RewritePattern() = default;
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind(). If successful, perform
+ /// the rewrite.
+ ///
+ /// Note: Implementations must modify the IR if and only if the function
+ /// returns "success".
+ virtual LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const = 0;
/// This method provides a convenient interface for creating and initializing
/// derived rewrite patterns of the given type `T`.
@@ -317,36 +328,19 @@ namespace detail {
/// class or Interface.
template <typename SourceOp>
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+ using OperationT = SourceOp;
using RewritePattern::RewritePattern;
- /// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, PatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), rewriter);
- }
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
+ /// Wrapper around the RewritePattern method that passes the derived op type.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
- /// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
- llvm_unreachable("must override rewrite or matchAndRewrite");
- }
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
+ /// Method that operates on the SourceOp type. Must be overridden by the
+ /// derived pattern class.
virtual LogicalResult matchAndRewrite(SourceOp op,
- PatternRewriter &rewriter) const {
- if (succeeded(match(op))) {
- rewrite(op, rewriter);
- return success();
- }
- return failure();
- }
+ PatternRewriter &rewriter) const = 0;
};
} // namespace detail
@@ -356,6 +350,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
template <typename SourceOp>
struct OpRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ using SplitMatchAndRewrite =
+ detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
+
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
/// ops.
@@ -371,6 +368,9 @@ struct OpRewritePattern
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ using SplitMatchAndRewrite =
+ detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
+
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfa..120709bbe5b67 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -528,24 +528,78 @@ class TypeConverter {
// Conversion Patterns
//===----------------------------------------------------------------------===//
+namespace detail {
+/// Helper class that derives from a ConversionRewritePattern class and
+/// provides separate `match` and `rewrite` entry points instead of a combined
+/// `matchAndRewrite`.
+template <typename PatternT>
+class ConversionSplitMatchAndRewriteImpl : public PatternT {
+ using PatternT::PatternT;
+
+ /// Attempt to match against IR rooted at the specified operation, which is
+ /// the same operation kind as getRootKind().
+ ///
+ /// Note: This function must not modify the IR.
+ virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+
+ /// Rewrite the IR rooted at the specified operation with the result of
+ /// this pattern, generating any new operations with the specified
+ /// rewriter.
+ virtual void rewrite(typename PatternT::OperationT op,
+ typename PatternT::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // One of the two `rewrite` functions must be implemented.
+ llvm_unreachable("rewrite is not implemented");
+ }
+
+ virtual void rewrite(typename PatternT::OperationT op,
+ typename PatternT::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ if constexpr (std::is_same<typename PatternT::OpAdaptor,
+ ArrayRef<Value>>::value) {
+ rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
+ } else {
+ SmallVector<Value> oneToOneOperands =
+ PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
+ rewriter);
+ }
+ }
+
+ LogicalResult
+ matchAndRewrite(typename PatternT::OperationT op,
+ typename PatternT::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ if (succeeded(match(op))) {
+ rewrite(op, adaptor, rewriter);
+ return success();
+ }
+ return failure();
+ }
+
+ LogicalResult
+ matchAndRewrite(typename PatternT::OperationT op,
+ typename PatternT::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // Users would normally override this function in conversion patterns to
+ // implement a 1:1 pattern. Patterns that are derived from this class have
+ // separate `match` and `rewrite` functions, so this `matchAndRewrite`
+ // overload is obsolete.
+ llvm_unreachable("this function is unreachable");
+ }
+};
+} // namespace detail
+
/// Base class for the conversion patterns. This pattern class enables type
/// conversions, and other uses specific to the conversion framework. As such,
/// patterns of this type can only be used with the 'apply*' methods below.
class ConversionPattern : public RewritePattern {
public:
- /// Hook for derived classes to implement rewriting. `op` is the (first)
- /// operation matched by the pattern, `operands` is a list of the rewritten
- /// operand values that are passed to `op`, `rewriter` can be used to emit the
- /// new operations. This function should not fail. If some specific cases of
- /// the operation are not supported, these cases should not be matched.
- virtual void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("unimplemented rewrite");
- }
- virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const {
- rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
- }
+ using OperationT = Operation *;
+ using OpAdaptor = ArrayRef<Value>;
+ using OneToNOpAdaptor = ArrayRef<ValueRange>;
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
@@ -554,10 +608,7 @@ class ConversionPattern : public RewritePattern {
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, operands, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
/// Hook for derived classes to implement combined matching and rewriting.
@@ -606,9 +657,6 @@ class ConversionPattern : public RewritePattern {
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
-
-private:
- using RewritePattern::rewrite;
};
/// OpConversionPattern is a wrapper around ConversionPattern that allows for
@@ -617,9 +665,12 @@ class ConversionPattern : public RewritePattern {
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
+ using OperationT = SourceOp;
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+ using SplitMatchAndRewrite =
+ detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -630,19 +681,6 @@ class OpConversionPattern : public ConversionPattern {
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- LogicalResult match(Operation *op) const final {
- return match(cast<SourceOp>(op));
- }
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -657,28 +695,12 @@ class OpConversionPattern : public ConversionPattern {
rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual LogicalResult match(SourceOp op) const {
- llvm_unreachable("must override match or matchAndRewrite");
- }
- virtual void rewrite(SourceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
- }
- virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, adaptor, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -708,14 +730,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), operands, rewriter);
- }
- void rewrite(Operation *op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), operands, rewriter);
- }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -727,23 +741,12 @@ class OpInterfaceConversionPattern : public ConversionPattern {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
- /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// Methods that operate on the SourceOp type. One of these must be
/// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
- }
- virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
- ConversionPatternRewriter &rewriter) const {
- rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
- }
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (failed(match(op)))
- return failure();
- rewrite(op, operands, rewriter);
- return success();
+ llvm_unreachable("matchAndRewrite is not implemented");
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cba71740f9380..734c4839f9a10 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final
void runOnOperation() override;
};
-struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final
+ : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
+ using SplitMatchAndRewrite::SplitMatchAndRewrite;
Chipset chipset;
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
- : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+ : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), chipset(chipset) {}
LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
-struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
+struct TruncFToFloat8RewritePattern final
+ : OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
bool saturateFP8 = false;
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
Chipset chipset)
- : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
- chipset(chipset) {}
+ : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx),
+ saturateFP8(saturateFP8), chipset(chipset) {}
Chipset chipset;
LogicalResult match(arith::TruncFOp op) const override;
@@ -65,9 +67,9 @@ struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
};
struct TruncfToFloat16RewritePattern final
- : public OpRewritePattern<arith::TruncFOp> {
+ : public OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
- using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+ using SplitMatchAndRewrite::SplitMatchAndRewrite;
LogicalResult match(arith::TruncFOp op) const override;
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 3646416def810..80310ce56a51b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -363,11 +363,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
using Base = LoadStoreOpLowering<Derived>;
-
- LogicalResult match(Derived op) const override {
- MemRefType type = op.getMemRefType();
- return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
- }
};
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
@@ -662,8 +657,9 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
}
};
-struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
- using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
+struct MemRefCastOpLowering
+ : public ConvertOpToLLVMPattern<memref::CastOp>::SplitMatchAndRewrite {
+ using SplitMatchAndRewrite::SplitMatchAndRewrite;
LogicalResult match(memref::CastOp memRefCastOp) const override {
Type srcType = memRefCastOp.getOperand().getType();
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17..f105534626082 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -40,9 +40,10 @@ struct EmulateUnsupportedFloatsPass
void runOnOperation() override;
};
-struct EmulateFloatPattern final : ConversionPattern {
+struct EmulateFloatPattern final : ConversionPattern::SplitMatchAndRewrite {
EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
- : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
+ : ConversionPattern::SplitMatchAndRewrite(
+ converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
LogicalResult match(Operation *op) const override;
void rewrite(Operation *op, ArrayRef<Value> operands,
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 5982f5f55549e..6da28ddeede3c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -115,9 +115,11 @@ class DataFlowListener : public RewriterBase::Listener {
/// and replace their uses with that constant. Return success() if all results
/// where thus replaced and the operation is erased. Also replace any block
/// arguments with their constant values.
-struct MaterializeKnownConstantValues : public RewritePattern {
+struct MaterializeKnownConstantValues
+ : public RewritePattern::SplitMatchAndRewrite {
MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
- : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
+ : RewritePattern::SplitMatchAndRewrite(Pattern::MatchAnyOpTypeTag(),
+ /*benefit=*/1, context),
solver(s) {}
LogicalResult match(Operation *op) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index f13e54901f690..2413a4126f3f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -772,14 +772,14 @@ class FlattenContiguousRowMajorTransferWritePattern
/// `vector.extract` and `vector.extract_element`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
- : public OpRewritePattern<VectorExtractOp> {
- using Base = OpRewritePattern<VectorExtractOp>;
+ : public OpRewritePattern<VectorExtractOp>::SplitMatchAndRewrite {
+ using Base = typename OpRewritePattern<VectorExtractOp>::SplitMatchAndRewrite;
public:
RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
PatternBenefit benefit,
bool allowMultipleUses)
- : Base::OpRewritePattern(context, benefit),
+ : Base::SplitMatchAndRewrite(context, benefit),
allowMultipleUses(allowMultipleUses) {}
LogicalResult match(VectorExtractOp extractOp) const override {
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 286f47ce69136..3e3c06bebf142 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -87,15 +87,6 @@ Pattern::Pattern(const void *rootValue, RootKind rootKind,
// RewritePattern
//===----------------------------------------------------------------------===//
-void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
- llvm_unreachable("need to implement either matchAndRewrite or one of the "
- "rewrite functions!");
-}
-
-LogicalResult RewritePattern::match(Operation *op) const {
- llvm_unreachable("need to implement either match or matchAndRewrite!");
-}
-
/// Out-of-line vtable anchor.
void RewritePattern::anchor() {}
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
index 75d5228c82d99..1c67bfc284d32 100644
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -19,6 +19,11 @@ struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
AnOpRewritePattern(MLIRContext *context)
: OpRewritePattern(context, /*benefit=*/1,
/*generatedNames=*/{test::OpB::getOperationName()}) {}
+
+ LogicalResult matchAndRewrite(test::OpA op,
+ PatternRewriter &rewriter) const override {
+ return failure();
+ }
};
TEST(OpRewritePatternTest, GetGeneratedNames) {
MLIRContext context;
More information about the flang-commits
mailing list