[Mlir-commits] [mlir] [MLIR] Make `resolveCallable` customizable in `CallOpInterface` (PR #100361)

Henrich Lauko llvmlistbot at llvm.org
Wed Jul 24 05:52:24 PDT 2024


https://github.com/xlauko created https://github.com/llvm/llvm-project/pull/100361

Allow customization of the `resolveCallable` method in the `CallOpInterface`. This change allows operations implementing this interface to provide their own logic for resolving callables.

- Introduce the `resolveCallable` interface method, which takes no symbol table parameter. It replaces the previous extra class declaration `resolveCallable` for callers that did not pass a symbol table.

- Introduce the `resolveCallableInTable` interface method, which takes a mandatory `SymbolTableCollection &` parameter. It replaces the previous extra class declaration `resolveCallable` for callers that passed the optional symbol table.

>From ab2ddd07820f01e09c90f37029740fd4d42533a3 Mon Sep 17 00:00:00 2001
From: xlauko <xlauko at mail.muni.cz>
Date: Wed, 24 Jul 2024 14:47:17 +0200
Subject: [PATCH] [MLIR] Make `resolveCallable` customizable in
 `CallOpInterface`

Allow customization of the `resolveCallable` method in the `CallOpInterface`. This change allows operations implementing this interface to provide their own logic for resolving callables.

- Introduce the `resolveCallable` interface method, which takes no symbol table parameter. It replaces the previous extra class declaration `resolveCallable` for callers that did not pass a symbol table.

- Introduce the `resolveCallableInTable` interface method, which takes a mandatory `SymbolTableCollection &` parameter. It replaces the previous extra class declaration `resolveCallable` for callers that passed the optional symbol table.
---
 mlir/include/mlir/Interfaces/CallInterfaces.h |  6 +--
 .../include/mlir/Interfaces/CallInterfaces.td | 44 ++++++++++++++-----
 mlir/lib/Analysis/CallGraph.cpp               |  2 +-
 .../Analysis/DataFlow/DeadCodeAnalysis.cpp    |  2 +-
 mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp  |  2 +-
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp |  2 +-
 .../OwnershipBasedBufferDeallocation.cpp      |  3 +-
 mlir/lib/Interfaces/CallInterfaces.cpp        | 17 -------
 8 files changed, 43 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 7dbcddb01b241..15a3d260f4166 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -25,9 +25,6 @@ struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
 };
 } // namespace mlir
 
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/CallInterfaces.h.inc"
-
 namespace llvm {
 
 // Allow llvm::cast style functions.
@@ -41,4 +38,7 @@ struct CastInfo<To, const mlir::CallInterfaceCallable>
 
 } // namespace llvm
 
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/CallInterfaces.h.inc"
+
 #endif // MLIR_INTERFACES_CALLINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 752de74e6e4d7..1da2a6d4e7eac 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -59,17 +59,41 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
         Returns the operands within this call that are used as arguments to the
         callee as a mutable range.
       }],
-      "::mlir::MutableOperandRange", "getArgOperandsMutable">,
-  ];
+      "::mlir::MutableOperandRange", "getArgOperandsMutable"
+    >,
+    InterfaceMethod<[{
+        Resolve the callable operation for given callee to a
+        CallableOpInterface, or nullptr if a valid callable was not resolved.
+        The `symbolTable` parameter allows using a cached symbol table for
+        symbol lookups instead of performing an O(N) scan.
+      }],
+      "::mlir::Operation *", "resolveCallableInTable", (ins "::mlir::SymbolTableCollection &":$symbolTable),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{
+        ::mlir::CallInterfaceCallable callable = $_op.getCallableForCallee();
+        if (auto symbolVal = dyn_cast<::mlir::Value>(callable))
+          return symbolVal.getDefiningOp();
 
-  let extraClassDeclaration = [{
-    /// Resolve the callable operation for given callee to a
-    /// CallableOpInterface, or nullptr if a valid callable was not resolved.
-    /// `symbolTable` is an optional parameter that will allow for using a
-    /// cached symbol table for symbol lookups instead of performing an O(N)
-    /// scan.
-    ::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr);
-  }];
+        // If the callable isn't a value, lookup the symbol reference.
+        auto symbolRef = callable.get<::mlir::SymbolRefAttr>();
+        return symbolTable.lookupNearestSymbolFrom($_op, symbolRef);
+      }]
+    >,
+    InterfaceMethod<[{
+        Resolve the callable operation for given callee to a
+        CallableOpInterface, or nullptr if a valid callable was not resolved.
+    }],
+      "::mlir::Operation *", "resolveCallable", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{
+        ::mlir::CallInterfaceCallable callable = $_op.getCallableForCallee();
+        if (auto symbolVal = dyn_cast<::mlir::Value>(callable))
+          return symbolVal.getDefiningOp();
+
+        // If the callable isn't a value, lookup the symbol reference.
+        auto symbolRef = callable.get<::mlir::SymbolRefAttr>();
+        return SymbolTable::lookupNearestSymbolFrom($_op, symbolRef);
+      }]
+    >
+  ];
 }
 
 /// Interface for callable operations.
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index ccd4676632136..6a85fb1a0da27 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
 CallGraphNode *
 CallGraph::resolveCallable(CallOpInterface call,
                            SymbolTableCollection &symbolTable) const {
-  Operation *callable = call.resolveCallable(&symbolTable);
+  Operation *callable = call.resolveCallableInTable(symbolTable);
   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
     if (auto *node = lookupNode(callableOp.getCallableRegion()))
       return node;
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index fab2bd83888da..cb80d229ae678 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -295,7 +295,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
 }
 
 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
-  Operation *callableOp = call.resolveCallable(&symbolTable);
+  Operation *callableOp = call.resolveCallableInTable(symbolTable);
 
   // A call to a externally-defined callable has unknown predecessors.
   const auto isExternalCallable = [this](Operation *op) {
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 9894810f0e04b..21f99c1b2b6cd 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -281,7 +281,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
     CallOpInterface call, const AbstractDenseLattice &after,
     AbstractDenseLattice *before) {
   // Find the callee.
-  Operation *callee = call.resolveCallable(&symbolTable);
+  Operation *callee = call.resolveCallableInTable(symbolTable);
 
   auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
   // No region means the callee is only declared in this module.
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index ad956b73e4b1d..a64c3b50ed6e7 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -438,7 +438,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // For function calls, connect the arguments of the entry blocks to the
   // operands of the call op that are forwarded to these arguments.
   if (auto call = dyn_cast<CallOpInterface>(op)) {
-    Operation *callableOp = call.resolveCallable(&symbolTable);
+    Operation *callableOp = call.resolveCallableInTable(symbolTable);
     if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
       // Not all operands of a call op forward to arguments. Such operands are
       // stored in `unaccounted`.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index ca5d0688b5b59..949b08682443f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -824,7 +824,8 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
   // the function is referenced by SSA value instead of a Symbol, it's assumed
   // to be public. (And we cannot easily change the type of the SSA value
   // anyway.)
-  Operation *funcOp = op.resolveCallable(state.getSymbolTable());
+  SymbolTableCollection *symbolTable = state.getSymbolTable();
+  Operation *funcOp = op.resolveCallableInTable(*symbolTable);
   bool isPrivate = false;
   if (auto symbol = dyn_cast_or_null<SymbolOpInterface>(funcOp))
     isPrivate = symbol.isPrivate() && !symbol.isDeclaration();
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 455684d8e2ea7..527c19713addf 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -14,23 +14,6 @@ using namespace mlir;
 // CallOpInterface
 //===----------------------------------------------------------------------===//
 
-/// Resolve the callable operation for given callee to a CallableOpInterface, or
-/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
-/// parameter that will allow for using a cached symbol table for symbol lookups
-/// instead of performing an O(N) scan.
-Operation *
-CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
-  CallInterfaceCallable callable = getCallableForCallee();
-  if (auto symbolVal = dyn_cast<Value>(callable))
-    return symbolVal.getDefiningOp();
-
-  // If the callable isn't a value, lookup the symbol reference.
-  auto symbolRef = callable.get<SymbolRefAttr>();
-  if (symbolTable)
-    return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
-  return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
-}
-
 //===----------------------------------------------------------------------===//
 // CallInterfaces
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list