[Mlir-commits] [mlir] 5267f5e - [mlir] Add a hook to PatternRewriter to allow for patterns to notify why a match failed.

River Riddle llvmlistbot at llvm.org
Tue Mar 17 12:16:53 PDT 2020


Author: River Riddle
Date: 2020-03-17T12:12:21-07:00
New Revision: 5267f5e6b4cb7ad78680fc33a4414b11d95f4c12

URL: https://github.com/llvm/llvm-project/commit/5267f5e6b4cb7ad78680fc33a4414b11d95f4c12
DIFF: https://github.com/llvm/llvm-project/commit/5267f5e6b4cb7ad78680fc33a4414b11d95f4c12.diff

LOG: [mlir] Add a hook to PatternRewriter to allow for patterns to notify why a match failed.

Summary:
This revision adds a new hook, `notifyMatchFailure`, that allows a pattern to notify the rewriter that a match failure is about to occur, along with the reason for it. The hook takes as a parameter a callback that fills a `Diagnostic` instance with the reason why the match failed. This allows the rewriter to decide how this information is displayed to the end user, or to ignore it completely if desired (e.g. in optimized builds). For now, DialectConversion is updated to include this information in its debug output.

Differential Revision: https://reviews.llvm.org/D76203

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/lib/TestDialect/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 48f581998146..c36e8ab5aed1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -334,6 +334,23 @@ class PatternRewriter : public OpBuilder {
     finalizeRootUpdate(root);
   }
 
+  /// Notify the pattern rewriter that the pattern is failing to match the given
+  /// operation, and provide a callback to populate a diagnostic with the reason
+  /// why the failure occurred. The default implementation ignores the reason
+  /// and returns failure(); derived rewriters may override this hook to display
+  /// the reason to users. The Twine/char* overloads below forward to this hook.
+  virtual LogicalResult
+  notifyMatchFailure(Operation *op,
+                     function_ref<void(Diagnostic &)> reasonCallback) {
+    return failure();
+  }
+  LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
+    return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
+  }
+  LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
+    return notifyMatchFailure(op, Twine(msg));
+  }
+
 protected:
   explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
   virtual ~PatternRewriter();

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a58c85499d63..7db8355e4177 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -379,6 +379,12 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// PatternRewriter hook for updating the root operation in-place.
   void cancelRootUpdate(Operation *op) override;
 
+  /// PatternRewriter hook for match failures; logs the reason via LLVM_DEBUG.
+  LogicalResult
+  notifyMatchFailure(Operation *op,
+                     function_ref<void(Diagnostic &)> reasonCallback) override;
+  using PatternRewriter::notifyMatchFailure;
+
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 5c0c0625c392..ae04e117282a 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -989,6 +989,17 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
 }
 
+/// Log the match failure reason via LLVM_DEBUG; always returns failure().
+LogicalResult ConversionPatternRewriter::notifyMatchFailure(
+    Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
+  LLVM_DEBUG({ // reasonCallback is only invoked when debug output is enabled.
+    Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+    reasonCallback(diag);
+    impl->logger.startLine() << "** Failure : " << diag.str() << "\n";
+  });
+  return failure();
+}
+
 /// Return a reference to the internal implementation.
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;

diff  --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index 3760ed7dae90..997d6090be80 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -272,7 +272,7 @@ struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const final {
     // If the type is F32, change the type to F64.
     if (!Type(*op->result_type_begin()).isF32())
-      return matchFailure();
+      return rewriter.notifyMatchFailure(op, "expected single f32 operand");
     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
     return matchSuccess();
   }


        


More information about the Mlir-commits mailing list