[Mlir-commits] [mlir] 4a7aed4 - [mlir][IR] Add a new SymbolUserMap class
River Riddle
llvmlistbot at llvm.org
Tue Mar 9 15:08:07 PST 2021
Author: River Riddle
Date: 2021-03-09T15:07:52-08:00
New Revision: 4a7aed4ee739de5ad13bd631d4af727df7fc5849
URL: https://github.com/llvm/llvm-project/commit/4a7aed4ee739de5ad13bd631d4af727df7fc5849
DIFF: https://github.com/llvm/llvm-project/commit/4a7aed4ee739de5ad13bd631d4af727df7fc5849.diff
LOG: [mlir][IR] Add a new SymbolUserMap class
This class provides efficient implementations of symbol queries related to uses, such as collecting the users of a symbol, replacing all uses, etc. It provides the same kind of benefit for use-related queries that SymbolTableCollection provided for lookup queries.
Differential Revision: https://reviews.llvm.org/D98071
Added:
Modified:
mlir/include/mlir/IR/SymbolTable.h
mlir/lib/IR/SymbolTable.cpp
mlir/test/lib/IR/TestSymbolUses.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 0db01cce6ce2..e0d626cd9dc4 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -11,6 +11,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringMap.h"
namespace mlir {
@@ -260,6 +261,43 @@ class SymbolTableCollection {
DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
};
+//===----------------------------------------------------------------------===//
+// SymbolUserMap
+//===----------------------------------------------------------------------===//
+
+/// Maps each symbol defined in regions nested under a symbol table operation
+/// to the set of operations that reference it. Use-related queries, such as
+/// collecting the users of a symbol or replacing all of its uses, are answered
+/// from this precomputed map instead of rescanning the IR each time.
+class SymbolUserMap {
+public:
+  /// Build a user map for all of the symbols defined in regions nested under
+  /// 'symbolTableOp'. The map holds on to `symbolTable` for later lookups, so
+  /// the collection must outlive this map.
+  SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp);
+
+  /// Return the recorded users of `symbol`; empty if it has none.
+  ArrayRef<Operation *> getUsers(Operation *symbol) const {
+    auto found = symbolToUsers.find(symbol);
+    if (found == symbolToUsers.end())
+      return llvm::None;
+    return found->second.getArrayRef();
+  }
+
+  /// Return true if no users are recorded for `symbol`.
+  bool use_empty(Operation *symbol) const {
+    return symbolToUsers.find(symbol) == symbolToUsers.end();
+  }
+
+  /// Rewrite every recorded reference to `symbol` so that it refers to
+  /// `newSymbolName`. Only the users are updated; `symbol` itself is not
+  /// renamed by this call.
+  void replaceAllUsesWith(Operation *symbol, StringRef newSymbolName);
+
+private:
+  /// The symbol table collection this map was built from (not owned).
+  SymbolTableCollection &symbolTable;
+
+  /// For each symbol operation, the ordered, de-duplicated set of its users.
+  DenseMap<Operation *, llvm::SetVector<Operation *>> symbolToUsers;
+};
+
//===----------------------------------------------------------------------===//
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 70133d22482f..4620a5bcb381 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -1000,6 +1000,61 @@ SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
return *it.first->second;
}
+//===----------------------------------------------------------------------===//
+// SymbolUserMap
+//===----------------------------------------------------------------------===//
+
+SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
+                             Operation *symbolTableOp)
+    : symbolTable(symbolTable) {
+  // Scratch buffer reused across lookups; receives every symbol operation a
+  // (possibly nested) symbol reference resolves to.
+  SmallVector<Operation *> resolved;
+
+  // For each operation in the first region of a symbol table, resolve every
+  // symbol reference it holds and record the referencing operation as a user
+  // of each resolved symbol.
+  // NOTE(review): uses must be collectable for every such operation; this is
+  // only asserted, so with assertions disabled a failure here is undefined.
+  auto recordUses = [&](Operation *tableOp, bool /*allUsesVisible*/) {
+    for (Operation &op : tableOp->getRegion(0).getOps()) {
+      auto uses = SymbolTable::getSymbolUses(&op);
+      assert(uses && "expected uses to be valid");
+
+      for (const SymbolTable::SymbolUse &use : *uses) {
+        resolved.clear();
+        (void)symbolTable.lookupSymbolIn(tableOp, use.getSymbolRef(),
+                                         resolved);
+        Operation *user = use.getUser();
+        for (Operation *symbolOp : resolved)
+          symbolToUsers[symbolOp].insert(user);
+      }
+    }
+  };
+
+  // Whether all symbol uses are visible does not matter when only building
+  // the user map, so `false` is passed unconditionally.
+  SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
+                                recordUses);
+}
+
+/// Rewrite every recorded reference to `symbol` so that it refers to
+/// `newSymbolName`, then re-file the users under the symbol named
+/// `newSymbolName` if such a symbol already exists in the parent symbol table.
+/// `symbol` itself is not renamed here.
+void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
+                                       StringRef newSymbolName) {
+  auto it = symbolToUsers.find(symbol);
+  if (it == symbolToUsers.end())
+    return;
+
+  // Replace the uses within the users of `symbol`.
+  for (Operation *user : it->second)
+    (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
+
+  // Move the current users of `symbol` to the new symbol if it is in the
+  // symbol table. If no symbol with that name exists yet (e.g. the caller
+  // renames `symbol` after this call), keep the users recorded under `symbol`
+  // rather than filing them under a null key.
+  Operation *newSymbol =
+      symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
+  if (!newSymbol || newSymbol == symbol)
+    return;
+
+  // Take the user set out of the map and drop the old entry *before* touching
+  // the entry for `newSymbol`: inserting into a DenseMap may grow it, which
+  // invalidates `it` and any reference into the map's storage.
+  llvm::SetVector<Operation *> users = std::move(it->second);
+  symbolToUsers.erase(it);
+
+  // Transfer over the users to the new symbol.
+  auto newIt = symbolToUsers.find(newSymbol);
+  if (newIt == symbolToUsers.end())
+    symbolToUsers.try_emplace(newSymbol, std::move(users));
+  else
+    newIt->second.set_union(users);
+}
+
//===----------------------------------------------------------------------===//
// Visibility parsing implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 8f9a6ae17e07..db4127d7522e 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -90,16 +90,20 @@ struct SymbolUsesPass
+/// Test pass: for every operation carrying a `sym.new_name` string attribute,
+/// rewrite all uses of that symbol to the new name and rename the symbol.
struct SymbolReplacementPass
: public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
void runOnOperation() override {
- auto module = getOperation();
+ ModuleOp module = getOperation();
+
+ // Don't try to replace if we can't collect symbol uses.
+ if (!SymbolTable::getSymbolUses(&module.getBodyRegion()))
+ return;
- // Walk nested functions and modules.
+ // Build the symbol -> users map once, up front, before any renaming.
+ SymbolTableCollection symbolTable;
+ SymbolUserMap symbolUsers(symbolTable, module);
module.getBodyRegion().walk([&](Operation *nestedOp) {
StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
if (!newName)
return;
- if (succeeded(SymbolTable::replaceAllSymbolUses(
- nestedOp, newName.getValue(), &module.getBodyRegion())))
- SymbolTable::setSymbolName(nestedOp, newName.getValue());
+ // replaceAllUsesWith only rewrites the users; the symbol itself is
+ // renamed separately afterwards.
+ symbolUsers.replaceAllUsesWith(nestedOp, newName.getValue());
+ SymbolTable::setSymbolName(nestedOp, newName.getValue());
});
}
};
More information about the Mlir-commits mailing list