[Mlir-commits] [mlir] [MLIR] Enforce symbol visibility during symbol lookup (PR #179370)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 2 18:04:25 PST 2026


https://github.com/neildhar created https://github.com/llvm/llvm-project/pull/179370

This PR is stacked on top of https://github.com/llvm/llvm-project/pull/179362. The first commit here belongs to that PR, so please review that commit there.

Update symbol resolution to check whether each nested symbol being resolved is private, and fail the lookup if it is. This enforces the symbol-visibility invariants that optimisations rely on: a private symbol cannot be reached through a nested reference from outside its enclosing symbol table.

>From ec446eff377a4dbd134a9af709d8f98e646f0870 Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Mon, 2 Feb 2026 15:59:52 -0800
Subject: [PATCH 1/2] [NFC][MLIR] Simplify lookup of nested symbols

---
 mlir/lib/IR/SymbolTable.cpp | 32 ++++++++++++--------------------
 1 file changed, 12 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 9f5dd2c9e3b72..4f8418fac7b9d 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -413,30 +413,22 @@ static LogicalResult lookupSymbolInImpl(
   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
 
   // Lookup the root reference for this symbol.
-  symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
-  if (!symbolTableOp)
+  auto *symbolOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
+  if (!symbolOp)
     return failure();
-  symbols.push_back(symbolTableOp);
+  symbols.push_back(symbolOp);
 
-  // If there are no nested references, just return the root symbol directly.
-  ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
-  if (nestedRefs.empty())
-    return success();
-
-  // Verify that the root is also a symbol table.
-  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
-    return failure();
-
-  // Otherwise, lookup each of the nested non-leaf references and ensure that
-  // each corresponds to a valid symbol table.
-  for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
-    symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
-    if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
+  // Lookup each of the nested references.
+  for (FlatSymbolRefAttr ref : symbol.getNestedReferences()) {
+    // Check that we have a valid symbol table to lookup ref.
+    if (!symbolOp->hasTrait<OpTrait::SymbolTable>())
+      return failure();
+    symbolOp = lookupSymbolFn(symbolOp, ref.getAttr());
+    if (!symbolOp)
       return failure();
-    symbols.push_back(symbolTableOp);
+    symbols.push_back(symbolOp);
   }
-  symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
-  return success(symbols.back());
+  return success();
 }
 
 LogicalResult

>From 0f17b7290547abc2f2927a929c513c1328b45896 Mon Sep 17 00:00:00 2001
From: Neil Dhar <neildhar at meta.com>
Date: Mon, 2 Feb 2026 17:59:38 -0800
Subject: [PATCH 2/2] [MLIR] Enforce visibility during symbol lookup

---
 mlir/lib/IR/SymbolTable.cpp         |  4 +++-
 mlir/test/Dialect/GPU/invalid.mlir  |  2 +-
 mlir/test/IR/test-symbol-uses.mlir  | 25 +++++++++++++++++++++++++
 mlir/test/lib/IR/TestSymbolUses.cpp |  4 ++++
 4 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 4f8418fac7b9d..4e191e7d612ad 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -424,7 +424,9 @@ static LogicalResult lookupSymbolInImpl(
     if (!symbolOp->hasTrait<OpTrait::SymbolTable>())
       return failure();
     symbolOp = lookupSymbolFn(symbolOp, ref.getAttr());
-    if (!symbolOp)
+    // If the nested symbol is private, lookup failed.
+    if (!symbolOp || SymbolTable::getSymbolVisibility(symbolOp) ==
+                         SymbolTable::Visibility::Private)
       return failure();
     symbols.push_back(symbolOp);
   }
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index ad6ad7338ff38..9c338e06f1b26 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -135,7 +135,7 @@ module attributes {gpu.container_module} {
 module attributes {gpu.container_module} {
   gpu.module @kernels {
     // expected-note at +1 {{see the kernel definition here}}
-    memref.global "private" @kernel_1 : memref<4xi32>
+    memref.global @kernel_1 : memref<4xi32>
   }
 
   func.func @launch_func_undefined_function(%sz : index) {
diff --git a/mlir/test/IR/test-symbol-uses.mlir b/mlir/test/IR/test-symbol-uses.mlir
index 54e3ef1812510..d9d839e9fc307 100644
--- a/mlir/test/IR/test-symbol-uses.mlir
+++ b/mlir/test/IR/test-symbol-uses.mlir
@@ -68,3 +68,28 @@ func.func @symbol_bar() {
   "foo.possibly_unknown_symbol_table"() ({
   }) : () -> ()
 }
+
+// -----
+
+module {
+  // expected-remark at below {{symbol has 2 uses}}
+  module @inner_module {
+    // expected-remark at below {{symbol has 1 uses}}
+    func.func private @private_inner()
+    // expected-remark at below {{symbol has 1 uses}}
+    func.func nested @nested_inner()
+  }
+
+
+  // expected-remark at below {{symbol has no uses}}
+  // expected-remark at below {{symbol contains 2 nested references}}
+  func.func @outer_caller() {
+    // expected-remark at below {{found use of symbol : @inner_module::@nested_inner : "inner_module"}}
+    // expected-remark at below {{found use of symbol : @inner_module::@nested_inner : "nested_inner"}}
+    "foo.op"() { use = @inner_module::@nested_inner } : () -> ()
+    // expected-remark at below {{failed to resolve use of symbol : @inner_module::@private_inner : "inner_module"}}
+    // expected-remark at below {{failed to resolve use of symbol : @inner_module::@private_inner : "private_inner"}}
+    "foo.op"() { use = @inner_module::@private_inner } : () -> ()
+    return
+  }
+}
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index e841e142c6563..24f7d8505bebd 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -60,6 +60,10 @@ struct SymbolUsesPass
         symbolUse.getUser()->emitRemark()
             << "found use of symbol : " << symbolUse.getSymbolRef() << " : "
             << symbol.getNameAttr();
+      } else {
+        symbolUse.getUser()->emitRemark()
+            << "failed to resolve use of symbol : " << symbolUse.getSymbolRef()
+            << " : " << symbol.getNameAttr();
       }
     }
     symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses";



More information about the Mlir-commits mailing list