[Mlir-commits] [mlir] cd80617 - [mlir][bufferize][NFC] Make func BufferizableOpInterface impl compatible with One-Shot Bufferize

Matthias Springer llvmlistbot at llvm.org
Wed Jun 15 01:05:24 PDT 2022


Author: Matthias Springer
Date: 2022-06-15T10:05:15+02:00
New Revision: cd80617a8afca896b78dc0d05cceb3e785503c38

URL: https://github.com/llvm/llvm-project/commit/cd80617a8afca896b78dc0d05cceb3e785503c38
DIFF: https://github.com/llvm/llvm-project/commit/cd80617a8afca896b78dc0d05cceb3e785503c38.diff

LOG: [mlir][bufferize][NFC] Make func BufferizableOpInterface impl compatible with One-Shot Bufferize

Bufferization of func dialect ops must still go through `OneShotModuleBufferize`. With this change, however, the analysis methods of the func dialect's BufferizableOpInterface implementations can also be used together with the regular `OneShotBufferize`. (In the absence of module-level analysis information, these methods return conservative results.)

Differential Revision: https://reviews.llvm.org/D127299

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 23846346d638..6cdd8b494215 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -111,9 +111,14 @@ getFuncAnalysisState(const AnalysisState &state) {
 /// Return the state (phase) of analysis of the FuncOp.
 static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                   FuncOp funcOp) {
-  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
-  auto it = funcState.analyzedFuncOps.find(funcOp);
-  if (it == funcState.analyzedFuncOps.end())
+  Optional<const FuncAnalysisState *> maybeState =
+      state.getDialectState<FuncAnalysisState>(
+          func::FuncDialect::getDialectNamespace());
+  if (!maybeState.hasValue())
+    return FuncOpAnalysisState::NotAnalyzed;
+  const auto &analyzedFuncOps = maybeState.getValue()->analyzedFuncOps;
+  auto it = analyzedFuncOps.find(funcOp);
+  if (it == analyzedFuncOps.end())
     return FuncOpAnalysisState::NotAnalyzed;
   return it->second;
 }
@@ -145,11 +150,11 @@ struct CallOpInterface
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
 
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is read.
       return true;
 
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     return funcState.readBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
@@ -160,11 +165,11 @@ struct CallOpInterface
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
 
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is written.
       return true;
 
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     return funcState.writtenBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
@@ -174,7 +179,6 @@ struct CallOpInterface
     func::CallOp callOp = cast<func::CallOp>(op);
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) !=
         FuncOpAnalysisState::Analyzed) {
       // FuncOp not analyzed yet. Any OpResult may be aliasing.
@@ -186,6 +190,7 @@ struct CallOpInterface
     }
 
     // Get aliasing results from state.
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     auto aliasingReturnVals =
         funcState.aliasingReturnVals.lookup(funcOp).lookup(
             opOperand.getOperandNumber());
@@ -201,7 +206,6 @@ struct CallOpInterface
     func::CallOp callOp = cast<func::CallOp>(op);
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) !=
         FuncOpAnalysisState::Analyzed) {
       // FuncOp not analyzed yet. Any OpOperand may be aliasing.
@@ -213,6 +217,7 @@ struct CallOpInterface
     }
 
     // Get aliasing bbArgs from state.
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
         opResult.getResultNumber());
     SmallVector<OpOperand *> result;
@@ -226,13 +231,13 @@ struct CallOpInterface
     func::CallOp callOp = cast<func::CallOp>(op);
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     if (getFuncOpAnalysisState(state, funcOp) !=
         FuncOpAnalysisState::Analyzed) {
       // Function not analyzed yet. The conservative answer is "None".
       return BufferRelation::None;
     }
 
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     Optional<int64_t> maybeEquiv =
         getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
     if (maybeEquiv.hasValue()) {


        


More information about the Mlir-commits mailing list