[Mlir-commits] [mlir] 4a7aed4 - [mlir][IR] Add a new SymbolUserMap class

River Riddle llvmlistbot at llvm.org
Tue Mar 9 15:08:07 PST 2021


Author: River Riddle
Date: 2021-03-09T15:07:52-08:00
New Revision: 4a7aed4ee739de5ad13bd631d4af727df7fc5849

URL: https://github.com/llvm/llvm-project/commit/4a7aed4ee739de5ad13bd631d4af727df7fc5849
DIFF: https://github.com/llvm/llvm-project/commit/4a7aed4ee739de5ad13bd631d4af727df7fc5849.diff

LOG: [mlir][IR] Add a new SymbolUserMap class

This class provides efficient implementations of symbol queries related to uses, such as collecting the users of a symbol or replacing all uses. It provides the same kind of benefit for use-related queries that SymbolTableCollection provides for lookup queries.

Differential Revision: https://reviews.llvm.org/D98071

Added: 
    

Modified: 
    mlir/include/mlir/IR/SymbolTable.h
    mlir/lib/IR/SymbolTable.cpp
    mlir/test/lib/IR/TestSymbolUses.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 0db01cce6ce2..e0d626cd9dc4 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringMap.h"
 
 namespace mlir {
@@ -260,6 +261,48 @@ class SymbolTableCollection {
   DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
 };
 
+//===----------------------------------------------------------------------===//
+// SymbolUserMap
+//===----------------------------------------------------------------------===//
+
+/// This class represents a map of symbols to users, and provides efficient
+/// implementations of symbol queries related to users, such as collecting the
+/// users of a symbol or replacing all uses.
+class SymbolUserMap {
+public:
+  /// Build a user map for all of the symbols defined in regions nested under
+  /// 'symbolTableOp'. A reference to the provided symbol table collection is
+  /// kept by the user map to ensure efficient lookups, thus its lifetime must
+  /// extend beyond that of this map.
+  SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp);
+
+  /// Return the users of the provided symbol operation, or an empty range if
+  /// it has none. The returned range refers to storage owned by this map and
+  /// may be invalidated by a later call to `replaceAllUsesWith`.
+  ArrayRef<Operation *> getUsers(Operation *symbol) const {
+    auto it = symbolToUsers.find(symbol);
+    return it != symbolToUsers.end() ? it->second.getArrayRef() : llvm::None;
+  }
+
+  /// Return true if the given symbol has no recorded users.
+  bool use_empty(Operation *symbol) const {
+    return !symbolToUsers.count(symbol);
+  }
+
+  /// Replace all of the uses of the given symbol with `newSymbolName`. This
+  /// only rewrites the references held by the users; it does not rename
+  /// `symbol` itself. The recorded users are then filed under the operation
+  /// that `newSymbolName` resolves to in the parent op of `symbol`.
+  void replaceAllUsesWith(Operation *symbol, StringRef newSymbolName);
+
+private:
+  /// A reference to the symbol table collection used to construct this map.
+  SymbolTableCollection &symbolTable;
+
+  /// A map of symbol operations to symbol users.
+  DenseMap<Operation *, llvm::SetVector<Operation *>> symbolToUsers;
+};
+
 //===----------------------------------------------------------------------===//
 // SymbolTable Trait Types
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 70133d22482f..4620a5bcb381 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -1000,6 +1000,72 @@ SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
   return *it.first->second;
 }
 
+//===----------------------------------------------------------------------===//
+// SymbolUserMap
+//===----------------------------------------------------------------------===//
+
+SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
+                             Operation *symbolTableOp)
+    : symbolTable(symbolTable) {
+  // Walk each of the nested symbol tables and record, for every symbol
+  // operation referenced from within them, the operations that reference it.
+  SmallVector<Operation *> symbols;
+  auto walkFn = [&](Operation *tableOp, bool /*allUsesVisible*/) {
+    for (Operation &nestedOp : tableOp->getRegion(0).getOps()) {
+      // The map requires the uses of every nested operation to be collectable;
+      // this precondition is only checked in asserts-enabled builds.
+      auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
+      assert(symbolUses && "expected uses to be valid");
+
+      for (const SymbolTable::SymbolUse &use : *symbolUses) {
+        // Record the user against every symbol operation the reference
+        // resolves to.
+        symbols.clear();
+        (void)symbolTable.lookupSymbolIn(tableOp, use.getSymbolRef(),
+                                         symbols);
+        for (Operation *symbolOp : symbols)
+          symbolToUsers[symbolOp].insert(use.getUser());
+      }
+    }
+  };
+  // We just set `allSymUsesVisible` to false here because it isn't necessary
+  // for building the user map.
+  SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
+                                walkFn);
+}
+
+void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
+                                       StringRef newSymbolName) {
+  auto it = symbolToUsers.find(symbol);
+  if (it == symbolToUsers.end())
+    return;
+
+  // Replace the uses within the users of `symbol`.
+  for (Operation *user : it->second)
+    (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
+
+  // Find the operation `newSymbolName` currently resolves to. If that is
+  // `symbol` itself, its users are already filed under the right key.
+  Operation *newSymbol =
+      symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
+  if (newSymbol == symbol)
+    return;
+
+  // Move the users out of the map before inserting `newSymbol`: the insertion
+  // may grow the DenseMap, invalidating `it` and any reference into it.
+  llvm::SetVector<Operation *> users = std::move(it->second);
+  symbolToUsers.erase(it);
+
+  // Transfer the users to the new symbol, merging with any it already has.
+  // NOTE(review): if `newSymbolName` resolves to nothing, lookupSymbolIn
+  // presumably yields null and the users are filed under a null key; confirm.
+  auto inserted = symbolToUsers.try_emplace(newSymbol);
+  if (inserted.second)
+    inserted.first->second = std::move(users);
+  else
+    inserted.first->second.set_union(users);
+}
+
 //===----------------------------------------------------------------------===//
 // Visibility parsing implementation.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 8f9a6ae17e07..db4127d7522e 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -90,16 +90,25 @@ struct SymbolUsesPass
 struct SymbolReplacementPass
     : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
   void runOnOperation() override {
-    auto module = getOperation();
+    ModuleOp module = getOperation();
+
+    // Don't try to replace if we can't collect symbol uses: the SymbolUserMap
+    // constructor asserts that the uses of the nested operations are valid.
+    if (!SymbolTable::getSymbolUses(&module.getBodyRegion()))
+      return;
 
-    // Walk nested functions and modules.
+    // The user map keeps a reference to this collection for its lookups, so
+    // the collection must outlive the map; declaring it first ensures that.
+    SymbolTableCollection symbolTable;
+    SymbolUserMap symbolUsers(symbolTable, module);
     module.getBodyRegion().walk([&](Operation *nestedOp) {
       StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
       if (!newName)
         return;
-      if (succeeded(SymbolTable::replaceAllSymbolUses(
-              nestedOp, newName.getValue(), &module.getBodyRegion())))
-        SymbolTable::setSymbolName(nestedOp, newName.getValue());
+      // replaceAllUsesWith only rewrites the references held by the users, so
+      // the symbol's own name still has to be updated afterwards.
+      symbolUsers.replaceAllUsesWith(nestedOp, newName.getValue());
+      SymbolTable::setSymbolName(nestedOp, newName.getValue());
     });
   }
 };


        


More information about the Mlir-commits mailing list