[Mlir-commits] [mlir] 2cfc66a - [mlir] Add a SharedSymbolTableCollection class
Jeff Niu
llvmlistbot at llvm.org
Wed Feb 22 13:57:37 PST 2023
Author: Jeff Niu
Date: 2023-02-22T13:57:30-08:00
New Revision: 2cfc66a6fa660fb0b03de40815392f25038ccfba
URL: https://github.com/llvm/llvm-project/commit/2cfc66a6fa660fb0b03de40815392f25038ccfba
DIFF: https://github.com/llvm/llvm-project/commit/2cfc66a6fa660fb0b03de40815392f25038ccfba.diff
LOG: [mlir] Add a SharedSymbolTableCollection class
This class (named `LockedSymbolTableCollection` in the code below) wraps a
`SymbolTableCollection` to allow shared access to the collection of symbol
tables (but not to the individual symbol tables). This can be used, for
example, in a pass that shards work among symbols and requires symbol lookups.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D144507
Added:
Modified:
mlir/include/mlir/IR/SymbolTable.h
mlir/lib/IR/SymbolTable.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 7b325c8b896ce..33427788a075e 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -13,6 +13,7 @@
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/RWMutex.h"
namespace mlir {
@@ -281,10 +282,66 @@ class SymbolTableCollection {
SymbolTable &getSymbolTable(Operation *op);
private:
+ friend class LockedSymbolTableCollection;
+
/// The constructed symbol tables nested within this table.
DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
};
+//===----------------------------------------------------------------------===//
+// LockedSymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+/// This class implements a lock-based shared wrapper around a symbol table
+/// collection that allows shared access to the collection of symbol tables.
+/// This class does not protect shared access to individual symbol tables.
+/// `SymbolTableCollection` lazily instantiates `SymbolTable` instances for
+/// symbol table operations, making read operations not thread-safe. This class
+/// provides a thread-safe `lookupSymbolIn` implementation by synchronizing the
+/// lazy `SymbolTable` lookup.
+///
+/// All lookups made through the members declared here are served from the
+/// wrapped `collection`, which is held by reference (not copied) and must
+/// therefore outlive this object. None of these members read or populate the
+/// `symbolTables` map inherited from the `SymbolTableCollection` base.
+///
+/// NOTE(review): `SymbolTableCollection::getSymbolTable` is not virtual (and
+/// presumably neither are the base `lookupSymbolIn` overloads -- confirm).
+/// Using this object through a `SymbolTableCollection &`, or calling any base
+/// member that is not redeclared here, goes to the base's own (empty) map and
+/// bypasses both the lock and the wrapped `collection`.
+///
+/// NOTE(review): `mutex` only coordinates accesses made through this wrapper;
+/// code that uses the wrapped `collection` directly at the same time is not
+/// synchronized with it.
+class LockedSymbolTableCollection : public SymbolTableCollection {
+public:
+  /// Wrap `collection`. The collection is stored by reference.
+  explicit LockedSymbolTableCollection(SymbolTableCollection &collection)
+      : collection(collection) {}
+
+  /// Look up a symbol with the specified name within the specified symbol table
+  /// operation, returning null if no such name exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
+  /// Look up a symbol with the specified name within the specified symbol table
+  /// operation, returning null if no such name exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, FlatSymbolRefAttr symbol);
+  /// Look up a potentially nested symbol within the specified symbol table
+  /// operation, returning null if no such symbol exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
+
+  /// Lookup a symbol of a particular kind within the specified symbol table,
+  /// returning null if the symbol was not found. Because the result goes
+  /// through `dyn_cast_or_null`, null is also returned when the symbol exists
+  /// but is not a `T`.
+  template <typename T, typename NameT>
+  T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
+    return dyn_cast_or_null<T>(
+        lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
+  }
+
+  /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
+  /// by a given SymbolRefAttr when resolved within the provided symbol table
+  /// operation. Returns failure if any of the nested references could not be
+  /// resolved.
+  LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
+                               SmallVectorImpl<Operation *> &symbols);
+
+private:
+  /// Get the symbol table for the symbol table operation, constructing if it
+  /// does not exist. This function provides thread safety over `collection`
+  /// by locking when performing the lookup and when inserting
+  /// lazily-constructed symbol tables.
+  /// This private declaration hides the base's public `getSymbolTable` only
+  /// for calls made directly on a `LockedSymbolTableCollection`.
+  SymbolTable &getSymbolTable(Operation *symbolTableOp);
+
+  /// The symbol tables to manage. Not owned by this wrapper.
+  SymbolTableCollection &collection;
+  /// The mutex protecting access to the symbol table collection. Lookups of
+  /// existing tables take it in shared (reader) mode; publishing a newly
+  /// built table takes it in exclusive (writer) mode.
+  llvm::sys::SmartRWMutex<true> mutex;
+};
+
//===----------------------------------------------------------------------===//
// SymbolUserMap
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 446cce3db4132..2f44514127522 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -916,6 +916,58 @@ SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
return *it.first->second;
}
+//===----------------------------------------------------------------------===//
+// LockedSymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+/// Resolve the single-level name `symbol` in the symbol table of
+/// `symbolTableOp`; the table is fetched (and lazily built) under the lock.
+Operation *
+LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                            StringAttr symbol) {
+  SymbolTable &table = getSymbolTable(symbolTableOp);
+  return table.lookup(symbol);
+}
+
+/// Resolve a flat (single-component) symbol reference by forwarding its leaf
+/// name to the `StringAttr` overload.
+Operation *LockedSymbolTableCollection::lookupSymbolIn(
+    Operation *symbolTableOp, FlatSymbolRefAttr symbol) {
+  StringAttr leafName = symbol.getAttr();
+  return lookupSymbolIn(symbolTableOp, leafName);
+}
+
+/// Resolve a possibly nested symbol reference and return the innermost symbol,
+/// or null if any component of the path could not be resolved.
+Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                                       SymbolRefAttr name) {
+  // Resolve every component of the path; only the last one is returned.
+  SmallVector<Operation *> resolvedPath;
+  LogicalResult status = lookupSymbolIn(symbolTableOp, name, resolvedPath);
+  return succeeded(status) ? resolvedPath.back() : nullptr;
+}
+
+/// Resolve every component of `name` starting at `symbolTableOp`, appending
+/// each resolved operation to `symbols`. Each per-level lookup goes through the
+/// lock-guarded `StringAttr` overload of this class.
+LogicalResult LockedSymbolTableCollection::lookupSymbolIn(
+    Operation *symbolTableOp, SymbolRefAttr name,
+    SmallVectorImpl<Operation *> &symbols) {
+  return lookupSymbolInImpl(
+      symbolTableOp, name, symbols,
+      [this](Operation *scopeOp, StringAttr leafName) {
+        return lookupSymbolIn(scopeOp, leafName);
+      });
+}
+
+/// Return the symbol table for `symbolTableOp`, lazily constructing it and
+/// caching it in the wrapped `collection` on first use.
+///
+/// The fast path takes only the shared (reader) lock. A missing table is built
+/// with no lock held, so other threads are not blocked during construction.
+/// The exclusive (writer) lock is then taken only to publish the table.
+SymbolTable &
+LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
+  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
+         "expected a symbol table operation");
+  // Fast path: try to find an existing symbol table under the reader lock.
+  {
+    llvm::sys::SmartScopedReader<true> lock(mutex);
+    auto it = collection.symbolTables.find(symbolTableOp);
+    if (it != collection.symbolTables.end())
+      return *it->second;
+  }
+  // Slow path: create a symbol table for the operation. Perform construction
+  // outside of the critical section.
+  auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
+  // Publish the constructed symbol table under the writer lock. If another
+  // thread inserted a table for the same operation after our read above,
+  // `insert` keeps the existing entry, and the table we just built is
+  // discarded. In either case the returned reference points at a heap object
+  // owned by the map through a `std::unique_ptr`. It therefore stays valid
+  // across later insertions (which may rehash the map) for as long as the
+  // entry is not erased.
+  llvm::sys::SmartScopedWriter<true> lock(mutex);
+  return *collection.symbolTables
+              .insert({symbolTableOp, std::move(symbolTable)})
+              .first->second;
+}
+
//===----------------------------------------------------------------------===//
// SymbolUserMap
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list