[Mlir-commits] [mlir] [MLIR] Make `resolveCallable` customizable in `CallOpInterface` (PR #100361)
Henrich Lauko
llvmlistbot at llvm.org
Mon Sep 9 13:20:06 PDT 2024
https://github.com/xlauko updated https://github.com/llvm/llvm-project/pull/100361
>From e86afb4a51ef44731e0858919071698dc85e4ec1 Mon Sep 17 00:00:00 2001
From: xlauko <xlauko at mail.muni.cz>
Date: Wed, 24 Jul 2024 14:47:17 +0200
Subject: [PATCH] [MLIR] Make `resolveCallable` customizable in
`CallOpInterface`
Allow customization of the `resolveCallable` method in `CallOpInterface`. This change allows operations implementing this interface to provide their own logic for resolving callables.
- Introduce the `resolveCallable` interface method, which takes no symbol table parameter. It replaces uses of the previous extra class declaration `resolveCallable` that did not pass a symbol table.
- Introduce the `resolveCallableInTable` interface method, which takes a symbol table parameter. It replaces uses of the previous extra class declaration `resolveCallable` that passed the optional symbol table.
---
mlir/include/mlir/Interfaces/CallInterfaces.h | 19 +++++++++--
.../include/mlir/Interfaces/CallInterfaces.td | 32 +++++++++++++------
mlir/lib/Analysis/CallGraph.cpp | 2 +-
.../Analysis/DataFlow/DeadCodeAnalysis.cpp | 2 +-
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 2 +-
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 2 +-
.../OwnershipBasedBufferDeallocation.cpp | 2 +-
mlir/lib/Interfaces/CallInterfaces.cpp | 12 +++----
8 files changed, 47 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 7dbcddb01b241e..58c37f01caef09 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -23,10 +23,24 @@ namespace mlir {
struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
};
-} // namespace mlir
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/CallInterfaces.h.inc"
+// Forward-declared so `call_interface_impl::resolveCallable` can be declared
+// before the generated interface declarations are included below.
+class CallOpInterface;
+
+namespace call_interface_impl {
+
+/// Resolve the callee of `call` to the operation it refers to, or nullptr if
+/// it could not be resolved. The result is not guaranteed to implement
+/// CallableOpInterface (e.g. for a value callee it is the defining op).
+/// `symbolTable` is an optional cache used for symbol lookups instead of
+/// performing an O(N) scan.
+Operation *resolveCallable(CallOpInterface call,
+                           SymbolTableCollection *symbolTable = nullptr);
+
+} // namespace call_interface_impl
+
+} // namespace mlir
namespace llvm {
@@ -41,4 +55,7 @@ struct CastInfo<To, const mlir::CallInterfaceCallable>
} // namespace llvm
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/CallInterfaces.h.inc"
+
#endif // MLIR_INTERFACES_CALLINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 752de74e6e4d7e..c6002da0d491ce 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -59,17 +59,32 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
Returns the operands within this call that are used as arguments to the
callee as a mutable range.
}],
-      "::mlir::MutableOperandRange", "getArgOperandsMutable">,
+      "::mlir::MutableOperandRange", "getArgOperandsMutable"
+    >,
+    InterfaceMethod<[{
+        Resolve the callable operation for the given callee, or nullptr if a
+        valid callable was not resolved. The `symbolTable` parameter allows
+        using a cached symbol table for symbol lookups instead of performing
+        an O(N) scan.
+      }],
+      "::mlir::Operation *", "resolveCallableInTable",
+      (ins "::mlir::SymbolTableCollection *":$symbolTable),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{
+        return ::mlir::call_interface_impl::resolveCallable($_op, symbolTable);
+      }]
+    >,
+    InterfaceMethod<[{
+        Resolve the callable operation for the given callee, or nullptr if a
+        valid callable was not resolved. It is not implemented in terms of
+        `resolveCallableInTable`, so operations that customize resolution
+        should override both methods.
+      }],
+      "::mlir::Operation *", "resolveCallable", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{
+        return ::mlir::call_interface_impl::resolveCallable($_op);
+      }]
+    >
];
-
-  let extraClassDeclaration = [{
-    /// Resolve the callable operation for given callee to a
-    /// CallableOpInterface, or nullptr if a valid callable was not resolved.
-    /// `symbolTable` is an optional parameter that will allow for using a
-    /// cached symbol table for symbol lookups instead of performing an O(N)
-    /// scan.
-    ::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr);
-  }];
}
/// Interface for callable operations.
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index ccd4676632136b..780c7caee767c1 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
CallGraphNode *
CallGraph::resolveCallable(CallOpInterface call,
SymbolTableCollection &symbolTable) const {
-  Operation *callable = call.resolveCallable(&symbolTable);
+  Operation *callable = call.resolveCallableInTable(&symbolTable);
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
if (auto *node = lookupNode(callableOp.getCallableRegion()))
return node;
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 532480b6fad57d..beb68018a3b16e 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -297,7 +297,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
}
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
-  Operation *callableOp = call.resolveCallable(&symbolTable);
+  Operation *callableOp = call.resolveCallableInTable(&symbolTable);
// A call to a externally-defined callable has unknown predecessors.
const auto isExternalCallable = [this](Operation *op) {
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 37f4ceaaa56cee..300c6e5f9b8919 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -284,7 +284,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
// Find the callee.
-  Operation *callee = call.resolveCallable(&symbolTable);
+  Operation *callee = call.resolveCallableInTable(&symbolTable);
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
// No region means the callee is only declared in this module.
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4a73f21a18aae7..1bd6defef90be0 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -442,7 +442,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
-    Operation *callableOp = call.resolveCallable(&symbolTable);
+    Operation *callableOp = call.resolveCallableInTable(&symbolTable);
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
// Not all operands of a call op forward to arguments. Such operands are
// stored in `unaccounted`.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index ca5d0688b5b594..b973618004497b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -824,7 +824,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
// the function is referenced by SSA value instead of a Symbol, it's assumed
// to be public. (And we cannot easily change the type of the SSA value
// anyway.)
-  Operation *funcOp = op.resolveCallable(state.getSymbolTable());
+  Operation *funcOp = op.resolveCallableInTable(state.getSymbolTable());
bool isPrivate = false;
if (auto symbol = dyn_cast_or_null<SymbolOpInterface>(funcOp))
isPrivate = symbol.isPrivate() && !symbol.isDeclaration();
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 455684d8e2ea7c..47f8021f50cd28 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -14,21 +14,18 @@ using namespace mlir;
// CallOpInterface
//===----------------------------------------------------------------------===//
-/// Resolve the callable operation for given callee to a CallableOpInterface, or
-/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
-/// parameter that will allow for using a cached symbol table for symbol lookups
-/// instead of performing an O(N) scan.
Operation *
-CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
-  CallInterfaceCallable callable = getCallableForCallee();
+call_interface_impl::resolveCallable(CallOpInterface call,
+                                     SymbolTableCollection *symbolTable) {
+  CallInterfaceCallable callable = call.getCallableForCallee();
if (auto symbolVal = dyn_cast<Value>(callable))
return symbolVal.getDefiningOp();
// If the callable isn't a value, lookup the symbol reference.
auto symbolRef = callable.get<SymbolRefAttr>();
if (symbolTable)
-    return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
-  return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
+    return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
+  return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list