[Mlir-commits] [mlir] [MLIR] Make `OneShotModuleBufferize` use `OpInterface` (PR #107295)
Tzung-Han Juang
llvmlistbot at llvm.org
Tue Sep 10 11:30:28 PDT 2024
https://github.com/tzunghanjuang updated https://github.com/llvm/llvm-project/pull/107295
>From 8a5aca204bb7ed1a0a05f14994274a70f732b3d6 Mon Sep 17 00:00:00 2001
From: Tzung-Han Juang <tzunghan.juang at gmail.com>
Date: Wed, 4 Sep 2024 15:04:36 -0400
Subject: [PATCH 1/6] Make OneShotModuleBufferize accept FunctionOpInterface
and CallOpInterface
---
.../Transforms/OneShotModuleBufferize.cpp | 81 ++++++++++++-------
1 file changed, 50 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..2983af0fcbf3f7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;
/// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
/// Get or create FuncAnalysisState.
static FuncAnalysisState &
@@ -247,6 +247,15 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
+static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+ if (!sym)
+ return nullptr;
+ return dyn_cast_or_null<FunctionOpInterface>(
+ SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+}
+
/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was already
/// analyzed.
@@ -277,10 +286,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
}
/// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(func::FuncOp funcOp) {
- return llvm::any_of(funcOp.getFunctionType().getInputs(),
+static bool hasTensorSignature(FunctionOpInterface funcOp) {
+ return llvm::any_of(funcOp.getArgumentTypes(),
llvm::IsaPred<TensorType>) ||
- llvm::any_of(funcOp.getFunctionType().getResults(),
+ llvm::any_of(funcOp.getResultTypes(),
llvm::IsaPred<TensorType>);
}
@@ -291,26 +300,30 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// retrieve the called FuncOp from any func::CallOp.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
- DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
+ DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
- DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
- WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
- if (!funcOp.getBody().empty()) {
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
- return funcOp->emitError()
- << "cannot bufferize a FuncOp with tensors and "
- "without a unique ReturnOp";
+ DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
+ WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
+ // Only handle ReturnOp if funcOp is exactly the FuncOp type.
+ if(isa<FuncOp>(funcOp)) {
+ FuncOp funcOpCasted = cast<FuncOp>(funcOp);
+ if (!funcOpCasted.getBody().empty()) {
+ func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted);
+ if (!returnOp)
+ return funcOp->emitError()
+ << "cannot bufferize a FuncOp with tensors and "
+ "without a unique ReturnOp";
+ }
}
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
- return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp);
+ return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+ FunctionOpInterface calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
@@ -379,7 +392,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<FunctionOpInterface> orderedFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
@@ -388,27 +401,33 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return failure();
// Analyze ops.
- for (func::FuncOp funcOp : orderedFuncOps) {
- if (!state.getOptions().isOpAllowed(funcOp))
+ for (FunctionOpInterface funcOp : orderedFuncOps) {
+
+ // The following analysis is specific to the FuncOp type.
+ if(!isa<FuncOp>(funcOp))
+ continue;
+ FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);
+
+ if (!state.getOptions().isOpAllowed(funcOpCasted))
continue;
// Now analyzing function.
- funcState.startFunctionAnalysis(funcOp);
+ funcState.startFunctionAnalysis(funcOpCasted);
// Gather equivalence info for CallOps.
- equivalenceAnalysis(funcOp, state, funcState);
+ equivalenceAnalysis(funcOpCasted, state, funcState);
// Analyze funcOp.
- if (failed(analyzeOp(funcOp, state, statistics)))
+ if (failed(analyzeOp(funcOpCasted, state, statistics)))
return failure();
// Run some extra function analyses.
- if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
- failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
+ if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) ||
+ failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState)))
return failure();
// Mark op as fully analyzed.
- funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+ funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed;
}
return success();
@@ -430,20 +449,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<FunctionOpInterface> orderedFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
+ SmallVector<FunctionOpInterface> ops;
// Bufferize functions.
- for (func::FuncOp funcOp : orderedFuncOps) {
+ for (FunctionOpInterface funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
-
- if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+ if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
// This function was not analyzed and RaW conflicts were not resolved.
// Buffer copies must be inserted before every write.
OneShotBufferizationOptions updatedOptions = options;
@@ -456,8 +475,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Change buffer return types to more precise layout maps.
- if (options.inferFunctionResultLayout)
- foldMemRefCasts(funcOp);
+ if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp))
+ foldMemRefCasts(cast<func::FuncOp>(funcOp));
}
// Bufferize all other ops.
>From 5153af3ee72d4322273b1614a6637a952b10cdcc Mon Sep 17 00:00:00 2001
From: Tzung-Han Juang <tzunghan.juang at gmail.com>
Date: Wed, 4 Sep 2024 15:42:08 -0400
Subject: [PATCH 2/6] Cleanup
---
.../Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 2983af0fcbf3f7..5231fe86055371 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -456,12 +456,12 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
- SmallVector<FunctionOpInterface> ops;
// Bufferize functions.
for (FunctionOpInterface funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
+
if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
// This function was not analyzed and RaW conflicts were not resolved.
// Buffer copies must be inserted before every write.
>From 1f8d847077716be2f0115c4fadcb7c2d4eafe945 Mon Sep 17 00:00:00 2001
From: Tzung-Han Juang <tzunghan.juang at gmail.com>
Date: Fri, 6 Sep 2024 10:37:18 -0400
Subject: [PATCH 3/6] Make getAssumedUniqueReturnOp detect ReturnLike and
FuncAnalysisState use FunctionOpInterface
---
.../FuncBufferizableOpInterfaceImpl.h | 12 +-
.../FuncBufferizableOpInterfaceImpl.cpp | 2 +-
.../Transforms/OneShotModuleBufferize.cpp | 117 ++++++++----------
3 files changed, 59 insertions(+), 72 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index 0b91d3d675b7c9..8bed0dfc5814b7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
- DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
+ DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;
/// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
- DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;
+ DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;
/// A set of all read BlockArguments of FuncOps.
- DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
+ DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;
/// A set of all written-to BlockArguments of FuncOps.
- DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
+ DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;
/// Keep track of which FuncOps are fully analyzed or currently being
/// analyzed.
- DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+ DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;
/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
- void startFunctionAnalysis(FuncOp funcOp);
+ void startFunctionAnalysis(FunctionOpInterface funcOp);
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..9749a71f3514bc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -22,7 +22,7 @@ namespace mlir {
namespace bufferization {
namespace func_ext {
-void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
+void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
auto createdAliasingResults =
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 5231fe86055371..cfb87aef6e64bb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
- for (Block &b : funcOp.getBody()) {
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
+static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
+ Operation *returnOp = nullptr;
+ for (Block &b : funcOp.getFunctionBody()) {
+ auto candidateOp = b.getTerminator();
+ if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
if (returnOp)
return nullptr;
returnOp = candidateOp;
@@ -126,16 +127,15 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
-aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
- if (funcOp.getBody().empty()) {
+ if (funcOp.getFunctionBody().empty()) {
// No function body available. Conservatively assume that every tensor
// return value may alias with any tensor bbArg.
- FunctionType type = funcOp.getFunctionType();
- for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
+ for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
if (!isa<TensorType>(inputIt.value()))
continue;
- for (const auto &resultIt : llvm::enumerate(type.getResults())) {
+ for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
if (!isa<TensorType>(resultIt.value()))
continue;
int64_t returnIdx = resultIt.index();
@@ -147,7 +147,9 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
}
// Support only single return-terminated block in the function.
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ if (!isa<func::FuncOp>(funcOp))
+ return success();
+ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -168,7 +170,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
return success();
}
-static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
+static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx, bool isRead,
bool isWritten) {
OpBuilder b(funcOp.getContext());
Attribute accessType;
@@ -189,12 +191,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
-funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
- for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
+ for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e;
++idx) {
// Skip non-tensor arguments.
- if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
+ if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
continue;
bool isRead;
bool isWritten;
@@ -204,7 +206,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
StringRef str = accessAttr.getValue();
isRead = str == "read" || str == "read-write";
isWritten = str == "write" || str == "read-write";
- } else if (funcOp.getBody().empty()) {
+ } else if (funcOp.getFunctionBody().empty()) {
// If the function has no body, conservatively assume that all args are
// read + written.
isRead = true;
@@ -230,23 +232,13 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
- auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
+ auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizationDialect::kBufferLayoutAttrName);
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizationDialect::kWritableAttrName);
}
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
@@ -260,12 +252,12 @@ static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
/// Note: This only adds new equivalence info if the called function was already
/// analyzed.
// TODO: This does not handle cyclic function call graphs etc.
-static void equivalenceAnalysis(func::FuncOp funcOp,
+static void equivalenceAnalysis(FunctionOpInterface funcOp,
OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
- funcOp->walk([&](func::CallOp callOp) {
- func::FuncOp calledFunction = getCalledFunction(callOp);
- assert(calledFunction && "could not retrieved called func::FuncOp");
+ funcOp->walk([&](CallOpInterface callOp) {
+ FunctionOpInterface calledFunction = getCalledFunction(callOp);
+ assert(calledFunction && "could not retrieved called FunctionOpInterface");
// No equivalence info available for the called function.
if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -276,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
int64_t bbargIdx = it.second;
if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
continue;
- Value returnVal = callOp.getResult(returnIdx);
+ Value returnVal = callOp->getResult(returnIdx);
Value argVal = callOp->getOperand(bbargIdx);
state.unionEquivalenceClasses(returnVal, argVal);
}
@@ -308,23 +300,19 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
- // Only handle ReturnOp if funcOp is exactly the FuncOp type.
- if(isa<FuncOp>(funcOp)) {
- FuncOp funcOpCasted = cast<FuncOp>(funcOp);
- if (!funcOpCasted.getBody().empty()) {
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted);
- if (!returnOp)
- return funcOp->emitError()
- << "cannot bufferize a FuncOp with tensors and "
- "without a unique ReturnOp";
- }
+ if (!funcOp.getFunctionBody().empty() && isa<func::FuncOp>(funcOp)) {
+ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+ if (!returnOp)
+ return funcOp->emitError()
+ << "cannot bufferize a FuncOp with tensors and "
+ "without a unique ReturnOp";
}
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
FunctionOpInterface calledFunction = getCalledFunction(callOp);
- assert(calledFunction && "could not retrieved called func::FuncOp");
+ assert(calledFunction && "could not retrieved called FunctionOpInterface");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
@@ -362,11 +350,15 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
-static void foldMemRefCasts(func::FuncOp funcOp) {
- if (funcOp.getBody().empty())
+static void foldMemRefCasts(FunctionOpInterface funcOp) {
+ if (funcOp.getFunctionBody().empty())
+ return;
+
+ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+
+ if (!returnOp)
return;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
SmallVector<Type> resultTypes;
for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -379,7 +371,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
auto newFuncType = FunctionType::get(
- funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
+ funcOp.getContext(), funcOp.getArgumentTypes(), resultTypes);
funcOp.setType(newFuncType);
}
@@ -403,31 +395,26 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
// Analyze ops.
for (FunctionOpInterface funcOp : orderedFuncOps) {
- // The following analysis is specific to the FuncOp type.
- if(!isa<FuncOp>(funcOp))
- continue;
- FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);
-
- if (!state.getOptions().isOpAllowed(funcOpCasted))
+ if (!state.getOptions().isOpAllowed(funcOp))
continue;
// Now analyzing function.
- funcState.startFunctionAnalysis(funcOpCasted);
+ funcState.startFunctionAnalysis(funcOp);
// Gather equivalence info for CallOps.
- equivalenceAnalysis(funcOpCasted, state, funcState);
+ equivalenceAnalysis(funcOp, state, funcState);
// Analyze funcOp.
- if (failed(analyzeOp(funcOpCasted, state, statistics)))
+ if (failed(analyzeOp(funcOp, state, statistics)))
return failure();
// Run some extra function analyses.
- if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) ||
- failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState)))
+ if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
+ failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
return failure();
// Mark op as fully analyzed.
- funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed;
+ funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
}
return success();
@@ -435,7 +422,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
void mlir::bufferization::removeBufferizationAttributesInModule(
ModuleOp moduleOp) {
- moduleOp.walk([&](func::FuncOp op) {
+ moduleOp.walk([&](FunctionOpInterface op) {
for (BlockArgument bbArg : op.getArguments())
removeBufferizationAttributes(bbArg);
});
@@ -475,14 +462,14 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Change buffer return types to more precise layout maps.
- if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp))
- foldMemRefCasts(cast<func::FuncOp>(funcOp));
+ if (options.inferFunctionResultLayout)
+ foldMemRefCasts(funcOp);
}
// Bufferize all other ops.
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
- if (isa<func::FuncOp>(&op))
+ if (isa<FunctionOpInterface>(&op))
continue;
if (failed(bufferizeOp(&op, options, statistics)))
return failure();
@@ -509,12 +496,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
// not be analyzed. Ops in these FuncOps will not be analyzed as well.
OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
- auto func = dyn_cast<func::FuncOp>(op);
+ auto func = dyn_cast<FunctionOpInterface>(op);
if (!func)
- func = op->getParentOfType<func::FuncOp>();
+ func = op->getParentOfType<FunctionOpInterface>();
if (func)
return llvm::is_contained(options.noAnalysisFuncFilter,
- func.getSymName());
+ func.getName());
return false;
};
OneShotBufferizationOptions updatedOptions(options);
>From 26e69ad35197b7c1b7a2084810b714898af2aeb7 Mon Sep 17 00:00:00 2001
From: Tzung-Han Juang <tzunghan.juang at gmail.com>
Date: Fri, 6 Sep 2024 10:56:50 -0400
Subject: [PATCH 4/6] Make getAssumedUniqueReturnOp return funcOp if there is
no return
---
.../Transforms/OneShotModuleBufferize.cpp | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index cfb87aef6e64bb..bd054ac4e7b87e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -88,6 +88,7 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
+/// Return `funcOp` it self if there is no ReturnOp.
static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
Operation *returnOp = nullptr;
for (Block &b : funcOp.getFunctionBody()) {
@@ -98,6 +99,8 @@ static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
returnOp = candidateOp;
}
}
+ if (!returnOp)
+ return funcOp;
return returnOp;
}
@@ -147,9 +150,10 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &s
}
// Support only single return-terminated block in the function.
- if (!isa<func::FuncOp>(funcOp))
- return success();
+ // If funcOp has no returnOp, skip the following analysis.
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+ if (returnOp == funcOp)
+ return success();
assert(returnOp && "expected func with single return op");
for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -300,9 +304,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
- if (!funcOp.getFunctionBody().empty() && isa<func::FuncOp>(funcOp)) {
+ if (!funcOp.getFunctionBody().empty()) {
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
+ if (!returnOp && returnOp != funcOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
"without a unique ReturnOp";
@@ -356,7 +360,7 @@ static void foldMemRefCasts(FunctionOpInterface funcOp) {
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
+ if (!returnOp || returnOp == funcOp)
return;
SmallVector<Type> resultTypes;
>From 074192ca0e62ba600f63de4e914d44fb4bf86ffb Mon Sep 17 00:00:00 2001
From: Tzung-Han Juang <tzunghan.juang at gmail.com>
Date: Fri, 6 Sep 2024 14:35:56 -0400
Subject: [PATCH 5/6] Use getNumResults to guard functions without any return
type
---
.../Transforms/OneShotModuleBufferize.cpp | 19 ++++---------------
1 file changed, 4 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index bd054ac4e7b87e..6933fde7f95657 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -88,7 +88,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
-/// Return `funcOp` it self if there is no ReturnOp.
static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
Operation *returnOp = nullptr;
for (Block &b : funcOp.getFunctionBody()) {
@@ -99,8 +98,6 @@ static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
returnOp = candidateOp;
}
}
- if (!returnOp)
- return funcOp;
return returnOp;
}
@@ -132,7 +129,7 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
- if (funcOp.getFunctionBody().empty()) {
+ if (funcOp.getFunctionBody().empty() || funcOp.getNumResults() == 0) {
// No function body available. Conservatively assume that every tensor
// return value may alias with any tensor bbArg.
for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
@@ -150,10 +147,7 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &s
}
// Support only single return-terminated block in the function.
- // If funcOp has no returnOp, skip the following analysis.
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
- if (returnOp == funcOp)
- return success();
assert(returnOp && "expected func with single return op");
for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -304,9 +298,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
- if (!funcOp.getFunctionBody().empty()) {
+ if (!funcOp.getFunctionBody().empty() && funcOp.getNumResults() != 0) {
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp && returnOp != funcOp)
+ if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
"without a unique ReturnOp";
@@ -355,14 +349,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
static void foldMemRefCasts(FunctionOpInterface funcOp) {
- if (funcOp.getFunctionBody().empty())
+ if (funcOp.getFunctionBody().empty() || funcOp.getNumResults() == 0)
return;
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
-
- if (!returnOp || returnOp == funcOp)
- return;
-
SmallVector<Type> resultTypes;
for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -398,7 +388,6 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
// Analyze ops.
for (FunctionOpInterface funcOp : orderedFuncOps) {
-
if (!state.getOptions().isOpAllowed(funcOp))
continue;
>From 4ba535b93e607698f3319cc5d13a3432fb0c67c4 Mon Sep 17 00:00:00 2001
From: Tzung-Han Juang <tzunghan.juang at xanadu.ai>
Date: Tue, 10 Sep 2024 14:30:18 -0400
Subject: [PATCH 6/6] Update
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
Co-authored-by: erick-xanadu <110487834+erick-xanadu at users.noreply.github.com>
---
.../Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 6933fde7f95657..bf29b7e86a46d9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -349,7 +349,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
static void foldMemRefCasts(FunctionOpInterface funcOp) {
- if (funcOp.getFunctionBody().empty() || funcOp.getNumResults() == 0)
+ if (funcOp.getFunctionBody().empty())
return;
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
More information about the Mlir-commits
mailing list