[Mlir-commits] [mlir] [mlir][operation] weak refs (PR #97340)
Maksim Levental
llvmlistbot at llvm.org
Mon Jul 1 13:06:40 PDT 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/97340
>From 3e6ed96c23de05d492f29de3bfbc9b5323021808 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 1 Jul 2024 14:34:22 -0500
Subject: [PATCH] [mlir][operation] weak refs
---
mlir/include/mlir/IR/Operation.h | 19 +++++++++++++------
mlir/lib/IR/Operation.cpp | 28 ++++++++++++++--------------
2 files changed, 27 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index f0dd7c5178056..3cc5a34a63ab4 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -210,7 +210,7 @@ class alignas(8) Operation final
Operation *cloneWithoutRegions();
/// Returns the operation block that contains this operation.
- Block *getBlock() { return block; }
+ Block *getBlock() { return blockHasWeakRefPair.getPointer(); }
/// Return the context this operation is associated with.
MLIRContext *getContext() { return location->getContext(); }
@@ -227,11 +227,15 @@ class alignas(8) Operation final
/// Returns the region to which the instruction belongs. Returns nullptr if
/// the instruction is unlinked.
- Region *getParentRegion() { return block ? block->getParent() : nullptr; }
+ Region *getParentRegion() {
+ return getBlock() ? getBlock()->getParent() : nullptr;
+ }
/// Returns the closest surrounding operation that contains this operation
/// or nullptr if this is a top-level operation.
- Operation *getParentOp() { return block ? block->getParentOp() : nullptr; }
+ Operation *getParentOp() {
+ return getBlock() ? getBlock()->getParentOp() : nullptr;
+ }
/// Return the closest surrounding parent operation that is of type 'OpTy'.
template <typename OpTy>
@@ -1016,7 +1020,7 @@ class alignas(8) Operation final
/// requires a 'getParent() const' method. Once ilist_node removes this
/// constraint, we should drop the const to fit the rest of the MLIR const
/// model.
- Block *getParent() const { return block; }
+ Block *getParent() const { return blockHasWeakRefPair.getPointer(); }
/// Expose a few methods explicitly for the debugger to call for
/// visualization.
@@ -1031,8 +1035,12 @@ class alignas(8) Operation final
}
#endif
- /// The operation block that contains this operation.
- Block *block = nullptr;
+ /// The operation block that contains this operation and a bit that signifies
+ /// if the operation has a weak reference.
+ llvm::PointerIntPair<Block *, /*IntBits=*/1, bool> blockHasWeakRefPair;
+
+ /// Returns the weak-reference bit packed into blockHasWeakRefPair.
+ bool hasWeakReference() const { return blockHasWeakRefPair.getInt(); }
/// This holds information about the source location the operation was defined
/// or derived from.
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b51357198b1ca..ae44475a83d3e 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -322,8 +322,8 @@ void Operation::setAttrs(DictionaryAttr newAttrs) {
}
void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
if (getPropertiesStorageSize()) {
- // We're spliting the providing array of attributes by removing the inherentAttr
- // which will be stored in the properties.
+ // We're splitting the provided array of attributes by removing the
+ // inherentAttr which will be stored in the properties.
SmallVector<NamedAttribute> discardableAttrs;
discardableAttrs.reserve(newAttrs.size());
for (NamedAttribute attr : newAttrs) {
@@ -389,8 +389,8 @@ bool Operation::isBeforeInBlock(Operation *other) {
"Expected other operation to have the same parent block.");
// If the order of the block is already invalid, directly recompute the
// parent.
- if (!block->isOpOrderValid()) {
- block->recomputeOpOrder();
+ if (!getBlock()->isOpOrderValid()) {
+ getBlock()->recomputeOpOrder();
} else {
// Update the order either operation if necessary.
updateOrderIfNecessary();
@@ -408,8 +408,8 @@ void Operation::updateOrderIfNecessary() {
// If the order is valid for this operation there is nothing to do.
if (hasValidOrder())
return;
- Operation *blockFront = &block->front();
- Operation *blockBack = &block->back();
+ Operation *blockFront = &getBlock()->front();
+ Operation *blockBack = &getBlock()->back();
// This method is expected to only be invoked on blocks with more than one
// operation.
@@ -419,7 +419,7 @@ void Operation::updateOrderIfNecessary() {
if (this == blockBack) {
Operation *prevNode = getPrevNode();
if (!prevNode->hasValidOrder())
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
// Add the stride to the previous operation.
orderIndex = prevNode->orderIndex + kOrderStride;
@@ -431,10 +431,10 @@ void Operation::updateOrderIfNecessary() {
if (this == blockFront) {
Operation *nextNode = getNextNode();
if (!nextNode->hasValidOrder())
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
// There is no order to give this operation.
if (nextNode->orderIndex == 0)
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
// If we can't use the stride, just take the middle value left. This is safe
// because we know there is at least one valid index to assign to.
@@ -449,12 +449,12 @@ void Operation::updateOrderIfNecessary() {
// the middle of the previous and next if possible.
Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;
// Check to see if there is a valid order between the two.
if (prevOrder + 1 == nextOrder)
- return block->recomputeOpOrder();
+ return getBlock()->recomputeOpOrder();
orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}
@@ -502,7 +502,7 @@ Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
/// keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
assert(!op->getBlock() && "already in an operation block!");
- op->block = getContainingBlock();
+ op->blockHasWeakRefPair.setPointer(getContainingBlock());
// Invalidate the order on the operation.
op->orderIndex = Operation::kInvalidOrderIdx;
@@ -512,7 +512,7 @@ void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
/// We keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
- assert(op->block && "not already in an operation block!");
- op->block = nullptr;
+ assert(op->getBlock() && "not already in an operation block!");
+ op->blockHasWeakRefPair.setPointer(nullptr);
}
/// This is a trait method invoked when an operation is moved from one block
@@ -531,7 +531,7 @@ void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
// Update the 'block' member of each operation.
for (; first != last; ++first)
- first->block = curParent;
+ first->blockHasWeakRefPair.setPointer(curParent);
}
/// Remove this operation (and its descendants) from its Block and delete
More information about the Mlir-commits
mailing list