[Mlir-commits] [mlir] 5c159b9 - [mlir] Add a utility method on CallOpInterface for resolving the callable.

River Riddle llvmlistbot at llvm.org
Sat Feb 8 10:46:54 PST 2020


Author: River Riddle
Date: 2020-02-08T10:44:29-08:00
New Revision: 5c159b91a24b07974328ab17fd56a244995f2944

URL: https://github.com/llvm/llvm-project/commit/5c159b91a24b07974328ab17fd56a244995f2944
DIFF: https://github.com/llvm/llvm-project/commit/5c159b91a24b07974328ab17fd56a244995f2944.diff

LOG: [mlir] Add a utility method on CallOpInterface for resolving the callable.

Summary: Resolving the callable is the most common operation performed on a CallOpInterface. This change moves the existing resolution logic out of the CallGraph and into a utility method on the interface, so that other users can access it.

Differential Revision: https://reviews.llvm.org/D74250

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/CallGraph.h
    mlir/include/mlir/Analysis/CallInterfaces.h
    mlir/include/mlir/Analysis/CallInterfaces.td
    mlir/lib/Analysis/CallGraph.cpp
    mlir/lib/Transforms/Inliner.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
index 6b486ba2d1ac..cd25151da4c0 100644
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -23,6 +23,7 @@
 #include "llvm/ADT/SetVector.h"
 
 namespace mlir {
+class CallOpInterface;
 struct CallInterfaceCallable;
 class Operation;
 class Region;
@@ -188,11 +189,8 @@ class CallGraph {
   }
 
   /// Resolve the callable for given callee to a node in the callgraph, or the
-  /// external node if a valid node was not resolved. 'from' provides an anchor
-  /// for symbol table lookups, and is only required if the callable is a symbol
-  /// reference.
-  CallGraphNode *resolveCallable(CallInterfaceCallable callable,
-                                 Operation *from = nullptr) const;
+  /// external node if a valid node was not resolved.
+  CallGraphNode *resolveCallable(CallOpInterface call) const;
 
   /// An iterator over the nodes of the graph.
   using iterator = NodeIterator;

diff  --git a/mlir/include/mlir/Analysis/CallInterfaces.h b/mlir/include/mlir/Analysis/CallInterfaces.h
index e0d9ba2fcbdc..ebefc88e21a9 100644
--- a/mlir/include/mlir/Analysis/CallInterfaces.h
+++ b/mlir/include/mlir/Analysis/CallInterfaces.h
@@ -14,11 +14,10 @@
 #ifndef MLIR_ANALYSIS_CALLINTERFACES_H
 #define MLIR_ANALYSIS_CALLINTERFACES_H
 
-#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
 #include "llvm/ADT/PointerUnion.h"
 
 namespace mlir {
-
 /// A callable is either a symbol, or an SSA value, that is referenced by a
 /// call-like operation. This represents the destination of the call.
 struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {

diff  --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td
index 7dd58a48fa1d..2bc59c224b4f 100644
--- a/mlir/include/mlir/Analysis/CallInterfaces.td
+++ b/mlir/include/mlir/Analysis/CallInterfaces.td
@@ -44,6 +44,18 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
       }],
       "Operation::operand_range", "getArgOperands"
     >,
+    InterfaceMethod<[{
+        Resolve the callable operation for given callee to a
+        CallableOpInterface, or nullptr if a valid callable was not resolved.
+      }],
+      "Operation *", "resolveCallable", (ins), [{
+        // If the callable isn't a value, lookup the symbol reference.
+        CallInterfaceCallable callable = op.getCallableForCallee();
+        if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
+          return SymbolTable::lookupNearestSymbolFrom(op, symbolRef);
+        return callable.get<Value>().getDefiningOp();
+      }]
+    >,
   ];
 }
 

diff  --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 0f3bfbb653c8..d61b2359e691 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -79,10 +79,8 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
     // If there is no parent node, we ignore this operation. Even if this
     // operation was a call, there would be no callgraph node to attribute it
     // to.
-    if (!resolveCalls || !parentNode)
-      return;
-    parentNode->addCallEdge(
-        cg.resolveCallable(call.getCallableForCallee(), op));
+    if (resolveCalls && parentNode)
+      parentNode->addCallEdge(cg.resolveCallable(call));
     return;
   }
 
@@ -141,23 +139,11 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
 
 /// Resolve the callable for given callee to a node in the callgraph, or the
 /// external node if a valid node was not resolved.
-CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
-                                          Operation *from) const {
-  // Get the callee operation from the callable.
-  Operation *callee;
-  if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
-    callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef);
-  else
-    callee = callable.get<Value>().getDefiningOp();
-
-  // If the callee is non-null and is a valid callable object, try to get the
-  // called region from it.
-  if (callee && callee->getNumRegions()) {
-    if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
-      if (auto *node = lookupNode(callableOp.getCallableRegion()))
-        return node;
-    }
-  }
+CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
+  Operation *callable = call.resolveCallable();
+  if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
+    if (auto *node = lookupNode(callableOp.getCallableRegion()))
+      return node;
 
   // If we don't have a valid direct region, this is an external call.
   return getExternalNode();

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 3b6855366657..b6fcf8bc3941 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -86,15 +86,14 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
   while (!worklist.empty()) {
     for (Operation &op : *worklist.pop_back_val()) {
       if (auto call = dyn_cast<CallOpInterface>(op)) {
-        CallInterfaceCallable callable = call.getCallableForCallee();
-
         // TODO(riverriddle) Support inlining nested call references.
+        CallInterfaceCallable callable = call.getCallableForCallee();
         if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
           if (!symRef.isa<FlatSymbolRefAttr>())
             continue;
         }
 
-        CallGraphNode *node = cg.resolveCallable(callable, &op);
+        CallGraphNode *node = cg.resolveCallable(call);
         if (!node->isExternal())
           calls.emplace_back(call, node);
         continue;


        


More information about the Mlir-commits mailing list