[Mlir-commits] [mlir] 2cfc66a - [mlir] Add a SharedSymbolTableCollection class

Jeff Niu llvmlistbot at llvm.org
Wed Feb 22 13:57:37 PST 2023


Author: Jeff Niu
Date: 2023-02-22T13:57:30-08:00
New Revision: 2cfc66a6fa660fb0b03de40815392f25038ccfba

URL: https://github.com/llvm/llvm-project/commit/2cfc66a6fa660fb0b03de40815392f25038ccfba
DIFF: https://github.com/llvm/llvm-project/commit/2cfc66a6fa660fb0b03de40815392f25038ccfba.diff

LOG: [mlir] Add a SharedSymbolTableCollection class

This class (`LockedSymbolTableCollection`) wraps a `SymbolTableCollection` to
allow shared, thread-safe access to the collection of symbol tables (but not
to the individual symbol tables). It can be used, for example, in a pass that
shards work among symbols and needs to perform symbol lookups.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D144507

Added: 
    

Modified: 
    mlir/include/mlir/IR/SymbolTable.h
    mlir/lib/IR/SymbolTable.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 7b325c8b896ce..33427788a075e 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -13,6 +13,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/Support/RWMutex.h"
 
 namespace mlir {
 
@@ -281,10 +282,66 @@ class SymbolTableCollection {
   SymbolTable &getSymbolTable(Operation *op);
 
 private:
+  friend class LockedSymbolTableCollection;
+
   /// The constructed symbol tables nested within this table.
   DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
 };
 
+//===----------------------------------------------------------------------===//
+// LockedSymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+/// A lock-based, shareable facade over an existing `SymbolTableCollection`.
+///
+/// `SymbolTableCollection` builds its `SymbolTable` instances lazily, so even a
+/// pure lookup may insert into the collection; read operations on the plain
+/// collection are therefore not thread-safe. This wrapper synchronizes access
+/// to the collection's map of symbol tables with a reader/writer lock, which
+/// makes the `lookupSymbolIn` overloads declared here safe to call from
+/// multiple threads.
+///
+/// Only the map of symbol tables is protected. Access to the individual
+/// `SymbolTable` objects stored in it is NOT synchronized.
+///
+/// NOTE(review): this class publicly derives from `SymbolTableCollection` but
+/// forwards to the wrapped `collection`. Base-class members that are not
+/// redeclared below presumably still operate on this object's own inherited,
+/// unsynchronized state rather than on `collection` -- confirm before calling
+/// them through this type.
+class LockedSymbolTableCollection : public SymbolTableCollection {
+public:
+  /// Wrap `collection`. Only a reference is stored, so `collection` must
+  /// outlive this wrapper; tables built by lookups are inserted into it.
+  explicit LockedSymbolTableCollection(SymbolTableCollection &collection)
+      : collection(collection) {}
+
+  /// Resolve `symbol` directly inside `symbolTableOp`; null if absent.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
+  /// Resolve the name carried by a flat reference inside `symbolTableOp`;
+  /// null if absent.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, FlatSymbolRefAttr symbol);
+  /// Resolve a potentially nested reference starting at `symbolTableOp`;
+  /// null if no such symbol exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
+  /// Variant of `lookupSymbolIn` that collects, into `symbols`, every symbol
+  /// referenced by `name` when resolved within `symbolTableOp`. Returns
+  /// failure if any of the nested references could not be resolved.
+  LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
+                               SmallVectorImpl<Operation *> &symbols);
+
+  /// Typed convenience lookup: dispatch to the matching overload above and
+  /// return the result as a `T`, or null when the symbol is missing or is
+  /// not a `T`.
+  template <typename T, typename NameT>
+  T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
+    Operation *found =
+        lookupSymbolIn(symbolTableOp, std::forward<NameT>(name));
+    return dyn_cast_or_null<T>(found);
+  }
+
+private:
+  /// Return the symbol table for `symbolTableOp`, building it on first use.
+  /// Probing for an existing table happens under a shared (reader) lock, and
+  /// inserting a freshly built table happens under an exclusive (writer)
+  /// lock, so lookups through this wrapper may share `collection`.
+  SymbolTable &getSymbolTable(Operation *symbolTableOp);
+
+  /// The wrapped collection whose map of symbol tables is being managed.
+  SymbolTableCollection &collection;
+  /// Reader/writer lock guarding `collection`'s map of symbol tables.
+  llvm::sys::SmartRWMutex<true> mutex;
+};
+
 //===----------------------------------------------------------------------===//
 // SymbolUserMap
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 446cce3db4132..2f44514127522 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -916,6 +916,58 @@ SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
   return *it.first->second;
 }
 
+//===----------------------------------------------------------------------===//
+// LockedSymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+Operation *
+LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                            StringAttr symbol) {
+  // Obtaining the table goes through the locked accessor; the lookup inside
+  // the table itself is not synchronized by this class.
+  SymbolTable &table = getSymbolTable(symbolTableOp);
+  return table.lookup(symbol);
+}
+
+Operation *LockedSymbolTableCollection::lookupSymbolIn(
+    Operation *symbolTableOp, FlatSymbolRefAttr symbol) {
+  // A flat reference names a single symbol: unwrap it and reuse the
+  // StringAttr overload.
+  auto leafName = symbol.getAttr();
+  return lookupSymbolIn(symbolTableOp, leafName);
+}
+
+Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                                       SymbolRefAttr name) {
+  // Resolve the whole reference path, then report only its last element.
+  SmallVector<Operation *> resolved;
+  LogicalResult status = lookupSymbolIn(symbolTableOp, name, resolved);
+  return succeeded(status) ? resolved.back() : nullptr;
+}
+
+LogicalResult LockedSymbolTableCollection::lookupSymbolIn(
+    Operation *symbolTableOp, SymbolRefAttr name,
+    SmallVectorImpl<Operation *> &symbols) {
+  // Each per-table step of the nested resolution goes through the locked
+  // StringAttr overload, so lazily-built tables are created under the lock.
+  auto lockedLookup = [this](Operation *tableOp, StringAttr leaf) {
+    return lookupSymbolIn(tableOp, leaf);
+  };
+  return lookupSymbolInImpl(symbolTableOp, name, symbols, lockedLookup);
+}
+
+SymbolTable &
+LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
+  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
+         "expected operation to have the SymbolTable trait");
+  // Fast path: the table already exists. Concurrent readers may probe the
+  // map together under the shared lock.
+  {
+    llvm::sys::SmartScopedReader<true> readLock(mutex);
+    auto it = collection.symbolTables.find(symbolTableOp);
+    if (it != collection.symbolTables.end())
+      return *it->second;
+  }
+  // Slow path: build the table outside of the critical section to keep the
+  // exclusive lock short. Two threads may both get here for the same op and
+  // each build a table; `insert` keeps whichever entry landed first and the
+  // losing table is simply destroyed, so every caller sees the same table.
+  auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
+  llvm::sys::SmartScopedWriter<true> writeLock(mutex);
+  return *collection.symbolTables
+              .insert({symbolTableOp, std::move(symbolTable)})
+              .first->second;
+}
+
 //===----------------------------------------------------------------------===//
 // SymbolUserMap
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list