[Mlir-commits] [mlir] 5c159b9 - [mlir] Add a utility method on CallOpInterface for resolving the callable.
River Riddle
llvmlistbot at llvm.org
Sat Feb 8 10:46:54 PST 2020
Author: River Riddle
Date: 2020-02-08T10:44:29-08:00
New Revision: 5c159b91a24b07974328ab17fd56a244995f2944
URL: https://github.com/llvm/llvm-project/commit/5c159b91a24b07974328ab17fd56a244995f2944
DIFF: https://github.com/llvm/llvm-project/commit/5c159b91a24b07974328ab17fd56a244995f2944.diff
LOG: [mlir] Add a utility method on CallOpInterface for resolving the callable.
Summary: This is the most common operation performed on a CallOpInterface. This just moves the existing functionality from the CallGraph so that other users can access it.
Differential Revision: https://reviews.llvm.org/D74250
Added:
Modified:
mlir/include/mlir/Analysis/CallGraph.h
mlir/include/mlir/Analysis/CallInterfaces.h
mlir/include/mlir/Analysis/CallInterfaces.td
mlir/lib/Analysis/CallGraph.cpp
mlir/lib/Transforms/Inliner.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
index 6b486ba2d1ac..cd25151da4c0 100644
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -23,6 +23,7 @@
#include "llvm/ADT/SetVector.h"
namespace mlir {
+class CallOpInterface;
struct CallInterfaceCallable;
class Operation;
class Region;
@@ -188,11 +189,8 @@ class CallGraph {
}
/// Resolve the callable for given callee to a node in the callgraph, or the
- /// external node if a valid node was not resolved. 'from' provides an anchor
- /// for symbol table lookups, and is only required if the callable is a symbol
- /// reference.
- CallGraphNode *resolveCallable(CallInterfaceCallable callable,
- Operation *from = nullptr) const;
+ /// external node if a valid node was not resolved.
+ CallGraphNode *resolveCallable(CallOpInterface call) const;
/// An iterator over the nodes of the graph.
using iterator = NodeIterator;
diff --git a/mlir/include/mlir/Analysis/CallInterfaces.h b/mlir/include/mlir/Analysis/CallInterfaces.h
index e0d9ba2fcbdc..ebefc88e21a9 100644
--- a/mlir/include/mlir/Analysis/CallInterfaces.h
+++ b/mlir/include/mlir/Analysis/CallInterfaces.h
@@ -14,11 +14,10 @@
#ifndef MLIR_ANALYSIS_CALLINTERFACES_H
#define MLIR_ANALYSIS_CALLINTERFACES_H
-#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/PointerUnion.h"
namespace mlir {
-
/// A callable is either a symbol, or an SSA value, that is referenced by a
/// call-like operation. This represents the destination of the call.
struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td
index 7dd58a48fa1d..2bc59c224b4f 100644
--- a/mlir/include/mlir/Analysis/CallInterfaces.td
+++ b/mlir/include/mlir/Analysis/CallInterfaces.td
@@ -44,6 +44,18 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
}],
"Operation::operand_range", "getArgOperands"
>,
+ InterfaceMethod<[{
+ Resolve the callable operation for given callee to a
+ CallableOpInterface, or nullptr if a valid callable was not resolved.
+ }],
+ "Operation *", "resolveCallable", (ins), [{
+ // If the callable isn't a value, lookup the symbol reference.
+ CallInterfaceCallable callable = op.getCallableForCallee();
+ if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
+ return SymbolTable::lookupNearestSymbolFrom(op, symbolRef);
+ return callable.get<Value>().getDefiningOp();
+ }]
+ >,
];
}
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 0f3bfbb653c8..d61b2359e691 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -79,10 +79,8 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
// If there is no parent node, we ignore this operation. Even if this
// operation was a call, there would be no callgraph node to attribute it
// to.
- if (!resolveCalls || !parentNode)
- return;
- parentNode->addCallEdge(
- cg.resolveCallable(call.getCallableForCallee(), op));
+ if (resolveCalls && parentNode)
+ parentNode->addCallEdge(cg.resolveCallable(call));
return;
}
@@ -141,23 +139,11 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
/// Resolve the callable for given callee to a node in the callgraph, or the
/// external node if a valid node was not resolved.
-CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
- Operation *from) const {
- // Get the callee operation from the callable.
- Operation *callee;
- if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
- callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef);
- else
- callee = callable.get<Value>().getDefiningOp();
-
- // If the callee is non-null and is a valid callable object, try to get the
- // called region from it.
- if (callee && callee->getNumRegions()) {
- if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
- if (auto *node = lookupNode(callableOp.getCallableRegion()))
- return node;
- }
- }
+CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
+ Operation *callable = call.resolveCallable();
+ if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
+ if (auto *node = lookupNode(callableOp.getCallableRegion()))
+ return node;
// If we don't have a valid direct region, this is an external call.
return getExternalNode();
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 3b6855366657..b6fcf8bc3941 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -86,15 +86,14 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
while (!worklist.empty()) {
for (Operation &op : *worklist.pop_back_val()) {
if (auto call = dyn_cast<CallOpInterface>(op)) {
- CallInterfaceCallable callable = call.getCallableForCallee();
-
// TODO(riverriddle) Support inlining nested call references.
+ CallInterfaceCallable callable = call.getCallableForCallee();
if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
if (!symRef.isa<FlatSymbolRefAttr>())
continue;
}
- CallGraphNode *node = cg.resolveCallable(callable, &op);
+ CallGraphNode *node = cg.resolveCallable(call);
if (!node->isExternal())
calls.emplace_back(call, node);
continue;
More information about the Mlir-commits mailing list