[Mlir-commits] [mlir] [mlir][operation] weak refs (PR #97340)

Maksim Levental llvmlistbot at llvm.org
Mon Jul 1 23:29:23 PDT 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/97340

From 7d5b5aad366cb6497e7aded0e80c1466b8652018 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 1 Jul 2024 14:34:22 -0500
Subject: [PATCH] [mlir][operation] weak refs

---
 mlir/include/mlir/IR/MLIRContext.h         |  5 ++
 mlir/include/mlir/IR/Operation.h           | 66 +++++++++++++++++++--
 mlir/lib/IR/MLIRContext.cpp                | 39 +++++++++++++
 mlir/lib/IR/Operation.cpp                  | 67 +++++++++++++++++-----
 mlir/unittests/IR/OperationSupportTest.cpp | 21 +++++++
 5 files changed, 178 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 11e5329f43e68..f05f34870764f 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -31,9 +31,11 @@ class DynamicDialect;
 class InFlightDiagnostic;
 class Location;
 class MLIRContextImpl;
+class Operation;
 class RegisteredOperationName;
 class StorageUniquer;
 class IRUnit;
+class WeakOpRef;
 
 /// MLIRContext is the top-level object for a collection of MLIR operations. It
 /// holds immortal uniqued objects like types, and the tables used to unique
@@ -275,6 +277,16 @@ class MLIRContext {
       actionFn();
     }
 
+  /// Return a WeakOpRef to `op`. References acquired for the same operation
+  /// share one WeakOpRefHolder for as long as any of them is alive, and `op`
+  /// reports `hasWeakReference()` while it is registered.
+  WeakOpRef acquireWeakOpRef(Operation *op);
+
+  /// Remove the context's weak-reference registration for `op` and clear its
+  /// `hasWeakReference()` flag. Called when `op` is destroyed and when the
+  /// last WeakOpRef to it is released; outstanding WeakOpRefs are untouched.
+  void expireWeakRefs(Operation *op);
+
 private:
   /// Return true if the given dialect is currently loading.
   bool isDialectLoading(StringRef dialectNamespace);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index f0dd7c5178056..319ae5cfc018e 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -22,6 +22,70 @@
+#include <memory>
 #include <optional>
 
 namespace mlir {
+class Operation;
+class WeakOpRef;
+
+/// Shared state behind all WeakOpRefs to a single operation, created by
+/// MLIRContext::acquireWeakOpRef and owned through std::shared_ptr.
+class WeakOpRefHolder {
+private:
+  mlir::Operation *op;
+
+public:
+  explicit WeakOpRefHolder(mlir::Operation *op) : op(op) {}
+
+  // Every holder unregisters its operation when destroyed, so a copy would
+  // unregister it a second time.
+  WeakOpRefHolder(const WeakOpRefHolder &) = delete;
+  WeakOpRefHolder &operator=(const WeakOpRefHolder &) = delete;
+
+  /// Unregisters the operation via MLIRContext::expireWeakRefs. The
+  /// operation is dereferenced to reach its context, so it must still be
+  /// alive at this point.
+  ~WeakOpRefHolder();
+  friend class WeakOpRef;
+};
+
+/// A copyable, reference-counted handle to an operation, obtained from
+/// MLIRContext::acquireWeakOpRef. All handles acquired for the same
+/// operation share one WeakOpRefHolder.
+///
+/// NOTE(review): despite the name, each WeakOpRef owns its holder through a
+/// std::shared_ptr, and nothing clears the holder's operation pointer when
+/// the operation is destroyed. `expired()` therefore does not observe the
+/// operation's lifetime and `operator->` may return a dangling pointer;
+/// confirm the intended semantics before relying on it.
+class WeakOpRef {
+private:
+  std::shared_ptr<WeakOpRefHolder> holder;
+
+public:
+  /// Wraps `r`; a null holder yields an expired reference. Deliberately
+  /// implicit: MLIRContext returns holders via braced initialization.
+  WeakOpRef(std::shared_ptr<WeakOpRefHolder> const &r);
+
+  WeakOpRef(WeakOpRef const &r);
+  /// Leaves `r` expired.
+  WeakOpRef(WeakOpRef &&r);
+  ~WeakOpRef();
+
+  WeakOpRef &operator=(WeakOpRef const &r);
+  WeakOpRef &operator=(WeakOpRef &&r);
+
+  void swap(WeakOpRef &r);
+  /// Returns true if this reference holds no holder, e.g. after being moved
+  /// from.
+  bool expired() const;
+  /// Number of references sharing this holder, or 0 when expired.
+  long use_count() const { return holder ? holder.use_count() : 0; }
+
+  /// Access the referenced operation. Must not be used on an expired
+  /// reference.
+  mlir::Operation *operator->() const;
+  mlir::Operation &operator*() const;
+};
+
 namespace detail {
 /// This is a "tag" used for mapping the properties storage in
 /// llvm::TrailingObjects.
@@ -210,7 +243,7 @@ class alignas(8) Operation final
   Operation *cloneWithoutRegions();
 
   /// Returns the operation block that contains this operation.
-  Block *getBlock() { return block; }
+  Block *getBlock() { return blockHasWeakRefPair.getPointer(); }
 
   /// Return the context this operation is associated with.
   MLIRContext *getContext() { return location->getContext(); }
@@ -227,11 +260,15 @@ class alignas(8) Operation final
 
   /// Returns the region to which the instruction belongs. Returns nullptr if
   /// the instruction is unlinked.
-  Region *getParentRegion() { return block ? block->getParent() : nullptr; }
+  Region *getParentRegion() {
+    return getBlock() ? getBlock()->getParent() : nullptr;
+  }
 
   /// Returns the closest surrounding operation that contains this operation
   /// or nullptr if this is a top-level operation.
-  Operation *getParentOp() { return block ? block->getParentOp() : nullptr; }
+  Operation *getParentOp() {
+    return getBlock() ? getBlock()->getParentOp() : nullptr;
+  }
 
   /// Return the closest surrounding parent operation that is of type 'OpTy'.
   template <typename OpTy>
@@ -545,6 +582,7 @@ class alignas(8) Operation final
   AttrClass getAttrOfType(StringAttr name) {
     return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
   }
+
   template <typename AttrClass>
   AttrClass getAttrOfType(StringRef name) {
     return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
@@ -559,6 +597,7 @@ class alignas(8) Operation final
     }
     return attrs.contains(name);
   }
+
   bool hasAttr(StringRef name) {
     if (getPropertiesStorageSize()) {
       if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
@@ -566,6 +605,7 @@ class alignas(8) Operation final
     }
     return attrs.contains(name);
   }
+
   template <typename AttrClass, typename NameT>
   bool hasAttrOfType(NameT &&name) {
     return static_cast<bool>(
@@ -585,6 +625,7 @@ class alignas(8) Operation final
     if (attributes.set(name, value) != value)
       attrs = attributes.getDictionary(getContext());
   }
+
   void setAttr(StringRef name, Attribute value) {
     setAttr(StringAttr::get(getContext(), name), value);
   }
@@ -605,6 +646,7 @@ class alignas(8) Operation final
       attrs = attributes.getDictionary(getContext());
     return removedAttr;
   }
+
   Attribute removeAttr(StringRef name) {
     return removeAttr(StringAttr::get(getContext(), name));
   }
@@ -626,6 +668,7 @@ class alignas(8) Operation final
     // Allow access to the constructor.
     friend Operation;
   };
+
   using dialect_attr_range = iterator_range<dialect_attr_iterator>;
 
   /// Return a range corresponding to the dialect attributes for this operation.
@@ -634,10 +677,12 @@ class alignas(8) Operation final
     return {dialect_attr_iterator(attrs.begin(), attrs.end()),
             dialect_attr_iterator(attrs.end(), attrs.end())};
   }
+
   dialect_attr_iterator dialect_attr_begin() {
     auto attrs = getAttrs();
     return dialect_attr_iterator(attrs.begin(), attrs.end());
   }
+
   dialect_attr_iterator dialect_attr_end() {
     auto attrs = getAttrs();
     return dialect_attr_iterator(attrs.end(), attrs.end());
@@ -705,6 +750,7 @@ class alignas(8) Operation final
     assert(index < getNumSuccessors());
     return getBlockOperands()[index].get();
   }
+
   void setSuccessor(Block *block, unsigned index);
 
   //===--------------------------------------------------------------------===//
@@ -892,12 +938,14 @@ class alignas(8) Operation final
   int getPropertiesStorageSize() const {
     return ((int)propertiesStorageSize) * 8;
   }
+
   /// Returns the properties storage.
   OpaqueProperties getPropertiesStorage() {
     if (propertiesStorageSize)
       return getPropertiesStorageUnsafe();
     return {nullptr};
   }
+
   OpaqueProperties getPropertiesStorage() const {
     if (propertiesStorageSize)
       return {reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
@@ -933,6 +981,19 @@ class alignas(8) Operation final
   /// Compute a hash for the op properties (if any).
   llvm::hash_code hashProperties();
 
+  /// Returns the weak-reference flag, which MLIRContext::acquireWeakOpRef sets
+  /// and MLIRContext::expireWeakRefs clears.
+  bool hasWeakReference() { return blockHasWeakRefPair.getInt(); }
+
+  /// Set or clear the weak-reference flag. MLIRContext updates it together
+  /// with its weak-reference table; other writers can make the two disagree.
+  /// NOTE(review): the bit shares storage with the parent-block pointer, which
+  /// the ilist traits rewrite without taking the context's lock; confirm the
+  /// two are never written concurrently.
+  void setHasWeakReference(bool hasWeakRef) {
+    blockHasWeakRefPair.setInt(hasWeakRef);
+  }
+
 private:
   //===--------------------------------------------------------------------===//
   // Ordering
@@ -1016,7 +1069,7 @@ class alignas(8) Operation final
   /// requires a 'getParent() const' method. Once ilist_node removes this
   /// constraint, we should drop the const to fit the rest of the MLIR const
   /// model.
-  Block *getParent() const { return block; }
+  Block *getParent() const { return blockHasWeakRefPair.getPointer(); }
 
   /// Expose a few methods explicitly for the debugger to call for
   /// visualization.
@@ -1031,8 +1084,9 @@ class alignas(8) Operation final
   }
 #endif
 
-  /// The operation block that contains this operation.
-  Block *block = nullptr;
+  /// The operation block that contains this operation and a bit that signifies
+  /// if the operation has a weak reference.
+  llvm::PointerIntPair<Block *, /*IntBits=*/1, bool> blockHasWeakRefPair;
 
   /// This holds information about the source location the operation was defined
   /// or derived from.
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 214b354c5347e..bd466a2cde990 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -271,6 +271,13 @@ class MLIRContextImpl {
   /// destruction.
   DistinctAttributeAllocator distinctAttributeAllocator;
 
+  /// Mutex guarding `weakOperationReferences`.
+  llvm::sys::SmartRWMutex<true> weakOperationRefsMutex;
+  /// Maps each operation with a registered weak reference to the holder
+  /// shared by its WeakOpRefs. Entries are erased by expireWeakRefs, which
+  /// runs when the operation is destroyed or its last WeakOpRef is released.
+  DenseMap<Operation *, std::weak_ptr<WeakOpRefHolder>> weakOperationReferences;
+
 public:
   MLIRContextImpl(bool threadingIsEnabled)
       : threadingIsEnabled(threadingIsEnabled) {
@@ -393,6 +396,59 @@ void MLIRContext::executeActionInternal(function_ref<void()> actionFn,
 
 bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
 
+/// Return a WeakOpRef to `op`, sharing the holder already registered for it
+/// when that holder is still alive and registering a new one otherwise.
+WeakOpRef MLIRContext::acquireWeakOpRef(Operation *op) {
+  assert(op && "cannot acquire a weak reference to a null operation");
+  bool threadingEnabled = isMultithreadingEnabled();
+
+  // Fast path: a live holder is already registered. An entry can be present
+  // but expired while the last WeakOpRef's holder is still being destroyed,
+  // so a failed lock() falls through to the slow path instead of returning
+  // an empty reference.
+  if (threadingEnabled) {
+    llvm::sys::SmartScopedReader<true> contextLock(
+        impl->weakOperationRefsMutex);
+    auto it = impl->weakOperationReferences.find(op);
+    if (it != impl->weakOperationReferences.end()) {
+      if (std::shared_ptr<WeakOpRefHolder> existing = it->second.lock()) {
+        assert(op->hasWeakReference() &&
+               "op should report having weak references");
+        return {existing};
+      }
+    }
+  }
+
+  // Slow path: re-check under the writer lock, because another thread may
+  // have registered a holder after the reader lock was released. A missing
+  // entry is created and an expired one replaced, so later lookups share the
+  // holder returned here.
+  ScopedWriterLock contextLock(impl->weakOperationRefsMutex, threadingEnabled);
+  std::weak_ptr<WeakOpRefHolder> &entry = impl->weakOperationReferences[op];
+  if (std::shared_ptr<WeakOpRefHolder> existing = entry.lock())
+    return {existing};
+  auto holder = std::make_shared<WeakOpRefHolder>(op);
+  entry = holder;
+  op->setHasWeakReference(true);
+  return {holder};
+}
+
+/// Remove the registration for `op` and clear its weak-reference flag.
+/// Outstanding WeakOpRefs are not touched: they keep their holder, and the
+/// raw operation pointer inside it, alive.
+void MLIRContext::expireWeakRefs(Operation *op) {
+  if (!op || !impl)
+    return;
+  ScopedWriterLock lock(impl->weakOperationRefsMutex,
+                        isMultithreadingEnabled());
+  // The table stores std::weak_ptr, which does not own the holder, so erasing
+  // the entry is all that is needed.
+  auto it = impl->weakOperationReferences.find(op);
+  if (it != impl->weakOperationReferences.end())
+    impl->weakOperationReferences.erase(it);
+  op->setHasWeakReference(false);
+}
+
 //===----------------------------------------------------------------------===//
 // Diagnostic Handlers
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b51357198b1ca..755f0aeb2d2ee 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -26,6 +26,62 @@
 
 using namespace mlir;
 
+// Wrap an existing holder; a null holder produces an expired reference.
+WeakOpRef::WeakOpRef(const std::shared_ptr<WeakOpRefHolder> &r) : holder(r) {}
+
+// Copy constructor: shares the holder, incrementing its use count.
+WeakOpRef::WeakOpRef(const WeakOpRef &r) : holder(r.holder) {}
+
+// Move constructor: takes over the holder without touching its use count.
+// A moved-from std::shared_ptr is empty, so `r` is left expired.
+WeakOpRef::WeakOpRef(WeakOpRef &&r) : holder(std::move(r.holder)) {}
+
+WeakOpRef::~WeakOpRef() = default;
+
+// Copy assignment via copy-and-swap.
+WeakOpRef &WeakOpRef::operator=(const WeakOpRef &r) {
+  WeakOpRef(r).swap(*this);
+  return *this;
+}
+
+// Move assignment via move-and-swap; the previously held holder is released
+// when the temporary is destroyed.
+WeakOpRef &WeakOpRef::operator=(WeakOpRef &&r) {
+  WeakOpRef(std::move(r)).swap(*this);
+  return *this;
+}
+
+void WeakOpRef::swap(WeakOpRef &r) { holder.swap(r.holder); }
+
+// NOTE(review): this overload is defined in the global namespace and is not
+// declared in Operation.h, so argument-dependent lookup on mlir::WeakOpRef
+// won't find it. Consider declaring it beside WeakOpRef in namespace mlir.
+void swap(WeakOpRef &x, WeakOpRef &y) { x.swap(y); }
+
+Operation *WeakOpRef::operator->() const {
+  assert(holder && "cannot dereference an expired WeakOpRef");
+  return holder->op;
+}
+
+Operation &WeakOpRef::operator*() const {
+  assert(holder && "cannot dereference an expired WeakOpRef");
+  return *holder->op;
+}
+
+// A non-empty std::shared_ptr always reports use_count() >= 1, so a reference
+// is expired exactly when it holds no holder. This does not track whether
+// the operation itself has been destroyed.
+bool WeakOpRef::expired() const { return !holder; }
+
+// Unregister the operation once the last WeakOpRef to it is released.
+// NOTE(review): `op` is dereferenced to reach its context, so this is only
+// safe while the operation is alive; releasing the last WeakOpRef after the
+// operation was destroyed reads freed memory.
+WeakOpRefHolder::~WeakOpRefHolder() {
+  if (op)
+    op->getContext()->expireWeakRefs(op);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation
 //===----------------------------------------------------------------------===//
@@ -202,6 +238,9 @@ Operation::~Operation() {
     region.~Region();
   if (propertiesStorageSize)
     name.destroyOpProperties(getPropertiesStorage());
+
+  if (hasWeakReference())
+    getContext()->expireWeakRefs(this);
 }
 
 /// Destroy this operation or one of its subclasses.
@@ -322,8 +361,8 @@ void Operation::setAttrs(DictionaryAttr newAttrs) {
 }
 void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
   if (getPropertiesStorageSize()) {
-    // We're spliting the providing array of attributes by removing the inherentAttr
-    // which will be stored in the properties.
+    // We're spliting the providing array of attributes by removing the
+    // inherentAttr which will be stored in the properties.
     SmallVector<NamedAttribute> discardableAttrs;
     discardableAttrs.reserve(newAttrs.size());
     for (NamedAttribute attr : newAttrs) {
@@ -389,8 +428,8 @@ bool Operation::isBeforeInBlock(Operation *other) {
          "Expected other operation to have the same parent block.");
   // If the order of the block is already invalid, directly recompute the
   // parent.
-  if (!block->isOpOrderValid()) {
-    block->recomputeOpOrder();
+  if (!getBlock()->isOpOrderValid()) {
+    getBlock()->recomputeOpOrder();
   } else {
     // Update the order either operation if necessary.
     updateOrderIfNecessary();
@@ -408,8 +447,8 @@ void Operation::updateOrderIfNecessary() {
   // If the order is valid for this operation there is nothing to do.
   if (hasValidOrder())
     return;
-  Operation *blockFront = &block->front();
-  Operation *blockBack = &block->back();
+  Operation *blockFront = &getBlock()->front();
+  Operation *blockBack = &getBlock()->back();
 
   // This method is expected to only be invoked on blocks with more than one
   // operation.
@@ -419,7 +458,7 @@ void Operation::updateOrderIfNecessary() {
   if (this == blockBack) {
     Operation *prevNode = getPrevNode();
     if (!prevNode->hasValidOrder())
-      return block->recomputeOpOrder();
+      return getBlock()->recomputeOpOrder();
 
     // Add the stride to the previous operation.
     orderIndex = prevNode->orderIndex + kOrderStride;
@@ -431,10 +470,10 @@ void Operation::updateOrderIfNecessary() {
   if (this == blockFront) {
     Operation *nextNode = getNextNode();
     if (!nextNode->hasValidOrder())
-      return block->recomputeOpOrder();
+      return getBlock()->recomputeOpOrder();
     // There is no order to give this operation.
     if (nextNode->orderIndex == 0)
-      return block->recomputeOpOrder();
+      return getBlock()->recomputeOpOrder();
 
     // If we can't use the stride, just take the middle value left. This is safe
     // because we know there is at least one valid index to assign to.
@@ -449,12 +488,12 @@ void Operation::updateOrderIfNecessary() {
   // the middle of the previous and next if possible.
   Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
   if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
-    return block->recomputeOpOrder();
+    return getBlock()->recomputeOpOrder();
   unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;
 
   // Check to see if there is a valid order between the two.
   if (prevOrder + 1 == nextOrder)
-    return block->recomputeOpOrder();
+    return getBlock()->recomputeOpOrder();
   orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
 }
 
@@ -502,7 +541,7 @@ Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
 /// keep the block pointer up to date.
 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
   assert(!op->getBlock() && "already in an operation block!");
-  op->block = getContainingBlock();
+  op->blockHasWeakRefPair.setPointer(getContainingBlock());
 
   // Invalidate the order on the operation.
   op->orderIndex = Operation::kInvalidOrderIdx;
@@ -512,7 +551,7 @@ void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
 /// We keep the block pointer up to date.
 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
-  assert(op->block && "not already in an operation block!");
-  op->block = nullptr;
+  assert(op->getBlock() && "not already in an operation block!");
+  op->blockHasWeakRefPair.setPointer(nullptr);
 }
 
 /// This is a trait method invoked when an operation is moved from one block
@@ -531,7 +570,7 @@ void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
 
   // Update the 'block' member of each operation.
   for (; first != last; ++first)
-    first->block = curParent;
+    first->blockHasWeakRefPair.setPointer(curParent);
 }
 
 /// Remove this operation (and its descendants) from its Block and delete
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index f94dc78445807..e5b48851cc48e 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -313,4 +313,50 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) {
   op2->destroy();
 }
 
+// Handles acquired for the same operation share one holder, and the
+// operation's flag is cleared once the last handle is released.
+TEST(WeakOpRefTest, Test1) {
+  MLIRContext context;
+  context.getOrLoadDialect<test::TestDialect>();
+
+  auto *op1 = createOp(&context);
+  EXPECT_EQ(op1->hasWeakReference(), false);
+  {
+    WeakOpRef weakRef1 = context.acquireWeakOpRef(op1);
+    EXPECT_EQ(weakRef1.use_count(), 1);
+    EXPECT_EQ(op1->hasWeakReference(), true);
+    {
+      WeakOpRef weakRef2 = context.acquireWeakOpRef(op1);
+      EXPECT_EQ(weakRef2.use_count(), 2);
+      EXPECT_EQ(op1->hasWeakReference(), true);
+    }
+    EXPECT_EQ(weakRef1.use_count(), 1);
+    EXPECT_EQ(op1->hasWeakReference(), true);
+  }
+  EXPECT_EQ(op1->hasWeakReference(), false);
+  op1->destroy();
+}
+
+// Copies share the holder; a move transfers it and expires the source.
+TEST(WeakOpRefTest, CopyAndMove) {
+  MLIRContext context;
+  context.getOrLoadDialect<test::TestDialect>();
+
+  Operation *op = createOp(&context);
+  {
+    WeakOpRef ref = context.acquireWeakOpRef(op);
+    WeakOpRef copy = ref;
+    EXPECT_EQ(ref.use_count(), 2);
+    EXPECT_EQ(&*copy, op);
+
+    WeakOpRef moved = std::move(copy);
+    EXPECT_TRUE(copy.expired());
+    EXPECT_EQ(copy.use_count(), 0);
+    EXPECT_EQ(moved.use_count(), 2);
+    EXPECT_EQ(moved.operator->(), op);
+  }
+  EXPECT_FALSE(op->hasWeakReference());
+  op->destroy();
+}
+
 } // namespace



More information about the Mlir-commits mailing list