[Mlir-commits] [mlir] [MLIR] Avoid resolving callable outside the analysis scope in DeadCodeAnalysis (NFC) (PR #155088)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Aug 23 03:48:03 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

We use the symbol table machinery to look up a callable, but when the analysis scope is a function, such a lookup will resolve outside of the scope. This can lead to race conditions, since other passes may operate in parallel on the sibling functions.
The callable would be discarded right after the lookup (we check the analysis scope), so avoiding the lookup is NFC.

Fix #<!-- -->154948

---
Full diff: https://github.com/llvm/llvm-project/pull/155088.diff


2 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h (+7) 
- (modified) mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp (+28-9) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 2250db823b551..c7c405e1423cb 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -229,6 +229,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
   /// considered an external callable.
   Operation *analysisScope;
 
+  /// Whether the analysis scope has a symbol table. This is used to avoid
+  /// resolving callables outside the analysis scope.
+  /// It is updated when recursing into a region in case where the top-level
+  /// operation does not have a symbol table, but one is encountered in a nested
+  /// region.
+  bool hasSymbolTable = false;
+
   /// A symbol table used for O(1) symbol lookups during simplification.
   SymbolTableCollection symbolTable;
 };
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 9424eff3e6b6f..131c49c44171b 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
@@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
   LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
          << OpWithFlags(top, OpPrintingFlags().skipRegions());
   analysisScope = top;
+  hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>();
   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
     LDBG() << "[init] Processing symbol table op: "
            << OpWithFlags(symTable, OpPrintingFlags().skipRegions());
@@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
       return failure();
   }
   // Recurse on nested operations.
-  for (Region &region : op->getRegions()) {
-    LDBG() << "[init] Recursing into region of op: "
-           << OpWithFlags(op, OpPrintingFlags().skipRegions());
-    for (Operation &nestedOp : region.getOps()) {
-      LDBG() << "[init] Recursing into nested op: "
-             << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
-      if (failed(initializeRecursively(&nestedOp)))
-        return failure();
+  if (op->getNumRegions()) {
+    // If we haven't seen a symbol table yet, check if the current operation
+    // has one. If so, update the flag to allow for resolving callables in
+    // nested regions.
+    bool savedHasSymbolTable = hasSymbolTable;
+    auto restoreHasSymbolTable =
+        llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; });
+    if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>())
+      hasSymbolTable = true;
+
+    for (Region &region : op->getRegions()) {
+      LDBG() << "[init] Recursing into region of op: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
+      for (Operation &nestedOp : region.getOps()) {
+        LDBG() << "[init] Recursing into nested op: "
+               << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
+        if (failed(initializeRecursively(&nestedOp)))
+          return failure();
+      }
     }
   }
   LDBG() << "[init] Finished initializeRecursively for op: "
@@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
   LDBG() << "visitCallOperation: "
          << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
-  Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+
+  Operation *callableOp = nullptr;
+  if (hasSymbolTable)
+    callableOp = call.resolveCallableInTable(&symbolTable);
+  else
+    LDBG()
+        << "No symbol table present in analysis scope, can't resolve callable";
 
   // A call to a externally-defined callable has unknown predecessors.
   const auto isExternalCallable = [this](Operation *op) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/155088


More information about the Mlir-commits mailing list