[Mlir-commits] [mlir] [MLIR] Use cached symbol tables in `getFuncOpsOrderedByCalls` (PR #141967)

Michele Scuttari llvmlistbot at llvm.org
Thu May 29 08:52:46 PDT 2025


https://github.com/mscuttari updated https://github.com/llvm/llvm-project/pull/141967

>From da3f4b1170ed2bb343a1c2ae8ec766db5c1f829f Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Thu, 29 May 2025 17:36:17 +0200
Subject: [PATCH 1/2] Use cached symbol tables in `getFuncOpsOrderedByCalls`

---
 .../Transforms/OneShotModuleBufferize.cpp     | 28 ++++++++++++++-----
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index dee2af8271ce8..f7b72a8ab022b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -310,21 +310,19 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// any func::CallOp.
 static LogicalResult getFuncOpsOrderedByCalls(
     ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
-    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
+    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
+    SymbolTableCollection &symbolTables) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
 
-  // TODO Avoid recomputing the symbol tables every time.
-  mlir::SymbolTableCollection symbolTable;
-
   for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
+      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.
@@ -362,6 +360,21 @@ static LogicalResult getFuncOpsOrderedByCalls(
   return success();
 }
 
+static LogicalResult getFuncOpsOrderedByCalls(
+    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
+    OneShotAnalysisState &analysisState) {
+  auto *funcAnalysisState = analysisState.getExtension<FuncAnalysisState>();
+
+  if (funcAnalysisState)
+    return getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, remainingFuncOps,
+                                    callerMap, funcAnalysisState->symbolTables);
+
+  SymbolTableCollection symbolTables;
+  return getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, remainingFuncOps,
+                                  callerMap, symbolTables);
+}
+
 /// Helper function that extracts the source from a memref.cast. If the given
 /// value is not a memref.cast result, simply returns the given value.
 static Value unpackCast(Value v) {
@@ -458,7 +471,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncCallerMap callerMap;
 
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
-                                      remainingFuncOps, callerMap)))
+                                      remainingFuncOps, callerMap, state)))
     return failure();
 
   // Analyze functions in order. Starting with functions that are not calling
@@ -534,7 +547,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // each other recursively are bufferized in an unspecified order at the end.
   // We may use unnecessarily "complex" (in terms of layout map) buffer types.
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
-                                      remainingFuncOps, callerMap)))
+                                      remainingFuncOps, callerMap,
+                                      state.getSymbolTables())))
     return failure();
   llvm::append_range(orderedFuncOps, remainingFuncOps);
 

>From bc51f3a28da26468928cff4eb818e8c17d450edf Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Thu, 29 May 2025 17:52:33 +0200
Subject: [PATCH 2/2] Take symbol tables directly from the
 FuncAnalysisState object

---
 .../Transforms/OneShotModuleBufferize.cpp      | 18 ++----------------
 1 file changed, 2 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index f7b72a8ab022b..fc6424a25ac70 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -360,21 +360,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
   return success();
 }
 
-static LogicalResult getFuncOpsOrderedByCalls(
-    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
-    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
-    OneShotAnalysisState &analysisState) {
-  auto *funcAnalysisState = analysisState.getExtension<FuncAnalysisState>();
-
-  if (funcAnalysisState)
-    return getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, remainingFuncOps,
-                                    callerMap, funcAnalysisState->symbolTables);
-
-  SymbolTableCollection symbolTables;
-  return getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, remainingFuncOps,
-                                  callerMap, symbolTables);
-}
-
 /// Helper function that extracts the source from a memref.cast. If the given
 /// value is not a memref.cast result, simply returns the given value.
 static Value unpackCast(Value v) {
@@ -471,7 +456,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncCallerMap callerMap;
 
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
-                                      remainingFuncOps, callerMap, state)))
+                                      remainingFuncOps, callerMap,
+                                      funcState.symbolTables)))
     return failure();
 
   // Analyze functions in order. Starting with functions that are not calling



More information about the Mlir-commits mailing list