[Mlir-commits] [mlir] [mlir][operation] weak refs (PR #97340)

Maksim Levental llvmlistbot at llvm.org
Mon Jul 1 13:06:40 PDT 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/97340

From 3e6ed96c23de05d492f29de3bfbc9b5323021808 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 1 Jul 2024 14:34:22 -0500
Subject: [PATCH] [mlir][operation] weak refs

---
 mlir/include/mlir/IR/Operation.h | 19 +++++++++++++------
 mlir/lib/IR/Operation.cpp        | 28 ++++++++++++++--------------
 2 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index f0dd7c5178056..3cc5a34a63ab4 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -210,7 +210,7 @@ class alignas(8) Operation final
   Operation *cloneWithoutRegions();
 
   /// Returns the operation block that contains this operation.
-  Block *getBlock() { return block; }
+  Block *getBlock() { return blockHasWeakRefPair.getPointer(); }
 
   /// Return the context this operation is associated with.
   MLIRContext *getContext() { return location->getContext(); }
@@ -227,11 +227,15 @@ class alignas(8) Operation final
 
   /// Returns the region to which the instruction belongs. Returns nullptr if
   /// the instruction is unlinked.
-  Region *getParentRegion() { return block ? block->getParent() : nullptr; }
+  Region *getParentRegion() {
+    return getBlock() ? getBlock()->getParent() : nullptr;
+  }
 
   /// Returns the closest surrounding operation that contains this operation
   /// or nullptr if this is a top-level operation.
-  Operation *getParentOp() { return block ? block->getParentOp() : nullptr; }
+  Operation *getParentOp() {
+    return getBlock() ? getBlock()->getParentOp() : nullptr;
+  }
 
   /// Return the closest surrounding parent operation that is of type 'OpTy'.
   template <typename OpTy>
@@ -1016,7 +1020,7 @@ class alignas(8) Operation final
   /// requires a 'getParent() const' method. Once ilist_node removes this
   /// constraint, we should drop the const to fit the rest of the MLIR const
   /// model.
-  Block *getParent() const { return block; }
+  Block *getParent() const { return blockHasWeakRefPair.getPointer(); }
 
   /// Expose a few methods explicitly for the debugger to call for
   /// visualization.
@@ -1031,8 +1035,11 @@ class alignas(8) Operation final
   }
 #endif
 
-  /// The operation block that contains this operation.
-  Block *block = nullptr;
+  /// The operation block that contains this operation and a bit that signifies
+  /// if the operation has a weak reference.
+  llvm::PointerIntPair<Block *, /*IntBits=*/1, bool> blockHasWeakRefPair;
+
+  bool hasWeakReference() { return blockHasWeakRefPair.getInt(); }
 
   /// This holds information about the source location the operation was defined
   /// or derived from.
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b51357198b1ca..ae44475a83d3e 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -322,8 +322,8 @@ void Operation::setAttrs(DictionaryAttr newAttrs) {
 }
 void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
   if (getPropertiesStorageSize()) {
-    // We're spliting the providing array of attributes by removing the inherentAttr
-    // which will be stored in the properties.
+    // We're splitting the provided array of attributes by removing the
+    // inherent attributes, which will be stored in the properties.
     SmallVector<NamedAttribute> discardableAttrs;
     discardableAttrs.reserve(newAttrs.size());
     for (NamedAttribute attr : newAttrs) {
@@ -389,8 +389,8 @@ bool Operation::isBeforeInBlock(Operation *other) {
          "Expected other operation to have the same parent block.");
   // If the order of the block is already invalid, directly recompute the
   // parent.
-  if (!block->isOpOrderValid()) {
-    block->recomputeOpOrder();
+  if (!getBlock()->isOpOrderValid()) {
+    getBlock()->recomputeOpOrder();
   } else {
     // Update the order either operation if necessary.
     updateOrderIfNecessary();
@@ -408,8 +408,8 @@ void Operation::updateOrderIfNecessary() {
   // If the order is valid for this operation there is nothing to do.
   if (hasValidOrder())
     return;
-  Operation *blockFront = &block->front();
-  Operation *blockBack = &block->back();
+  Operation *blockFront = &getBlock()->front();
+  Operation *blockBack = &getBlock()->back();
 
   // This method is expected to only be invoked on blocks with more than one
   // operation.
@@ -419,7 +419,7 @@ void Operation::updateOrderIfNecessary() {
   if (this == blockBack) {
     Operation *prevNode = getPrevNode();
     if (!prevNode->hasValidOrder())
-      return block->recomputeOpOrder();
+      return getBlock()->recomputeOpOrder();
 
     // Add the stride to the previous operation.
     orderIndex = prevNode->orderIndex + kOrderStride;
@@ -431,10 +431,10 @@ void Operation::updateOrderIfNecessary() {
   if (this == blockFront) {
     Operation *nextNode = getNextNode();
     if (!nextNode->hasValidOrder())
-      return block->recomputeOpOrder();
+      return getBlock()->recomputeOpOrder();
     // There is no order to give this operation.
     if (nextNode->orderIndex == 0)
-      return block->recomputeOpOrder();
+      return getBlock()->recomputeOpOrder();
 
     // If we can't use the stride, just take the middle value left. This is safe
     // because we know there is at least one valid index to assign to.
@@ -449,12 +449,12 @@ void Operation::updateOrderIfNecessary() {
   // the middle of the previous and next if possible.
   Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
   if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
-    return block->recomputeOpOrder();
+    return getBlock()->recomputeOpOrder();
   unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;
 
   // Check to see if there is a valid order between the two.
   if (prevOrder + 1 == nextOrder)
-    return block->recomputeOpOrder();
+    return getBlock()->recomputeOpOrder();
   orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
 }
 
@@ -502,7 +502,7 @@ Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
 /// keep the block pointer up to date.
 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
   assert(!op->getBlock() && "already in an operation block!");
-  op->block = getContainingBlock();
+  op->blockHasWeakRefPair.setPointer(getContainingBlock());
 
   // Invalidate the order on the operation.
   op->orderIndex = Operation::kInvalidOrderIdx;
@@ -512,7 +512,7 @@ void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
 /// We keep the block pointer up to date.
 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
   assert(op->block && "not already in an operation block!");
-  op->block = nullptr;
+  op->blockHasWeakRefPair.setPointer(nullptr);
 }
 
 /// This is a trait method invoked when an operation is moved from one block
@@ -531,7 +531,7 @@ void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
 
   // Update the 'block' member of each operation.
   for (; first != last; ++first)
-    first->block = curParent;
+    first->blockHasWeakRefPair.setPointer(curParent);
 }
 
 /// Remove this operation (and its descendants) from its Block and delete



More information about the Mlir-commits mailing list