[Mlir-commits] [mlir] cb186bc - [mlir][bufferize][NFC] Rename ModuleAnalysisState to FuncAnalysisState
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 6 08:49:01 PDT 2022
Author: Matthias Springer
Date: 2022-04-07T00:48:53+09:00
New Revision: cb186bc5084ddace49e6eef2de2346b781391dc4
URL: https://github.com/llvm/llvm-project/commit/cb186bc5084ddace49e6eef2de2346b781391dc4
DIFF: https://github.com/llvm/llvm-project/commit/cb186bc5084ddace49e6eef2de2346b781391dc4.diff
LOG: [mlir][bufferize][NFC] Rename ModuleAnalysisState to FuncAnalysisState
This is for consistency: the name of every `*AnalysisState` type starts with the name of the dialect it belongs to.
Differential Revision: https://reviews.llvm.org/D123209
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 778199aa25202..5b0b77b76c89d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -16,8 +16,7 @@
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
-// gathered through PostAnalysisStepFns and stored in
-// `ModuleAnalysisState`.
+// gathered through PostAnalysisStepFns and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
// for
@@ -93,7 +92,7 @@ enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
/// Extra analysis state that is required for bufferization of function
/// boundaries.
-struct ModuleAnalysisState : public DialectAnalysisState {
+struct FuncAnalysisState : public DialectAnalysisState {
// Note: Function arguments and/or function return values may disappear during
// bufferization. Functions and their CallOps are analyzed and bufferized
// separately. To ensure that a CallOp analysis/bufferization can access an
@@ -162,26 +161,26 @@ struct ModuleAnalysisState : public DialectAnalysisState {
};
} // namespace
-/// Get ModuleAnalysisState.
-static const ModuleAnalysisState &
-getModuleAnalysisState(const AnalysisState &state) {
- Optional<const ModuleAnalysisState *> maybeState =
- state.getDialectState<ModuleAnalysisState>(
+/// Get FuncAnalysisState.
+static const FuncAnalysisState &
+getFuncAnalysisState(const AnalysisState &state) {
+ Optional<const FuncAnalysisState *> maybeState =
+ state.getDialectState<FuncAnalysisState>(
func::FuncDialect::getDialectNamespace());
- assert(maybeState.hasValue() && "ModuleAnalysisState does not exist");
+ assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
return **maybeState;
}
-/// Get or create ModuleAnalysisState.
-static ModuleAnalysisState &getModuleAnalysisState(AnalysisState &state) {
- return state.getOrCreateDialectState<ModuleAnalysisState>(
+/// Get or create FuncAnalysisState.
+static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) {
+ return state.getOrCreateDialectState<FuncAnalysisState>(
func::FuncDialect::getDialectNamespace());
}
/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
FuncOp funcOp) {
- const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+ const FuncAnalysisState &moduleState = getFuncAnalysisState(state);
auto it = moduleState.analyzedFuncOps.find(funcOp);
if (it == moduleState.analyzedFuncOps.end())
return FuncOpAnalysisState::NotAnalyzed;
@@ -226,12 +225,12 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
}
/// Store function BlockArguments that are equivalent to/aliasing a returned
-/// value in ModuleAnalysisState.
+/// value in FuncAnalysisState.
static LogicalResult
aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
- ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+ FuncAnalysisState &funcState = getFuncAnalysisState(state);
// Support only single return-terminated block in the function.
auto funcOp = cast<FuncOp>(op);
@@ -245,14 +244,13 @@ aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state,
int64_t returnIdx = returnVal.getOperandNumber();
int64_t bbArgIdx = bbArg.getArgNumber();
if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
- moduleState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
+ funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
if (state.getOptions().testAnalysisOnly)
annotateEquivalentReturnBbArg(returnVal, bbArg);
}
if (aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) {
- moduleState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
- moduleState.aliasingReturnVals[funcOp][bbArgIdx].push_back(
- returnIdx);
+ funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
+ funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
}
}
@@ -311,15 +309,15 @@ static LogicalResult
funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
- ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+ FuncAnalysisState &funcState = getFuncAnalysisState(state);
auto funcOp = cast<FuncOp>(op);
// If the function has no body, conservatively assume that all args are
// read + written.
if (funcOp.getBody().empty()) {
for (BlockArgument bbArg : funcOp.getArguments()) {
- moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
- moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
+ funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
+ funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
}
return success();
@@ -333,9 +331,9 @@ funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
if (state.getOptions().testAnalysisOnly)
annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
if (isRead)
- moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
+ funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
if (isWritten)
- moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
+ funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
}
return success();
@@ -399,16 +397,16 @@ getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes,
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(FuncOp funcOp,
BufferizationAliasInfo &aliasInfo,
- ModuleAnalysisState &moduleState) {
+ FuncAnalysisState &funcState) {
funcOp->walk([&](func::CallOp callOp) {
FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called FuncOp");
// No equivalence info available for the called function.
- if (!moduleState.equivalentFuncArgs.count(calledFunction))
+ if (!funcState.equivalentFuncArgs.count(calledFunction))
return WalkResult::skip();
- for (auto it : moduleState.equivalentFuncArgs[calledFunction]) {
+ for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
int64_t returnIdx = it.first;
int64_t bbargIdx = it.second;
Value returnVal = callOp.getResult(returnIdx);
@@ -437,8 +435,8 @@ static void equivalenceAnalysis(FuncOp funcOp,
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
RewriterBase &rewriter,
BufferizationState &state) {
- const ModuleAnalysisState &moduleState =
- getModuleAnalysisState(state.getAnalysisState());
+ const FuncAnalysisState &funcState =
+ getFuncAnalysisState(state.getAnalysisState());
// If nothing to do then we are done.
if (!llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) &&
@@ -490,8 +488,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
}
// If return operand is equivalent to some bbArg, no need to return it.
- auto funcOpIt = moduleState.equivalentFuncArgs.find(funcOp);
- if (funcOpIt != moduleState.equivalentFuncArgs.end() &&
+ auto funcOpIt = funcState.equivalentFuncArgs.find(funcOp);
+ if (funcOpIt != funcState.equivalentFuncArgs.end() &&
funcOpIt->second.count(returnOperand.getOperandNumber()))
continue;
@@ -726,9 +724,9 @@ namespace std_ext {
/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
-static Optional<int64_t>
-getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleAnalysisState &state,
- int64_t returnValIdx) {
+static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
+ const FuncAnalysisState &state,
+ int64_t returnValIdx) {
auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
if (funcOpIt == state.equivalentFuncArgs.end())
// No equivalence info stores for funcOp.
@@ -751,12 +749,12 @@ struct CallOpInterface
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is read.
return true;
- return moduleState.readBbArgs.lookup(funcOp).contains(
+ return funcState.readBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
@@ -766,12 +764,12 @@ struct CallOpInterface
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is written.
return true;
- return moduleState.writtenBbArgs.lookup(funcOp).contains(
+ return funcState.writtenBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
@@ -780,7 +778,7 @@ struct CallOpInterface
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
if (getFuncOpAnalysisState(state, funcOp) !=
FuncOpAnalysisState::Analyzed) {
// FuncOp not analyzed yet. Any OpResult may be aliasing.
@@ -793,7 +791,7 @@ struct CallOpInterface
// Get aliasing results from state.
auto aliasingReturnVals =
- moduleState.aliasingReturnVals.lookup(funcOp).lookup(
+ funcState.aliasingReturnVals.lookup(funcOp).lookup(
opOperand.getOperandNumber());
SmallVector<OpResult> result;
for (int64_t resultIdx : aliasingReturnVals)
@@ -807,7 +805,7 @@ struct CallOpInterface
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
if (getFuncOpAnalysisState(state, funcOp) !=
FuncOpAnalysisState::Analyzed) {
// FuncOp not analyzed yet. Any OpOperand may be aliasing.
@@ -819,7 +817,7 @@ struct CallOpInterface
}
// Get aliasing bbArgs from state.
- auto aliasingFuncArgs = moduleState.aliasingFuncArgs.lookup(funcOp).lookup(
+ auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
opResult.getResultNumber());
SmallVector<OpOperand *> result;
for (int64_t bbArgIdx : aliasingFuncArgs)
@@ -842,8 +840,8 @@ struct CallOpInterface
unsigned numOperands = callOp->getNumOperands();
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleAnalysisState &moduleState =
- getModuleAnalysisState(state.getAnalysisState());
+ const FuncAnalysisState &funcState =
+ getFuncAnalysisState(state.getAnalysisState());
const OneShotBufferizationOptions &options =
static_cast<const OneShotBufferizationOptions &>(state.getOptions());
@@ -885,7 +883,7 @@ struct CallOpInterface
}
if (Optional<int64_t> bbArgIdx =
- getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
+ getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
// Return operands that are equivalent to some bbArg, are not
// returned.
FailureOr<Value> bufferOrFailure =
@@ -1068,11 +1066,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
IRRewriter rewriter(moduleOp.getContext());
OneShotAnalysisState analysisState(moduleOp, options);
BufferizationState bufferizationState(analysisState);
- ModuleAnalysisState &moduleState = getModuleAnalysisState(analysisState);
+ FuncAnalysisState &funcState = getFuncAnalysisState(analysisState);
BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo();
- if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
- moduleState.callerMap)))
+ if (failed(getFuncOpsOrderedByCalls(moduleOp, funcState.orderedFuncOps,
+ funcState.callerMap)))
return failure();
// Collect bbArg/return value information after the analysis.
@@ -1080,23 +1078,23 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis);
// Analyze ops.
- for (FuncOp funcOp : moduleState.orderedFuncOps) {
+ for (FuncOp funcOp : funcState.orderedFuncOps) {
// No body => no analysis.
if (funcOp.getBody().empty())
continue;
// Now analyzing function.
- moduleState.startFunctionAnalysis(funcOp);
+ funcState.startFunctionAnalysis(funcOp);
// Gather equivalence info for CallOps.
- equivalenceAnalysis(funcOp, aliasInfo, moduleState);
+ equivalenceAnalysis(funcOp, aliasInfo, funcState);
// Analyze funcOp.
if (failed(analyzeOp(funcOp, analysisState)))
return failure();
// Mark op as fully analyzed.
- moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+ funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
// Add annotations to function arguments.
if (options.testAnalysisOnly)
@@ -1107,7 +1105,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
return success();
// Bufferize functions.
- for (FuncOp funcOp : moduleState.orderedFuncOps) {
+ for (FuncOp funcOp : funcState.orderedFuncOps) {
// No body => no analysis.
if (!funcOp.getBody().empty())
if (failed(bufferizeOp(funcOp, bufferizationState)))
@@ -1120,7 +1118,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
}
// Check result.
- for (FuncOp funcOp : moduleState.orderedFuncOps) {
+ for (FuncOp funcOp : funcState.orderedFuncOps) {
if (!options.allowReturnAllocs &&
llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) {
return t.isa<MemRefType, UnrankedMemRefType>();
More information about the Mlir-commits mailing list