[Mlir-commits] [mlir] [MLIR] Make `OneShotModuleBufferize` use `OpInterface` (PR #107295)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 4 12:26:22 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Tzung-Han Juang (tzunghanjuang)

<details>
<summary>Changes</summary>

**Description:** 

`OneShotModuleBufferize` handles the bufferization of `FuncOp`, `CallOp`, and `ReturnOp`, but these op types are hard-coded, so custom function-like operations are not handled. This PR replaces some of the uses of `FuncOp` and `CallOp` in `OneShotModuleBufferize` with `FunctionOpInterface` and `CallOpInterface`, so that custom function ops and call ops can be bufferized.

**Limitations:** 

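There is no interface covering `ReturnOp`, so its handling is not yet generalized. For now, the code checks whether a function is a `FuncOp` (via `isa<FuncOp>`) and only then triggers the `ReturnOp`-related bufferization logic. Likewise, the per-function analysis in `analyzeModuleOp` is still skipped for functions that are not `func::FuncOp`.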

**Related Discord Discussion:** [Link](https://discord.com/channels/636084430946959380/642426447167881246/1280556809911799900)



---
Full diff: https://github.com/llvm/llvm-project/pull/107295.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+50-31) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..2983af0fcbf3f7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
 using namespace mlir::bufferization::func_ext;
 
 /// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
 
 /// Get or create FuncAnalysisState.
 static FuncAnalysisState &
@@ -247,6 +247,15 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
+static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
+  SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+  if (!sym)
+    return nullptr;
+  return dyn_cast_or_null<FunctionOpInterface>(
+      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+}
+
 /// Gather equivalence info of CallOps.
 /// Note: This only adds new equivalence info if the called function was already
 /// analyzed.
@@ -277,10 +286,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
 }
 
 /// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(func::FuncOp funcOp) {
-  return llvm::any_of(funcOp.getFunctionType().getInputs(),
+static bool hasTensorSignature(FunctionOpInterface funcOp) {
+  return llvm::any_of(funcOp.getArgumentTypes(),
                       llvm::IsaPred<TensorType>) ||
-         llvm::any_of(funcOp.getFunctionType().getResults(),
+         llvm::any_of(funcOp.getResultTypes(),
                       llvm::IsaPred<TensorType>);
 }
 
@@ -291,26 +300,30 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// retrieve the called FuncOp from any func::CallOp.
 static LogicalResult
 getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+                         SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
                          FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
-  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
+  DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
-  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
+  DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
+  WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
+    // Only handle ReturnOp if funcOp is exactly the FuncOp type.
+    if(isa<FuncOp>(funcOp)) {
+      FuncOp funcOpCasted = cast<FuncOp>(funcOp);
+      if (!funcOpCasted.getBody().empty()) {
+        func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted);
+        if (!returnOp)
+          return funcOp->emitError()
+                 << "cannot bufferize a FuncOp with tensors and "
+                    "without a unique ReturnOp";
+      }
     }
 
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp);
+    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+      FunctionOpInterface calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.
@@ -379,7 +392,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<FunctionOpInterface> orderedFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
@@ -388,27 +401,33 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     return failure();
 
   // Analyze ops.
-  for (func::FuncOp funcOp : orderedFuncOps) {
-    if (!state.getOptions().isOpAllowed(funcOp))
+  for (FunctionOpInterface funcOp : orderedFuncOps) {
+
+    // The following analysis is specific to the FuncOp type.
+    if(!isa<FuncOp>(funcOp))
+        continue;
+    FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);
+
+    if (!state.getOptions().isOpAllowed(funcOpCasted))
       continue;
 
     // Now analyzing function.
-    funcState.startFunctionAnalysis(funcOp);
+    funcState.startFunctionAnalysis(funcOpCasted);
 
     // Gather equivalence info for CallOps.
-    equivalenceAnalysis(funcOp, state, funcState);
+    equivalenceAnalysis(funcOpCasted, state, funcState);
 
     // Analyze funcOp.
-    if (failed(analyzeOp(funcOp, state, statistics)))
+    if (failed(analyzeOp(funcOpCasted, state, statistics)))
       return failure();
 
     // Run some extra function analyses.
-    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
-        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
+    if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) ||
+        failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState)))
       return failure();
 
     // Mark op as fully analyzed.
-    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+    funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed;
   }
 
   return success();
@@ -430,20 +449,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   IRRewriter rewriter(moduleOp.getContext());
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<FunctionOpInterface> orderedFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
 
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return failure();
+  SmallVector<FunctionOpInterface> ops;
 
   // Bufferize functions.
-  for (func::FuncOp funcOp : orderedFuncOps) {
+  for (FunctionOpInterface funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-
-    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
       // This function was not analyzed and RaW conflicts were not resolved.
       // Buffer copies must be inserted before every write.
       OneShotBufferizationOptions updatedOptions = options;
@@ -456,8 +475,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
     }
 
     // Change buffer return types to more precise layout maps.
-    if (options.inferFunctionResultLayout)
-      foldMemRefCasts(funcOp);
+    if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp))
+      foldMemRefCasts(cast<func::FuncOp>(funcOp));
   }
 
   // Bufferize all other ops.

``````````

</details>


https://github.com/llvm/llvm-project/pull/107295


More information about the Mlir-commits mailing list