[Mlir-commits] [mlir] bd5941b - [mlir] Remove the PatternState class and simplify PatternMatchResult.

River Riddle llvmlistbot at llvm.org
Mon Mar 16 17:59:54 PDT 2020


Author: River Riddle
Date: 2020-03-16T17:55:54-07:00
New Revision: bd5941b9ceef0c1af0a10b292d22fdbd2433f4cd

URL: https://github.com/llvm/llvm-project/commit/bd5941b9ceef0c1af0a10b292d22fdbd2433f4cd
DIFF: https://github.com/llvm/llvm-project/commit/bd5941b9ceef0c1af0a10b292d22fdbd2433f4cd.diff

LOG: [mlir] Remove the PatternState class and simplify PatternMatchResult.

Summary: PatternState was a mechanism to pass state between the match and rewrite calls of a RewritePattern. With the rise of matchAndRewrite, this class is unused and unnecessary. This revision removes PatternState and simplifies PatternMatchResult to just be a LogicalResult. A future revision will replace all usages of PatternMatchResult/matchSuccess/matchFailure with LogicalResult equivalents.

Differential Revision: https://reviews.llvm.org/D76202

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index aa17952b79c4..48f581998146 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -54,22 +54,9 @@ class PatternBenefit {
   unsigned short representation;
 };
 
-/// Pattern state is used by patterns that want to maintain state between their
-/// match and rewrite phases.  Patterns can define a pattern-specific subclass
-/// of this.
-class PatternState {
-public:
-  virtual ~PatternState() {}
-
-protected:
-  // Must be subclassed.
-  PatternState() {}
-};
-
-/// This is the type returned by a pattern match.  A match failure returns a
-/// None value.  A match success returns a Some value with any state the pattern
-/// may need to maintain (but may also be null).
-using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
+/// This is the type returned by a pattern match.
+/// TODO: Replace usages with LogicalResult directly.
+using PatternMatchResult = LogicalResult;
 
 //===----------------------------------------------------------------------===//
 // Pattern class
@@ -97,9 +84,7 @@ class Pattern {
   //===--------------------------------------------------------------------===//
 
   /// Attempt to match against code rooted at the specified operation,
-  /// which is the same operation code as getRootKind().  On failure, this
-  /// returns a None value.  On success it returns a (possibly null)
-  /// pattern-specific state wrapped in an Optional.
+  /// which is the same operation code as getRootKind().
   virtual PatternMatchResult match(Operation *op) const = 0;
 
   virtual ~Pattern() {}
@@ -108,14 +93,11 @@ class Pattern {
   // Helper methods to simplify pattern implementations
   //===--------------------------------------------------------------------===//
 
-  /// This method indicates that no match was found.
-  static PatternMatchResult matchFailure() { return None; }
+  /// Return a result, indicating that no match was found.
+  PatternMatchResult matchFailure() const { return failure(); }
 
-  /// This method indicates that a match was found and has the specified cost.
-  PatternMatchResult
-  matchSuccess(std::unique_ptr<PatternState> state = {}) const {
-    return PatternMatchResult(std::move(state));
-  }
+  /// This method indicates that a match was found.
+  PatternMatchResult matchSuccess() const { return success(); }
 
 protected:
   /// Patterns must specify the root operation name they match against, and can
@@ -136,19 +118,10 @@ class Pattern {
 ///       separate the concerns of matching and rewriting.
 ///   * Single-step RewritePattern with "matchAndRewrite"
 ///     - By overloading the "matchAndRewrite" function, the user can perform
-///       the rewrite in the same call as the match. This removes the need for
-///       any PatternState.
+///       the rewrite in the same call as the match.
 ///
 class RewritePattern : public Pattern {
 public:
-  /// Rewrite the IR rooted at the specified operation with the result of
-  /// this pattern, generating any new operations with the specified
-  /// rewriter.  If an unexpected error is encountered (an internal
-  /// compiler error), it is emitted through the normal MLIR diagnostic
-  /// hooks and the IR is left in a valid state.
-  virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
-                       PatternRewriter &rewriter) const;
-
   /// Rewrite the IR rooted at the specified operation with the result of
   /// this pattern, generating any new operations with the specified
   /// builder.  If an unexpected error is encountered (an internal
@@ -168,8 +141,8 @@ class RewritePattern : public Pattern {
   /// function will automatically perform the rewrite.
   virtual PatternMatchResult matchAndRewrite(Operation *op,
                                              PatternRewriter &rewriter) const {
-    if (auto matchResult = match(op)) {
-      rewrite(op, std::move(*matchResult), rewriter);
+    if (succeeded(match(op))) {
+      rewrite(op, rewriter);
       return matchSuccess();
     }
     return matchFailure();
@@ -206,10 +179,6 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
       : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
 
   /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, std::unique_ptr<PatternState> state,
-               PatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), std::move(state), rewriter);
-  }
   void rewrite(Operation *op, PatternRewriter &rewriter) const final {
     rewrite(cast<SourceOp>(op), rewriter);
   }
@@ -223,20 +192,16 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
-                       PatternRewriter &rewriter) const {
-    rewrite(op, rewriter);
-  }
   virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
+    llvm_unreachable("must override rewrite or matchAndRewrite");
   }
   virtual PatternMatchResult match(SourceOp op) const {
     llvm_unreachable("must override match or matchAndRewrite");
   }
   virtual PatternMatchResult matchAndRewrite(SourceOp op,
                                              PatternRewriter &rewriter) const {
-    if (auto matchResult = match(op)) {
-      rewrite(op, std::move(*matchResult), rewriter);
+    if (succeeded(match(op))) {
+      rewrite(op, rewriter);
       return matchSuccess();
     }
     return matchFailure();

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 48ce470b88dd..a58c85499d63 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -238,7 +238,7 @@ class ConversionPattern : public RewritePattern {
   virtual PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (!match(op))
+    if (failed(match(op)))
       return matchFailure();
     rewrite(op, operands, rewriter);
     return matchSuccess();
@@ -285,7 +285,7 @@ struct OpConversionPattern : public ConversionPattern {
   virtual PatternMatchResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (!match(op))
+    if (failed(match(op)))
       return matchFailure();
     rewrite(op, operands, rewriter);
     return matchSuccess();

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d295870d09ab..f967793c36f0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -805,7 +805,7 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
         // multiple times.
         auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
         (void)success;
-        assert(success && "Unexpected failure");
+        assert(succeeded(success) && "Unexpected failure");
         extractedSource = insertStridedSliceOp;
       }
       // 4. Insert the extractedSource into the res vector.
@@ -1083,7 +1083,7 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
         // multiple times.
         auto success = matchAndRewrite(stridedSliceOp, rewriter);
         (void)success;
-        assert(success && "Unexpected failure");
+        assert(succeeded(success) && "Unexpected failure");
         extracted = stridedSliceOp;
       }
       res = insertOne(rewriter, loc, extracted, res, idx);

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index c705dc87bfa8..68c7c018b584 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -318,9 +318,8 @@ struct ConvertSelectionOpToSelect
     auto *falseBlock = brConditionalOp.getSuccessor(1);
     auto *mergeBlock = selectionOp.getMergeBlock();
 
-    if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) {
+    if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
       return matchFailure();
-    }
 
     auto trueValue = getSrcValue(trueBlock);
     auto falseValue = getSrcValue(falseBlock);

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index b111137e736a..bf1cd3c29b3e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -39,11 +39,6 @@ void Pattern::anchor() {}
 // RewritePattern and PatternRewriter implementation
 //===----------------------------------------------------------------------===//
 
-void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
-                             PatternRewriter &rewriter) const {
-  rewrite(op, rewriter);
-}
-
 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
   llvm_unreachable("need to implement either matchAndRewrite or one of the "
                    "rewrite functions!");
@@ -191,7 +186,7 @@ bool RewritePatternMatcher::matchAndRewrite(Operation *op,
 
     // Try to match and rewrite this pattern. The patterns are sorted by
     // benefit, so if we match we can immediately rewrite and return.
-    if (pattern->matchAndRewrite(op, rewriter))
+    if (succeeded(pattern->matchAndRewrite(op, rewriter)))
       return true;
   }
   return false;

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index a58dd05d0812..5c0c0625c392 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1237,12 +1237,12 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
 
   // Try to rewrite with the given pattern.
   rewriter.setInsertionPoint(op);
-  auto matchedPattern = pattern->matchAndRewrite(op, rewriter);
+  LogicalResult matchedPattern = pattern->matchAndRewrite(op, rewriter);
 #ifndef NDEBUG
   assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
 #endif
 
-  if (!matchedPattern) {
+  if (failed(matchedPattern)) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
     return cleanupFailure();
   }


        


More information about the Mlir-commits mailing list