[Mlir-commits] [mlir] [MLIR] Make `OneShotModuleBufferize` use `OpInterface` (PR #107295)

Tzung-Han Juang llvmlistbot at llvm.org
Wed Sep 4 12:25:29 PDT 2024


https://github.com/tzunghanjuang created https://github.com/llvm/llvm-project/pull/107295

**Description:** 

`OneShotModuleBufferize` handles the bufferization of `FuncOp`, `CallOp`, and `ReturnOp`, but these op types are hard-coded, so custom function-like operations are not handled. This PR replaces some of the uses of `FuncOp` and `CallOp` in `OneShotModuleBufferize` with `FunctionOpInterface` and `CallOpInterface`, so that custom function ops and call ops can be bufferized.

**Limitations:** 

`ReturnOp` is not modeled by any interface. For now, the code uses explicit checks for `FuncOp` to decide when to apply the `ReturnOp`-specific bufferization logic.

**Related Discord Discussion:** [Link](https://discord.com/channels/636084430946959380/642426447167881246/1280556809911799900)



>From 8a5aca204bb7ed1a0a05f14994274a70f732b3d6 Mon Sep 17 00:00:00 2001
From: Tzung-Han Juang <tzunghan.juang at gmail.com>
Date: Wed, 4 Sep 2024 15:04:36 -0400
Subject: [PATCH] Make OneShotModuleBufferize accept FunctionOpInterface and
 CallOpInterface

---
 .../Transforms/OneShotModuleBufferize.cpp     | 84 ++++++++++++-------
 1 file changed, 53 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..2983af0fcbf3f7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
 using namespace mlir::bufferization::func_ext;

 /// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;

 /// Get or create FuncAnalysisState.
 static FuncAnalysisState &
@@ -247,6 +247,17 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }

+/// Return the function called by `callOp`, or null if the callee is not a
+/// symbol (e.g. indirect calls) or does not resolve to a FunctionOpInterface.
+static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
+  SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+  if (!sym)
+    return nullptr;
+  return llvm::dyn_cast_if_present<FunctionOpInterface>(
+      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+}
+
 /// Gather equivalence info of CallOps.
 /// Note: This only adds new equivalence info if the called function was already
 /// analyzed.
@@ -277,10 +288,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
 }

 /// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(func::FuncOp funcOp) {
-  return llvm::any_of(funcOp.getFunctionType().getInputs(),
+static bool hasTensorSignature(FunctionOpInterface funcOp) {
+  return llvm::any_of(funcOp.getArgumentTypes(),
                       llvm::IsaPred<TensorType>) ||
-         llvm::any_of(funcOp.getFunctionType().getResults(),
+         llvm::any_of(funcOp.getResultTypes(),
                       llvm::IsaPred<TensorType>);
 }

@@ -291,26 +302,32 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// retrieve the called FuncOp from any func::CallOp.
 static LogicalResult
 getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+                         SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
                          FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
-  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
+  DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
-  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
+  DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
+  WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
+    // Only handle ReturnOp if funcOp is exactly the FuncOp type.
+    if (auto funcOpCasted = dyn_cast<func::FuncOp>(funcOp.getOperation())) {
+      if (!funcOpCasted.getBody().empty()) {
+        func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted);
+        if (!returnOp)
+          return funcOp->emitError()
+                 << "cannot bufferize a FuncOp with tensors and "
+                    "without a unique ReturnOp";
+      }
     }

     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp);
-      assert(calledFunction && "could not retrieved called func::FuncOp");
+    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+      FunctionOpInterface calledFunction = getCalledFunction(callOp);
+      // Skip indirect calls and calls that do not resolve to a function op:
+      // they have no statically known callee to bufferize first.
+      if (!calledFunction)
+        return WalkResult::advance();
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.
@@ -379,7 +396,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<FunctionOpInterface> orderedFuncOps;

   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
@@ -388,27 +405,32 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     return failure();

   // Analyze ops.
-  for (func::FuncOp funcOp : orderedFuncOps) {
-    if (!state.getOptions().isOpAllowed(funcOp))
+  for (FunctionOpInterface funcOp : orderedFuncOps) {
+    // The following analysis is specific to the FuncOp type.
+    auto funcOpCasted = dyn_cast<func::FuncOp>(funcOp.getOperation());
+    if (!funcOpCasted)
+      continue;
+
+    if (!state.getOptions().isOpAllowed(funcOpCasted))
       continue;

     // Now analyzing function.
-    funcState.startFunctionAnalysis(funcOp);
+    funcState.startFunctionAnalysis(funcOpCasted);

     // Gather equivalence info for CallOps.
-    equivalenceAnalysis(funcOp, state, funcState);
+    equivalenceAnalysis(funcOpCasted, state, funcState);

     // Analyze funcOp.
-    if (failed(analyzeOp(funcOp, state, statistics)))
+    if (failed(analyzeOp(funcOpCasted, state, statistics)))
       return failure();

     // Run some extra function analyses.
-    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
-        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
+    if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) ||
+        failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState)))
       return failure();

     // Mark op as fully analyzed.
-    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+    funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed;
   }

   return success();
@@ -430,20 +452,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   IRRewriter rewriter(moduleOp.getContext());

   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<FunctionOpInterface> orderedFuncOps;

   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;

   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return failure();

   // Bufferize functions.
-  for (func::FuncOp funcOp : orderedFuncOps) {
+  for (FunctionOpInterface funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.

-    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
       // This function was not analyzed and RaW conflicts were not resolved.
       // Buffer copies must be inserted before every write.
       OneShotBufferizationOptions updatedOptions = options;
@@ -456,8 +478,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
     }

     // Change buffer return types to more precise layout maps.
-    if (options.inferFunctionResultLayout)
-      foldMemRefCasts(funcOp);
+    if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp))
+      foldMemRefCasts(cast<func::FuncOp>(funcOp));
   }

   // Bufferize all other ops.



More information about the Mlir-commits mailing list