[Mlir-commits] [mlir] a5ea604 - [mlir] Update SCCP and the Inliner to use SymbolTableCollection for symbol lookups

River Riddle llvmlistbot at llvm.org
Fri Oct 16 12:09:27 PDT 2020


Author: River Riddle
Date: 2020-10-16T12:08:48-07:00
New Revision: a5ea60456c16faf7c75df98b03d5de5b9b6f506d

URL: https://github.com/llvm/llvm-project/commit/a5ea60456c16faf7c75df98b03d5de5b9b6f506d
DIFF: https://github.com/llvm/llvm-project/commit/a5ea60456c16faf7c75df98b03d5de5b9b6f506d.diff

LOG: [mlir] Update SCCP and the Inliner to use SymbolTableCollection for symbol lookups

This changes symbol lookups from O(NM) to O(1), greatly speeding up both passes. For a large MLIR module, this shaved seconds off the compilation time.

Differential Revision: https://reviews.llvm.org/D89522

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/CallGraph.h
    mlir/include/mlir/Interfaces/CallInterfaces.td
    mlir/lib/Analysis/CallGraph.cpp
    mlir/lib/Interfaces/CallInterfaces.cpp
    mlir/lib/Transforms/Inliner.cpp
    mlir/lib/Transforms/SCCP.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
index 189641502d56..ef751bf58289 100644
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -27,6 +27,7 @@ class CallOpInterface;
 struct CallInterfaceCallable;
 class Operation;
 class Region;
+class SymbolTableCollection;
 
 //===----------------------------------------------------------------------===//
 // CallGraphNode
@@ -189,8 +190,11 @@ class CallGraph {
   }
 
   /// Resolve the callable for given callee to a node in the callgraph, or the
-  /// external node if a valid node was not resolved.
-  CallGraphNode *resolveCallable(CallOpInterface call) const;
+  /// external node if a valid node was not resolved. The provided symbol table
+  /// is used when resolving calls that reference callables via a symbol
+  /// reference.
+  CallGraphNode *resolveCallable(CallOpInterface call,
+                                 SymbolTableCollection &symbolTable) const;
 
   /// Erase the given node from the callgraph.
   void eraseNode(CallGraphNode *node);

diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 7db6730c5e99..d51252305c23 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -46,19 +46,16 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
       }],
       "Operation::operand_range", "getArgOperands"
     >,
-    InterfaceMethod<[{
-        Resolve the callable operation for given callee to a
-        CallableOpInterface, or nullptr if a valid callable was not resolved.
-      }],
-      "Operation *", "resolveCallable", (ins), [{
-        // If the callable isn't a value, lookup the symbol reference.
-        CallInterfaceCallable callable = $_op.getCallableForCallee();
-        if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
-          return SymbolTable::lookupNearestSymbolFrom($_op, symbolRef);
-        return callable.get<Value>().getDefiningOp();
-      }]
-    >,
   ];
+
+  let extraClassDeclaration = [{
+    /// Resolve the callable operation for given callee to a
+    /// CallableOpInterface, or nullptr if a valid callable was not resolved.
+    /// `symbolTable` is an optional parameter that will allow for using a
+    /// cached symbol table for symbol lookups instead of performing an O(N)
+    /// scan.
+    Operation *resolveCallable(SymbolTableCollection *symbolTable = nullptr);
+  }];
 }
 
 /// Interface for callable operations.

diff  --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index f9c8c48d275b..0f214446bbc6 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -68,13 +68,14 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
 /// Recursively compute the callgraph edges for the given operation. Computed
 /// edges are placed into the given callgraph object.
 static void computeCallGraph(Operation *op, CallGraph &cg,
+                             SymbolTableCollection &symbolTable,
                              CallGraphNode *parentNode, bool resolveCalls) {
   if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
     // If there is no parent node, we ignore this operation. Even if this
     // operation was a call, there would be no callgraph node to attribute it
     // to.
     if (resolveCalls && parentNode)
-      parentNode->addCallEdge(cg.resolveCallable(call));
+      parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
     return;
   }
 
@@ -88,15 +89,18 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
 
   for (Region &region : op->getRegions())
     for (Operation &nested : region.getOps())
-      computeCallGraph(&nested, cg, parentNode, resolveCalls);
+      computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
 }
 
 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
   // Make two passes over the graph, one to compute the callables and one to
   // resolve the calls. We split these up as we may have nested callable objects
   // that need to be reserved before the calls.
-  computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
-  computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
+  SymbolTableCollection symbolTable;
+  computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
+                   /*resolveCalls=*/false);
+  computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
+                   /*resolveCalls=*/true);
 }
 
 /// Get or add a call graph node for the given region.
@@ -109,16 +113,17 @@ CallGraphNode *CallGraph::getOrAddNode(Region *region,
     node.reset(new CallGraphNode(region));
 
     // Add this node to the given parent node if necessary.
-    if (parentNode)
+    if (parentNode) {
       parentNode->addChildEdge(node.get());
-    else
+    } else {
       // Otherwise, connect all callable nodes to the external node, this allows
       // for conservatively including all callable nodes within the graph.
-      // FIXME(riverriddle) This isn't correct, this is only necessary for
-      // callable nodes that *could* be called from external sources. This
-      // requires extending the interface for callables to check if they may be
-      // referenced externally.
+      // FIXME This isn't correct, this is only necessary for callable nodes
+      // that *could* be called from external sources. This requires extending
+      // the interface for callables to check if they may be referenced
+      // externally.
       externalNode.addAbstractEdge(node.get());
+    }
   }
   return node.get();
 }
@@ -132,8 +137,10 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
 
 /// Resolve the callable for given callee to a node in the callgraph, or the
 /// external node if a valid node was not resolved.
-CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
-  Operation *callable = call.resolveCallable();
+CallGraphNode *
+CallGraph::resolveCallable(CallOpInterface call,
+                           SymbolTableCollection &symbolTable) const {
+  Operation *callable = call.resolveCallable(&symbolTable);
   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
     if (auto *node = lookupNode(callableOp.getCallableRegion()))
       return node;

diff  --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 580cd13de5ab..c7bcb7a96099 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -10,6 +10,27 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// CallOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Resolve the callable operation for given callee to a CallableOpInterface, or
+/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
+/// parameter that will allow for using a cached symbol table for symbol lookups
+/// instead of performing an O(N) scan.
+Operation *
+CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
+  CallInterfaceCallable callable = getCallableForCallee();
+  if (auto symbolVal = callable.dyn_cast<Value>())
+    return symbolVal.getDefiningOp();
+
+  // If the callable isn't a value, lookup the symbol reference.
+  auto symbolRef = callable.get<SymbolRefAttr>();
+  if (symbolTable)
+    return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
+  return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
+}
+
 //===----------------------------------------------------------------------===//
 // CallInterfaces
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 2ddb10a3a088..d4ab3cc4d549 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -33,7 +33,7 @@ using namespace mlir;
 
 /// Walk all of the used symbol callgraph nodes referenced with the given op.
 static void walkReferencedSymbolNodes(
-    Operation *op, CallGraph &cg,
+    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
     function_ref<void(CallGraphNode *, Operation *)> callback) {
   auto symbolUses = SymbolTable::getSymbolUses(op);
@@ -47,8 +47,8 @@ static void walkReferencedSymbolNodes(
     // If this is the first instance of this reference, try to resolve a
     // callgraph node for it.
     if (refIt.second) {
-      auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp,
-                                                            use.getSymbolRef());
+      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
+                                                           use.getSymbolRef());
       auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
       if (!callableOp)
         continue;
@@ -80,7 +80,7 @@ struct CGUseList {
     DenseMap<CallGraphNode *, int> innerUses;
   };
 
-  CGUseList(Operation *op, CallGraph &cg);
+  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
 
   /// Drop uses of nodes referred to by the given call operation that resides
   /// within 'userNode'.
@@ -110,13 +110,19 @@ struct CGUseList {
   /// A mapping between a discardable callgraph node (that is a symbol) and the
   /// number of uses for this node.
   DenseMap<CallGraphNode *, int> discardableSymNodeUses;
+
   /// A mapping between a callgraph node and the symbol callgraph nodes that it
   /// uses.
   DenseMap<CallGraphNode *, CGUser> nodeUses;
+
+  /// A symbol table to use when resolving call lookups.
+  SymbolTableCollection &symbolTable;
 };
 } // end anonymous namespace
 
-CGUseList::CGUseList(Operation *op, CallGraph &cg) {
+CGUseList::CGUseList(Operation *op, CallGraph &cg,
+                     SymbolTableCollection &symbolTable)
+    : symbolTable(symbolTable) {
   /// A set of callgraph nodes that are always known to be live during inlining.
   DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
 
@@ -135,7 +141,7 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
         }
       }
       // Otherwise, check for any referenced nodes. These will be always-live.
-      walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
+      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                 [](CallGraphNode *, Operation *) {});
     }
   };
@@ -162,7 +168,7 @@ void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
     --discardableSymNodeUses[node];
   };
   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
-  walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn);
+  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
 }
 
 void CGUseList::eraseNode(CallGraphNode *node) {
@@ -220,7 +226,7 @@ void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
       return;
     ++discardSymIt->second;
   };
-  walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn);
+  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
 }
 
 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
@@ -305,6 +311,7 @@ struct ResolvedCall {
 /// inside of nested callgraph nodes.
 static void collectCallOps(iterator_range<Region::iterator> blocks,
                            CallGraphNode *sourceNode, CallGraph &cg,
+                           SymbolTableCollection &symbolTable,
                            SmallVectorImpl<ResolvedCall> &calls,
                            bool traverseNestedCGNodes) {
   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
@@ -328,7 +335,7 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
             continue;
         }
 
-        CallGraphNode *targetNode = cg.resolveCallable(call);
+        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
         if (!targetNode->isExternal())
           calls.emplace_back(call, sourceNode, targetNode);
         continue;
@@ -352,8 +359,9 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
 namespace {
 /// This class provides a specialization of the main inlining interface.
 struct Inliner : public InlinerInterface {
-  Inliner(MLIRContext *context, CallGraph &cg)
-      : InlinerInterface(context), cg(cg) {}
+  Inliner(MLIRContext *context, CallGraph &cg,
+          SymbolTableCollection &symbolTable)
+      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
 
   /// Process a set of blocks that have been inlined. This callback is invoked
   /// *before* inlined terminator operations have been processed.
@@ -367,7 +375,7 @@ struct Inliner : public InlinerInterface {
       assert(region && "expected valid parent node");
     }
 
-    collectCallOps(inlinedBlocks, node, cg, calls,
+    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                    /*traverseNestedCGNodes=*/true);
   }
 
@@ -389,6 +397,9 @@ struct Inliner : public InlinerInterface {
 
   /// The callgraph being operated on.
   CallGraph &cg;
+
+  /// A symbol table to use when resolving call lookups.
+  SymbolTableCollection &symbolTable;
 };
 } // namespace
 
@@ -427,11 +438,12 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
       continue;
 
     // Don't collect calls if the node is already dead.
-    if (useList.isDead(node))
+    if (useList.isDead(node)) {
       deadNodes.push_back(node);
-    else
-      collectCallOps(*node->getCallableRegion(), node, cg, calls,
-                     /*traverseNestedCGNodes=*/false);
+    } else {
+      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
+                     calls, /*traverseNestedCGNodes=*/false);
+    }
   }
 
   // Try to inline each of the call operations. Don't cache the end iterator
@@ -585,8 +597,9 @@ void InlinerPass::runOnOperation() {
     op->getCanonicalizationPatterns(canonPatterns, context);
 
   // Run the inline transform in post-order over the SCCs in the callgraph.
-  Inliner inliner(context, cg);
-  CGUseList useList(getOperation(), cg);
+  SymbolTableCollection symbolTable;
+  Inliner inliner(context, cg, symbolTable);
+  CGUseList useList(getOperation(), cg, symbolTable);
   runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
     inlineSCC(inliner, useList, scc, context, canonPatterns);
   });

diff  --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 95b035ee68cd..0a8f224f6fdd 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -304,6 +304,9 @@ class SCCPSolver {
   /// avoids re-resolving symbol references during propagation. Value based
   /// callables are trivial to resolve, so they can be done in-place.
   DenseMap<Operation *, Operation *> callToSymbolCallable;
+
+  /// A symbol table used for O(1) symbol lookups during simplification.
+  SymbolTableCollection symbolTable;
 };
 } // end anonymous namespace
 
@@ -425,7 +428,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
       // If the use is a call, track it to avoid the need to recompute the
       // reference later.
       if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
-        Operation *symCallable = callOp.resolveCallable();
+        Operation *symCallable = callOp.resolveCallable(&symbolTable);
         auto callableLatticeIt = callableLatticeState.find(symCallable);
         if (callableLatticeIt != callableLatticeState.end()) {
           callToSymbolCallable.try_emplace(callOp, symCallable);
@@ -438,7 +441,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
         continue;
       }
       // This use isn't a call, so don't we know all of the callers.
-      auto *symbol = SymbolTable::lookupSymbolIn(op, use.getSymbolRef());
+      auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
       auto it = callableLatticeState.find(symbol);
       if (it != callableLatticeState.end())
         markAllOverdefined(it->second.getCallableArguments());


        


More information about the Mlir-commits mailing list