[Mlir-commits] [mlir] bd5941b - [mlir] Remove the PatternState class and simplify PatternMatchResult.
River Riddle
llvmlistbot at llvm.org
Mon Mar 16 17:59:54 PDT 2020
Author: River Riddle
Date: 2020-03-16T17:55:54-07:00
New Revision: bd5941b9ceef0c1af0a10b292d22fdbd2433f4cd
URL: https://github.com/llvm/llvm-project/commit/bd5941b9ceef0c1af0a10b292d22fdbd2433f4cd
DIFF: https://github.com/llvm/llvm-project/commit/bd5941b9ceef0c1af0a10b292d22fdbd2433f4cd.diff
LOG: [mlir] Remove the PatternState class and simplify PatternMatchResult.
Summary: PatternState was a mechanism to pass state between the match and rewrite calls of a RewritePattern. Now that patterns perform matching and rewriting in a single matchAndRewrite call, no state needs to be carried between the two phases, so this class is unused and unnecessary. This revision removes PatternState and simplifies PatternMatchResult to just be a LogicalResult. A future revision will replace all usages of PatternMatchResult/matchSuccess/matchFailure with LogicalResult equivalents.
Differential Revision: https://reviews.llvm.org/D76202
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index aa17952b79c4..48f581998146 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -54,22 +54,9 @@ class PatternBenefit {
unsigned short representation;
};
-/// Pattern state is used by patterns that want to maintain state between their
-/// match and rewrite phases. Patterns can define a pattern-specific subclass
-/// of this.
-class PatternState {
-public:
- virtual ~PatternState() {}
-
-protected:
- // Must be subclassed.
- PatternState() {}
-};
-
-/// This is the type returned by a pattern match. A match failure returns a
-/// None value. A match success returns a Some value with any state the pattern
-/// may need to maintain (but may also be null).
-using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
+/// This is the type returned by a pattern match.
+/// TODO: Replace usages with LogicalResult directly.
+using PatternMatchResult = LogicalResult;
//===----------------------------------------------------------------------===//
// Pattern class
@@ -97,9 +84,7 @@ class Pattern {
//===--------------------------------------------------------------------===//
/// Attempt to match against code rooted at the specified operation,
- /// which is the same operation code as getRootKind(). On failure, this
- /// returns a None value. On success it returns a (possibly null)
- /// pattern-specific state wrapped in an Optional.
+ /// which is the same operation code as getRootKind().
virtual PatternMatchResult match(Operation *op) const = 0;
virtual ~Pattern() {}
@@ -108,14 +93,11 @@ class Pattern {
// Helper methods to simplify pattern implementations
//===--------------------------------------------------------------------===//
- /// This method indicates that no match was found.
- static PatternMatchResult matchFailure() { return None; }
+ /// Return a result, indicating that no match was found.
+ PatternMatchResult matchFailure() const { return failure(); }
- /// This method indicates that a match was found and has the specified cost.
- PatternMatchResult
- matchSuccess(std::unique_ptr<PatternState> state = {}) const {
- return PatternMatchResult(std::move(state));
- }
+ /// This method indicates that a match was found.
+ PatternMatchResult matchSuccess() const { return success(); }
protected:
/// Patterns must specify the root operation name they match against, and can
@@ -136,19 +118,10 @@ class Pattern {
/// separate the concerns of matching and rewriting.
/// * Single-step RewritePattern with "matchAndRewrite"
/// - By overloading the "matchAndRewrite" function, the user can perform
-/// the rewrite in the same call as the match. This removes the need for
-/// any PatternState.
+/// the rewrite in the same call as the match.
///
class RewritePattern : public Pattern {
public:
- /// Rewrite the IR rooted at the specified operation with the result of
- /// this pattern, generating any new operations with the specified
- /// rewriter. If an unexpected error is encountered (an internal
- /// compiler error), it is emitted through the normal MLIR diagnostic
- /// hooks and the IR is left in a valid state.
- virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
- PatternRewriter &rewriter) const;
-
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
@@ -168,8 +141,8 @@ class RewritePattern : public Pattern {
/// function will automatically perform the rewrite.
virtual PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- if (auto matchResult = match(op)) {
- rewrite(op, std::move(*matchResult), rewriter);
+ if (succeeded(match(op))) {
+ rewrite(op, rewriter);
return matchSuccess();
}
return matchFailure();
@@ -206,10 +179,6 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
: RewritePattern(SourceOp::getOperationName(), benefit, context) {}
/// Wrappers around the RewritePattern methods that pass the derived op type.
- void rewrite(Operation *op, std::unique_ptr<PatternState> state,
- PatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), std::move(state), rewriter);
- }
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), rewriter);
}
@@ -223,20 +192,16 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
- PatternRewriter &rewriter) const {
- rewrite(op, rewriter);
- }
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
- llvm_unreachable("must override matchAndRewrite or a rewrite method");
+ llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual PatternMatchResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
virtual PatternMatchResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const {
- if (auto matchResult = match(op)) {
- rewrite(op, std::move(*matchResult), rewriter);
+ if (succeeded(match(op))) {
+ rewrite(op, rewriter);
return matchSuccess();
}
return matchFailure();
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 48ce470b88dd..a58c85499d63 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -238,7 +238,7 @@ class ConversionPattern : public RewritePattern {
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (!match(op))
+ if (failed(match(op)))
return matchFailure();
rewrite(op, operands, rewriter);
return matchSuccess();
@@ -285,7 +285,7 @@ struct OpConversionPattern : public ConversionPattern {
virtual PatternMatchResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (!match(op))
+ if (failed(match(op)))
return matchFailure();
rewrite(op, operands, rewriter);
return matchSuccess();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d295870d09ab..f967793c36f0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -805,7 +805,7 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
// multiple times.
auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
(void)success;
- assert(success && "Unexpected failure");
+ assert(succeeded(success) && "Unexpected failure");
extractedSource = insertStridedSliceOp;
}
// 4. Insert the extractedSource into the res vector.
@@ -1083,7 +1083,7 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
// multiple times.
auto success = matchAndRewrite(stridedSliceOp, rewriter);
(void)success;
- assert(success && "Unexpected failure");
+ assert(succeeded(success) && "Unexpected failure");
extracted = stridedSliceOp;
}
res = insertOne(rewriter, loc, extracted, res, idx);
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index c705dc87bfa8..68c7c018b584 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -318,9 +318,8 @@ struct ConvertSelectionOpToSelect
auto *falseBlock = brConditionalOp.getSuccessor(1);
auto *mergeBlock = selectionOp.getMergeBlock();
- if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) {
+ if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
return matchFailure();
- }
auto trueValue = getSrcValue(trueBlock);
auto falseValue = getSrcValue(falseBlock);
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index b111137e736a..bf1cd3c29b3e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -39,11 +39,6 @@ void Pattern::anchor() {}
// RewritePattern and PatternRewriter implementation
//===----------------------------------------------------------------------===//
-void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
- PatternRewriter &rewriter) const {
- rewrite(op, rewriter);
-}
-
void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
llvm_unreachable("need to implement either matchAndRewrite or one of the "
"rewrite functions!");
@@ -191,7 +186,7 @@ bool RewritePatternMatcher::matchAndRewrite(Operation *op,
// Try to match and rewrite this pattern. The patterns are sorted by
// benefit, so if we match we can immediately rewrite and return.
- if (pattern->matchAndRewrite(op, rewriter))
+ if (succeeded(pattern->matchAndRewrite(op, rewriter)))
return true;
}
return false;
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index a58dd05d0812..5c0c0625c392 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1237,12 +1237,12 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
// Try to rewrite with the given pattern.
rewriter.setInsertionPoint(op);
- auto matchedPattern = pattern->matchAndRewrite(op, rewriter);
+ LogicalResult matchedPattern = pattern->matchAndRewrite(op, rewriter);
#ifndef NDEBUG
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
#endif
- if (!matchedPattern) {
+ if (failed(matchedPattern)) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
return cleanupFailure();
}
More information about the Mlir-commits
mailing list