[Mlir-commits] [mlir] b8c6b15 - [mlir] Support collecting logs from notifyMatchFailure().

Chia-hung Duan llvmlistbot at llvm.org
Fri Dec 3 20:41:07 PST 2021


Author: Chia-hung Duan
Date: 2021-12-04T04:35:24Z
New Revision: b8c6b15283000f1f065acd10d487ef87df0542c9

URL: https://github.com/llvm/llvm-project/commit/b8c6b15283000f1f065acd10d487ef87df0542c9
DIFF: https://github.com/llvm/llvm-project/commit/b8c6b15283000f1f065acd10d487ef87df0542c9.diff

LOG: [mlir] Support collecting logs from notifyMatchFailure().

Let the user register their own handler for processing the match
failure information.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D110896

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e66dbbc664b4e..d5fb29387b2b2 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -525,7 +525,7 @@ struct ConversionPatternRewriterImpl;
 /// hooks.
 class ConversionPatternRewriter final : public PatternRewriter {
 public:
-  ConversionPatternRewriter(MLIRContext *ctx);
+  explicit ConversionPatternRewriter(MLIRContext *ctx);
   ~ConversionPatternRewriter() override;
 
   /// Apply a signature conversion to the entry block of the given region. This
@@ -932,14 +932,20 @@ LogicalResult applyFullConversion(Operation *op, ConversionTarget &target,
 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
 /// operations on success and only pre-existing operations are added to the set.
 /// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops'.
-LogicalResult applyAnalysisConversion(ArrayRef<Operation *> ops,
-                                      ConversionTarget &target,
-                                      const FrozenRewritePatternSet &patterns,
-                                      DenseSet<Operation *> &convertedOps);
-LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target,
-                                      const FrozenRewritePatternSet &patterns,
-                                      DenseSet<Operation *> &convertedOps);
+/// the regions nested within 'ops'. There's an additional argument
+/// `notifyCallback` which is used for collecting match failure diagnostics
+/// generated during the conversion. Diagnostics reported to this callback
+/// may only be available in debug mode.
+LogicalResult applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+LogicalResult applyAnalysisConversion(
+    Operation *op, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1d793f91da5fd..ad34eebff5b9b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -851,8 +851,9 @@ void ArgConverter::insertConversion(Block *newBlock,
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl {
-  ConversionPatternRewriterImpl(PatternRewriter &rewriter)
-      : argConverter(rewriter, unresolvedMaterializations) {}
+  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+      : argConverter(rewriter, unresolvedMaterializations),
+        notifyCallback(nullptr) {}
 
   /// Cleanup and destroy any generated rewrite operations. This method is
   /// invoked when the conversion process fails.
@@ -1004,6 +1005,9 @@ struct ConversionPatternRewriterImpl {
   /// active.
   TypeConverter *currentTypeConverter = nullptr;
 
+  /// This allows the user to collect the match failure message.
+  function_ref<void(Diagnostic &)> notifyCallback;
+
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1475,6 +1479,8 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
+    if (notifyCallback)
+      notifyCallback(diag);
   });
   return failure();
 }
@@ -1949,7 +1955,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   // Functor that cleans up the rewriter state after a pattern failed to match.
   RewriterState curState = rewriterImpl.getCurrentState();
   auto onFailure = [&](const Pattern &pattern) {
-    LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
+    LLVM_DEBUG({
+      logFailure(rewriterImpl.logger, "pattern failed to match");
+      if (rewriterImpl.notifyCallback) {
+        Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+        diag << "Failed to apply pattern \"" << pattern.getDebugName()
+             << "\" on op:\n"
+             << *op;
+        rewriterImpl.notifyCallback(diag);
+      }
+    });
     rewriterImpl.resetState(curState);
     appliedPatterns.erase(&pattern);
   };
@@ -2333,7 +2348,9 @@ struct OperationConverter {
       : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult convertOperations(ArrayRef<Operation *> ops);
+  LogicalResult
+  convertOperations(ArrayRef<Operation *> ops,
+                    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 
 private:
   /// Converts an operation with the given rewriter.
@@ -2410,7 +2427,9 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
   return success();
 }
 
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult OperationConverter::convertOperations(
+    ArrayRef<Operation *> ops,
+    function_ref<void(Diagnostic &)> notifyCallback) {
   if (ops.empty())
     return success();
   ConversionTarget &target = opLegalizer.getTarget();
@@ -2428,6 +2447,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   // Convert each operation and discard rewrites on failure.
   ConversionPatternRewriter rewriter(ops.front()->getContext());
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  rewriterImpl.notifyCallback = notifyCallback;
+
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
       return rewriterImpl.discardRewrites(), failure();
@@ -3275,15 +3296,17 @@ LogicalResult
 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
                               ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps) {
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
                                  &convertedOps);
-  return opConverter.convertOperations(ops);
+  return opConverter.convertOperations(ops, notifyCallback);
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps) {
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
-                                 convertedOps);
+                                 convertedOps, notifyCallback);
 }


        


More information about the Mlir-commits mailing list