[Mlir-commits] [mlir] cb186bc - [mlir][bufferize][NFC] Rename ModuleAnalysisState to FuncAnalysisState

Matthias Springer llvmlistbot at llvm.org
Wed Apr 6 08:49:01 PDT 2022


Author: Matthias Springer
Date: 2022-04-07T00:48:53+09:00
New Revision: cb186bc5084ddace49e6eef2de2346b781391dc4

URL: https://github.com/llvm/llvm-project/commit/cb186bc5084ddace49e6eef2de2346b781391dc4
DIFF: https://github.com/llvm/llvm-project/commit/cb186bc5084ddace49e6eef2de2346b781391dc4.diff

LOG: [mlir][bufferize][NFC] Rename ModuleAnalysisState to FuncAnalysisState

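This is for consistency: the name of each dialect's `*AnalysisState` class starts with the name of that dialect. This state is registered under the `func` dialect, so it is now called `FuncAnalysisState`.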

Differential Revision: https://reviews.llvm.org/D123209

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 778199aa25202..5b0b77b76c89d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -16,8 +16,7 @@
 // respective callers.
 //
 // After analyzing a FuncOp, additional information about its bbArgs is
-// gathered through PostAnalysisStepFns and stored in
-// `ModuleAnalysisState`.
+// gathered through PostAnalysisStepFns and stored in `FuncAnalysisState`.
 //
 // * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
 // for
@@ -93,7 +92,7 @@ enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
 
 /// Extra analysis state that is required for bufferization of function
 /// boundaries.
-struct ModuleAnalysisState : public DialectAnalysisState {
+struct FuncAnalysisState : public DialectAnalysisState {
   // Note: Function arguments and/or function return values may disappear during
   // bufferization. Functions and their CallOps are analyzed and bufferized
   // separately. To ensure that a CallOp analysis/bufferization can access an
@@ -162,26 +161,26 @@ struct ModuleAnalysisState : public DialectAnalysisState {
 };
 } // namespace
 
-/// Get ModuleAnalysisState.
-static const ModuleAnalysisState &
-getModuleAnalysisState(const AnalysisState &state) {
-  Optional<const ModuleAnalysisState *> maybeState =
-      state.getDialectState<ModuleAnalysisState>(
+/// Get FuncAnalysisState.
+static const FuncAnalysisState &
+getFuncAnalysisState(const AnalysisState &state) {
+  Optional<const FuncAnalysisState *> maybeState =
+      state.getDialectState<FuncAnalysisState>(
           func::FuncDialect::getDialectNamespace());
-  assert(maybeState.hasValue() && "ModuleAnalysisState does not exist");
+  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
   return **maybeState;
 }
 
-/// Get or create ModuleAnalysisState.
-static ModuleAnalysisState &getModuleAnalysisState(AnalysisState &state) {
-  return state.getOrCreateDialectState<ModuleAnalysisState>(
+/// Get or create FuncAnalysisState.
+static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) {
+  return state.getOrCreateDialectState<FuncAnalysisState>(
       func::FuncDialect::getDialectNamespace());
 }
 
 /// Return the state (phase) of analysis of the FuncOp.
 static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                   FuncOp funcOp) {
-  const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+  const FuncAnalysisState &moduleState = getFuncAnalysisState(state);
   auto it = moduleState.analyzedFuncOps.find(funcOp);
   if (it == moduleState.analyzedFuncOps.end())
     return FuncOpAnalysisState::NotAnalyzed;
@@ -226,12 +225,12 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
 }
 
 /// Store function BlockArguments that are equivalent to/aliasing a returned
-/// value in ModuleAnalysisState.
+/// value in FuncAnalysisState.
 static LogicalResult
 aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state,
                              BufferizationAliasInfo &aliasInfo,
                              SmallVector<Operation *> &newOps) {
-  ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+  FuncAnalysisState &funcState = getFuncAnalysisState(state);
 
   // Support only single return-terminated block in the function.
   auto funcOp = cast<FuncOp>(op);
@@ -245,14 +244,13 @@ aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state,
           int64_t returnIdx = returnVal.getOperandNumber();
           int64_t bbArgIdx = bbArg.getArgNumber();
           if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
-            moduleState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
+            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
             if (state.getOptions().testAnalysisOnly)
               annotateEquivalentReturnBbArg(returnVal, bbArg);
           }
           if (aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) {
-            moduleState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
-            moduleState.aliasingReturnVals[funcOp][bbArgIdx].push_back(
-                returnIdx);
+            funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
+            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
           }
         }
 
@@ -311,15 +309,15 @@ static LogicalResult
 funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
                              BufferizationAliasInfo &aliasInfo,
                              SmallVector<Operation *> &newOps) {
-  ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+  FuncAnalysisState &funcState = getFuncAnalysisState(state);
   auto funcOp = cast<FuncOp>(op);
 
   // If the function has no body, conservatively assume that all args are
   // read + written.
   if (funcOp.getBody().empty()) {
     for (BlockArgument bbArg : funcOp.getArguments()) {
-      moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
-      moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
+      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
+      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
     }
 
     return success();
@@ -333,9 +331,9 @@ funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
     if (state.getOptions().testAnalysisOnly)
       annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
     if (isRead)
-      moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
+      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
     if (isWritten)
-      moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
+      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
   }
 
   return success();
@@ -399,16 +397,16 @@ getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes,
 // TODO: This does not handle cyclic function call graphs etc.
 static void equivalenceAnalysis(FuncOp funcOp,
                                 BufferizationAliasInfo &aliasInfo,
-                                ModuleAnalysisState &moduleState) {
+                                FuncAnalysisState &funcState) {
   funcOp->walk([&](func::CallOp callOp) {
     FuncOp calledFunction = getCalledFunction(callOp);
     assert(calledFunction && "could not retrieved called FuncOp");
 
     // No equivalence info available for the called function.
-    if (!moduleState.equivalentFuncArgs.count(calledFunction))
+    if (!funcState.equivalentFuncArgs.count(calledFunction))
       return WalkResult::skip();
 
-    for (auto it : moduleState.equivalentFuncArgs[calledFunction]) {
+    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
       int64_t returnIdx = it.first;
       int64_t bbargIdx = it.second;
       Value returnVal = callOp.getResult(returnIdx);
@@ -437,8 +435,8 @@ static void equivalenceAnalysis(FuncOp funcOp,
 static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
                                              RewriterBase &rewriter,
                                              BufferizationState &state) {
-  const ModuleAnalysisState &moduleState =
-      getModuleAnalysisState(state.getAnalysisState());
+  const FuncAnalysisState &funcState =
+      getFuncAnalysisState(state.getAnalysisState());
 
   // If nothing to do then we are done.
   if (!llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) &&
@@ -490,8 +488,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    auto funcOpIt = moduleState.equivalentFuncArgs.find(funcOp);
-    if (funcOpIt != moduleState.equivalentFuncArgs.end() &&
+    auto funcOpIt = funcState.equivalentFuncArgs.find(funcOp);
+    if (funcOpIt != funcState.equivalentFuncArgs.end() &&
         funcOpIt->second.count(returnOperand.getOperandNumber()))
       continue;
 
@@ -726,9 +724,9 @@ namespace std_ext {
 
 /// Return the index of the bbArg in the given FuncOp that is equivalent to the
 /// specified return value (if any).
-static Optional<int64_t>
-getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleAnalysisState &state,
-                        int64_t returnValIdx) {
+static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
+                                                 const FuncAnalysisState &state,
+                                                 int64_t returnValIdx) {
   auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
   if (funcOpIt == state.equivalentFuncArgs.end())
     // No equivalence info stores for funcOp.
@@ -751,12 +749,12 @@ struct CallOpInterface
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
 
-    const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is read.
       return true;
 
-    return moduleState.readBbArgs.lookup(funcOp).contains(
+    return funcState.readBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
 
@@ -766,12 +764,12 @@ struct CallOpInterface
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
 
-    const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is written.
       return true;
 
-    return moduleState.writtenBbArgs.lookup(funcOp).contains(
+    return funcState.writtenBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
 
@@ -780,7 +778,7 @@ struct CallOpInterface
     func::CallOp callOp = cast<func::CallOp>(op);
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
-    const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) !=
         FuncOpAnalysisState::Analyzed) {
       // FuncOp not analyzed yet. Any OpResult may be aliasing.
@@ -793,7 +791,7 @@ struct CallOpInterface
 
     // Get aliasing results from state.
     auto aliasingReturnVals =
-        moduleState.aliasingReturnVals.lookup(funcOp).lookup(
+        funcState.aliasingReturnVals.lookup(funcOp).lookup(
             opOperand.getOperandNumber());
     SmallVector<OpResult> result;
     for (int64_t resultIdx : aliasingReturnVals)
@@ -807,7 +805,7 @@ struct CallOpInterface
     func::CallOp callOp = cast<func::CallOp>(op);
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
-    const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) !=
         FuncOpAnalysisState::Analyzed) {
       // FuncOp not analyzed yet. Any OpOperand may be aliasing.
@@ -819,7 +817,7 @@ struct CallOpInterface
     }
 
     // Get aliasing bbArgs from state.
-    auto aliasingFuncArgs = moduleState.aliasingFuncArgs.lookup(funcOp).lookup(
+    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
         opResult.getResultNumber());
     SmallVector<OpOperand *> result;
     for (int64_t bbArgIdx : aliasingFuncArgs)
@@ -842,8 +840,8 @@ struct CallOpInterface
     unsigned numOperands = callOp->getNumOperands();
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
-    const ModuleAnalysisState &moduleState =
-        getModuleAnalysisState(state.getAnalysisState());
+    const FuncAnalysisState &funcState =
+        getFuncAnalysisState(state.getAnalysisState());
     const OneShotBufferizationOptions &options =
         static_cast<const OneShotBufferizationOptions &>(state.getOptions());
 
@@ -885,7 +883,7 @@ struct CallOpInterface
       }
 
       if (Optional<int64_t> bbArgIdx =
-              getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
+              getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
         // Return operands that are equivalent to some bbArg, are not
         // returned.
         FailureOr<Value> bufferOrFailure =
@@ -1068,11 +1066,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
   IRRewriter rewriter(moduleOp.getContext());
   OneShotAnalysisState analysisState(moduleOp, options);
   BufferizationState bufferizationState(analysisState);
-  ModuleAnalysisState &moduleState = getModuleAnalysisState(analysisState);
+  FuncAnalysisState &funcState = getFuncAnalysisState(analysisState);
   BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo();
 
-  if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
-                                      moduleState.callerMap)))
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, funcState.orderedFuncOps,
+                                      funcState.callerMap)))
     return failure();
 
   // Collect bbArg/return value information after the analysis.
@@ -1080,23 +1078,23 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
   options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis);
 
   // Analyze ops.
-  for (FuncOp funcOp : moduleState.orderedFuncOps) {
+  for (FuncOp funcOp : funcState.orderedFuncOps) {
     // No body => no analysis.
     if (funcOp.getBody().empty())
       continue;
 
     // Now analyzing function.
-    moduleState.startFunctionAnalysis(funcOp);
+    funcState.startFunctionAnalysis(funcOp);
 
     // Gather equivalence info for CallOps.
-    equivalenceAnalysis(funcOp, aliasInfo, moduleState);
+    equivalenceAnalysis(funcOp, aliasInfo, funcState);
 
     // Analyze funcOp.
     if (failed(analyzeOp(funcOp, analysisState)))
       return failure();
 
     // Mark op as fully analyzed.
-    moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
 
     // Add annotations to function arguments.
     if (options.testAnalysisOnly)
@@ -1107,7 +1105,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
     return success();
 
   // Bufferize functions.
-  for (FuncOp funcOp : moduleState.orderedFuncOps) {
+  for (FuncOp funcOp : funcState.orderedFuncOps) {
     // No body => no analysis.
     if (!funcOp.getBody().empty())
       if (failed(bufferizeOp(funcOp, bufferizationState)))
@@ -1120,7 +1118,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
   }
 
   // Check result.
-  for (FuncOp funcOp : moduleState.orderedFuncOps) {
+  for (FuncOp funcOp : funcState.orderedFuncOps) {
     if (!options.allowReturnAllocs &&
         llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) {
           return t.isa<MemRefType, UnrankedMemRefType>();


        


More information about the Mlir-commits mailing list