[flang-commits] [flang] [flang][NFC] use mlir::SymbolTable in lowering (PR #86673)

via flang-commits flang-commits at lists.llvm.org
Tue Apr 2 02:00:48 PDT 2024


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/86673

>From f8d7244708cdcc343f1674a3ee8b67ce1316f98f Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 26 Mar 2024 03:17:10 -0700
Subject: [PATCH 1/4] [flang][NFC] use mlir::SymbolTable in lowering

Whenever lowering checked whether a function or global already
existed in the mlir::Module, it called ModuleOp::lookupSymbol.

On big programs (~5000 globals and functions), this causes
significant slowdowns because these lookups are linear in the
number of symbols. Use mlir::SymbolTable to speed up these lookups.
The SymbolTable has to be created from the ModuleOp and kept in
sync with it. It is therefore placed in the converter, and
FirOpBuilders can take a pointer to it to speed up their lookups.

This patch does not bring mlir::SymbolTable to FIR/HLFIR
passes, but some passes that create many runtime calls could
benefit from it too. More analysis is needed.

As an example, this patch reduces the compilation time of
Whizard compare_amplitude_UFO from 5 minutes to 2 minutes on
my machine (there is still room for further speed-ups).
---
 flang/include/flang/Lower/AbstractConverter.h | 13 ++++
 .../flang/Optimizer/Builder/FIRBuilder.h      | 70 +++++++++----------
 .../flang/Optimizer/Dialect/FIROpsSupport.h   | 19 ++---
 flang/lib/Lower/Bridge.cpp                    | 23 ++++--
 flang/lib/Lower/CallInterface.cpp             |  9 ++-
 flang/lib/Lower/OpenACC.cpp                   |  6 +-
 flang/lib/Optimizer/Builder/FIRBuilder.cpp    | 45 ++++++++----
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp |  6 +-
 .../Optimizer/Builder/LowLevelIntrinsics.cpp  | 62 ++++++++--------
 .../Optimizer/Builder/PPCIntrinsicCall.cpp    | 34 ++++-----
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 18 +++--
 .../Transforms/SimplifyIntrinsics.cpp         |  7 +-
 12 files changed, 182 insertions(+), 130 deletions(-)

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 32e7a5e2b04061..d5dab9040d22bd 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -23,6 +23,10 @@
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/ArrayRef.h"
 
+namespace mlir {
+class SymbolTable;
+}
+
 namespace fir {
 class KindMapping;
 class FirOpBuilder;
@@ -305,6 +309,15 @@ class AbstractConverter {
   virtual Fortran::lower::SymbolBox
   lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
 
+  /// Return the mlir::SymbolTable associated to the ModuleOp.
+  /// Look-ups are faster using it than using module.lookup<>,
+  /// but the module op should be queried in case of failure
+  /// because this symbol table is not guaranteed to contain
+  /// all the symbols from the ModuleOp (the symbol table should
+  /// always be provided to the builder helper creating globals and
+  /// functions in order to be in sync).
+  virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index d61bf681be6194..8537f29b2e549c 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -28,6 +28,10 @@
 #include <optional>
 #include <utility>
 
+namespace mlir {
+class SymbolTable;
+}
+
 namespace fir {
 class AbstractArrayBox;
 class ExtendedValue;
@@ -42,8 +46,10 @@ class BoxValue;
 /// patterns.
 class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
 public:
-  explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap)
-      : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)} {}
+  explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap,
+                        mlir::SymbolTable *symbolTable = nullptr)
+      : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
+        symbolTable{symbolTable} {}
   explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap)
       : OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
     setListener(this);
@@ -69,13 +75,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   // The listener self-reference has to be updated in case of copy-construction.
   FirOpBuilder(const FirOpBuilder &other)
       : OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap},
-        fastMathFlags{other.fastMathFlags} {
+        fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
     setListener(this);
   }
 
   FirOpBuilder(FirOpBuilder &&other)
-      : OpBuilder(other), OpBuilder::Listener(),
-        kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags} {
+      : OpBuilder(other), OpBuilder::Listener(), kindMap{std::move(
+                                                     other.kindMap)},
+        fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
     setListener(this);
   }
 
@@ -95,6 +102,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// Get a reference to the kind map.
   const fir::KindMapping &getKindMap() { return kindMap; }
 
+  /// Get func.func/fir.global symbol table attached to this builder if any.
+  mlir::SymbolTable *getMLIRSymbolTable() { return symbolTable; }
+
   /// Get the default integer type
   [[maybe_unused]] mlir::IntegerType getDefaultIntegerType() {
     return getIntegerType(
@@ -280,25 +290,28 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// Get a function by name. If the function exists in the current module, it
   /// is returned. Otherwise, a null FuncOp is returned.
   mlir::func::FuncOp getNamedFunction(llvm::StringRef name) {
-    return getNamedFunction(getModule(), name);
+    return getNamedFunction(getModule(), name, getMLIRSymbolTable());
   }
-  static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
-                                             llvm::StringRef name);
+  static mlir::func::FuncOp
+  getNamedFunction(mlir::ModuleOp module, llvm::StringRef name,
+                   const mlir::SymbolTable *symbolTable);
 
   /// Get a function by symbol name. The result will be null if there is no
   /// function with the given symbol in the module.
   mlir::func::FuncOp getNamedFunction(mlir::SymbolRefAttr symbol) {
-    return getNamedFunction(getModule(), symbol);
+    return getNamedFunction(getModule(), symbol, getMLIRSymbolTable());
   }
-  static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
-                                             mlir::SymbolRefAttr symbol);
+  static mlir::func::FuncOp
+  getNamedFunction(mlir::ModuleOp module, mlir::SymbolRefAttr symbol,
+                   const mlir::SymbolTable *symbolTable);
 
   fir::GlobalOp getNamedGlobal(llvm::StringRef name) {
-    return getNamedGlobal(getModule(), name);
+    return getNamedGlobal(getModule(), name, getMLIRSymbolTable());
   }
 
   static fir::GlobalOp getNamedGlobal(mlir::ModuleOp module,
-                                      llvm::StringRef name);
+                                      llvm::StringRef name,
+                                      const mlir::SymbolTable *symbolTable);
 
   /// Lazy creation of fir.convert op.
   mlir::Value createConvert(mlir::Location loc, mlir::Type toTy,
@@ -313,35 +326,18 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// result of the load if it was created, otherwise return \p val
   mlir::Value loadIfRef(mlir::Location loc, mlir::Value val);
 
-  /// Create a new FuncOp. If the function may have already been created, use
-  /// `addNamedFunction` instead.
+  /// Determine if the named function is already in the module. Return the
+  /// instance if found, otherwise add a new named function to the module.
   mlir::func::FuncOp createFunction(mlir::Location loc, llvm::StringRef name,
                                     mlir::FunctionType ty) {
-    return createFunction(loc, getModule(), name, ty);
+    return createFunction(loc, getModule(), name, ty, getMLIRSymbolTable());
   }
 
   static mlir::func::FuncOp createFunction(mlir::Location loc,
                                            mlir::ModuleOp module,
                                            llvm::StringRef name,
-                                           mlir::FunctionType ty);
-
-  /// Determine if the named function is already in the module. Return the
-  /// instance if found, otherwise add a new named function to the module.
-  mlir::func::FuncOp addNamedFunction(mlir::Location loc, llvm::StringRef name,
-                                      mlir::FunctionType ty) {
-    if (auto func = getNamedFunction(name))
-      return func;
-    return createFunction(loc, name, ty);
-  }
-
-  static mlir::func::FuncOp addNamedFunction(mlir::Location loc,
-                                             mlir::ModuleOp module,
-                                             llvm::StringRef name,
-                                             mlir::FunctionType ty) {
-    if (auto func = getNamedFunction(module, name))
-      return func;
-    return createFunction(loc, module, name, ty);
-  }
+                                           mlir::FunctionType ty,
+                                           mlir::SymbolTable *);
 
   /// Cast the input value to IndexType.
   mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) {
@@ -515,6 +511,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// FastMathFlags that need to be set for operations that support
   /// mlir::arith::FastMathAttr.
   mlir::arith::FastMathFlags fastMathFlags{};
+
+  /// fir::GlobalOp and func::FuncOp symbol table to speed-up
+  /// lookups.
+  mlir::SymbolTable *symbolTable = nullptr;
 };
 
 } // namespace fir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index e8226b6df58ca2..f29e44504acb63 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -52,16 +52,19 @@ inline bool pureCall(mlir::Operation *op) {
 /// Get or create a FuncOp in a module.
 ///
 /// If `module` already contains FuncOp `name`, it is returned. Otherwise, a new
-/// FuncOp is created, and that new FuncOp is returned.
-mlir::func::FuncOp
-createFuncOp(mlir::Location loc, mlir::ModuleOp module, llvm::StringRef name,
-             mlir::FunctionType type,
-             llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
-
-/// Get or create a GlobalOp in a module.
+/// FuncOp is created, and that new FuncOp is returned. A symbol table can
+/// be provided to speed-up the lookups.
+mlir::func::FuncOp createFuncOp(mlir::Location loc, mlir::ModuleOp module,
+                                llvm::StringRef name, mlir::FunctionType type,
+                                llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
+                                const mlir::SymbolTable *symbolTable = nullptr);
+
+/// Get or create a GlobalOp in a module. A symbol table can be provided to
+/// speed-up the lookups.
 fir::GlobalOp createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
                              llvm::StringRef name, mlir::Type type,
-                             llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
+                             llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
+                             const mlir::SymbolTable *symbolTable = nullptr);
 
 /// Attribute to mark Fortran entities with the CONTIGUOUS attribute.
 constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; }
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 48830dc55578c2..46a259d9ae86c9 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -273,7 +273,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 public:
   explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
       : Fortran::lower::AbstractConverter(bridge.getLoweringOptions()),
-        bridge{bridge}, foldingContext{bridge.createFoldingContext()} {}
+        bridge{bridge}, foldingContext{bridge.createFoldingContext()},
+        mlirSymbolTable{bridge.getModule()} {}
   virtual ~FirConverter() = default;
 
   /// Convert the PFT to FIR.
@@ -329,8 +330,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
               [&](Fortran::lower::pft::BlockDataUnit &b) {},
               [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
               [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
-                builder = new fir::FirOpBuilder(bridge.getModule(),
-                                                bridge.getKindMap());
+                builder = new fir::FirOpBuilder(
+                    bridge.getModule(), bridge.getKindMap(), &mlirSymbolTable);
                 Fortran::lower::genOpenACCRoutineConstruct(
                     *this, bridge.getSemanticsContext(), bridge.getModule(),
                     d.routine, accRoutineInfos);
@@ -1036,6 +1037,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return {};
   }
 
+  mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
+
   /// Add the symbol to the local map and return `true`. If the symbol is
   /// already in the map and \p forced is `false`, the map is not updated.
   /// Instead the value `false` is returned.
@@ -4570,7 +4573,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                llvm::dbgs() << "\n");
     Fortran::lower::CalleeInterface callee(funit, *this);
     mlir::func::FuncOp func = callee.addEntryBlockAndMapArguments();
-    builder = new fir::FirOpBuilder(func, bridge.getKindMap());
+    builder =
+        new fir::FirOpBuilder(func, bridge.getKindMap(), &mlirSymbolTable);
     assert(builder && "FirOpBuilder did not instantiate");
     builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
     builder->setInsertionPointToStart(&func.front());
@@ -4838,12 +4842,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // FIXME: get rid of the bogus function context and instantiate the
     // globals directly into the module.
     mlir::MLIRContext *context = &getMLIRContext();
+    mlir::SymbolTable *symbolTable = getMLIRSymbolTable();
     mlir::func::FuncOp func = fir::FirOpBuilder::createFunction(
         mlir::UnknownLoc::get(context), getModuleOp(),
         fir::NameUniquer::doGenerated("Sham"),
-        mlir::FunctionType::get(context, std::nullopt, std::nullopt));
+        mlir::FunctionType::get(context, std::nullopt, std::nullopt),
+        symbolTable);
     func.addEntryBlock();
-    builder = new fir::FirOpBuilder(func, bridge.getKindMap());
+    builder = new fir::FirOpBuilder(func, bridge.getKindMap(), symbolTable);
     assert(builder && "FirOpBuilder did not instantiate");
     builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
     createGlobals();
@@ -5335,6 +5341,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// utilities to deal with procedure pointer components whose arguments have
   /// the type of the containing derived type.
   Fortran::lower::TypeConstructionStack typeConstructionStack;
+  /// MLIR symbol table of the fir.global/func.func operations. Note that it is
+  /// not guaranteed to contain all operations of the ModuleOp with Symbol
+  /// attribute since mlirSymbolTable must pro-actively be maintained when
+  /// new Symbol operations are created.
+  mlir::SymbolTable mlirSymbolTable;
 };
 
 } // namespace
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index c65becc497459c..fef38da0133060 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -667,11 +667,13 @@ void Fortran::lower::CallInterface<T>::declare() {
   if (!side().isIndirectCall()) {
     std::string name = side().getMangledName();
     mlir::ModuleOp module = converter.getModuleOp();
-    func = fir::FirOpBuilder::getNamedFunction(module, name);
+    mlir::SymbolTable *symbolTable = converter.getMLIRSymbolTable();
+    func = fir::FirOpBuilder::getNamedFunction(module, name, symbolTable);
     if (!func) {
       mlir::Location loc = side().getCalleeLocation();
       mlir::FunctionType ty = genFunctionType();
-      func = fir::FirOpBuilder::createFunction(loc, module, name, ty);
+      func =
+          fir::FirOpBuilder::createFunction(loc, module, name, ty, symbolTable);
       if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol()) {
         if (side().isMainProgram()) {
           func->setAttr(fir::getSymbolAttrName(),
@@ -1644,7 +1646,8 @@ mlir::func::FuncOp Fortran::lower::getOrDeclareFunction(
     Fortran::lower::AbstractConverter &converter) {
   mlir::ModuleOp module = converter.getModuleOp();
   std::string name = getProcMangledName(proc, converter);
-  mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(module, name);
+  mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(
+      module, name, converter.getMLIRSymbolTable());
   if (func)
     return func;
 
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 7b7e4a875cd8e8..0ef3baa19c0199 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3809,7 +3809,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
   std::string funcName;
   if (name) {
     funcName = converter.mangleName(*name->symbol);
-    funcOp = builder.getNamedFunction(mod, funcName);
+    funcOp =
+        builder.getNamedFunction(mod, funcName, builder.getMLIRSymbolTable());
   } else {
     Fortran::semantics::Scope &scope =
         semanticsContext.FindScope(routineConstruct.source);
@@ -3821,7 +3822,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
             : nullptr};
     if (subpDetails && subpDetails->isInterface()) {
       funcName = converter.mangleName(*progUnit.symbol());
-      funcOp = builder.getNamedFunction(mod, funcName);
+      funcOp =
+          builder.getNamedFunction(mod, funcName, builder.getMLIRSymbolTable());
     } else {
       funcOp = builder.getFunction();
       funcName = funcOp.getName();
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 2bcd5e5914027d..a8606a79af1671 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -36,26 +36,39 @@ static llvm::cl::opt<std::size_t>
                                       "name"),
                        llvm::cl::init(32));
 
-mlir::func::FuncOp fir::FirOpBuilder::createFunction(mlir::Location loc,
-                                                     mlir::ModuleOp module,
-                                                     llvm::StringRef name,
-                                                     mlir::FunctionType ty) {
-  return fir::createFuncOp(loc, module, name, ty);
+mlir::func::FuncOp
+fir::FirOpBuilder::createFunction(mlir::Location loc, mlir::ModuleOp module,
+                                  llvm::StringRef name, mlir::FunctionType ty,
+                                  mlir::SymbolTable *symbolTable) {
+  return fir::createFuncOp(loc, module, name, ty, /*attrs*/ {}, symbolTable);
 }
 
-mlir::func::FuncOp fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
-                                                       llvm::StringRef name) {
+mlir::func::FuncOp
+fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp, llvm::StringRef name,
+                                    const mlir::SymbolTable *symbolTable) {
+  if (symbolTable)
+    if (auto func = symbolTable->lookup<mlir::func::FuncOp>(name))
+      return func;
   return modOp.lookupSymbol<mlir::func::FuncOp>(name);
 }
 
 mlir::func::FuncOp
 fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
-                                    mlir::SymbolRefAttr symbol) {
+                                    mlir::SymbolRefAttr symbol,
+                                    const mlir::SymbolTable *symbolTable) {
+  if (symbolTable)
+    if (auto func =
+            symbolTable->lookup<mlir::func::FuncOp>(symbol.getLeafReference()))
+      return func;
   return modOp.lookupSymbol<mlir::func::FuncOp>(symbol);
 }
 
-fir::GlobalOp fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
-                                                llvm::StringRef name) {
+fir::GlobalOp
+fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp, llvm::StringRef name,
+                                  const mlir::SymbolTable *symbolTable) {
+  if (symbolTable)
+    if (auto global = symbolTable->lookup<fir::GlobalOp>(name))
+      return global;
   return modOp.lookupSymbol<fir::GlobalOp>(name);
 }
 
@@ -279,10 +292,10 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
     mlir::Location loc, mlir::Type type, llvm::StringRef name,
     mlir::StringAttr linkage, mlir::Attribute value, bool isConst,
     bool isTarget, fir::CUDADataAttributeAttr cudaAttr) {
+  if (auto global = getNamedGlobal(name))
+    return global;
   auto module = getModule();
   auto insertPt = saveInsertionPoint();
-  if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
-    return glob;
   setInsertionPoint(module.getBody(), module.getBody()->end());
   llvm::SmallVector<mlir::NamedAttribute> attrs;
   if (cudaAttr) {
@@ -294,6 +307,8 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
   auto glob = create<fir::GlobalOp>(loc, name, isConst, isTarget, type, value,
                                     linkage, attrs);
   restoreInsertionPoint(insertPt);
+  if (symbolTable)
+    symbolTable->insert(glob);
   return glob;
 }
 
@@ -301,10 +316,10 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
     mlir::Location loc, mlir::Type type, llvm::StringRef name, bool isConst,
     bool isTarget, std::function<void(FirOpBuilder &)> bodyBuilder,
     mlir::StringAttr linkage, fir::CUDADataAttributeAttr cudaAttr) {
+  if (auto global = getNamedGlobal(name))
+    return global;
   auto module = getModule();
   auto insertPt = saveInsertionPoint();
-  if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
-    return glob;
   setInsertionPoint(module.getBody(), module.getBody()->end());
   auto glob = create<fir::GlobalOp>(loc, name, isConst, isTarget, type,
                                     mlir::Attribute{}, linkage);
@@ -314,6 +329,8 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
   setInsertionPointToStart(&block);
   bodyBuilder(*this);
   restoreInsertionPoint(insertPt);
+  if (symbolTable)
+    symbolTable->insert(glob);
   return glob;
 }
 
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ea1ef1f08aba20..069ba81cfe96ab 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -724,7 +724,7 @@ mlir::Value genLibCall(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::func::FuncOp funcOp = builder.getNamedFunction(libFuncName);
 
   if (!funcOp) {
-    funcOp = builder.addNamedFunction(loc, libFuncName, libFuncType);
+    funcOp = builder.createFunction(loc, libFuncName, libFuncType);
     // C-interoperability rules apply to these library functions.
     funcOp->setAttr(fir::getSymbolAttrName(),
                     mlir::StringAttr::get(builder.getContext(), libFuncName));
@@ -1894,8 +1894,8 @@ mlir::func::FuncOp IntrinsicLibrary::getWrapper(GeneratorType generator,
     // Create local context to emit code into the newly created function
     // This new function is not linked to a source file location, only
     // its calls will be.
-    auto localBuilder =
-        std::make_unique<fir::FirOpBuilder>(function, builder.getKindMap());
+    auto localBuilder = std::make_unique<fir::FirOpBuilder>(
+        function, builder.getKindMap(), builder.getMLIRSymbolTable());
     localBuilder->setFastMathFlags(builder.getFastMathFlags());
     localBuilder->setInsertionPointToStart(&function.front());
     // Location of code inside wrapper of the wrapper is independent from
diff --git a/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp b/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
index 1d07b1e724d745..bb5f77d5d4d1de 100644
--- a/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
+++ b/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
@@ -27,8 +27,8 @@ mlir::func::FuncOp fir::factory::getLlvmMemcpy(fir::FirOpBuilder &builder) {
                                         builder.getI1Type()};
   auto memcpyTy =
       mlir::FunctionType::get(builder.getContext(), args, std::nullopt);
-  return builder.addNamedFunction(builder.getUnknownLoc(),
-                                  "llvm.memcpy.p0.p0.i64", memcpyTy);
+  return builder.createFunction(builder.getUnknownLoc(),
+                                "llvm.memcpy.p0.p0.i64", memcpyTy);
 }
 
 mlir::func::FuncOp fir::factory::getLlvmMemmove(fir::FirOpBuilder &builder) {
@@ -37,8 +37,8 @@ mlir::func::FuncOp fir::factory::getLlvmMemmove(fir::FirOpBuilder &builder) {
                                         builder.getI1Type()};
   auto memmoveTy =
       mlir::FunctionType::get(builder.getContext(), args, std::nullopt);
-  return builder.addNamedFunction(builder.getUnknownLoc(),
-                                  "llvm.memmove.p0.p0.i64", memmoveTy);
+  return builder.createFunction(builder.getUnknownLoc(),
+                                "llvm.memmove.p0.p0.i64", memmoveTy);
 }
 
 mlir::func::FuncOp fir::factory::getLlvmMemset(fir::FirOpBuilder &builder) {
@@ -47,16 +47,15 @@ mlir::func::FuncOp fir::factory::getLlvmMemset(fir::FirOpBuilder &builder) {
                                         builder.getI1Type()};
   auto memsetTy =
       mlir::FunctionType::get(builder.getContext(), args, std::nullopt);
-  return builder.addNamedFunction(builder.getUnknownLoc(),
-                                  "llvm.memset.p0.p0.i64", memsetTy);
+  return builder.createFunction(builder.getUnknownLoc(),
+                                "llvm.memset.p0.p0.i64", memsetTy);
 }
 
 mlir::func::FuncOp fir::factory::getRealloc(fir::FirOpBuilder &builder) {
   auto ptrTy = builder.getRefType(builder.getIntegerType(8));
   llvm::SmallVector<mlir::Type> args = {ptrTy, builder.getI64Type()};
   auto reallocTy = mlir::FunctionType::get(builder.getContext(), args, {ptrTy});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "realloc",
-                                  reallocTy);
+  return builder.createFunction(builder.getUnknownLoc(), "realloc", reallocTy);
 }
 
 mlir::func::FuncOp
@@ -64,8 +63,8 @@ fir::factory::getLlvmGetRounding(fir::FirOpBuilder &builder) {
   auto int32Ty = builder.getIntegerType(32);
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), std::nullopt, {int32Ty});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "llvm.get.rounding",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "llvm.get.rounding",
+                                funcTy);
 }
 
 mlir::func::FuncOp
@@ -73,8 +72,8 @@ fir::factory::getLlvmSetRounding(fir::FirOpBuilder &builder) {
   auto int32Ty = builder.getIntegerType(32);
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), {int32Ty}, std::nullopt);
-  return builder.addNamedFunction(builder.getUnknownLoc(), "llvm.set.rounding",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "llvm.set.rounding",
+                                funcTy);
 }
 
 mlir::func::FuncOp fir::factory::getLlvmStackSave(fir::FirOpBuilder &builder) {
@@ -82,8 +81,8 @@ mlir::func::FuncOp fir::factory::getLlvmStackSave(fir::FirOpBuilder &builder) {
   auto ptrTy = builder.getRefType(builder.getIntegerType(8));
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), std::nullopt, {ptrTy});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "llvm.stacksave.p0",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "llvm.stacksave.p0",
+                                funcTy);
 }
 
 mlir::func::FuncOp
@@ -92,8 +91,8 @@ fir::factory::getLlvmStackRestore(fir::FirOpBuilder &builder) {
   auto ptrTy = builder.getRefType(builder.getIntegerType(8));
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), {ptrTy}, std::nullopt);
-  return builder.addNamedFunction(builder.getUnknownLoc(),
-                                  "llvm.stackrestore.p0", funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "llvm.stackrestore.p0",
+                                funcTy);
 }
 
 mlir::func::FuncOp
@@ -101,24 +100,24 @@ fir::factory::getLlvmInitTrampoline(fir::FirOpBuilder &builder) {
   auto ptrTy = builder.getRefType(builder.getIntegerType(8));
   auto funcTy = mlir::FunctionType::get(builder.getContext(),
                                         {ptrTy, ptrTy, ptrTy}, std::nullopt);
-  return builder.addNamedFunction(builder.getUnknownLoc(),
-                                  "llvm.init.trampoline", funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "llvm.init.trampoline",
+                                funcTy);
 }
 
 mlir::func::FuncOp
 fir::factory::getLlvmAdjustTrampoline(fir::FirOpBuilder &builder) {
   auto ptrTy = builder.getRefType(builder.getIntegerType(8));
   auto funcTy = mlir::FunctionType::get(builder.getContext(), {ptrTy}, {ptrTy});
-  return builder.addNamedFunction(builder.getUnknownLoc(),
-                                  "llvm.adjust.trampoline", funcTy);
+  return builder.createFunction(builder.getUnknownLoc(),
+                                "llvm.adjust.trampoline", funcTy);
 }
 
 mlir::func::FuncOp fir::factory::getFeclearexcept(fir::FirOpBuilder &builder) {
   auto int32Ty = builder.getIntegerType(32);
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "feclearexcept",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "feclearexcept",
+                                funcTy);
 }
 
 mlir::func::FuncOp
@@ -126,38 +125,37 @@ fir::factory::getFedisableexcept(fir::FirOpBuilder &builder) {
   auto int32Ty = builder.getIntegerType(32);
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "fedisableexcept",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "fedisableexcept",
+                                funcTy);
 }
 
 mlir::func::FuncOp fir::factory::getFeenableexcept(fir::FirOpBuilder &builder) {
   auto int32Ty = builder.getIntegerType(32);
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "feenableexcept",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "feenableexcept",
+                                funcTy);
 }
 
 mlir::func::FuncOp fir::factory::getFegetexcept(fir::FirOpBuilder &builder) {
   auto int32Ty = builder.getIntegerType(32);
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), std::nullopt, {int32Ty});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "fegetexcept",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "fegetexcept", funcTy);
 }
 
 mlir::func::FuncOp fir::factory::getFeraiseexcept(fir::FirOpBuilder &builder) {
   auto int32Ty = builder.getIntegerType(32);
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "feraiseexcept",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "feraiseexcept",
+                                funcTy);
 }
 
 mlir::func::FuncOp fir::factory::getFetestexcept(fir::FirOpBuilder &builder) {
   auto int32Ty = builder.getIntegerType(32);
   auto funcTy =
       mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
-  return builder.addNamedFunction(builder.getUnknownLoc(), "fetestexcept",
-                                  funcTy);
+  return builder.createFunction(builder.getUnknownLoc(), "fetestexcept",
+                                funcTy);
 }
diff --git a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
index ab0d5079d8afe0..e588b19dded4f1 100644
--- a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
@@ -1084,11 +1084,11 @@ void PPCIntrinsicLibrary::genMtfsf(llvm::ArrayRef<fir::ExtendedValue> args) {
   if (isImm) {
     libFuncType = genFuncType<Ty::Void, Ty::Integer<4>, Ty::Integer<4>>(
         builder.getContext(), builder);
-    funcOp = builder.addNamedFunction(loc, "llvm.ppc.mtfsfi", libFuncType);
+    funcOp = builder.createFunction(loc, "llvm.ppc.mtfsfi", libFuncType);
   } else {
     libFuncType = genFuncType<Ty::Void, Ty::Integer<4>, Ty::Real<8>>(
         builder.getContext(), builder);
-    funcOp = builder.addNamedFunction(loc, "llvm.ppc.mtfsf", libFuncType);
+    funcOp = builder.createFunction(loc, "llvm.ppc.mtfsf", libFuncType);
   }
   builder.create<fir::CallOp>(loc, funcOp, scalarArgs);
 }
@@ -1116,7 +1116,7 @@ PPCIntrinsicLibrary::genVecAbs(mlir::Type resultType,
           genFuncType<Ty::RealVector<8>, Ty::RealVector<8>>(context, builder);
     }
 
-    funcOp = builder.addNamedFunction(loc, fname, ftype);
+    funcOp = builder.createFunction(loc, fname, ftype);
     auto callOp{builder.create<fir::CallOp>(loc, funcOp, argBases[0])};
     return callOp.getResult(0);
   } else if (auto eleTy = vTypeInfo.eleTy.dyn_cast<mlir::IntegerType>()) {
@@ -1155,7 +1155,7 @@ PPCIntrinsicLibrary::genVecAbs(mlir::Type resultType,
     default:
       llvm_unreachable("invalid integer size");
     }
-    funcOp = builder.addNamedFunction(loc, fname, ftype);
+    funcOp = builder.createFunction(loc, fname, ftype);
 
     mlir::Value args[] = {zeroSubVarg1, varg1};
     auto callOp{builder.create<fir::CallOp>(loc, funcOp, args)};
@@ -1339,7 +1339,7 @@ PPCIntrinsicLibrary::genVecAnyCompare(mlir::Type resultType,
   }
   assert((!fname.empty() && ftype) && "invalid type");
 
-  mlir::func::FuncOp funcOp{builder.addNamedFunction(loc, fname, ftype)};
+  mlir::func::FuncOp funcOp{builder.createFunction(loc, fname, ftype)};
   auto callOp{builder.create<fir::CallOp>(loc, funcOp, cmpArgs)};
   return callOp.getResult(0);
 }
@@ -1445,7 +1445,7 @@ PPCIntrinsicLibrary::genVecCmp(mlir::Type resultType,
   std::pair<llvm::StringRef, mlir::FunctionType> funcTyNam{
       getVecCmpFuncTypeAndName(vecTyInfo, vop, builder)};
 
-  mlir::func::FuncOp funcOp = builder.addNamedFunction(
+  mlir::func::FuncOp funcOp = builder.createFunction(
       loc, std::get<0>(funcTyNam), std::get<1>(funcTyNam));
 
   mlir::Value res{nullptr};
@@ -1572,7 +1572,7 @@ PPCIntrinsicLibrary::genVecConvert(mlir::Type resultType,
                                    Ty::Integer<4>>(context, builder)};
       const llvm::StringRef fname{(isUnsigned) ? "llvm.ppc.altivec.vcfux"
                                                : "llvm.ppc.altivec.vcfsx"};
-      auto funcOp{builder.addNamedFunction(loc, fname, ftype)};
+      auto funcOp{builder.createFunction(loc, fname, ftype)};
       mlir::Value newArgs[] = {argBases[0], convArg};
       auto callOp{builder.create<fir::CallOp>(loc, funcOp, newArgs)};
 
@@ -1627,7 +1627,7 @@ PPCIntrinsicLibrary::genVecConvert(mlir::Type resultType,
       const llvm::StringRef fname{"llvm.ppc.vsx.xvcvspdp"};
       auto ftype{
           genFuncType<Ty::RealVector<8>, Ty::RealVector<4>>(context, builder)};
-      auto funcOp{builder.addNamedFunction(loc, fname, ftype)};
+      auto funcOp{builder.createFunction(loc, fname, ftype)};
       auto callOp{builder.create<fir::CallOp>(loc, funcOp, newArgs)};
 
       return callOp.getResult(0);
@@ -1635,7 +1635,7 @@ PPCIntrinsicLibrary::genVecConvert(mlir::Type resultType,
       const llvm::StringRef fname{"llvm.ppc.vsx.xvcvdpsp"};
       auto ftype{
           genFuncType<Ty::RealVector<4>, Ty::RealVector<8>>(context, builder)};
-      auto funcOp{builder.addNamedFunction(loc, fname, ftype)};
+      auto funcOp{builder.createFunction(loc, fname, ftype)};
       newArgs[0] =
           builder.create<fir::CallOp>(loc, funcOp, newArgs).getResult(0);
       auto fvf32Ty{newArgs[0].getType()};
@@ -1963,7 +1963,7 @@ PPCIntrinsicLibrary::genVecLdCallGrp(mlir::Type resultType,
 
   auto funcType{
       mlir::FunctionType::get(context, {addr.getType()}, {intrinResTy})};
-  auto funcOp{builder.addNamedFunction(loc, fname, funcType)};
+  auto funcOp{builder.createFunction(loc, fname, funcType)};
   auto result{
       builder.create<fir::CallOp>(loc, funcOp, parsedArgs).getResult(0)};
 
@@ -2022,7 +2022,7 @@ PPCIntrinsicLibrary::genVecLvsGrp(mlir::Type resultType,
     llvm_unreachable("invalid vector operation for generator");
   }
   auto funcType{mlir::FunctionType::get(context, {addr.getType()}, {mlirTy})};
-  auto funcOp{builder.addNamedFunction(loc, fname, funcType)};
+  auto funcOp{builder.createFunction(loc, fname, funcType)};
   auto result{
       builder.create<fir::CallOp>(loc, funcOp, parsedArgs).getResult(0)};
 
@@ -2057,8 +2057,8 @@ PPCIntrinsicLibrary::genVecNmaddMsub(mlir::Type resultType,
            genFuncType<Ty::RealVector<8>, Ty::RealVector<8>, Ty::RealVector<8>>(
                context, builder))}};
 
-  auto funcOp{builder.addNamedFunction(loc, std::get<0>(fmaMap[width]),
-                                       std::get<1>(fmaMap[width]))};
+  auto funcOp{builder.createFunction(loc, std::get<0>(fmaMap[width]),
+                                     std::get<1>(fmaMap[width]))};
   if (vop == VecOp::Nmadd) {
     // vec_nmadd(arg1, arg2, arg3) = -fma(arg1, arg2, arg3)
     auto callOp{builder.create<fir::CallOp>(loc, funcOp, newArgs)};
@@ -2110,7 +2110,7 @@ PPCIntrinsicLibrary::genVecPerm(mlir::Type resultType,
           builder.create<mlir::LLVM::BitcastOp>(loc, vi32Ty, mArg1).getResult();
     }
 
-    auto funcOp{builder.addNamedFunction(
+    auto funcOp{builder.createFunction(
         loc, "llvm.ppc.altivec.vperm",
         genFuncType<Ty::IntegerVector<4>, Ty::IntegerVector<4>,
                     Ty::IntegerVector<4>, Ty::IntegerVector<1>>(context,
@@ -2307,7 +2307,7 @@ PPCIntrinsicLibrary::genVecShift(mlir::Type resultType,
     }
     auto funcTy{genFuncType<Ty::IntegerVector<4>, Ty::IntegerVector<4>,
                             Ty::IntegerVector<4>>(context, builder)};
-    mlir::func::FuncOp funcOp{builder.addNamedFunction(loc, funcName, funcTy)};
+    mlir::func::FuncOp funcOp{builder.createFunction(loc, funcName, funcTy)};
     auto callOp{builder.create<fir::CallOp>(loc, funcOp, mlirVecArgs)};
 
     // If the result vector type is different from the original type, need
@@ -2755,7 +2755,7 @@ void PPCIntrinsicLibrary::genMmaIntr(llvm::ArrayRef<fir::ExtendedValue> args) {
   auto context{builder.getContext()};
   mlir::FunctionType intrFuncType{getMmaIrFuncType(context, IntrId)};
   mlir::func::FuncOp funcOp{
-      builder.addNamedFunction(loc, getMmaIrIntrName(IntrId), intrFuncType)};
+      builder.createFunction(loc, getMmaIrIntrName(IntrId), intrFuncType)};
   llvm::SmallVector<mlir::Value> intrArgs;
 
   // Depending on SubToFunc, change the subroutine call to a function call.
@@ -2892,7 +2892,7 @@ void PPCIntrinsicLibrary::genVecStore(llvm::ArrayRef<fir::ExtendedValue> args) {
 
   auto funcType{
       mlir::FunctionType::get(context, {stTy, addr.getType()}, std::nullopt)};
-  mlir::func::FuncOp funcOp = builder.addNamedFunction(loc, fname, funcType);
+  mlir::func::FuncOp funcOp = builder.createFunction(loc, fname, funcType);
 
   llvm::SmallVector<mlir::Value, 4> biArgs;
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 9bb10a42a3997c..aeae6ff01507f1 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3677,10 +3677,14 @@ fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result,
   return mlir::success();
 }
 
-mlir::func::FuncOp
-fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
-                  llvm::StringRef name, mlir::FunctionType type,
-                  llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+mlir::func::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
+                                     llvm::StringRef name,
+                                     mlir::FunctionType type,
+                                     llvm::ArrayRef<mlir::NamedAttribute> attrs,
+                                     const mlir::SymbolTable *symbolTable) {
+  if (symbolTable)
+    if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name))
+      return f;
   if (auto f = module.lookupSymbol<mlir::func::FuncOp>(name))
     return f;
   mlir::OpBuilder modBuilder(module.getBodyRegion());
@@ -3692,7 +3696,11 @@ fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
 
 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
                                   llvm::StringRef name, mlir::Type type,
-                                  llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+                                  llvm::ArrayRef<mlir::NamedAttribute> attrs,
+                                  const mlir::SymbolTable *symbolTable) {
+  if (symbolTable)
+    if (auto g = symbolTable->lookup<fir::GlobalOp>(name))
+      return g;
   if (auto g = module.lookupSymbol<fir::GlobalOp>(name))
     return g;
   mlir::OpBuilder modBuilder(module.getBodyRegion());
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index a11aa38c771bd1..f7820b6b8170ba 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -1004,10 +1004,8 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
   //          We can also avoid this by using internal linkage, but
   //          this may increase the size of final executable/shared library.
   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
-  mlir::ModuleOp module = builder.getModule();
   // If we already have a function, just return it.
-  mlir::func::FuncOp newFunc =
-      fir::FirOpBuilder::getNamedFunction(module, replacementName);
+  mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName);
   mlir::FunctionType fType = typeGenerator(builder);
   if (newFunc) {
     assert(newFunc.getFunctionType() == fType &&
@@ -1017,8 +1015,7 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
 
   // Need to build the function!
   auto loc = mlir::UnknownLoc::get(builder.getContext());
-  newFunc =
-      fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
+  newFunc = builder.createFunction(loc, replacementName, fType);
   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
   auto linkage =
       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);

>From c13845ce6e1813a18f796321b123b5fca1293f25 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 27 Mar 2024 03:17:39 -0700
Subject: [PATCH 2/4] apply clang-format suggestion

---
 flang/include/flang/Optimizer/Builder/FIRBuilder.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 8537f29b2e549c..5868684135bd8e 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -80,9 +80,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   }
 
   FirOpBuilder(FirOpBuilder &&other)
-      : OpBuilder(other), OpBuilder::Listener(), kindMap{std::move(
-                                                     other.kindMap)},
-        fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
+      : OpBuilder(other), OpBuilder::Listener(),
+        kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags},
+        symbolTable{other.symbolTable} {
     setListener(this);
   }
 

>From 5715efd4f04ea25ad077fef6ae29d9a0fbd43f80 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 27 Mar 2024 03:36:12 -0700
Subject: [PATCH 3/4] add symbolTable/module consistency asserts under EXPENSIVE_CHECKS

---
 flang/lib/Optimizer/Builder/FIRBuilder.cpp | 23 ++++++++++++++++++----
 flang/lib/Optimizer/Dialect/FIROps.cpp     | 14 +++++++++++--
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index a8606a79af1671..c29e12a203ad88 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -47,8 +47,13 @@ mlir::func::FuncOp
 fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp, llvm::StringRef name,
                                     const mlir::SymbolTable *symbolTable) {
   if (symbolTable)
-    if (auto func = symbolTable->lookup<mlir::func::FuncOp>(name))
+    if (auto func = symbolTable->lookup<mlir::func::FuncOp>(name)) {
+#ifdef EXPENSIVE_CHECKS
+      assert(func == modOp.lookupSymbol<mlir::func::FuncOp>(name) &&
+             "symbolTable and module out of sync");
+#endif
       return func;
+    }
   return modOp.lookupSymbol<mlir::func::FuncOp>(name);
 }
 
@@ -57,9 +62,14 @@ fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
                                     mlir::SymbolRefAttr symbol,
                                     const mlir::SymbolTable *symbolTable) {
   if (symbolTable)
-    if (auto func =
-            symbolTable->lookup<mlir::func::FuncOp>(symbol.getLeafReference()))
+    if (auto func = symbolTable->lookup<mlir::func::FuncOp>(
+            symbol.getLeafReference())) {
+#ifdef EXPENSIVE_CHECKS
+      assert(func == modOp.lookupSymbol<mlir::func::FuncOp>(symbol) &&
+             "symbolTable and module out of sync");
+#endif
       return func;
+    }
   return modOp.lookupSymbol<mlir::func::FuncOp>(symbol);
 }
 
@@ -67,8 +77,13 @@ fir::GlobalOp
 fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp, llvm::StringRef name,
                                   const mlir::SymbolTable *symbolTable) {
   if (symbolTable)
-    if (auto global = symbolTable->lookup<fir::GlobalOp>(name))
+    if (auto global = symbolTable->lookup<fir::GlobalOp>(name)) {
+#ifdef EXPENSIVE_CHECKS
+      assert(global == modOp.lookupSymbol<fir::GlobalOp>(name) &&
+             "symbolTable and module out of sync");
+#endif
       return global;
+    }
   return modOp.lookupSymbol<fir::GlobalOp>(name);
 }
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index aeae6ff01507f1..dba2c30d1851bf 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3683,8 +3683,13 @@ mlir::func::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
                                      llvm::ArrayRef<mlir::NamedAttribute> attrs,
                                      const mlir::SymbolTable *symbolTable) {
   if (symbolTable)
-    if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name))
+    if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name)) {
+#ifdef EXPENSIVE_CHECKS
+      assert(f == module.lookupSymbol<mlir::func::FuncOp>(name) &&
+             "symbolTable and module out of sync");
+#endif
       return f;
+    }
   if (auto f = module.lookupSymbol<mlir::func::FuncOp>(name))
     return f;
   mlir::OpBuilder modBuilder(module.getBodyRegion());
@@ -3699,8 +3704,13 @@ fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
                                   llvm::ArrayRef<mlir::NamedAttribute> attrs,
                                   const mlir::SymbolTable *symbolTable) {
   if (symbolTable)
-    if (auto g = symbolTable->lookup<fir::GlobalOp>(name))
+    if (auto g = symbolTable->lookup<fir::GlobalOp>(name)) {
+#ifdef EXPENSIVE_CHECKS
+      assert(g == module.lookupSymbol<fir::GlobalOp>(name) &&
+             "symbolTable and module out of sync");
+#endif
       return g;
+    }
   if (auto g = module.lookupSymbol<fir::GlobalOp>(name))
     return g;
   mlir::OpBuilder modBuilder(module.getBodyRegion());

>From 41fa1556bb14d81206000b262f958b0c1f209448 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 2 Apr 2024 01:33:06 -0700
Subject: [PATCH 4/4] NFC: pass the symbol table before the name in getNamedFunction/getNamedGlobal, as suggested by Val

---
 .../flang/Optimizer/Builder/FIRBuilder.h       | 18 +++++++++---------
 flang/lib/Lower/CallInterface.cpp              |  4 ++--
 flang/lib/Lower/OpenACC.cpp                    |  4 ++--
 flang/lib/Optimizer/Builder/FIRBuilder.cpp     | 14 ++++++++------
 4 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 5868684135bd8e..940866b25d2fe8 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -290,28 +290,28 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// Get a function by name. If the function exists in the current module, it
   /// is returned. Otherwise, a null FuncOp is returned.
   mlir::func::FuncOp getNamedFunction(llvm::StringRef name) {
-    return getNamedFunction(getModule(), name, getMLIRSymbolTable());
+    return getNamedFunction(getModule(), getMLIRSymbolTable(), name);
   }
   static mlir::func::FuncOp
-  getNamedFunction(mlir::ModuleOp module, llvm::StringRef name,
-                   const mlir::SymbolTable *symbolTable);
+  getNamedFunction(mlir::ModuleOp module, const mlir::SymbolTable *symbolTable,
+                   llvm::StringRef name);
 
   /// Get a function by symbol name. The result will be null if there is no
   /// function with the given symbol in the module.
   mlir::func::FuncOp getNamedFunction(mlir::SymbolRefAttr symbol) {
-    return getNamedFunction(getModule(), symbol, getMLIRSymbolTable());
+    return getNamedFunction(getModule(), getMLIRSymbolTable(), symbol);
   }
   static mlir::func::FuncOp
-  getNamedFunction(mlir::ModuleOp module, mlir::SymbolRefAttr symbol,
-                   const mlir::SymbolTable *symbolTable);
+  getNamedFunction(mlir::ModuleOp module, const mlir::SymbolTable *symbolTable,
+                   mlir::SymbolRefAttr symbol);
 
   fir::GlobalOp getNamedGlobal(llvm::StringRef name) {
-    return getNamedGlobal(getModule(), name, getMLIRSymbolTable());
+    return getNamedGlobal(getModule(), getMLIRSymbolTable(), name);
   }
 
   static fir::GlobalOp getNamedGlobal(mlir::ModuleOp module,
-                                      llvm::StringRef name,
-                                      const mlir::SymbolTable *symbolTable);
+                                      const mlir::SymbolTable *symbolTable,
+                                      llvm::StringRef name);
 
   /// Lazy creation of fir.convert op.
   mlir::Value createConvert(mlir::Location loc, mlir::Type toTy,
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index fef38da0133060..29cdb3cff589ba 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -668,7 +668,7 @@ void Fortran::lower::CallInterface<T>::declare() {
     std::string name = side().getMangledName();
     mlir::ModuleOp module = converter.getModuleOp();
     mlir::SymbolTable *symbolTable = converter.getMLIRSymbolTable();
-    func = fir::FirOpBuilder::getNamedFunction(module, name, symbolTable);
+    func = fir::FirOpBuilder::getNamedFunction(module, symbolTable, name);
     if (!func) {
       mlir::Location loc = side().getCalleeLocation();
       mlir::FunctionType ty = genFunctionType();
@@ -1647,7 +1647,7 @@ mlir::func::FuncOp Fortran::lower::getOrDeclareFunction(
   mlir::ModuleOp module = converter.getModuleOp();
   std::string name = getProcMangledName(proc, converter);
   mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(
-      module, name, converter.getMLIRSymbolTable());
+      module, converter.getMLIRSymbolTable(), name);
   if (func)
     return func;
 
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 0ef3baa19c0199..0b2d8237791592 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3810,7 +3810,7 @@ void Fortran::lower::genOpenACCRoutineConstruct(
   if (name) {
     funcName = converter.mangleName(*name->symbol);
     funcOp =
-        builder.getNamedFunction(mod, funcName, builder.getMLIRSymbolTable());
+        builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
   } else {
     Fortran::semantics::Scope &scope =
         semanticsContext.FindScope(routineConstruct.source);
@@ -3823,7 +3823,7 @@ void Fortran::lower::genOpenACCRoutineConstruct(
     if (subpDetails && subpDetails->isInterface()) {
       funcName = converter.mangleName(*progUnit.symbol());
       funcOp =
-          builder.getNamedFunction(mod, funcName, builder.getMLIRSymbolTable());
+          builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
     } else {
       funcOp = builder.getFunction();
       funcName = funcOp.getName();
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index c29e12a203ad88..e4362b2f9e6945 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -44,8 +44,9 @@ fir::FirOpBuilder::createFunction(mlir::Location loc, mlir::ModuleOp module,
 }
 
 mlir::func::FuncOp
-fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp, llvm::StringRef name,
-                                    const mlir::SymbolTable *symbolTable) {
+fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
+                                    const mlir::SymbolTable *symbolTable,
+                                    llvm::StringRef name) {
   if (symbolTable)
     if (auto func = symbolTable->lookup<mlir::func::FuncOp>(name)) {
 #ifdef EXPENSIVE_CHECKS
@@ -59,8 +60,8 @@ fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp, llvm::StringRef name,
 
 mlir::func::FuncOp
 fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
-                                    mlir::SymbolRefAttr symbol,
-                                    const mlir::SymbolTable *symbolTable) {
+                                    const mlir::SymbolTable *symbolTable,
+                                    mlir::SymbolRefAttr symbol) {
   if (symbolTable)
     if (auto func = symbolTable->lookup<mlir::func::FuncOp>(
             symbol.getLeafReference())) {
@@ -74,8 +75,9 @@ fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
 }
 
 fir::GlobalOp
-fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp, llvm::StringRef name,
-                                  const mlir::SymbolTable *symbolTable) {
+fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
+                                  const mlir::SymbolTable *symbolTable,
+                                  llvm::StringRef name) {
   if (symbolTable)
     if (auto global = symbolTable->lookup<fir::GlobalOp>(name)) {
 #ifdef EXPENSIVE_CHECKS



More information about the flang-commits mailing list