[Mlir-commits] [mlir] [MLIR] Use cached symbol tables in `getFuncOpsOrderedByCalls` (PR #141967)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 29 08:39:40 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: Michele Scuttari (mscuttari)
<details>
<summary>Changes</summary>
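This addresses the TODO about recomputing symbol tables on every call: `getFuncOpsOrderedByCalls` now receives a `SymbolTableCollection` of cached symbol tables instead of constructing a fresh one internally. A convenience overload taking a `OneShotAnalysisState` reuses the symbol tables stored in its `FuncAnalysisState` extension when that extension is present, and otherwise falls back to a local `SymbolTableCollection`.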
---
Full diff: https://github.com/llvm/llvm-project/pull/141967.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+21-7)
``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index dee2af8271ce8..f7b72a8ab022b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -310,21 +310,19 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
- SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
+ SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
- // TODO Avoid recomputing the symbol tables every time.
- mlir::SymbolTableCollection symbolTable;
-
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
@@ -362,6 +360,21 @@ static LogicalResult getFuncOpsOrderedByCalls(
return success();
}
+static LogicalResult getFuncOpsOrderedByCalls(
+ ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
+ OneShotAnalysisState &analysisState) {
+ auto *funcAnalysisState = analysisState.getExtension<FuncAnalysisState>();
+
+ if (funcAnalysisState)
+ return getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, remainingFuncOps,
+ callerMap, funcAnalysisState->symbolTables);
+
+ SymbolTableCollection symbolTables;
+ return getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, remainingFuncOps,
+ callerMap, symbolTables);
+}
+
/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
@@ -458,7 +471,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncCallerMap callerMap;
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
- remainingFuncOps, callerMap)))
+ remainingFuncOps, callerMap, state)))
return failure();
// Analyze functions in order. Starting with functions that are not calling
@@ -534,7 +547,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// each other recursively are bufferized in an unspecified order at the end.
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
- remainingFuncOps, callerMap)))
+ remainingFuncOps, callerMap,
+ state.getSymbolTables())))
return failure();
llvm::append_range(orderedFuncOps, remainingFuncOps);
``````````
</details>
https://github.com/llvm/llvm-project/pull/141967
More information about the Mlir-commits mailing list