[Mlir-commits] [mlir] 5267f5e - [mlir] Add a hook to PatternRewriter to allow for patterns to notify why a match failed.
River Riddle
llvmlistbot at llvm.org
Tue Mar 17 12:16:53 PDT 2020
Author: River Riddle
Date: 2020-03-17T12:12:21-07:00
New Revision: 5267f5e6b4cb7ad78680fc33a4414b11d95f4c12
URL: https://github.com/llvm/llvm-project/commit/5267f5e6b4cb7ad78680fc33a4414b11d95f4c12
DIFF: https://github.com/llvm/llvm-project/commit/5267f5e6b4cb7ad78680fc33a4414b11d95f4c12.diff
LOG: [mlir] Add a hook to PatternRewriter to allow for patterns to notify why a match failed.
Summary:
This revision adds a new hook, `notifyMatchFailure`, that allows patterns to notify the rewriter that a match is about to fail and why. The hook takes as a parameter a callback that fills a `Diagnostic` instance with the reason why the match failed. This allows the rewriter to decide how this information is displayed to the end user, or to ignore it entirely if desired (e.g. in opt mode). For now, DialectConversion is updated to include this information in its debug output.
Differential Revision: https://reviews.llvm.org/D76203
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 48f581998146..c36e8ab5aed1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -334,6 +334,23 @@ class PatternRewriter : public OpBuilder {
finalizeRootUpdate(root);
}
+ /// Notify the pattern rewriter that the pattern is failing to match the given
+ /// operation, and provide a callback to populate a diagnostic with the reason
+ /// why the failure occurred. This method allows for derived rewriters to
+ /// optionally hook into the reason why a pattern failed, and display it to
+ /// users.
+ ///
+ /// The default implementation ignores `op`, never invokes `reasonCallback`,
+ /// and always returns failure(). This lets a pattern bail out with
+ /// `return rewriter.notifyMatchFailure(op, ...);`.
+ virtual LogicalResult
+ notifyMatchFailure(Operation *op,
+ function_ref<void(Diagnostic &)> reasonCallback) {
+ return failure();
+ }
+ /// Convenience overload: forwards to the callback form with a callback that
+ /// streams `msg` into the diagnostic. The callback is a temporary lambda
+ /// that captures `msg` by reference, so an override must invoke it (if at
+ /// all) before returning and must not store it.
+ LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
+ return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
+ }
+ /// Convenience overload for C strings/string literals; forwards via Twine.
+ /// NOTE(review): presumably present so a string literal does not resolve
+ /// ambiguously between the Twine and function_ref overloads. Confirm.
+ LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
+ return notifyMatchFailure(op, Twine(msg));
+ }
+
protected:
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
virtual ~PatternRewriter();
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a58c85499d63..7db8355e4177 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -379,6 +379,12 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// PatternRewriter hook for updating the root operation in-place.
void cancelRootUpdate(Operation *op) override;
+ /// PatternRewriter hook for notifying match failure reasons.
+ /// The reason is only materialized and logged when LLVM_DEBUG output is
+ /// enabled; the result is always failure().
+ LogicalResult
+ notifyMatchFailure(Operation *op,
+ function_ref<void(Diagnostic &)> reasonCallback) override;
+ /// Re-expose the base class's `Twine` and `const char *` convenience
+ /// overloads. Without this, the override above would hide them.
+ using PatternRewriter::notifyMatchFailure;
+
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 5c0c0625c392..ae04e117282a 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -989,6 +989,17 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
}
+/// PatternRewriter hook for notifying match failure reasons.
+///
+/// When LLVM_DEBUG output is enabled, the reason is collected into a
+/// Remark-severity diagnostic located at `op` and written to the conversion
+/// logger as "** Failure : <reason>". Otherwise `reasonCallback` is never
+/// invoked. Always returns failure().
+LogicalResult ConversionPatternRewriter::notifyMatchFailure(
+ Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
+ // Only pay for building the diagnostic when debug output is enabled.
+ LLVM_DEBUG({
+ Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+ reasonCallback(diag);
+ impl->logger.startLine() << "** Failure : " << diag.str() << "\n";
+ });
+ return failure();
+}
+
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index 3760ed7dae90..997d6090be80 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -272,7 +272,7 @@ struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
// If the type is F32, change the type to F64.
if (!Type(*op->result_type_begin()).isF32())
- return matchFailure();
+ return rewriter.notifyMatchFailure(op, "expected single f32 operand");
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
return matchSuccess();
}
More information about the Mlir-commits
mailing list