[Mlir-commits] [mlir] Cache symbol tables during OneShotBufferization analyses (PR #138125)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 1 05:20:16 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: Michele Scuttari (mscuttari)
<details>
<summary>Changes</summary>
During bufferization, the callee of each `func::CallOp` / `CallOpInterface` operation is looked up through a symbol table that is built temporarily just for that lookup. Building the symbol table requires a linear scan of the operation body (e.g., a linear scan of the `ModuleOp` body). Since functions are typically called at least once, this results in scaling that is quadratic in the number of symbols.
The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/
This PR aims to partially address this scaling issue by leveraging the `SymbolTableCollection` class, an instance of which is added to the `FuncAnalysisState` extension. Later modifications are expected to address the problem in other methods required by `BufferizableOpInterface` (e.g., `bufferize` and `getBufferType`), which suffer from the same problem but do not have access to any bufferization state.
---
Full diff: https://github.com/llvm/llvm-project/pull/138125.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h (+3)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+19-10)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+9-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac7..b63c0883c6c15 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
/// analyzed.
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+ /// A collection of cached SymbolTables used for faster function lookup.
+ mutable mlir::SymbolTableCollection symbolTable;
+
/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
void startFunctionAnalysis(FuncOp funcOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..86d15d4f0a607 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -76,13 +76,14 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
}
/// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
+static FuncOp getCalledFunction(CallOpInterface callOp,
+ mlir::SymbolTableCollection &symbolTable) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ symbolTable.lookupNearestSymbolFrom(callOp, sym));
}
/// Get FuncAnalysisState.
@@ -135,14 +136,14 @@ struct CallOpInterface
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is read.
return true;
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
return funcState.readBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
@@ -150,14 +151,14 @@ struct CallOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is written.
return true;
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
return funcState.writtenBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
@@ -165,14 +166,14 @@ struct CallOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Any OpResult may be aliasing.
return detail::unknownGetAliasingValues(opOperand);
// Get aliasing results from state.
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
auto aliasingReturnVals =
funcState.aliasingReturnVals.lookup(funcOp).lookup(
opOperand.getOperandNumber());
@@ -199,7 +200,11 @@ struct CallOpInterface
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
+ FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
// If the callee was already bufferized, we can directly take the type from
@@ -243,7 +248,11 @@ struct CallOpInterface
// 2. Rewrite tensor operands as memrefs based on type of the already
// bufferized callee.
SmallVector<Value> newOperands;
- FuncOp funcOp = getCalledFunction(callOp);
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
+ FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
FunctionType funcType = funcOp.getFunctionType();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index edd6bcf84f460..a025da8635135 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
}
/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static func::FuncOp
+getCalledFunction(func::CallOp callOp,
+ mlir::SymbolTableCollection &symbolTable) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ symbolTable.lookupNearestSymbolFrom(callOp, sym));
}
/// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp);
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
``````````
</details>
https://github.com/llvm/llvm-project/pull/138125
More information about the Mlir-commits
mailing list