[Mlir-commits] [mlir] a700a1d - [mlir] use unique pointer to keep mappings stable across vector resizes

Ashay Rane llvmlistbot at llvm.org
Fri Jul 7 08:42:01 PDT 2023


Author: Ashay Rane
Date: 2023-07-07T10:41:48-05:00
New Revision: a700a1db8bbaffe3cff09da4ee11419ab3a057b3

URL: https://github.com/llvm/llvm-project/commit/a700a1db8bbaffe3cff09da4ee11419ab3a057b3
DIFF: https://github.com/llvm/llvm-project/commit/a700a1db8bbaffe3cff09da4ee11419ab3a057b3.diff

LOG: [mlir] use unique pointer to keep mappings stable across vector resizes

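The `MapVector` type stores key-value pairs in a vector, which, when
resized, copies the entries into new storage and destroys the old ones.
Any reference to a `Mappings` entry taken before the resize (for example,
while applying nested transform ops during iteration over the mappings)
is left dangling, subsequently causing segfaults.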

This patch makes the `mappings` map store a `std::unique_ptr<Mappings>`
instead, so that resizing the vector only moves the pointers while the
`Mappings` objects themselves stay at stable addresses.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D154511

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 8411eb0dd84128..629d4b55815a00 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -285,7 +285,8 @@ class TransformState {
     /// transform IR region and payload IR objects.
     RegionScope(TransformState &state, Region &region)
         : state(state), region(&region) {
-      auto res = state.mappings.insert(std::make_pair(&region, Mappings()));
+      auto res = state.mappings.insert(
+          std::make_pair(&region, std::make_unique<Mappings>()));
       assert(res.second && "the region scope is already present");
       (void)res;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -437,7 +438,7 @@ class TransformState {
       }
     }
 #endif // NDEBUG
-    return it->second;
+    return *it->second.get();
   }
 
   /// Returns the mappings frame for the region in which the operation resides.
@@ -464,7 +465,7 @@ class TransformState {
       }
     }
 #endif // NDEBUG
-    return it->second;
+    return *it->second.get();
   }
 
   /// Updates the state to include the associations between op results and the
@@ -683,7 +684,10 @@ class TransformState {
 
   /// A stack of mappings between transform IR values and payload IR ops,
   /// aggregated by the region in which the transform IR values are defined.
-  llvm::MapVector<Region *, Mappings> mappings;
+  /// We use a pointer to the Mappings struct so that reallocations inside
+  /// MapVector don't invalidate iterators when we apply nested transform ops
+  /// while also iterating over the mappings.
+  llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
 
   /// Op handles may be temporarily mapped to nullptr to avoid invalidating
   /// payload op iterators. This set contains all op handles with nullptrs.

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 889be71c1f34cb..ed987ac4b51646 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -45,7 +45,8 @@ transform::TransformState::TransformState(
   for (ArrayRef<MappedValue> mapping : extraMappings)
     topLevelMappedValues.push_back(mapping);
 
-  auto result = mappings.insert(std::make_pair(region, Mappings()));
+  auto result =
+      mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
   assert(result.second && "the region scope is already present");
   (void)result;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -87,8 +88,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
     bool includeOutOfScope) const {
   bool found = false;
   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
-    auto iterator = mapping.reverse.find(op);
-    if (iterator != mapping.reverse.end()) {
+    auto iterator = mapping->reverse.find(op);
+    if (iterator != mapping->reverse.end()) {
       llvm::append_range(handles, iterator->getSecond());
       found = true;
     }
@@ -106,8 +107,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadValue(
     bool includeOutOfScope) const {
   bool found = false;
   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
-    auto iterator = mapping.reverseValues.find(payloadValue);
-    if (iterator != mapping.reverseValues.end()) {
+    auto iterator = mapping->reverseValues.find(payloadValue);
+    if (iterator != mapping->reverseValues.end()) {
       llvm::append_range(handles, iterator->getSecond());
       found = true;
     }
@@ -611,7 +612,7 @@ void transform::TransformState::recordOpHandleInvalidation(
     // Go over all op handle mappings and mark as invalidated any handle
     // pointing to any of the payload ops associated with the given handle or
     // any op nested in them.
-    for (const auto &[payloadOp, otherHandles] : mapping.reverse) {
+    for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
       for (Value otherHandle : otherHandles)
         recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
                                       otherHandle, throughValue,
@@ -622,7 +623,7 @@ void transform::TransformState::recordOpHandleInvalidation(
     // or any op nested in them. Similarly invalidate handles to argument of
     // blocks belonging to any region of any payload op associated with the
     // given handle or any op nested in them.
-    for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) {
+    for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
       for (Value valueHandle : valueHandles)
         recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
                                                    payloadValue, valueHandle,
@@ -842,8 +843,9 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     // Cache Operation* -> OperationName mappings. These will be checked after
     // the transform has been applied to detect incorrect memory side effects
     // and missing op tracking.
-    for (Mappings &mapping : llvm::make_second_range(mappings)) {
-      for (Operation *op : llvm::make_first_range(mapping.reverse)) {
+    for (std::unique_ptr<Mappings> &mapping :
+         llvm::make_second_range(mappings)) {
+      for (Operation *op : llvm::make_first_range(mapping->reverse)) {
         auto insertion = cachedNames.insert({op, op->getName()});
         if (!insertion.second) {
           if (insertion.first->second != op->getName()) {
@@ -993,8 +995,9 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     }
 
     // Check cached operation names.
-    for (Mappings &mapping : llvm::make_second_range(mappings)) {
-      for (Operation *op : llvm::make_first_range(mapping.reverse)) {
+    for (std::unique_ptr<Mappings> &mapping :
+         llvm::make_second_range(mappings)) {
+      for (Operation *op : llvm::make_first_range(mapping->reverse)) {
         // Make sure that the name of the op has not changed. If it has changed,
         // the op was removed and a new op was allocated at the same memory
         // location. This means that we are missing op tracking somewhere.
@@ -1106,7 +1109,7 @@ transform::TransformState::RegionScope::~RegionScope() {
   // Remember pointers to payload ops referenced by the handles going out of
   // scope.
   SmallVector<Operation *> referencedOps =
-      llvm::to_vector(llvm::make_first_range(state.mappings[region].reverse));
+      llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   state.mappings.erase(region);


        


More information about the Mlir-commits mailing list