[Mlir-commits] [mlir] b8c6b15 - [mlir] Support collecting logs from notifyMatchFailure().
Chia-hung Duan
llvmlistbot at llvm.org
Fri Dec 3 20:41:07 PST 2021
Author: Chia-hung Duan
Date: 2021-12-04T04:35:24Z
New Revision: b8c6b15283000f1f065acd10d487ef87df0542c9
URL: https://github.com/llvm/llvm-project/commit/b8c6b15283000f1f065acd10d487ef87df0542c9
DIFF: https://github.com/llvm/llvm-project/commit/b8c6b15283000f1f065acd10d487ef87df0542c9.diff
LOG: [mlir] Support collecting logs from notifyMatchFailure().
Let the user register their own handler to process the match-failure
information.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D110896
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e66dbbc664b4e..d5fb29387b2b2 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -525,7 +525,7 @@ struct ConversionPatternRewriterImpl;
/// hooks.
class ConversionPatternRewriter final : public PatternRewriter {
public:
- ConversionPatternRewriter(MLIRContext *ctx);
+ explicit ConversionPatternRewriter(MLIRContext *ctx);
~ConversionPatternRewriter() override;
/// Apply a signature conversion to the entry block of the given region. This
@@ -932,14 +932,20 @@ LogicalResult applyFullConversion(Operation *op, ConversionTarget &target,
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
/// operations on success and only pre-existing operations are added to the set.
/// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops'.
-LogicalResult applyAnalysisConversion(ArrayRef<Operation *> ops,
- ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps);
-LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps);
+/// the regions nested within 'ops'. There's an additional argument
+/// `notifyCallback` which is used for collecting match failure diagnostics
+/// generated during the conversion. Diagnostics reported to this callback
+/// may only be available in debug mode.
+LogicalResult applyAnalysisConversion(
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ DenseSet<Operation *> &convertedOps,
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+LogicalResult applyAnalysisConversion(
+ Operation *op, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ DenseSet<Operation *> &convertedOps,
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1d793f91da5fd..ad34eebff5b9b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -851,8 +851,9 @@ void ArgConverter::insertConversion(Block *newBlock,
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl {
- ConversionPatternRewriterImpl(PatternRewriter &rewriter)
- : argConverter(rewriter, unresolvedMaterializations) {}
+ explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+ : argConverter(rewriter, unresolvedMaterializations),
+ notifyCallback(nullptr) {}
/// Cleanup and destroy any generated rewrite operations. This method is
/// invoked when the conversion process fails.
@@ -1004,6 +1005,9 @@ struct ConversionPatternRewriterImpl {
/// active.
TypeConverter *currentTypeConverter = nullptr;
+ /// This allows the user to collect the match failure message.
+ function_ref<void(Diagnostic &)> notifyCallback;
+
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1475,6 +1479,8 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
logger.startLine() << "** Failure : " << diag.str() << "\n";
+ if (notifyCallback)
+ notifyCallback(diag);
});
return failure();
}
@@ -1949,7 +1955,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that cleans up the rewriter state after a pattern failed to match.
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
- LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
+ LLVM_DEBUG({
+ logFailure(rewriterImpl.logger, "pattern failed to match");
+ if (rewriterImpl.notifyCallback) {
+ Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+ diag << "Failed to apply pattern \"" << pattern.getDebugName()
+ << "\" on op:\n"
+ << *op;
+ rewriterImpl.notifyCallback(diag);
+ }
+ });
rewriterImpl.resetState(curState);
appliedPatterns.erase(&pattern);
};
@@ -2333,7 +2348,9 @@ struct OperationConverter {
: opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
/// Converts the given operations to the conversion target.
- LogicalResult convertOperations(ArrayRef<Operation *> ops);
+ LogicalResult
+ convertOperations(ArrayRef<Operation *> ops,
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr);
private:
/// Converts an operation with the given rewriter.
@@ -2410,7 +2427,9 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return success();
}
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult OperationConverter::convertOperations(
+ ArrayRef<Operation *> ops,
+ function_ref<void(Diagnostic &)> notifyCallback) {
if (ops.empty())
return success();
ConversionTarget &target = opLegalizer.getTarget();
@@ -2428,6 +2447,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Convert each operation and discard rewrites on failure.
ConversionPatternRewriter rewriter(ops.front()->getContext());
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+ rewriterImpl.notifyCallback = notifyCallback;
+
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
return rewriterImpl.discardRewrites(), failure();
@@ -3275,15 +3296,17 @@ LogicalResult
mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps) {
+ DenseSet<Operation *> &convertedOps,
+ function_ref<void(Diagnostic &)> notifyCallback) {
OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
&convertedOps);
- return opConverter.convertOperations(ops);
+ return opConverter.convertOperations(ops, notifyCallback);
}
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps) {
+ DenseSet<Operation *> &convertedOps,
+ function_ref<void(Diagnostic &)> notifyCallback) {
return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
- convertedOps);
+ convertedOps, notifyCallback);
}
More information about the Mlir-commits
mailing list