[Mlir-commits] [mlir] [mlir][transform] Remove `cachedNames` expensive check (PR #73961)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 09:24:42 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This check was trying to find cases of invalid API usage: incorrect/missing handle side effects and/or incorrect rewriter usage.

This check is not implemented correctly and can report false positives in the case of pointer reuse (i.e., a different op is created at the same memory location as a previously erased op). It is unclear whether such a check can be implemented correctly, given that we have both tracking-listener-based handle updates and handle consumption.

Fixes #<!-- -->72931.


---
Full diff: https://github.com/llvm/llvm-project/pull/73961.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (-12) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+3-121) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 5fcde11d52f03d9..2fdc15db9ad854a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -789,18 +789,6 @@ class TransformState {
   /// Each region must be an ancestor of the following regions in this list.
   /// These are also the keys for "mappings".
   SmallVector<Region *> regionStack;
-
-  /// This cache stores operation names for operations that are tracked in the
-  /// transform dialect state. It is used to detect missing memory side effects
-  /// and op tracking.
-  ///
-  /// All tracked ops are added to this cache before a transform op is applied.
-  /// After the application of the transform op, the names of all tracked ops
-  /// are compared with the names in the cache. If there is a mismatch (or a
-  /// crash), op tracking is missing somewhere. This is typically a missing
-  /// "consumesHandle" side effect or a pattern that removes an op without
-  /// notifying a TrackingListener.
-  DenseMap<Operation *, OperationName> cachedNames;
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 };
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index d0cd879d560c887..de5b7a81286bc4e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -386,23 +386,6 @@ transform::TransformState::replacePayloadOp(Operation *op,
     dropMappingEntry(mappings.reverse, op, handle);
   }
 
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  if (options.getExpensiveChecksEnabled()) {
-    auto it = cachedNames.find(op);
-    assert(it != cachedNames.end() && "entry not found");
-    assert(it->second == op->getName() && "operation name mismatch");
-    cachedNames.erase(it);
-    if (replacement) {
-      auto insertion =
-          cachedNames.insert({replacement, replacement->getName()});
-      if (!insertion.second) {
-        assert(insertion.first->second == replacement->getName() &&
-               "operation is already cached with a different name");
-      }
-    }
-  }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
   // Replace the pointed-to object of all handles with the replacement object.
   // In case a payload op was erased (replacement object is nullptr), a nullptr
   // is stored in the mapping. These nullptrs are removed after each transform.
@@ -494,10 +477,10 @@ void transform::TransformState::recordOpHandleInvalidationOne(
   unsigned operandNo = consumingHandle.getOperandNumber();
   for (Operation *ancestor : potentialAncestors) {
     // clang-format off
-    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 
+    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
       { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
-    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 
-      { (DBGS() << "----of payload with name: " 
+    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
+      { (DBGS() << "----of payload with name: "
                 << payloadOp->getName().getIdentifier() << "\n"); });
     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
       { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
@@ -872,29 +855,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
         FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
       }
     }
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-    // Cache Operation* -> OperationName mappings. These will be checked after
-    // the transform has been applied to detect incorrect memory side effects
-    // and missing op tracking.
-    for (std::unique_ptr<Mappings> &mapping :
-         llvm::make_second_range(mappings)) {
-      for (Operation *op : llvm::make_first_range(mapping->reverse)) {
-        auto insertion = cachedNames.insert({op, op->getName()});
-        if (!insertion.second) {
-          if (insertion.first->second != op->getName()) {
-            // Operation is already in the cache, but with a different name.
-            DiagnosedDefiniteFailure diag =
-                emitDefiniteFailure(transform->getLoc())
-                << "expensive checks failure: operation mismatch, expected "
-                << insertion.first->second;
-            diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
-            return diag;
-          }
-        }
-      }
-    }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   }
 
   // Find which operands are consumed.
@@ -908,22 +868,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   // IR after that.
   SmallVector<Value> origOpFlatResults;
   SmallVector<Operation *> origAssociatedOps;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  DenseSet<Operation *> consumedPayloadOps;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   for (OpOperand *opOperand : consumedOperands) {
     Value operand = opOperand->get();
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
       for (Operation *payloadOp : getPayloadOps(operand)) {
         llvm::append_range(origOpFlatResults, payloadOp->getResults());
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-        if (options.getExpensiveChecksEnabled()) {
-          // Store all consumed payload ops (and their nested ops) in a set for
-          // extra error checking.
-          payloadOp->walk(
-              [&](Operation *op) { consumedPayloadOps.insert(op); });
-        }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
       }
       continue;
     }
@@ -1004,63 +953,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     }
   }
 
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  if (options.getExpensiveChecksEnabled()) {
-    // Remove erased ops from the transform state.
-    for (Operation *op : consumedPayloadOps) {
-      // This payload op was consumed but it may still be mapped to one or
-      // multiple handles. Forget all handles that are mapped to the op, so that
-      // there are no dangling pointers in the transform dialect state. This is
-      // necessary so that the `cachedNames`-based checks work correctly.
-      //
-      // Note: Dangling pointers to erased payload ops are allowed if the
-      // corresponding handles are not used anymore. There is another
-      // "expensive-check" that looks for future uses of dangling payload op
-      // pointers (through arbitrary handles). Removing handles to erased ops
-      // does not interfere with the other expensive checks: handle invalidation
-      // happens earlier and keeps track of invalidated handles with
-      // pre-generated error messages, so we do not need the association to
-      // still be there when the invalidated handle is accessed.
-      SmallVector<Value> handles;
-      (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true);
-      for (Value handle : handles)
-        forgetMapping(handle, /*origOpFlatResults=*/ValueRange(),
-                      /*allowOutOfScope=*/true);
-      cachedNames.erase(op);
-    }
-
-    // Check cached operation names.
-    for (std::unique_ptr<Mappings> &mapping :
-         llvm::make_second_range(mappings)) {
-      for (Operation *op : llvm::make_first_range(mapping->reverse)) {
-        // Make sure that the name of the op has not changed. If it has changed,
-        // the op was removed and a new op was allocated at the same memory
-        // location. This means that we are missing op tracking somewhere.
-        auto cacheIt = cachedNames.find(op);
-        if (cacheIt == cachedNames.end()) {
-          DiagnosedDefiniteFailure diag =
-              emitDefiniteFailure(transform->getLoc())
-              << "expensive checks failure: operation not found in cache";
-          diag.attachNote(op->getLoc()) << "payload op";
-          return diag;
-        }
-        // If the `getName` call (or the above `attachNote`) is crashing, we
-        // have a dangling pointer. This usually means that an op was erased but
-        // the transform dialect was not made aware of that; e.g., missing
-        // "consumesHandle" or rewriter usage.
-        if (cacheIt->second != op->getName()) {
-          DiagnosedDefiniteFailure diag =
-              emitDefiniteFailure(transform->getLoc())
-              << "expensive checks failure: operation mismatch, expected "
-              << cacheIt->second;
-          diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
-          return diag;
-        }
-      }
-    }
-  }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
   if (failed(updateStateFromResults(results, transform->getResults())))
     return DiagnosedSilenceableFailure::definiteFailure();
 
@@ -1150,16 +1042,6 @@ transform::TransformState::RegionScope::~RegionScope() {
   state.mappings.erase(region);
 
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  // If the last handle to a payload op has gone out of scope, we no longer
-  // need to store the cached name. Pointers may get reused, leading to
-  // incorrect associations in the cache.
-  for (Operation *op : referencedOps) {
-    SmallVector<Value> handles;
-    if (succeeded(state.getHandlesForPayloadOp(op, handles)))
-      continue;
-    state.cachedNames.erase(op);
-  }
-
   state.regionStack.pop_back();
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/73961


More information about the Mlir-commits mailing list