[Mlir-commits] [mlir] Cache symbol tables during OneShotBufferization analyses (PR #138125)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 1 05:20:16 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Michele Scuttari (mscuttari)

<details>
<summary>Changes</summary>

During bufferization, the callee of each `func::CallOp` / `CallableOpInterface` operation is retrieved by means of a symbol table that is built temporarily just for that lookup. Creating the symbol table requires a linear scan of the operation body (e.g., a linear scan of the `ModuleOp` body). Considering that functions are typically called at least once, the overall cost scales quadratically with the number of symbols.
The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/

This PR aims to partially address this scaling issue by leveraging the `SymbolTableCollection` class, an instance of which is added to the `FuncAnalysisState` extension. Later modifications are also expected to address the problem in other methods required by `BufferizableOpInterface` (e.g., `bufferize` and `getBufferType`), which suffer from the same problem but do not provide access to any bufferization state.

---
Full diff: https://github.com/llvm/llvm-project/pull/138125.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h (+3) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+19-10) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+9-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac7..b63c0883c6c15 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
   /// analyzed.
   DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
 
+  /// A collection of cached SymbolTables used for faster function lookup.
+  mutable mlir::SymbolTableCollection symbolTable;
+
   /// This function is called right before analyzing the given FuncOp. It
   /// initializes the data structures for the FuncOp in this state object.
   void startFunctionAnalysis(FuncOp funcOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..86d15d4f0a607 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -76,13 +76,14 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 }
 
 /// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
+static FuncOp getCalledFunction(CallOpInterface callOp,
+                                mlir::SymbolTableCollection &symbolTable) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+      symbolTable.lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Get FuncAnalysisState.
@@ -135,14 +136,14 @@ struct CallOpInterface
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is read.
       return true;
 
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     return funcState.readBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
@@ -150,14 +151,14 @@ struct CallOpInterface
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is written.
       return true;
 
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     return funcState.writtenBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
@@ -165,14 +166,14 @@ struct CallOpInterface
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Any OpResult may be aliasing.
       return detail::unknownGetAliasingValues(opOperand);
 
     // Get aliasing results from state.
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     auto aliasingReturnVals =
         funcState.aliasingReturnVals.lookup(funcOp).lookup(
             opOperand.getOperandNumber());
@@ -199,7 +200,11 @@ struct CallOpInterface
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+
+    // TODO Avoid recomputing the symbol tables every time.
+    mlir::SymbolTableCollection symbolTable;
+
+    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     // If the callee was already bufferized, we can directly take the type from
@@ -243,7 +248,11 @@ struct CallOpInterface
     // 2. Rewrite tensor operands as memrefs based on type of the already
     //    bufferized callee.
     SmallVector<Value> newOperands;
-    FuncOp funcOp = getCalledFunction(callOp);
+
+    // TODO Avoid recomputing the symbol tables every time.
+    mlir::SymbolTableCollection symbolTable;
+
+    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
     FunctionType funcType = funcOp.getFunctionType();
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index edd6bcf84f460..a025da8635135 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
 }
 
 /// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static func::FuncOp
+getCalledFunction(func::CallOp callOp,
+                  mlir::SymbolTableCollection &symbolTable) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<func::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+      symbolTable.lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+
+  // TODO Avoid recomputing the symbol tables every time.
+  mlir::SymbolTableCollection symbolTable;
+
   for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp);
+      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.

``````````

</details>


https://github.com/llvm/llvm-project/pull/138125


More information about the Mlir-commits mailing list