[Mlir-commits] [mlir] 572b171 - [mlir][transform] TrackingListener: Distinguish between failure and "should be dropped"

Matthias Springer llvmlistbot at llvm.org
Fri Jun 9 02:40:40 PDT 2023


Author: Matthias Springer
Date: 2023-06-09T11:40:32+02:00
New Revision: 572b171fb58b58f4c3f95dfde495dd81059e1451

URL: https://github.com/llvm/llvm-project/commit/572b171fb58b58f4c3f95dfde495dd81059e1451
DIFF: https://github.com/llvm/llvm-project/commit/572b171fb58b58f4c3f95dfde495dd81059e1451.diff

LOG: [mlir][transform] TrackingListener: Distinguish between failure and "should be dropped"

When looking for replacement ops (`findReplacementOp`), distinguish between "no replacement could be found" and "this op should be dropped from the mapping". The latter case will be used in a subsequent revision, when a payload op is mapped to a consumed handle.

Differential Revision: https://reviews.llvm.org/D152375

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 5625d193b1508..d5fcdb31ab0c8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -53,6 +53,11 @@ class TrackingListener : public RewriterBase::Listener,
   /// the same op, which also has the same type as the given op, that defining
   /// op is used as a replacement.
   ///
+  /// A "failure" return value indicates that no replacement operation could be
+  /// found. A "nullptr" return value indicates that no replacement op is needed
+  /// (e.g., handle is dead or was consumed) and that the payload op should
+  /// be dropped from the mapping.
+  ///
   /// Example: A tracked "linalg.generic" with two results is replaced with two
   /// values defined by (another) "linalg.generic". It is reasonable to assume
   /// that the replacement "linalg.generic" represents the same "computation".
@@ -91,8 +96,8 @@ class TrackingListener : public RewriterBase::Listener,
   ///
   /// Derived classes may override `findReplacementOp` to specify custom
   /// replacement rules.
-  virtual Operation *findReplacementOp(Operation *op,
-                                       ValueRange newValues) const;
+  virtual FailureOr<Operation *> findReplacementOp(Operation *op,
+                                                   ValueRange newValues) const;
 
   /// Notify the listener that the pattern failed to match the given operation,
   /// and provide a callback to populate a diagnostic with the reason why the

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 69d0b2b20307b..a64b9d7aa365d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -70,7 +70,7 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
   return defOp;
 }
 
-Operation *
+FailureOr<Operation *>
 transform::TrackingListener::findReplacementOp(Operation *op,
                                                ValueRange newValues) const {
   assert(op->getNumResults() == newValues.size() &&
@@ -81,7 +81,7 @@ transform::TrackingListener::findReplacementOp(Operation *op,
     // If the replacement values belong to different ops, drop the mapping.
     Operation *defOp = getCommonDefiningOp(values);
     if (!defOp)
-      return nullptr;
+      return failure();
 
     // If the defining op has the same type, we take it as a replacement.
     if (op->getName() == defOp->getName())
@@ -108,7 +108,7 @@ transform::TrackingListener::findReplacementOp(Operation *op,
     }
   } while (!values.empty());
 
-  return nullptr;
+  return failure();
 }
 
 LogicalResult transform::TrackingListener::notifyMatchFailure(
@@ -173,12 +173,16 @@ void transform::TrackingListener::notifyOperationReplaced(
     return;
   }
 
-  Operation *replacement = findReplacementOp(op, newValues);
+  FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
   // If the op is tracked but no replacement op was found, send a
   // notification.
-  if (!replacement)
+  if (failed(replacement)) {
     notifyPayloadReplacementNotFound(op, newValues);
-  (void)replacePayloadOp(op, replacement);
+    (void)replacePayloadOp(op, nullptr);
+    return;
+  }
+
+  (void)replacePayloadOp(op, *replacement);
 }
 
 transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index c0a1348c8010a..c6aad2c8fdb35 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -302,7 +302,10 @@ class DummyTrackingListener : public transform::TrackingListener {
 
   // Expose `findReplacementOp` as a public function, so that it can be tested.
   Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
-    return findReplacementOp(op, newValues);
+    FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
+    if (failed(replacementOp))
+      return nullptr;
+    return *replacementOp;
   }
 };
 } // namespace

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index fc6f323a3b557..0c3697d1171ff 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -696,15 +696,15 @@ class TestTrackingListener : public transform::TrackingListener {
   using transform::TrackingListener::TrackingListener;
 
 protected:
-  Operation *findReplacementOp(Operation *op,
-                               ValueRange newValues) const override {
+  FailureOr<Operation *>
+  findReplacementOp(Operation *op, ValueRange newValues) const override {
     if (newValues.size() != 1)
-      return nullptr;
+      return failure();
     Operation *replacement = newValues[0].getDefiningOp();
     if (!replacement)
-      return nullptr;
+      return failure();
     if (replacement->getName().getStringRef() != "test.update_mapping")
-      return nullptr;
+      return failure();
     return replacement;
   }
 };


        


More information about the Mlir-commits mailing list