[Mlir-commits] [mlir] [MLIR] Make `OneShotModuleBufferize` use `OpInterface` (PR #107295)
Tzung-Han Juang
llvmlistbot at llvm.org
Wed Sep 4 12:25:29 PDT 2024
https://github.com/tzunghanjuang created https://github.com/llvm/llvm-project/pull/107295
**Description:**
`OneShotModuleBufferize` handles the bufferization of `FuncOp`, `CallOp`, and `ReturnOp`, but these op types are hard-coded, so custom function-like operations are not handled. This PR replaces some uses of `FuncOp` and `CallOp` in `OneShotModuleBufferize` with `FunctionOpInterface` and `CallOpInterface` so that custom function and call ops can be bufferized.
**Limitations:**
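`ReturnOp` is not covered by any interface, so for now the code uses `isa<FuncOp>` checks to decide when to process `ReturnOp`. Likewise, the per-function analysis in `analyzeModuleOp` (equivalence, aliasing, and read/write analysis) still runs only on `func::FuncOp`; other `FunctionOpInterface` ops are skipped.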
**Related Discord Discussion:** [Link](https://discord.com/channels/636084430946959380/642426447167881246/1280556809911799900)
>From 8a5aca204bb7ed1a0a05f14994274a70f732b3d6 Mon Sep 17 00:00:00 2001
From: Tzung-Han Juang <tzunghan.juang at gmail.com>
Date: Wed, 4 Sep 2024 15:04:36 -0400
Subject: [PATCH] Make OneShotModuleBufferize accept FunctionOpInterface and
CallOpInterface
---
.../Transforms/OneShotModuleBufferize.cpp | 81 ++++++++++++-------
1 file changed, 50 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..2983af0fcbf3f7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;
/// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
/// Get or create FuncAnalysisState.
static FuncAnalysisState &
@@ -247,6 +247,15 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
+/// Return the function-like op called by `callOp`, or nullptr if the callee is
+/// not a `SymbolRefAttr` or the symbol it names (resolved with
+/// `SymbolTable::lookupNearestSymbolFrom`) is not a `FunctionOpInterface`.
+/// NOTE(review): a nullptr result is reachable for any `CallOpInterface` op
+/// taking either of those paths, but the walk in `getFuncOpsOrderedByCalls`
+/// only asserts the result is non-null. Confirm whether such calls should be
+/// skipped or diagnosed there instead.
+static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+ // Callee is not a symbol reference: there is no symbol to resolve.
+ if (!sym)
+ return nullptr;
+ return dyn_cast_or_null<FunctionOpInterface>(
+ SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+}
+
/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was already
/// analyzed.
@@ -277,10 +286,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
}
/// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(func::FuncOp funcOp) {
- return llvm::any_of(funcOp.getFunctionType().getInputs(),
+static bool hasTensorSignature(FunctionOpInterface funcOp) {
+ return llvm::any_of(funcOp.getArgumentTypes(),
llvm::IsaPred<TensorType>) ||
- llvm::any_of(funcOp.getFunctionType().getResults(),
+ llvm::any_of(funcOp.getResultTypes(),
llvm::IsaPred<TensorType>);
}
@@ -291,26 +300,30 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// retrieve the called FuncOp from any func::CallOp.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
- DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
+ DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
- DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
- WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
- if (!funcOp.getBody().empty()) {
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
- return funcOp->emitError()
- << "cannot bufferize a FuncOp with tensors and "
- "without a unique ReturnOp";
+ DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
+ WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
+ // Only handle ReturnOp if funcOp is exactly the FuncOp type.
+ if(isa<FuncOp>(funcOp)) {
+ FuncOp funcOpCasted = cast<FuncOp>(funcOp);
+ if (!funcOpCasted.getBody().empty()) {
+ func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted);
+ if (!returnOp)
+ return funcOp->emitError()
+ << "cannot bufferize a FuncOp with tensors and "
+ "without a unique ReturnOp";
+ }
}
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
- return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp);
+ return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+ FunctionOpInterface calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
@@ -379,7 +392,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<FunctionOpInterface> orderedFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
@@ -388,27 +401,33 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return failure();
// Analyze ops.
- for (func::FuncOp funcOp : orderedFuncOps) {
- if (!state.getOptions().isOpAllowed(funcOp))
+ for (FunctionOpInterface funcOp : orderedFuncOps) {
+
+ // The following analysis is specific to the FuncOp type.
+ if(!isa<FuncOp>(funcOp))
+ continue;
+ FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);
+
+ if (!state.getOptions().isOpAllowed(funcOpCasted))
continue;
// Now analyzing function.
- funcState.startFunctionAnalysis(funcOp);
+ funcState.startFunctionAnalysis(funcOpCasted);
// Gather equivalence info for CallOps.
- equivalenceAnalysis(funcOp, state, funcState);
+ equivalenceAnalysis(funcOpCasted, state, funcState);
// Analyze funcOp.
- if (failed(analyzeOp(funcOp, state, statistics)))
+ if (failed(analyzeOp(funcOpCasted, state, statistics)))
return failure();
// Run some extra function analyses.
- if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
- failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
+ if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) ||
+ failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState)))
return failure();
// Mark op as fully analyzed.
- funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+ funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed;
}
return success();
@@ -430,20 +449,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<FunctionOpInterface> orderedFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
+ SmallVector<FunctionOpInterface> ops;
// Bufferize functions.
- for (func::FuncOp funcOp : orderedFuncOps) {
+ for (FunctionOpInterface funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
-
- if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+ if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
// This function was not analyzed and RaW conflicts were not resolved.
// Buffer copies must be inserted before every write.
OneShotBufferizationOptions updatedOptions = options;
@@ -456,8 +475,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Change buffer return types to more precise layout maps.
- if (options.inferFunctionResultLayout)
- foldMemRefCasts(funcOp);
+ if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp))
+ foldMemRefCasts(cast<func::FuncOp>(funcOp));
}
// Bufferize all other ops.
More information about the Mlir-commits
mailing list