[Mlir-commits] [mlir] 244a4e7 - [mlir][transform] Add optional error checking to TrackingListener

Matthias Springer llvmlistbot at llvm.org
Thu Mar 30 02:18:27 PDT 2023


Author: Matthias Springer
Date: 2023-03-30T11:10:38+02:00
New Revision: 244a4e75aee42126209a6950fa12b345af5edf3d

URL: https://github.com/llvm/llvm-project/commit/244a4e75aee42126209a6950fa12b345af5edf3d
DIFF: https://github.com/llvm/llvm-project/commit/244a4e75aee42126209a6950fa12b345af5edf3d.diff

LOG: [mlir][transform] Add optional error checking to TrackingListener

Derived classes can implement `notifyPayloadReplacementNotFound` for custom error checking.

Differential Revision: https://reviews.llvm.org/D147206

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index f0f209f071a1b..eb55f7f133e91 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -50,6 +50,12 @@ class TrackingListener : public RewriterBase::Listener,
   virtual Operation *findReplacementOp(Operation *op,
                                        ValueRange newValues) const;
 
+  /// Called when tracked payload op `op` is dropped because no replacement op
+  /// was found for the given replacement `values`. The default does nothing;
+  /// derived classes can override this (const) hook for custom error handling.
+  virtual void notifyPayloadReplacementNotFound(Operation *op,
+                                                ValueRange values) const {}
+
   /// Return "true" if the given op is a new op.
   bool isNewOp(Operation *op) const;
 

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 63df020f2eb79..a37822d9d0998 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -204,15 +204,19 @@ void transform::TrackingListener::notifyOperationReplaced(
     Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
          "invalid number of replacement values");
-  if (op->getNumResults() == 0)
-    return;
 
   // Replace value handles.
   for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
     (void)replacePayloadValue(oldValue, newValue);
 
   // Replace op handle.
-  (void)replacePayloadOp(op, findReplacementOp(op, newValues));
+  Operation *replacement = findReplacementOp(op, newValues);
+  if (succeeded(replacePayloadOp(op, replacement))) {
+    // replacePayloadOp succeeded, so `op` was tracked. If no replacement op
+    // was found, the op was dropped; notify derived classes.
+    if (!replacement)
+      notifyPayloadReplacementNotFound(op, newValues);
+  }
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list