[Mlir-commits] [mlir] bb37296 - [MLIR] Add optional cached symbol tables to LLVM conversion patterns (#144032)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jun 21 01:55:47 PDT 2025


Author: Michele Scuttari
Date: 2025-06-21T10:55:44+02:00
New Revision: bb372963dfcef9722d5aeb4f65ddb5c50be24e01

URL: https://github.com/llvm/llvm-project/commit/bb372963dfcef9722d5aeb4f65ddb5c50be24e01
DIFF: https://github.com/llvm/llvm-project/commit/bb372963dfcef9722d5aeb4f65ddb5c50be24e01.diff

LOG: [MLIR] Add optional cached symbol tables to LLVM conversion patterns (#144032)

This PR makes it possible to optionally speed up symbol lookups by providing a `SymbolTableCollection` instance to the conversion patterns that need it. It is a follow-up to the discussion about symbol / symbol-table management carried out on [Discourse](https://discourse.llvm.org/t/symbol-table-as-first-class-citizen-in-builders/86813).

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
    mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
    mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
    mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
    mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index 88f18022da9bb..2dfb6b03bcfcd 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -20,6 +20,7 @@ class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class Pass;
+class SymbolTableCollection;
 
 #define GEN_PASS_DECL_CONVERTCONTROLFLOWTOLLVMPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -39,9 +40,9 @@ void populateControlFlowToLLVMConversionPatterns(
 /// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
 /// set to false, the program execution continues when a condition is
 /// unsatisfied.
-void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
-                                           RewritePatternSet &patterns,
-                                           bool abortOnFailure = true);
+void populateAssertToLLVMConversionPattern(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool abortOnFailure = true, SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry);
 

diff  --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
index b1ea2740c0605..e530b0a43b8e0 100644
--- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
+++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
@@ -27,20 +27,23 @@ class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class SymbolTable;
+class SymbolTableCollection;
 
 /// Convert input FunctionOpInterface operation to LLVMFuncOp by using the
 /// provided LLVMTypeConverter. Return failure if failed to so.
 FailureOr<LLVM::LLVMFuncOp>
 convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                           ConversionPatternRewriter &rewriter,
-                          const LLVMTypeConverter &converter);
+                          const LLVMTypeConverter &converter,
+                          SymbolTableCollection *symbolTables = nullptr);
 
 /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
 /// `emitCWrappers` is set, the pattern will also produce functions
 /// that pass memref descriptors by pointer-to-structure in addition to the
 /// default unpacked form.
 void populateFuncToLLVMFuncOpConversionPattern(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables = nullptr);
 
 /// Collect the patterns to convert from the Func dialect to LLVM. The
 /// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
@@ -57,7 +60,7 @@ void populateFuncToLLVMFuncOpConversionPattern(
 /// not an error to provide it anyway.
 void populateFuncToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    const SymbolTable *symbolTable = nullptr);
+    SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertFuncToLLVMInterface(DialectRegistry &registry);
 

diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 33402301115b7..d7de40555bb6a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -17,6 +17,7 @@ namespace mlir {
 
 class OpBuilder;
 class LLVMTypeConverter;
+class SymbolTableCollection;
 
 namespace LLVM {
 
@@ -26,7 +27,8 @@ namespace LLVM {
 LogicalResult createPrintStrCall(
     OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
     StringRef string, const LLVMTypeConverter &typeConverter,
-    bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
+    bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {},
+    SymbolTableCollection *symbolTables = nullptr);
 } // namespace LLVM
 
 } // namespace mlir

diff  --git a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
index 996a64baf9dd5..e93d5bdce7bf2 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
@@ -16,6 +16,7 @@ class DialectRegistry;
 class Pass;
 class LLVMTypeConverter;
 class RewritePatternSet;
+class SymbolTableCollection;
 
 #define GEN_PASS_DECL_FINALIZEMEMREFTOLLVMCONVERSIONPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -23,7 +24,8 @@ class RewritePatternSet;
 /// Collect a set of patterns to convert memory-related operations from the
 /// MemRef dialect to the LLVM dialect.
 void populateFinalizeMemRefToLLVMConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 4a7ec6f2efe64..8ad9ed18acebd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -24,6 +24,7 @@ class OpBuilder;
 class Operation;
 class Type;
 class ValueRange;
+class SymbolTableCollection;
 
 namespace LLVM {
 class LLVMFuncOp;
@@ -33,55 +34,73 @@ class LLVMFuncOp;
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
 /// Failure if an unexpected version of function is found.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
-                                                      Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
+                          SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
-                            std::optional<StringRef> runtimeFunctionName = {});
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
-                                                      Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
-                                                       Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
-                                                       Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
-                                                         Operation *moduleOp);
+                            std::optional<StringRef> runtimeFunctionName = {},
+                            SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
+                          SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
+                           SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
+                           SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
+                             SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                       SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                             SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
-                                                 Operation *moduleOp);
+lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
+                     SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                             SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(
+    OpBuilder &b, Operation *moduleOp, Type indexType,
+    SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
-                                    Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
-                                                        Operation *moduleOp);
+lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
+                            SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
-                           Type unrankedDescriptorType);
+                           Type unrankedDescriptorType,
+                           SymbolTableCollection *symbolTables = nullptr);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 /// Return a failure if the FuncOp found has unexpected signature.
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                  ArrayRef<Type> paramTypes = {}, Type resultType = {},
-                 bool isVarArg = false, bool isReserved = false);
+                 bool isVarArg = false, bool isReserved = false,
+                 SymbolTableCollection *symbolTables = nullptr);
 
 } // namespace LLVM
 } // namespace mlir

diff  --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index d31d7d801e149..3d0804fd11b6b 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -44,9 +44,10 @@ namespace {
 /// lowering.
 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
   explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
-                            bool abortOnFailedAssert = true)
+                            bool abortOnFailedAssert = true,
+                            SymbolTableCollection *symbolTables = nullptr)
       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
-        abortOnFailedAssert(abortOnFailedAssert) {}
+        abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
@@ -64,7 +65,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
     auto createResult = LLVM::createPrintStrCall(
         rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
         /*addNewLine=*/false,
-        /*runtimeFunctionName=*/"puts");
+        /*runtimeFunctionName=*/"puts", symbolTables);
     if (createResult.failed())
       return failure();
 
@@ -96,6 +97,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
   /// If set to `false`, messages are printed but program execution continues.
   /// This is useful for testing asserts.
   bool abortOnFailedAssert = true;
+
+  SymbolTableCollection *symbolTables = nullptr;
 };
 
 /// Helper function for converting branch ops. This function converts the
@@ -232,8 +235,8 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
 
 void mlir::cf::populateAssertToLLVMConversionPattern(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool abortOnFailure) {
-  patterns.add<AssertOpLowering>(converter, abortOnFailure);
+    bool abortOnFailure, SymbolTableCollection *symbolTables) {
+  patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 538016927256b..4499cbd4d1a20 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -299,10 +299,9 @@ static void restoreByValRefArgumentType(
   }
 }
 
-FailureOr<LLVM::LLVMFuncOp>
-mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
-                                ConversionPatternRewriter &rewriter,
-                                const LLVMTypeConverter &converter) {
+FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
+    FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
+    const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
   // Check the funcOp has `FunctionType`.
   auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
   if (!funcTy)
@@ -361,10 +360,25 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
 
   SmallVector<NamedAttribute, 4> attributes;
   filterFuncAttributes(funcOp, attributes);
+
+  Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
+
+  if (symbolTables && symbolTableOp) {
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+    symbolTable.remove(funcOp);
+  }
+
   auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
       /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
       attributes);
+
+  if (symbolTables && symbolTableOp) {
+    auto ip = rewriter.getInsertionPoint();
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+    symbolTable.insert(newFuncOp, ip);
+  }
+
   cast<FunctionOpInterface>(newFuncOp.getOperation())
       .setVisibility(funcOp.getVisibility());
 
@@ -473,16 +487,20 @@ namespace {
 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
 /// information.
-struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
-  FuncOpConversion(const LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern(converter) {}
+class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
+  SymbolTableCollection *symbolTables = nullptr;
+
+public:
+  explicit FuncOpConversion(const LLVMTypeConverter &converter,
+                            SymbolTableCollection *symbolTables = nullptr)
+      : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
         cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
-        *getTypeConverter());
+        *getTypeConverter(), symbolTables);
     if (failed(newFuncOp))
       return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
 
@@ -591,11 +609,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
 
 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
 public:
-  CallOpLowering(const LLVMTypeConverter &typeConverter,
-                 // Can be nullptr.
-                 const SymbolTable *symbolTable, PatternBenefit benefit = 1)
+  explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
+                          SymbolTableCollection *symbolTables = nullptr,
+                          PatternBenefit benefit = 1)
       : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
-        symbolTable(symbolTable) {}
+        symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
@@ -603,10 +621,10 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
     bool useBarePtrCallConv = false;
     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
       useBarePtrCallConv = true;
-    } else if (symbolTable != nullptr) {
+    } else if (symbolTables != nullptr) {
       // Fast lookup.
       Operation *callee =
-          symbolTable->lookup(callOp.getCalleeAttr().getValue());
+          symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
       useBarePtrCallConv =
           callee != nullptr && callee->hasAttr(barePtrAttrName);
     } else {
@@ -620,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
   }
 
 private:
-  const SymbolTable *symbolTable = nullptr;
+  SymbolTableCollection *symbolTables = nullptr;
 };
 
 struct CallIndirectOpLowering
@@ -731,16 +749,17 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
 } // namespace
 
 void mlir::populateFuncToLLVMFuncOpConversionPattern(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<FuncOpConversion>(converter);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables) {
+  patterns.add<FuncOpConversion>(converter, symbolTables);
 }
 
 void mlir::populateFuncToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    const SymbolTable *symbolTable) {
-  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
+    SymbolTableCollection *symbolTables) {
+  populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
   patterns.add<CallIndirectOpLowering>(converter);
-  patterns.add<CallOpLowering>(converter, symbolTable);
+  patterns.add<CallOpLowering>(converter, symbolTables);
   patterns.add<ConstantOpLowering>(converter);
   patterns.add<ReturnOpLowering>(converter);
 }
@@ -780,15 +799,11 @@ struct ConvertFuncToLLVMPass
     LLVMTypeConverter typeConverter(&getContext(), options,
                                     &dataLayoutAnalysis);
 
-    std::optional<SymbolTable> optSymbolTable = std::nullopt;
-    const SymbolTable *symbolTable = nullptr;
-    if (!options.useBarePtrCallConv) {
-      optSymbolTable.emplace(m);
-      symbolTable = &optSymbolTable.value();
-    }
-
     RewritePatternSet patterns(&getContext());
-    populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
+    SymbolTableCollection symbolTables;
+
+    populateFuncToLLVMConversionPatterns(typeConverter, patterns,
+                                         &symbolTables);
 
     LLVMConversionTarget target(getContext());
     if (failed(applyPartialConversion(m, target, std::move(patterns))))

diff  --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 2815e05b3e11b..49c73fbc9dd79 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -17,8 +17,26 @@
 using namespace mlir;
 using namespace llvm;
 
-static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
-                                            StringRef symbolName) {
+/// Check if a given symbol name is already in use within the module operation.
+/// If no symbol with such name is present, then the same identifier is
+/// returned. Otherwise, a unique and yet unused identifier is computed starting
+/// from the requested one.
+static std::string
+ensureSymbolNameIsUnique(ModuleOp moduleOp, StringRef symbolName,
+                         SymbolTableCollection *symbolTables = nullptr) {
+  if (symbolTables) {
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
+    unsigned counter = 0;
+    SmallString<128> uniqueName = symbolTable.generateSymbolName<128>(
+        symbolName,
+        [&](const SmallString<128> &tentativeName) {
+          return symbolTable.lookupSymbolIn(moduleOp, tentativeName) != nullptr;
+        },
+        counter);
+
+    return static_cast<std::string>(uniqueName);
+  }
+
   static int counter = 0;
   std::string uniqueName = std::string(symbolName);
   while (moduleOp.lookupSymbol(uniqueName)) {
@@ -30,7 +48,8 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
 LogicalResult mlir::LLVM::createPrintStrCall(
     OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
     StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
-    std::optional<StringRef> runtimeFunctionName) {
+    std::optional<StringRef> runtimeFunctionName,
+    SymbolTableCollection *symbolTables) {
   auto ip = builder.saveInsertionPoint();
   builder.setInsertionPointToStart(moduleOp.getBody());
   MLIRContext *ctx = builder.getContext();
@@ -49,7 +68,7 @@ LogicalResult mlir::LLVM::createPrintStrCall(
       LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
   auto globalOp = builder.create<LLVM::GlobalOp>(
       loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
-      ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+      ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr);
 
   auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
   // Emit call to `printStr` in runtime library.

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 8ccf1bfc292d5..e8294a5234c4f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -48,35 +48,39 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
-getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
-          ModuleOp module) {
+getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
+          SymbolTableCollection *symbolTables) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
-    return LLVM::lookupOrCreateGenericFreeFn(b, module);
+    return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);
 
-  return LLVM::lookupOrCreateFreeFn(b, module);
+  return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
 getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
-                     Operation *module, Type indexType) {
+                     Operation *module, Type indexType,
+                     SymbolTableCollection *symbolTables) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
   if (useGenericFn)
-    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);
+    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
+                                              symbolTables);
 
-  return LLVM::lookupOrCreateMallocFn(b, module, indexType);
+  return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
 getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
-                  Operation *module, Type indexType) {
+                  Operation *module, Type indexType,
+                  SymbolTableCollection *symbolTables) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
-    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);
+    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
+                                                     symbolTables);
 
-  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
+  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
 }
 
 /// Computes the aligned value for 'input' as follows:
@@ -126,8 +130,15 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
   return allocatedPtr;
 }
 
-struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
-  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+  SymbolTableCollection *symbolTables = nullptr;
+
+public:
+  explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
+                           SymbolTableCollection *symbolTables = nullptr,
+                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
+        symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
@@ -138,9 +149,10 @@ struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
       return rewriter.notifyMatchFailure(op, "incompatible memref type");
 
     // Get or insert alloc function into the module.
-    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
-        rewriter, getTypeConverter(),
-        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
+    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
+        getNotalignedAllocFn(rewriter, getTypeConverter(),
+                             op->getParentWithTrait<OpTrait::SymbolTable>(),
+                             getIndexType(), symbolTables);
     if (failed(allocFuncOp))
       return failure();
 
@@ -210,8 +222,15 @@ struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
   }
 };
 
-struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
-  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+  SymbolTableCollection *symbolTables = nullptr;
+
+public:
+  explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
+                                  SymbolTableCollection *symbolTables = nullptr,
+                                  PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
+        symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
@@ -222,9 +241,10 @@ struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
       return rewriter.notifyMatchFailure(op, "incompatible memref type");
 
     // Get or insert alloc function into module.
-    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
-        rewriter, getTypeConverter(),
-        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
+    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
+        getAlignedAllocFn(rewriter, getTypeConverter(),
+                          op->getParentWithTrait<OpTrait::SymbolTable>(),
+                          getIndexType(), symbolTables);
     if (failed(allocFuncOp))
       return failure();
 
@@ -446,18 +466,23 @@ struct AssumeAlignmentOpLowering
 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
 // The memref descriptor being an SSA value, there is no need to clean it up
 // in any way.
-struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
-  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
+class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
+  SymbolTableCollection *symbolTables = nullptr;
 
-  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
+public:
+  explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
+                             SymbolTableCollection *symbolTables = nullptr,
+                             PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
+        symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
-    FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
-        rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
+    FailureOr<LLVM::LLVMFuncOp> freeFunc =
+        getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
+                  symbolTables);
     if (failed(freeFunc))
       return failure();
     Value allocatedPtr;
@@ -710,9 +735,15 @@ convertGlobalMemrefTypeToLLVM(MemRefType type,
 }
 
 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
-struct GlobalMemrefOpLowering
-    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
-  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
+class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
+  SymbolTableCollection *symbolTables = nullptr;
+
+public:
+  explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
+                                  SymbolTableCollection *symbolTables = nullptr,
+                                  PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
+        symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
@@ -743,9 +774,31 @@ struct GlobalMemrefOpLowering
     if (failed(addressSpace))
       return global.emitOpError(
           "memory space cannot be converted to an integer address space");
+
+    if (symbolTables) {
+      Operation *symbolTableOp =
+          global->getParentWithTrait<OpTrait::SymbolTable>();
+
+      if (symbolTableOp) {
+        SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+        symbolTable.remove(global);
+      }
+    }
+
     auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
         initialValue, alignment, *addressSpace);
+
+    if (symbolTables) {
+      Operation *symbolTableOp =
+          global->getParentWithTrait<OpTrait::SymbolTable>();
+
+      if (symbolTableOp) {
+        SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+        symbolTable.insert(newGlobal, rewriter.getInsertionPoint());
+      }
+    }
+
     if (!global.isExternal() && global.isUninitialized()) {
       rewriter.createBlock(&newGlobal.getInitializerRegion());
       Value undef[] = {
@@ -997,8 +1050,15 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
 /// For memrefs with identity layouts, the copy is lowered to the llvm
 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
 /// to the generic `MemrefCopyFn`.
-struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
-  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
+class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
+  SymbolTableCollection *symbolTables = nullptr;
+
+public:
+  explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
+                                SymbolTableCollection *symbolTables = nullptr,
+                                PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
+        symbolTables(symbolTables) {}
 
   LogicalResult
   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
@@ -1093,7 +1153,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
         rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
-        sourcePtr.getType());
+        sourcePtr.getType(), symbolTables);
     if (failed(copyFn))
       return failure();
     rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
@@ -1928,7 +1988,8 @@ class ExtractStridedMetadataOpLowering
 } // namespace
 
 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables) {
   // clang-format off
   patterns.add<
       AllocaOpLowering,
@@ -1939,11 +2000,9 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
       DimOpLowering,
       ExtractStridedMetadataOpLowering,
       GenericAtomicRMWOpLowering,
-      GlobalMemrefOpLowering,
       GetGlobalMemrefOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
-      MemRefCopyOpLowering,
       MemorySpaceCastOpLowering,
       MemRefReinterpretCastOpLowering,
       MemRefReshapeOpLowering,
@@ -1956,11 +2015,14 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
       TransposeOpLowering,
       ViewOpLowering>(converter);
   // clang-format on
+  patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
+                                                             symbolTables);
   auto allocLowering = converter.getOptions().allocLowering;
   if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
-    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
+    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
+                                                            symbolTables);
   else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
-    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
+    patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
 }
 
 namespace {
@@ -1987,7 +2049,9 @@ struct FinalizeMemRefToLLVMConversionPass
     LLVMTypeConverter typeConverter(&getContext(), options,
                                     &dataLayoutAnalysis);
     RewritePatternSet patterns(&getContext());
-    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
+    SymbolTableCollection symbolTables;
+    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
+                                                   &symbolTables);
     LLVMConversionTarget target(getContext());
     target.addLegalOp<func::FuncOp>();
     if (failed(applyPartialConversion(op, target, std::move(patterns))))

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f725993635672..d53d11f87efe8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1595,8 +1595,14 @@ class VectorCreateMaskOpConversion
 };
 
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
+  SymbolTableCollection *symbolTables = nullptr;
+
 public:
-  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
+  explicit VectorPrintOpConversion(
+      const LLVMTypeConverter &typeConverter,
+      SymbolTableCollection *symbolTables = nullptr, PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter, benefit),
+        symbolTables(symbolTables) {}
 
   // Lowering implementation that relies on a small runtime support library,
   // which only needs to provide a few printing methods (single value for all
@@ -1643,13 +1649,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
       FailureOr<LLVM::LLVMFuncOp> op = [&]() {
         switch (punct) {
         case PrintPunctuation::Close:
-          return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
+          return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
+                                                  symbolTables);
         case PrintPunctuation::Open:
-          return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
+          return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
+                                                 symbolTables);
         case PrintPunctuation::Comma:
-          return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
+          return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
+                                                  symbolTables);
         case PrintPunctuation::NewLine:
-          return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
+          return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
+                                                    symbolTables);
         default:
           llvm_unreachable("unexpected punctuation");
         }
@@ -1683,17 +1693,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     PrintConversion conversion = PrintConversion::None;
     FailureOr<Operation *> printer;
     if (printType.isF32()) {
-      printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
+      printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
     } else if (printType.isF64()) {
-      printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
+      printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
     } else if (printType.isF16()) {
       conversion = PrintConversion::Bitcast16; // bits!
-      printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
+      printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
     } else if (printType.isBF16()) {
       conversion = PrintConversion::Bitcast16; // bits!
-      printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
+      printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
     } else if (printType.isIndex()) {
-      printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
+      printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
     } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
       // Integers need a zero or sign extension on the operand
       // (depending on the source type) as well as a signed or
@@ -1703,7 +1713,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
         if (width <= 64) {
           if (width < 64)
             conversion = PrintConversion::ZeroExt64;
-          printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
+          printer =
+              LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
         } else {
           return failure();
         }
@@ -1716,7 +1727,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
             conversion = PrintConversion::ZeroExt64;
           else if (width < 64)
             conversion = PrintConversion::SignExt64;
-          printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
+          printer =
+              LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
         } else {
           return failure();
         }

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 1b4a8f496d3d0..89f765dacda35 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -44,15 +44,31 @@ static constexpr llvm::StringRef kGenericAlignedAlloc =
 static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
+/// Search for an LLVMFuncOp with a given name within an operation with the
+/// SymbolTable trait. An optional collection of cached symbol tables can be
+/// given to avoid a linear scan of the symbol table operation. Returns a null
+/// op when no matching LLVMFuncOp is found.
+static LLVM::LLVMFuncOp
+lookupFuncOp(StringRef name, Operation *symbolTableOp,
+             SymbolTableCollection *symbolTables = nullptr) {
+  if (symbolTables) {
+    return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(
+        symbolTableOp, StringAttr::get(symbolTableOp->getContext(), name));
+  }
+
+  return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(symbolTableOp, name));
+}
+
 /// Generic print function lookupOrCreate helper.
 FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                              ArrayRef<Type> paramTypes, Type resultType,
-                             bool isVarArg, bool isReserved) {
+                             bool isVarArg, bool isReserved,
+                             SymbolTableCollection *symbolTables) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
-  auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
-      SymbolTable::lookupSymbolIn(moduleOp, name));
+  auto func = lookupFuncOp(name, moduleOp, symbolTables);
   auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
   // Assert the signature of the found function is same as expected
   if (func) {
@@ -73,60 +89,75 @@ mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
   OpBuilder::InsertionGuard g(b);
   assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
   b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
-  return b.create<LLVM::LLVMFuncOp>(
+  auto funcOp = b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
       LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
+
+  if (symbolTables) {
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
+    symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
+  }
+
+  return funcOp;
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
-                         ArrayRef<Type> paramTypes, Type resultType) {
+                         ArrayRef<Type> paramTypes, Type resultType,
+                         SymbolTableCollection *symbolTables) {
   return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
-                          /*isVarArg=*/false, /*isReserved=*/true);
+                          /*isVarArg=*/false, /*isReserved=*/true,
+                          symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
+                                     SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
+                                     SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
+                                     SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintF16,
       IntegerType::get(moduleOp->getContext(), 16), // bits!
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
+                                      SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintBF16,
       IntegerType::get(moduleOp->getContext(), 16), // bits!
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
+                                     SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
+                                     SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -140,90 +171,102 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
 
 FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
     OpBuilder &b, Operation *moduleOp,
-    std::optional<StringRef> runtimeFunctionName) {
+    std::optional<StringRef> runtimeFunctionName,
+    SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, runtimeFunctionName.value_or(kPrintString),
       getCharPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
+                                      SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintOpen, {},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
+                                       SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintClose, {},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
+                                       SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintComma, {},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
+                                         SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kPrintNewline, {},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp,
-                                   Type indexType) {
+                                   Type indexType,
+                                   SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType,
-                                  getVoidPtr(moduleOp->getContext()));
+                                  getVoidPtr(moduleOp->getContext()),
+                                  symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
-                                         Type indexType) {
-  return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc,
-                                  {indexType, indexType},
-                                  getVoidPtr(moduleOp->getContext()));
+                                         Type indexType,
+                                         SymbolTableCollection *symbolTables) {
+  return lookupOrCreateReservedFn(
+      b, moduleOp, kAlignedAlloc, {indexType, indexType},
+      getVoidPtr(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
+                                 SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp,
-                                         Type indexType) {
+                                         Type indexType,
+                                         SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType,
-                                  getVoidPtr(moduleOp->getContext()));
+                                  getVoidPtr(moduleOp->getContext()),
+                                  symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(
-    OpBuilder &b, Operation *moduleOp, Type indexType) {
-  return lookupOrCreateReservedFn(b, moduleOp, kGenericAlignedAlloc,
-                                  {indexType, indexType},
-                                  getVoidPtr(moduleOp->getContext()));
+    OpBuilder &b, Operation *moduleOp, Type indexType,
+    SymbolTableCollection *symbolTables) {
+  return lookupOrCreateReservedFn(
+      b, moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+      getVoidPtr(moduleOp->getContext()), symbolTables);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
+                                        SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
-FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp,
-                                       Type indexType,
-                                       Type unrankedDescriptorType) {
+FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn(
+    OpBuilder &b, Operation *moduleOp, Type indexType,
+    Type unrankedDescriptorType, SymbolTableCollection *symbolTables) {
   return lookupOrCreateReservedFn(
       b, moduleOp, kMemRefCopy,
       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }


        


More information about the Mlir-commits mailing list