[Mlir-commits] [mlir] d1cad22 - Reland [MLIR] Make resolveCallable customizable in CallOpInterface (#107989)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 10 06:33:17 PDT 2024


Author: Henrich Lauko
Date: 2024-09-10T15:33:13+02:00
New Revision: d1cad2290c10712ea27509081f50769ed597ee0f

URL: https://github.com/llvm/llvm-project/commit/d1cad2290c10712ea27509081f50769ed597ee0f
DIFF: https://github.com/llvm/llvm-project/commit/d1cad2290c10712ea27509081f50769ed597ee0f.diff

LOG: Reland [MLIR] Make resolveCallable customizable in CallOpInterface (#107989)

Relands #100361 with fixed dependencies.

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/CallInterfaces.h
    mlir/include/mlir/Interfaces/CallInterfaces.td
    mlir/lib/Analysis/CallGraph.cpp
    mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
    mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/lib/Dialect/Async/IR/CMakeLists.txt
    mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
    mlir/lib/Interfaces/CallInterfaces.cpp
    mlir/lib/Transforms/Utils/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 7dbcddb01b241e..0020c19333d103 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -23,10 +23,21 @@ namespace mlir {
 struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
   using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
 };
-} // namespace mlir
 
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/CallInterfaces.h.inc"
+class CallOpInterface;
+
+namespace call_interface_impl {
+
+/// Resolve the callable operation for the given callee to a
+/// CallableOpInterface, or nullptr if a valid callable was not resolved.
+/// `symbolTable` is an optional parameter that allows using a cached symbol
+/// table for symbol lookups instead of performing an O(N) scan.
+Operation *resolveCallable(CallOpInterface call,
+                           SymbolTableCollection *symbolTable = nullptr);
+
+} // namespace call_interface_impl
+
+} // namespace mlir
 
 namespace llvm {
 
@@ -41,4 +52,7 @@ struct CastInfo<To, const mlir::CallInterfaceCallable>
 
 } // namespace llvm
 
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/CallInterfaces.h.inc"
+
 #endif // MLIR_INTERFACES_CALLINTERFACES_H

diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 752de74e6e4d7e..c6002da0d491ce 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -59,17 +59,29 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
         Returns the operands within this call that are used as arguments to the
         callee as a mutable range.
       }],
-      "::mlir::MutableOperandRange", "getArgOperandsMutable">,
+      "::mlir::MutableOperandRange", "getArgOperandsMutable"
+    >,
+    InterfaceMethod<[{
+        Resolve the callable operation for the given callee to a
+        CallableOpInterface, or nullptr if a valid callable was not resolved.
+        The `symbolTable` parameter allows using a cached symbol table for
+        symbol lookups instead of performing an O(N) scan.
+      }],
+      "::mlir::Operation *", "resolveCallableInTable", (ins "::mlir::SymbolTableCollection *":$symbolTable),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{
+        return ::mlir::call_interface_impl::resolveCallable($_op, symbolTable);
+      }]
+    >,
+    InterfaceMethod<[{
+        Resolve the callable operation for the given callee to a
+        CallableOpInterface, or nullptr if a valid callable was not resolved.
+    }],
+      "::mlir::Operation *", "resolveCallable", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{
+        return ::mlir::call_interface_impl::resolveCallable($_op);
+      }]
+    >
   ];
-
-  let extraClassDeclaration = [{
-    /// Resolve the callable operation for given callee to a
-    /// CallableOpInterface, or nullptr if a valid callable was not resolved.
-    /// `symbolTable` is an optional parameter that will allow for using a
-    /// cached symbol table for symbol lookups instead of performing an O(N)
-    /// scan.
-    ::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr);
-  }];
 }
 
 /// Interface for callable operations.

diff  --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index ccd4676632136b..780c7caee767c1 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
 CallGraphNode *
 CallGraph::resolveCallable(CallOpInterface call,
                            SymbolTableCollection &symbolTable) const {
-  Operation *callable = call.resolveCallable(&symbolTable);
+  Operation *callable = call.resolveCallableInTable(&symbolTable);
   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
     if (auto *node = lookupNode(callableOp.getCallableRegion()))
       return node;

diff  --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 532480b6fad57d..beb68018a3b16e 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -297,7 +297,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
 }
 
 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
-  Operation *callableOp = call.resolveCallable(&symbolTable);
+  Operation *callableOp = call.resolveCallableInTable(&symbolTable);
 
   // A call to a externally-defined callable has unknown predecessors.
   const auto isExternalCallable = [this](Operation *op) {

diff  --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 37f4ceaaa56cee..300c6e5f9b8919 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -284,7 +284,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
     CallOpInterface call, const AbstractDenseLattice &after,
     AbstractDenseLattice *before) {
   // Find the callee.
-  Operation *callee = call.resolveCallable(&symbolTable);
+  Operation *callee = call.resolveCallableInTable(&symbolTable);
 
   auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
   // No region means the callee is only declared in this module.

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4a73f21a18aae7..1bd6defef90be0 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -442,7 +442,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // For function calls, connect the arguments of the entry blocks to the
   // operands of the call op that are forwarded to these arguments.
   if (auto call = dyn_cast<CallOpInterface>(op)) {
-    Operation *callableOp = call.resolveCallable(&symbolTable);
+    Operation *callableOp = call.resolveCallableInTable(&symbolTable);
     if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
       // Not all operands of a call op forward to arguments. Such operands are
       // stored in `unaccounted`.

diff  --git a/mlir/lib/Dialect/Async/IR/CMakeLists.txt b/mlir/lib/Dialect/Async/IR/CMakeLists.txt
index db903a48c43196..e0e667500308ae 100644
--- a/mlir/lib/Dialect/Async/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRAsyncDialect
   MLIRAsyncOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRFunctionInterfaces
   MLIRDialect

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index f27d924416677a..50104e8f8346b4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRBufferizationDialect
+  MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRFuncDialect
   MLIRFunctionInterfaces
@@ -42,4 +43,3 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   MLIRViewLikeInterface
   MLIRSupport
 )
-

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index ca5d0688b5b594..b973618004497b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -824,7 +824,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
   // the function is referenced by SSA value instead of a Symbol, it's assumed
   // to be public. (And we cannot easily change the type of the SSA value
   // anyway.)
-  Operation *funcOp = op.resolveCallable(state.getSymbolTable());
+  Operation *funcOp = op.resolveCallableInTable(state.getSymbolTable());
   bool isPrivate = false;
   if (auto symbol = dyn_cast_or_null<SymbolOpInterface>(funcOp))
     isPrivate = symbol.isPrivate() && !symbol.isDeclaration();

diff  --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 455684d8e2ea7c..9e5bc159dc8908 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -14,21 +14,18 @@ using namespace mlir;
 // CallOpInterface
 //===----------------------------------------------------------------------===//
 
-/// Resolve the callable operation for given callee to a CallableOpInterface, or
-/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
-/// parameter that will allow for using a cached symbol table for symbol lookups
-/// instead of performing an O(N) scan.
 Operation *
-CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
-  CallInterfaceCallable callable = getCallableForCallee();
+call_interface_impl::resolveCallable(CallOpInterface call,
+                                     SymbolTableCollection *symbolTable) {
+  CallInterfaceCallable callable = call.getCallableForCallee();
   if (auto symbolVal = dyn_cast<Value>(callable))
     return symbolVal.getDefiningOp();
 
   // If the callable isn't a value, lookup the symbol reference.
   auto symbolRef = callable.get<SymbolRefAttr>();
   if (symbolTable)
-    return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
-  return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
+    return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
+  return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index b5788c679edc44..eb588640dbf83a 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_library(MLIRTransformUtils
 
   LINK_LIBS PUBLIC
   MLIRAnalysis
+  MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface


        


More information about the Mlir-commits mailing list