[Mlir-commits] [mlir] a5ea604 - [mlir] Update SCCP and the Inliner to use SymbolTableCollection for symbol lookups
River Riddle
llvmlistbot at llvm.org
Fri Oct 16 12:09:27 PDT 2020
Author: River Riddle
Date: 2020-10-16T12:08:48-07:00
New Revision: a5ea60456c16faf7c75df98b03d5de5b9b6f506d
URL: https://github.com/llvm/llvm-project/commit/a5ea60456c16faf7c75df98b03d5de5b9b6f506d
DIFF: https://github.com/llvm/llvm-project/commit/a5ea60456c16faf7c75df98b03d5de5b9b6f506d.diff
LOG: [mlir] Update SCCP and the Inliner to use SymbolTableCollection for symbol lookups
This changes symbol lookups from O(NM) to O(1), greatly speeding up both passes. On a large MLIR module, this shaved seconds off the compilation time.
Differential Revision: https://reviews.llvm.org/D89522
Added:
Modified:
mlir/include/mlir/Analysis/CallGraph.h
mlir/include/mlir/Interfaces/CallInterfaces.td
mlir/lib/Analysis/CallGraph.cpp
mlir/lib/Interfaces/CallInterfaces.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/lib/Transforms/SCCP.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
index 189641502d56..ef751bf58289 100644
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -27,6 +27,7 @@ class CallOpInterface;
struct CallInterfaceCallable;
class Operation;
class Region;
+class SymbolTableCollection;
//===----------------------------------------------------------------------===//
// CallGraphNode
@@ -189,8 +190,11 @@ class CallGraph {
}
/// Resolve the callable for given callee to a node in the callgraph, or the
- /// external node if a valid node was not resolved.
- CallGraphNode *resolveCallable(CallOpInterface call) const;
+ /// external node if a valid node was not resolved. The provided symbol table
+ /// is used when resolving calls that reference callables via a symbol
+ /// reference.
+ CallGraphNode *resolveCallable(CallOpInterface call,
+ SymbolTableCollection &symbolTable) const;
/// Erase the given node from the callgraph.
void eraseNode(CallGraphNode *node);
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 7db6730c5e99..d51252305c23 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -46,19 +46,16 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
}],
"Operation::operand_range", "getArgOperands"
>,
- InterfaceMethod<[{
- Resolve the callable operation for given callee to a
- CallableOpInterface, or nullptr if a valid callable was not resolved.
- }],
- "Operation *", "resolveCallable", (ins), [{
- // If the callable isn't a value, lookup the symbol reference.
- CallInterfaceCallable callable = $_op.getCallableForCallee();
- if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
- return SymbolTable::lookupNearestSymbolFrom($_op, symbolRef);
- return callable.get<Value>().getDefiningOp();
- }]
- >,
];
+
+ let extraClassDeclaration = [{
+ /// Resolve the callable operation for given callee to a
+ /// CallableOpInterface, or nullptr if a valid callable was not resolved.
+ /// `symbolTable` is an optional parameter that will allow for using a
+ /// cached symbol table for symbol lookups instead of performing an O(N)
+ /// scan.
+ Operation *resolveCallable(SymbolTableCollection *symbolTable = nullptr);
+ }];
}
/// Interface for callable operations.
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index f9c8c48d275b..0f214446bbc6 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -68,13 +68,14 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
/// Recursively compute the callgraph edges for the given operation. Computed
/// edges are placed into the given callgraph object.
static void computeCallGraph(Operation *op, CallGraph &cg,
+ SymbolTableCollection &symbolTable,
CallGraphNode *parentNode, bool resolveCalls) {
if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
// If there is no parent node, we ignore this operation. Even if this
// operation was a call, there would be no callgraph node to attribute it
// to.
if (resolveCalls && parentNode)
- parentNode->addCallEdge(cg.resolveCallable(call));
+ parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
return;
}
@@ -88,15 +89,18 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
for (Region &region : op->getRegions())
for (Operation &nested : region.getOps())
- computeCallGraph(&nested, cg, parentNode, resolveCalls);
+ computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
}
CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
// Make two passes over the graph, one to compute the callables and one to
// resolve the calls. We split these up as we may have nested callable objects
// that need to be reserved before the calls.
- computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
- computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
+ SymbolTableCollection symbolTable;
+ computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
+ /*resolveCalls=*/false);
+ computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
+ /*resolveCalls=*/true);
}
/// Get or add a call graph node for the given region.
@@ -109,16 +113,17 @@ CallGraphNode *CallGraph::getOrAddNode(Region *region,
node.reset(new CallGraphNode(region));
// Add this node to the given parent node if necessary.
- if (parentNode)
+ if (parentNode) {
parentNode->addChildEdge(node.get());
- else
+ } else {
// Otherwise, connect all callable nodes to the external node, this allows
// for conservatively including all callable nodes within the graph.
- // FIXME(riverriddle) This isn't correct, this is only necessary for
- // callable nodes that *could* be called from external sources. This
- // requires extending the interface for callables to check if they may be
- // referenced externally.
+ // FIXME This isn't correct, this is only necessary for callable nodes
+ // that *could* be called from external sources. This requires extending
+ // the interface for callables to check if they may be referenced
+ // externally.
externalNode.addAbstractEdge(node.get());
+ }
}
return node.get();
}
@@ -132,8 +137,10 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
/// Resolve the callable for given callee to a node in the callgraph, or the
/// external node if a valid node was not resolved.
-CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
- Operation *callable = call.resolveCallable();
+CallGraphNode *
+CallGraph::resolveCallable(CallOpInterface call,
+ SymbolTableCollection &symbolTable) const {
+ Operation *callable = call.resolveCallable(&symbolTable);
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
if (auto *node = lookupNode(callableOp.getCallableRegion()))
return node;
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 580cd13de5ab..c7bcb7a96099 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -10,6 +10,27 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// CallOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Resolve the callable operation for given callee to a CallableOpInterface, or
+/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
+/// parameter that will allow for using a cached symbol table for symbol lookups
+/// instead of performing an O(N) scan.
+Operation *
+CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
+ CallInterfaceCallable callable = getCallableForCallee();
+ if (auto symbolVal = callable.dyn_cast<Value>())
+ return symbolVal.getDefiningOp();
+
+ // If the callable isn't a value, lookup the symbol reference.
+ auto symbolRef = callable.get<SymbolRefAttr>();
+ if (symbolTable)
+ return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
+ return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
+}
+
//===----------------------------------------------------------------------===//
// CallInterfaces
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 2ddb10a3a088..d4ab3cc4d549 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -33,7 +33,7 @@ using namespace mlir;
/// Walk all of the used symbol callgraph nodes referenced with the given op.
static void walkReferencedSymbolNodes(
- Operation *op, CallGraph &cg,
+ Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
function_ref<void(CallGraphNode *, Operation *)> callback) {
auto symbolUses = SymbolTable::getSymbolUses(op);
@@ -47,8 +47,8 @@ static void walkReferencedSymbolNodes(
// If this is the first instance of this reference, try to resolve a
// callgraph node for it.
if (refIt.second) {
- auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp,
- use.getSymbolRef());
+ auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
+ use.getSymbolRef());
auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
if (!callableOp)
continue;
@@ -80,7 +80,7 @@ struct CGUseList {
DenseMap<CallGraphNode *, int> innerUses;
};
- CGUseList(Operation *op, CallGraph &cg);
+ CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
/// Drop uses of nodes referred to by the given call operation that resides
/// within 'userNode'.
@@ -110,13 +110,19 @@ struct CGUseList {
/// A mapping between a discardable callgraph node (that is a symbol) and the
/// number of uses for this node.
DenseMap<CallGraphNode *, int> discardableSymNodeUses;
+
/// A mapping between a callgraph node and the symbol callgraph nodes that it
/// uses.
DenseMap<CallGraphNode *, CGUser> nodeUses;
+
+ /// A symbol table to use when resolving call lookups.
+ SymbolTableCollection &symbolTable;
};
} // end anonymous namespace
-CGUseList::CGUseList(Operation *op, CallGraph &cg) {
+CGUseList::CGUseList(Operation *op, CallGraph &cg,
+ SymbolTableCollection &symbolTable)
+ : symbolTable(symbolTable) {
/// A set of callgraph nodes that are always known to be live during inlining.
DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
@@ -135,7 +141,7 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
}
}
// Otherwise, check for any referenced nodes. These will be always-live.
- walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
+ walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
[](CallGraphNode *, Operation *) {});
}
};
@@ -162,7 +168,7 @@ void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
--discardableSymNodeUses[node];
};
DenseMap<Attribute, CallGraphNode *> resolvedRefs;
- walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn);
+ walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}
void CGUseList::eraseNode(CallGraphNode *node) {
@@ -220,7 +226,7 @@ void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
return;
++discardSymIt->second;
};
- walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn);
+ walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}
void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
@@ -305,6 +311,7 @@ struct ResolvedCall {
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
CallGraphNode *sourceNode, CallGraph &cg,
+ SymbolTableCollection &symbolTable,
SmallVectorImpl<ResolvedCall> &calls,
bool traverseNestedCGNodes) {
SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
@@ -328,7 +335,7 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
continue;
}
- CallGraphNode *targetNode = cg.resolveCallable(call);
+ CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
if (!targetNode->isExternal())
calls.emplace_back(call, sourceNode, targetNode);
continue;
@@ -352,8 +359,9 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
- Inliner(MLIRContext *context, CallGraph &cg)
- : InlinerInterface(context), cg(cg) {}
+ Inliner(MLIRContext *context, CallGraph &cg,
+ SymbolTableCollection &symbolTable)
+ : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
/// Process a set of blocks that have been inlined. This callback is invoked
/// *before* inlined terminator operations have been processed.
@@ -367,7 +375,7 @@ struct Inliner : public InlinerInterface {
assert(region && "expected valid parent node");
}
- collectCallOps(inlinedBlocks, node, cg, calls,
+ collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
/*traverseNestedCGNodes=*/true);
}
@@ -389,6 +397,9 @@ struct Inliner : public InlinerInterface {
/// The callgraph being operated on.
CallGraph &cg;
+
+ /// A symbol table to use when resolving call lookups.
+ SymbolTableCollection &symbolTable;
};
} // namespace
@@ -427,11 +438,12 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
continue;
// Don't collect calls if the node is already dead.
- if (useList.isDead(node))
+ if (useList.isDead(node)) {
deadNodes.push_back(node);
- else
- collectCallOps(*node->getCallableRegion(), node, cg, calls,
- /*traverseNestedCGNodes=*/false);
+ } else {
+ collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
+ calls, /*traverseNestedCGNodes=*/false);
+ }
}
// Try to inline each of the call operations. Don't cache the end iterator
@@ -585,8 +597,9 @@ void InlinerPass::runOnOperation() {
op->getCanonicalizationPatterns(canonPatterns, context);
// Run the inline transform in post-order over the SCCs in the callgraph.
- Inliner inliner(context, cg);
- CGUseList useList(getOperation(), cg);
+ SymbolTableCollection symbolTable;
+ Inliner inliner(context, cg, symbolTable);
+ CGUseList useList(getOperation(), cg, symbolTable);
runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
inlineSCC(inliner, useList, scc, context, canonPatterns);
});
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 95b035ee68cd..0a8f224f6fdd 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -304,6 +304,9 @@ class SCCPSolver {
/// avoids re-resolving symbol references during propagation. Value based
/// callables are trivial to resolve, so they can be done in-place.
DenseMap<Operation *, Operation *> callToSymbolCallable;
+
+ /// A symbol table used for O(1) symbol lookups during simplification.
+ SymbolTableCollection symbolTable;
};
} // end anonymous namespace
@@ -425,7 +428,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
// If the use is a call, track it to avoid the need to recompute the
// reference later.
if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
- Operation *symCallable = callOp.resolveCallable();
+ Operation *symCallable = callOp.resolveCallable(&symbolTable);
auto callableLatticeIt = callableLatticeState.find(symCallable);
if (callableLatticeIt != callableLatticeState.end()) {
callToSymbolCallable.try_emplace(callOp, symCallable);
@@ -438,7 +441,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
continue;
}
// This use isn't a call, so don't we know all of the callers.
- auto *symbol = SymbolTable::lookupSymbolIn(op, use.getSymbolRef());
+ auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
auto it = callableLatticeState.find(symbol);
if (it != callableLatticeState.end())
markAllOverdefined(it->second.getCallableArguments());
More information about the Mlir-commits
mailing list