[Mlir-commits] [mlir] [mlir][operation] weak refs (PR #97340)
Maksim Levental
llvmlistbot at llvm.org
Mon Jul 1 23:49:30 PDT 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/97340
>From 53692ef7bb6f7d7126842e66acae725b788d39fe Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 1 Jul 2024 14:34:22 -0500
Subject: [PATCH] [mlir][operation] weak refs
---
mlir/include/mlir/IR/MLIRContext.h | 5 ++
mlir/include/mlir/IR/Operation.h | 66 +++++++++++++++++--
mlir/lib/IR/MLIRContext.cpp | 39 +++++++++++
mlir/lib/IR/Operation.cpp | 77 ++++++++++++++++------
mlir/unittests/IR/OperationSupportTest.cpp | 21 ++++++
5 files changed, 183 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 11e5329f43e68..f05f34870764f 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -31,9 +31,11 @@ class DynamicDialect;
class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
+class Operation;
class RegisteredOperationName;
class StorageUniquer;
class IRUnit;
+class WeakOpRef;
/// MLIRContext is the top-level object for a collection of MLIR operations. It
/// holds immortal uniqued objects like types, and the tables used to unique
@@ -275,6 +277,9 @@ class MLIRContext {
actionFn();
}
+ /// Returns a WeakOpRef for `op`. If this context already tracks a holder for
+ /// `op`, the returned handle shares it; otherwise a new holder is created,
+ /// registered, and `op` is flagged as having a weak reference.
+ WeakOpRef acquireWeakOpRef(Operation *op);
+ /// Drops this context's tracking entry for `op` (if any) and clears the
+ /// op's weak-reference flag. Invoked from ~Operation and from the
+ /// destructor of the shared holder once the last WeakOpRef releases it.
+ void expireWeakRefs(Operation *op);
+
private:
/// Return true if the given dialect is currently loading.
bool isDialectLoading(StringRef dialectNamespace);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index f0dd7c5178056..319ae5cfc018e 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -22,6 +22,39 @@
#include <optional>
namespace mlir {
+class WeakOpRef;
+/// Shared state behind every WeakOpRef handed out for one Operation: the
+/// Operation pointer. Its destructor (defined in Operation.cpp) asks the op's
+/// MLIRContext to drop its tracking entry for the op.
+class WeakOpRefHolder {
+private:
+ /// The referenced operation. NOTE(review): nothing visible in this patch
+ /// resets this pointer when the Operation is destroyed, so it can dangle
+ /// while WeakOpRefs sharing this holder are still alive.
+ mlir::Operation *op;
+
+public:
+ WeakOpRefHolder(mlir::Operation *op) : op(op) {}
+ ~WeakOpRefHolder();
+ /// WeakOpRef reads `op` directly when dereferenced.
+ friend class WeakOpRef;
+};
+
+/// Handle to an Operation obtained from MLIRContext::acquireWeakOpRef.
+/// Despite the name, each handle owns a *strong* reference to a shared
+/// WeakOpRefHolder; the context itself keeps only a std::weak_ptr to that
+/// holder. Copies share the holder; moves transfer it and leave the source
+/// empty.
+class WeakOpRef {
+private:
+ std::shared_ptr<WeakOpRefHolder> holder;
+
+public:
+ /// Wraps an existing holder; this is how MLIRContext::acquireWeakOpRef
+ /// builds the handles it returns.
+ WeakOpRef(std::shared_ptr<WeakOpRefHolder> const &r);
+
+ WeakOpRef(WeakOpRef const &r);
+ WeakOpRef(WeakOpRef &&r);
+ ~WeakOpRef();
+
+ WeakOpRef &operator=(WeakOpRef const &r);
+ WeakOpRef &operator=(WeakOpRef &&r);
+
+ /// Exchanges the holders of `*this` and `r`.
+ void swap(WeakOpRef &r);
+ /// Returns true only for an empty handle (e.g. moved-from). It does not
+ /// detect that the referenced Operation has been destroyed.
+ bool expired() const;
+ /// Number of shared owners of this handle's holder (i.e. the WeakOpRefs
+ /// sharing it, since the context only keeps a weak_ptr), or 0 when empty.
+ long use_count() const { return holder ? holder.use_count() : 0; }
+
+ /// Access the referenced operation. The holder is dereferenced without a
+ /// null check, so the handle must not be empty.
+ mlir::Operation *operator->() const;
+ mlir::Operation &operator*() const;
+};
+
namespace detail {
/// This is a "tag" used for mapping the properties storage in
/// llvm::TrailingObjects.
@@ -210,7 +243,7 @@ class alignas(8) Operation final
Operation *cloneWithoutRegions();
/// Returns the operation block that contains this operation.
- Block *getBlock() { return block; }
+ Block *getBlock() { return blockHasWeakRefPair.getPointer(); }
/// Return the context this operation is associated with.
MLIRContext *getContext() { return location->getContext(); }
@@ -227,11 +260,15 @@ class alignas(8) Operation final
/// Returns the region to which the instruction belongs. Returns nullptr if
/// the instruction is unlinked.
- Region *getParentRegion() { return block ? block->getParent() : nullptr; }
+ Region *getParentRegion() {
+ return getBlock() ? getBlock()->getParent() : nullptr;
+ }
/// Returns the closest surrounding operation that contains this operation
/// or nullptr if this is a top-level operation.
- Operation *getParentOp() { return block ? block->getParentOp() : nullptr; }
+ Operation *getParentOp() {
+ return getBlock() ? getBlock()->getParentOp() : nullptr;
+ }
/// Return the closest surrounding parent operation that is of type 'OpTy'.
template <typename OpTy>
@@ -545,6 +582,7 @@ class alignas(8) Operation final
AttrClass getAttrOfType(StringAttr name) {
return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
}
+
template <typename AttrClass>
AttrClass getAttrOfType(StringRef name) {
return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
@@ -559,6 +597,7 @@ class alignas(8) Operation final
}
return attrs.contains(name);
}
+
bool hasAttr(StringRef name) {
if (getPropertiesStorageSize()) {
if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
@@ -566,6 +605,7 @@ class alignas(8) Operation final
}
return attrs.contains(name);
}
+
template <typename AttrClass, typename NameT>
bool hasAttrOfType(NameT &&name) {
return static_cast<bool>(
@@ -585,6 +625,7 @@ class alignas(8) Operation final
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
}
+
void setAttr(StringRef name, Attribute value) {
setAttr(StringAttr::get(getContext(), name), value);
}
@@ -605,6 +646,7 @@ class alignas(8) Operation final
attrs = attributes.getDictionary(getContext());
return removedAttr;
}
+
Attribute removeAttr(StringRef name) {
return removeAttr(StringAttr::get(getContext(), name));
}
@@ -626,6 +668,7 @@ class alignas(8) Operation final
// Allow access to the constructor.
friend Operation;
};
+
using dialect_attr_range = iterator_range<dialect_attr_iterator>;
/// Return a range corresponding to the dialect attributes for this operation.
@@ -634,10 +677,12 @@ class alignas(8) Operation final
return {dialect_attr_iterator(attrs.begin(), attrs.end()),
dialect_attr_iterator(attrs.end(), attrs.end())};
}
+
dialect_attr_iterator dialect_attr_begin() {
auto attrs = getAttrs();
return dialect_attr_iterator(attrs.begin(), attrs.end());
}
+
dialect_attr_iterator dialect_attr_end() {
auto attrs = getAttrs();
return dialect_attr_iterator(attrs.end(), attrs.end());
@@ -705,6 +750,7 @@ class alignas(8) Operation final
assert(index < getNumSuccessors());
return getBlockOperands()[index].get();
}
+
void setSuccessor(Block *block, unsigned index);
//===--------------------------------------------------------------------===//
@@ -892,12 +938,14 @@ class alignas(8) Operation final
int getPropertiesStorageSize() const {
return ((int)propertiesStorageSize) * 8;
}
+
/// Returns the properties storage.
OpaqueProperties getPropertiesStorage() {
if (propertiesStorageSize)
return getPropertiesStorageUnsafe();
return {nullptr};
}
+
OpaqueProperties getPropertiesStorage() const {
if (propertiesStorageSize)
return {reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
@@ -933,6 +981,11 @@ class alignas(8) Operation final
/// Compute a hash for the op properties (if any).
llvm::hash_code hashProperties();
+ /// Returns true if this operation is flagged as having weak references. The
+ /// flag is the one-bit integer packed next to the parent block pointer in
+ /// `blockHasWeakRefPair`; MLIRContext::acquireWeakOpRef sets it and
+ /// MLIRContext::expireWeakRefs clears it.
+ bool hasWeakReference() { return blockHasWeakRefPair.getInt(); }
+ /// Sets or clears the weak-reference flag. NOTE(review): this appears to be
+ /// in the public section, so any caller can desynchronize the flag from the
+ /// context's tracking table; consider restricting access to MLIRContext.
+ void setHasWeakReference(bool hasWeakRef) {
+ blockHasWeakRefPair.setInt(hasWeakRef);
+ }
+
private:
//===--------------------------------------------------------------------===//
// Ordering
@@ -1016,7 +1069,7 @@ class alignas(8) Operation final
/// requires a 'getParent() const' method. Once ilist_node removes this
/// constraint, we should drop the const to fit the rest of the MLIR const
/// model.
- Block *getParent() const { return block; }
+ Block *getParent() const { return blockHasWeakRefPair.getPointer(); }
/// Expose a few methods explicitly for the debugger to call for
/// visualization.
@@ -1031,8 +1084,9 @@ class alignas(8) Operation final
}
#endif
- /// The operation block that contains this operation.
- Block *block = nullptr;
+ /// The operation block that contains this operation and a bit that signifies
+ /// if the operation has a weak reference.
+ llvm::PointerIntPair<Block *, /*IntBits=*/1, bool> blockHasWeakRefPair;
/// This holds information about the source location the operation was defined
/// or derived from.
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 214b354c5347e..bd466a2cde990 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -271,6 +271,9 @@ class MLIRContextImpl {
/// destruction.
DistinctAttributeAllocator distinctAttributeAllocator;
+ llvm::sys::SmartRWMutex<true> weakOperationRefsMutex;
+ DenseMap<Operation *, std::weak_ptr<WeakOpRefHolder>> weakOperationReferences;
+
public:
MLIRContextImpl(bool threadingIsEnabled)
: threadingIsEnabled(threadingIsEnabled) {
@@ -393,6 +396,42 @@ void MLIRContext::executeActionInternal(function_ref<void()> actionFn,
bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
+/// Returns a WeakOpRef sharing the holder this context tracks for `op`,
+/// creating and registering a new holder when none is tracked or the tracked
+/// one has already expired.
+WeakOpRef MLIRContext::acquireWeakOpRef(Operation *op) {
+  // Fast path: share an existing, still-live holder under the reader lock.
+  {
+    llvm::sys::SmartScopedReader<true> contextLock(
+        impl->weakOperationRefsMutex);
+    auto it = impl->weakOperationReferences.find(op);
+    if (it != impl->weakOperationReferences.end()) {
+      if (std::shared_ptr<WeakOpRefHolder> live = it->second.lock()) {
+        assert(op->hasWeakReference() &&
+               "op should report having weak references");
+        return {live};
+      }
+      // The tracked holder has expired; lock() yielded null. Fall through
+      // instead of handing out an empty handle that crashes on `->`.
+    }
+  }
+  // Slow path. Re-check under the writer lock: another thread may have
+  // registered a holder after the reader lock was released. A bare insert()
+  // would silently fail in that case, leaving an untracked holder whose
+  // destructor later erases the tracked entry.
+  ScopedWriterLock contextLock(impl->weakOperationRefsMutex,
+                               isMultithreadingEnabled());
+  std::weak_ptr<WeakOpRefHolder> &slot = impl->weakOperationReferences[op];
+  if (std::shared_ptr<WeakOpRefHolder> live = slot.lock())
+    return {live};
+  // NOTE(review): if `slot` held an expired holder, that holder's destructor
+  // may still be about to call expireWeakRefs(op) and drop this new entry;
+  // the holder/context protocol needs a follow-up to close that window.
+  auto shared = std::make_shared<WeakOpRefHolder>(op);
+  slot = shared;
+  op->setHasWeakReference(true);
+  return {shared};
+}
+
+/// Drops this context's tracking entry for `op`, if there is one, and clears
+/// the op's weak-reference flag. Reached both from ~Operation and from
+/// ~WeakOpRefHolder.
+void MLIRContext::expireWeakRefs(Operation *op) {
+  if (!op || !impl)
+    return;
+  ScopedWriterLock lock(impl->weakOperationRefsMutex,
+                        isMultithreadingEnabled());
+  // erase(key) is a single lookup and a no-op when `op` is not tracked.
+  // Destroying the stored weak_ptr never affects live holders, so resetting
+  // it (and asserting on it) before erasing was redundant.
+  (void)impl->weakOperationReferences.erase(op);
+  // NOTE(review): outstanding WeakOpRefs still point at `op` after this
+  // returns, and when reached from ~WeakOpRefHolder after ~Operation, `op`
+  // is already destroyed. The holder needs to learn about op destruction.
+  op->setHasWeakReference(false);
+}
+
//===----------------------------------------------------------------------===//
// Diagnostic Handlers
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b51357198b1ca..dbbd32cffc2bb 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -26,6 +26,42 @@
using namespace mlir;
+WeakOpRef::WeakOpRef(const std::shared_ptr<WeakOpRefHolder> &r) : holder(r) {}
+
+WeakOpRef::WeakOpRef(const WeakOpRef &r) = default;
+
+// Steal the holder rather than copying it: copy-then-clear costs an extra
+// pair of atomic refcount updates for the same end state (`r` left empty).
+WeakOpRef::WeakOpRef(WeakOpRef &&r) : holder(std::move(r.holder)) {}
+
+WeakOpRef::~WeakOpRef() = default;
+
+// Copy-and-swap keeps self-assignment safe.
+WeakOpRef &WeakOpRef::operator=(const WeakOpRef &r) {
+  WeakOpRef(r).swap(*this);
+  return *this;
+}
+
+WeakOpRef &WeakOpRef::operator=(WeakOpRef &&r) {
+  WeakOpRef(std::move(r)).swap(*this);
+  return *this;
+}
+
+void WeakOpRef::swap(WeakOpRef &r) { holder.swap(r.holder); }
+
+namespace mlir {
+// Defined in WeakOpRef's own namespace so argument-dependent lookup can find
+// it; at global scope (after `using namespace mlir`) ADL never considered it.
+void swap(WeakOpRef &x, WeakOpRef &y) { x.swap(y); }
+} // namespace mlir
+
+Operation *WeakOpRef::operator->() const {
+  assert(holder && "dereferencing an empty WeakOpRef");
+  return holder->op;
+}
+
+Operation &WeakOpRef::operator*() const {
+  assert(holder && "dereferencing an empty WeakOpRef");
+  return *holder->op;
+}
+
+// A non-null `holder` always has use_count() >= 1 because this object owns
+// it, so the only expired state detectable here is an empty handle.
+bool WeakOpRef::expired() const { return !holder; }
+
+WeakOpRefHolder::~WeakOpRefHolder() { op->getContext()->expireWeakRefs(op); }
+
//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//
@@ -177,7 +213,7 @@ Operation::Operation(Location location, OperationName name, unsigned numResults,
// Operations are deleted through the destroy() member because they are
// allocated via malloc.
Operation::~Operation() {
- assert(block == nullptr && "operation destroyed but still in a block");
+ assert(getBlock() == nullptr && "operation destroyed but still in a block");
#ifndef NDEBUG
if (!use_empty()) {
{
@@ -202,6 +238,9 @@ Operation::~Operation() {
region.~Region();
if (propertiesStorageSize)
name.destroyOpProperties(getPropertiesStorage());
+
+ if (hasWeakReference())
+ getContext()->expireWeakRefs(this);
}
/// Destroy this operation or one of its subclasses.
@@ -322,8 +361,8 @@ void Operation::setAttrs(DictionaryAttr newAttrs) {
}
void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
if (getPropertiesStorageSize()) {
- // We're spliting the providing array of attributes by removing the inherentAttr
- // which will be stored in the properties.
+ // We're splitting the provided array of attributes by removing the
+ // inherent attributes, which will be stored in the properties.
SmallVector<NamedAttribute> discardableAttrs;
discardableAttrs.reserve(newAttrs.size());
for (NamedAttribute attr : newAttrs) {
@@ -384,13 +423,13 @@ constexpr unsigned Operation::kOrderStride;
/// Note: This function has an average complexity of O(1), but worst case may
/// take O(N) where N is the number of operations within the parent block.
bool Operation::isBeforeInBlock(Operation *other) {
- assert(block && "Operations without parent blocks have no order.");
- assert(other && other->block == block &&
+ assert(getBlock() && "Operations without parent blocks have no order.");
+ assert(other && other->getBlock() == getBlock() &&
"Expected other operation to have the same parent block.");
// If the order of the block is already invalid, directly recompute the
// parent.
- if (!block->isOpOrderValid()) {
- block->recomputeOpOrder();
+ if (!getBlock()->isOpOrderValid()) {
+ getBlock()->recomputeOpOrder();
} else {
// Update the order either operation if necessary.
updateOrderIfNecessary();
@@ -403,13 +442,13 @@ bool Operation::isBeforeInBlock(Operation *other) {
/// Update the order index of this operation of this operation if necessary,
/// potentially recomputing the order of the parent block.
void Operation::updateOrderIfNecessary() {
- assert(block && "expected valid parent");
+ assert(getBlock() && "expected valid parent");
// If the order is valid for this operation there is nothing to do.
if (hasValidOrder())
return;
- Operation *blockFront = &block->front();
- Operation *blockBack = &block->back();
+ Operation *blockFront = &getBlock()->front();
+ Operation *blockBack = &getBlock()->back();
// This method is expected to only be invoked on blocks with more than one
// operation.
@@ -419,7 +458,7 @@ void Operation::updateOrderIfNecessary() {
if (this == blockBack) {
Operation *prevNode = getPrevNode();
if (!prevNode->hasValidOrder())
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
// Add the stride to the previous operation.
orderIndex = prevNode->orderIndex + kOrderStride;
@@ -431,10 +470,10 @@ void Operation::updateOrderIfNecessary() {
if (this == blockFront) {
Operation *nextNode = getNextNode();
if (!nextNode->hasValidOrder())
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
// There is no order to give this operation.
if (nextNode->orderIndex == 0)
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
// If we can't use the stride, just take the middle value left. This is safe
// because we know there is at least one valid index to assign to.
@@ -449,12 +488,12 @@ void Operation::updateOrderIfNecessary() {
// the middle of the previous and next if possible.
Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;
// Check to see if there is a valid order between the two.
if (prevOrder + 1 == nextOrder)
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}
@@ -502,7 +541,7 @@ Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
/// keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
assert(!op->getBlock() && "already in an operation block!");
- op->block = getContainingBlock();
+ op->blockHasWeakRefPair.setPointer(getContainingBlock());
// Invalidate the order on the operation.
op->orderIndex = Operation::kInvalidOrderIdx;
@@ -511,8 +550,8 @@ void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
/// This is a trait method invoked when an operation is removed from a block.
/// We keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
- assert(op->block && "not already in an operation block!");
- op->block = nullptr;
+ assert(op->getBlock() && "not already in an operation block!");
+ op->blockHasWeakRefPair.setPointer(nullptr);
}
/// This is a trait method invoked when an operation is moved from one block
@@ -531,7 +570,7 @@ void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
// Update the 'block' member of each operation.
for (; first != last; ++first)
- first->block = curParent;
+ first->blockHasWeakRefPair.setPointer(curParent);
}
/// Remove this operation (and its descendants) from its Block and delete
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index f94dc78445807..e5b48851cc48e 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -313,4 +313,25 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) {
op2->destroy();
}
+TEST(WeakOpRefTest, Test1) {
+  MLIRContext context;
+  context.getOrLoadDialect<test::TestDialect>();
+
+  auto *op1 = createOp(&context);
+  EXPECT_EQ(op1->hasWeakReference(), false);
+  {
+    WeakOpRef weakRef1 = context.acquireWeakOpRef(op1);
+    EXPECT_EQ(weakRef1.use_count(), 1);
+    EXPECT_EQ(op1->hasWeakReference(), true);
+    {
+      WeakOpRef weakRef2 = context.acquireWeakOpRef(op1);
+      EXPECT_EQ(weakRef2.use_count(), 2);
+      EXPECT_EQ(op1->hasWeakReference(), true);
+    }
+    EXPECT_EQ(weakRef1.use_count(), 1);
+    EXPECT_EQ(op1->hasWeakReference(), true);
+
+    // Moving transfers the shared holder without changing its owner count.
+    WeakOpRef moved = std::move(weakRef1);
+    EXPECT_EQ(moved.use_count(), 1);
+    EXPECT_FALSE(moved.expired());
+    EXPECT_EQ(&*moved, op1);
+  }
+  EXPECT_EQ(op1->hasWeakReference(), false);
+
+  // Release the op, as the other tests in this file do, so it is not leaked.
+  op1->destroy();
+}
+
} // namespace
More information about the Mlir-commits
mailing list