[Mlir-commits] [mlir] a700a1d - [mlir] use shared pointer to prevent vector resizes from destroying ops
Ashay Rane
llvmlistbot at llvm.org
Fri Jul 7 08:42:01 PDT 2023
Author: Ashay Rane
Date: 2023-07-07T10:41:48-05:00
New Revision: a700a1db8bbaffe3cff09da4ee11419ab3a057b3
URL: https://github.com/llvm/llvm-project/commit/a700a1db8bbaffe3cff09da4ee11419ab3a057b3
DIFF: https://github.com/llvm/llvm-project/commit/a700a1db8bbaffe3cff09da4ee11419ab3a057b3.diff
LOG: [mlir] use shared pointer to prevent vector resizes from destroying ops
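The `MapVector` type stores key-value pairs in a vector, which, when
resized, copies (or moves) the entries into new storage and destroys the
old ones. This invalidates the `Mappings` storage that holds the
payload-operation associations: references and iterators into it that
are still held (e.g. while nested transform ops are applied during
iteration over the mappings) dangle, subsequently causing segfaults.
This patch makes the `mappings` map hold a `std::unique_ptr<Mappings>`
instead (the subject line says "shared pointer", but the code uses a
unique pointer), so that resizing the vector only moves the pointers and
never destroys or relocates the `Mappings` objects themselves.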
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D154511
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 8411eb0dd84128..629d4b55815a00 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -285,7 +285,8 @@ class TransformState {
/// transform IR region and payload IR objects.
RegionScope(TransformState &state, Region &region)
: state(state), region(&region) {
- auto res = state.mappings.insert(std::make_pair(&region, Mappings()));
+ auto res = state.mappings.insert(
+ std::make_pair(&region, std::make_unique<Mappings>()));
assert(res.second && "the region scope is already present");
(void)res;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -437,7 +438,7 @@ class TransformState {
}
}
#endif // NDEBUG
- return it->second;
+ return *it->second.get();
}
/// Returns the mappings frame for the region in which the operation resides.
@@ -464,7 +465,7 @@ class TransformState {
}
}
#endif // NDEBUG
- return it->second;
+ return *it->second.get();
}
/// Updates the state to include the associations between op results and the
@@ -683,7 +684,10 @@ class TransformState {
/// A stack of mappings between transform IR values and payload IR ops,
/// aggregated by the region in which the transform IR values are defined.
- llvm::MapVector<Region *, Mappings> mappings;
+ /// We use a pointer to the Mappings struct so that reallocations inside
+ /// MapVector don't invalidate iterators when we apply nested transform ops
+ /// while also iterating over the mappings.
+ llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
/// Op handles may be temporarily mapped to nullptr to avoid invalidating
/// payload op iterators. This set contains all op handles with nullptrs.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 889be71c1f34cb..ed987ac4b51646 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -45,7 +45,8 @@ transform::TransformState::TransformState(
for (ArrayRef<MappedValue> mapping : extraMappings)
topLevelMappedValues.push_back(mapping);
- auto result = mappings.insert(std::make_pair(region, Mappings()));
+ auto result =
+ mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
assert(result.second && "the region scope is already present");
(void)result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -87,8 +88,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
bool includeOutOfScope) const {
bool found = false;
for (const auto &[region, mapping] : llvm::reverse(mappings)) {
- auto iterator = mapping.reverse.find(op);
- if (iterator != mapping.reverse.end()) {
+ auto iterator = mapping->reverse.find(op);
+ if (iterator != mapping->reverse.end()) {
llvm::append_range(handles, iterator->getSecond());
found = true;
}
@@ -106,8 +107,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadValue(
bool includeOutOfScope) const {
bool found = false;
for (const auto &[region, mapping] : llvm::reverse(mappings)) {
- auto iterator = mapping.reverseValues.find(payloadValue);
- if (iterator != mapping.reverseValues.end()) {
+ auto iterator = mapping->reverseValues.find(payloadValue);
+ if (iterator != mapping->reverseValues.end()) {
llvm::append_range(handles, iterator->getSecond());
found = true;
}
@@ -611,7 +612,7 @@ void transform::TransformState::recordOpHandleInvalidation(
// Go over all op handle mappings and mark as invalidated any handle
// pointing to any of the payload ops associated with the given handle or
// any op nested in them.
- for (const auto &[payloadOp, otherHandles] : mapping.reverse) {
+ for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
for (Value otherHandle : otherHandles)
recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
otherHandle, throughValue,
@@ -622,7 +623,7 @@ void transform::TransformState::recordOpHandleInvalidation(
// or any op nested in them. Similarly invalidate handles to argument of
// blocks belonging to any region of any payload op associated with the
// given handle or any op nested in them.
- for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) {
+ for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
for (Value valueHandle : valueHandles)
recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
payloadValue, valueHandle,
@@ -842,8 +843,9 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// Cache Operation* -> OperationName mappings. These will be checked after
// the transform has been applied to detect incorrect memory side effects
// and missing op tracking.
- for (Mappings &mapping : llvm::make_second_range(mappings)) {
- for (Operation *op : llvm::make_first_range(mapping.reverse)) {
+ for (std::unique_ptr<Mappings> &mapping :
+ llvm::make_second_range(mappings)) {
+ for (Operation *op : llvm::make_first_range(mapping->reverse)) {
auto insertion = cachedNames.insert({op, op->getName()});
if (!insertion.second) {
if (insertion.first->second != op->getName()) {
@@ -993,8 +995,9 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
// Check cached operation names.
- for (Mappings &mapping : llvm::make_second_range(mappings)) {
- for (Operation *op : llvm::make_first_range(mapping.reverse)) {
+ for (std::unique_ptr<Mappings> &mapping :
+ llvm::make_second_range(mappings)) {
+ for (Operation *op : llvm::make_first_range(mapping->reverse)) {
// Make sure that the name of the op has not changed. If it has changed,
// the op was removed and a new op was allocated at the same memory
// location. This means that we are missing op tracking somewhere.
@@ -1106,7 +1109,7 @@ transform::TransformState::RegionScope::~RegionScope() {
// Remember pointers to payload ops referenced by the handles going out of
// scope.
SmallVector<Operation *> referencedOps =
- llvm::to_vector(llvm::make_first_range(state.mappings[region].reverse));
+ llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
state.mappings.erase(region);
More information about the Mlir-commits
mailing list