[Mlir-commits] [mlir] [MLIR] Avoid resolving callable outside the analysis scope in DeadCodeAnalysis (NFC) (PR #155088)

Mehdi Amini llvmlistbot at llvm.org
Wed Sep 10 13:55:13 PDT 2025


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/155088

>From 19cd55036aea76632e4e95745fc67d9376a136cb Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 23 Aug 2025 03:44:10 -0700
Subject: [PATCH] [MLIR] Avoid resolving callable outside the analysis scope in
 DataFlow

We are using the symbol table machinery to look up a callable, but when the
analysis scope is a function, such a lookup will resolve outside of the scope.
This can lead to race-condition issues since other passes may operate in
parallel on the sibling functions.
In DeadCode Analysis, the callable would be discarded right after the lookup (we check the analysis
scope), so avoiding the lookup is NFC.

For the DataFlow solver, we're looking at the top-level operation, and if
it isn't a SymbolTable we disable the interprocedural optimization in the
solver config directly.
This strategy isn't NFC but seems reasonable and does not cause any
change in behavior in practice in tree.

Fix #154948
---
 .../mlir/Analysis/DataFlow/DeadCodeAnalysis.h |  7 ++++
 .../Analysis/DataFlow/DeadCodeAnalysis.cpp    | 37 ++++++++++++++-----
 mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp  | 22 +++++++----
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 10 +++--
 mlir/lib/Analysis/DataFlowFramework.cpp       |  7 ++++
 5 files changed, 62 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 2250db823b551..c7c405e1423cb 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -229,6 +229,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
   /// considered an external callable.
   Operation *analysisScope;
 
+  /// Whether the analysis scope has a symbol table. This is used to avoid
+  /// resolving callables outside the analysis scope.
+  /// It is updated when recursing into a region in case where the top-level
+  /// operation does not have a symbol table, but one is encountered in a nested
+  /// region.
+  bool hasSymbolTable = false;
+
   /// A symbol table used for O(1) symbol lookups during simplification.
   SymbolTableCollection symbolTable;
 };
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 9424eff3e6b6f..131c49c44171b 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
@@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
   LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
          << OpWithFlags(top, OpPrintingFlags().skipRegions());
   analysisScope = top;
+  hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>();
   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
     LDBG() << "[init] Processing symbol table op: "
            << OpWithFlags(symTable, OpPrintingFlags().skipRegions());
@@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
       return failure();
   }
   // Recurse on nested operations.
-  for (Region &region : op->getRegions()) {
-    LDBG() << "[init] Recursing into region of op: "
-           << OpWithFlags(op, OpPrintingFlags().skipRegions());
-    for (Operation &nestedOp : region.getOps()) {
-      LDBG() << "[init] Recursing into nested op: "
-             << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
-      if (failed(initializeRecursively(&nestedOp)))
-        return failure();
+  if (op->getNumRegions()) {
+    // If we haven't seen a symbol table yet, check if the current operation
+    // has one. If so, update the flag to allow for resolving callables in
+    // nested regions.
+    bool savedHasSymbolTable = hasSymbolTable;
+    auto restoreHasSymbolTable =
+        llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; });
+    if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>())
+      hasSymbolTable = true;
+
+    for (Region &region : op->getRegions()) {
+      LDBG() << "[init] Recursing into region of op: "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
+      for (Operation &nestedOp : region.getOps()) {
+        LDBG() << "[init] Recursing into nested op: "
+               << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
+        if (failed(initializeRecursively(&nestedOp)))
+          return failure();
+      }
     }
   }
   LDBG() << "[init] Finished initializeRecursively for op: "
@@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
   LDBG() << "visitCallOperation: "
          << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
-  Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+
+  Operation *callableOp = nullptr;
+  if (hasSymbolTable)
+    callableOp = call.resolveCallableInTable(&symbolTable);
+  else
+    LDBG()
+        << "No symbol table present in analysis scope, can't resolve callable";
 
   // A call to a externally-defined callable has unknown predecessors.
   const auto isExternalCallable = [this](Operation *op) {
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index d05374f667a51..b51465bc31ec3 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -64,10 +64,12 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
     AbstractDenseLattice *after) {
   // Allow for customizing the behavior of calls to external symbols, including
   // when the analysis is explicitly marked as non-interprocedural.
-  auto callable =
-      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
-  if (!getSolverConfig().isInterprocedural() ||
-      (callable && !callable.getCallableRegion())) {
+  auto isExternalCallable = [&]() {
+    auto callable =
+        dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+    return callable && !callable.getCallableRegion();
+  };
+  if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
     return visitCallControlFlowTransfer(
         call, CallControlFlowAction::ExternalCallee, before, after);
   }
@@ -290,6 +292,12 @@ AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
 void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
     CallOpInterface call, const AbstractDenseLattice &after,
     AbstractDenseLattice *before) {
+  // If the solver is not interprocedural, let the hook handle it as an external
+  // callee.
+  if (!getSolverConfig().isInterprocedural())
+    return visitCallControlFlowTransfer(
+        call, CallControlFlowAction::ExternalCallee, after, before);
+
   // Find the callee.
   Operation *callee = call.resolveCallableInTable(&symbolTable);
 
@@ -297,12 +305,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
   // No region means the callee is only declared in this module.
   // If that is the case or if the solver is not interprocedural,
   // let the hook handle it.
-  if (!getSolverConfig().isInterprocedural() ||
-      (callable && (!callable.getCallableRegion() ||
-                    callable.getCallableRegion()->empty()))) {
+  if (callable &&
+      (!callable.getCallableRegion() || callable.getCallableRegion()->empty()))
     return visitCallControlFlowTransfer(
         call, CallControlFlowAction::ExternalCallee, after, before);
-  }
 
   if (!callable)
     return setToExitState(before);
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 13a3e1480c836..0d2e2ed85549d 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -228,10 +228,12 @@ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
     ArrayRef<AbstractSparseLattice *> resultLattices) {
   // If the call operation is to an external function, attempt to infer the
   // results from the call arguments.
-  auto callable =
-      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
-  if (!getSolverConfig().isInterprocedural() ||
-      (callable && !callable.getCallableRegion())) {
+  auto isExternalCallable = [&]() {
+    auto callable =
+        dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+    return callable && !callable.getCallableRegion();
+  };
+  if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
     visitExternalCallImpl(call, operandLattices, resultLattices);
     return success();
   }
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 7e1b4052027d3..9352ab02f7472 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/iterator.h"
@@ -109,6 +110,12 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
   isRunning = true;
   auto guard = llvm::make_scope_exit([&]() { isRunning = false; });
 
+  bool isInterprocedural = config.isInterprocedural();
+  auto restoreInterprocedural = llvm::make_scope_exit(
+      [&]() { config.setInterprocedural(isInterprocedural); });
+  if (isInterprocedural && !top->hasTrait<OpTrait::SymbolTable>())
+    config.setInterprocedural(false);
+
   // Initialize equivalent lattice anchors.
   for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
     analysis.initializeEquivalentLatticeAnchor(top);



More information about the Mlir-commits mailing list