[flang-commits] [flang] a4798bb - [flang][NFC] use mlir::SymbolTable in lowering (#86673)
via flang-commits
flang-commits at lists.llvm.org
Tue Apr 2 05:29:33 PDT 2024
Author: jeanPerier
Date: 2024-04-02T14:29:29+02:00
New Revision: a4798bb0b67533b37d6b34fd5292714aac3b17d9
URL: https://github.com/llvm/llvm-project/commit/a4798bb0b67533b37d6b34fd5292714aac3b17d9
DIFF: https://github.com/llvm/llvm-project/commit/a4798bb0b67533b37d6b34fd5292714aac3b17d9.diff
LOG: [flang][NFC] use mlir::SymbolTable in lowering (#86673)
Whenever lowering checked whether a function or global already existed in
the mlir::Module, it called module->lookup.
On large programs (~5000 globals and functions), this causes significant
slowdowns because these lookups are linear. Use mlir::SymbolTable to
speed up these lookups. The SymbolTable has to be created from the
ModuleOp and kept in sync with it as new symbols are created. It is
therefore placed in the converter, and FirOpBuilders can take a pointer
to it to speed up the lookups.
This patch does not bring mlir::SymbolTable to the FIR/HLFIR passes, but
some passes that create many runtime calls could benefit from it too.
More analysis is needed.
As an example of the speed-ups, this patch speeds up compilation of
Whizard compare_amplitude_UFO.F90 from 5 minutes to 2 minutes on my machine
(there is still room for further speed-ups).
Added:
Modified:
flang/include/flang/Lower/AbstractConverter.h
flang/include/flang/Optimizer/Builder/FIRBuilder.h
flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/CallInterface.cpp
flang/lib/Lower/OpenACC.cpp
flang/lib/Optimizer/Builder/FIRBuilder.cpp
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 32e7a5e2b04061..d5dab9040d22bd 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -23,6 +23,10 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/ArrayRef.h"
+namespace mlir {
+class SymbolTable;
+}
+
namespace fir {
class KindMapping;
class FirOpBuilder;
@@ -305,6 +309,15 @@ class AbstractConverter {
virtual Fortran::lower::SymbolBox
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
+ /// Return the mlir::SymbolTable associated to the ModuleOp.
+ /// Look-ups are faster using it than using module.lookup<>,
+ /// but the module op should be queried in case of failure
+ /// because this symbol table is not guaranteed to contain
+ /// all the symbols from the ModuleOp (the symbol table should
+ /// always be provided to the builder helper creating globals and
+ /// functions in order to be in sync).
+ virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
+
private:
/// Options controlling lowering behavior.
const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index d61bf681be6194..940866b25d2fe8 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -28,6 +28,10 @@
#include <optional>
#include <utility>
+namespace mlir {
+class SymbolTable;
+}
+
namespace fir {
class AbstractArrayBox;
class ExtendedValue;
@@ -42,8 +46,10 @@ class BoxValue;
/// patterns.
class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
public:
- explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap)
- : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)} {}
+ explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap,
+ mlir::SymbolTable *symbolTable = nullptr)
+ : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
+ symbolTable{symbolTable} {}
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap)
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
setListener(this);
@@ -69,13 +75,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
// The listener self-reference has to be updated in case of copy-construction.
FirOpBuilder(const FirOpBuilder &other)
: OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap},
- fastMathFlags{other.fastMathFlags} {
+ fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
setListener(this);
}
FirOpBuilder(FirOpBuilder &&other)
: OpBuilder(other), OpBuilder::Listener(),
- kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags} {
+ kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags},
+ symbolTable{other.symbolTable} {
setListener(this);
}
@@ -95,6 +102,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// Get a reference to the kind map.
const fir::KindMapping &getKindMap() { return kindMap; }
+ /// Get func.func/fir.global symbol table attached to this builder if any.
+ mlir::SymbolTable *getMLIRSymbolTable() { return symbolTable; }
+
/// Get the default integer type
[[maybe_unused]] mlir::IntegerType getDefaultIntegerType() {
return getIntegerType(
@@ -280,24 +290,27 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// Get a function by name. If the function exists in the current module, it
/// is returned. Otherwise, a null FuncOp is returned.
mlir::func::FuncOp getNamedFunction(llvm::StringRef name) {
- return getNamedFunction(getModule(), name);
+ return getNamedFunction(getModule(), getMLIRSymbolTable(), name);
}
- static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
- llvm::StringRef name);
+ static mlir::func::FuncOp
+ getNamedFunction(mlir::ModuleOp module, const mlir::SymbolTable *symbolTable,
+ llvm::StringRef name);
/// Get a function by symbol name. The result will be null if there is no
/// function with the given symbol in the module.
mlir::func::FuncOp getNamedFunction(mlir::SymbolRefAttr symbol) {
- return getNamedFunction(getModule(), symbol);
+ return getNamedFunction(getModule(), getMLIRSymbolTable(), symbol);
}
- static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
- mlir::SymbolRefAttr symbol);
+ static mlir::func::FuncOp
+ getNamedFunction(mlir::ModuleOp module, const mlir::SymbolTable *symbolTable,
+ mlir::SymbolRefAttr symbol);
fir::GlobalOp getNamedGlobal(llvm::StringRef name) {
- return getNamedGlobal(getModule(), name);
+ return getNamedGlobal(getModule(), getMLIRSymbolTable(), name);
}
static fir::GlobalOp getNamedGlobal(mlir::ModuleOp module,
+ const mlir::SymbolTable *symbolTable,
llvm::StringRef name);
/// Lazy creation of fir.convert op.
@@ -313,35 +326,18 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// result of the load if it was created, otherwise return \p val
mlir::Value loadIfRef(mlir::Location loc, mlir::Value val);
- /// Create a new FuncOp. If the function may have already been created, use
- /// `addNamedFunction` instead.
+ /// Determine if the named function is already in the module. Return the
+ /// instance if found, otherwise add a new named function to the module.
mlir::func::FuncOp createFunction(mlir::Location loc, llvm::StringRef name,
mlir::FunctionType ty) {
- return createFunction(loc, getModule(), name, ty);
+ return createFunction(loc, getModule(), name, ty, getMLIRSymbolTable());
}
static mlir::func::FuncOp createFunction(mlir::Location loc,
mlir::ModuleOp module,
llvm::StringRef name,
- mlir::FunctionType ty);
-
- /// Determine if the named function is already in the module. Return the
- /// instance if found, otherwise add a new named function to the module.
- mlir::func::FuncOp addNamedFunction(mlir::Location loc, llvm::StringRef name,
- mlir::FunctionType ty) {
- if (auto func = getNamedFunction(name))
- return func;
- return createFunction(loc, name, ty);
- }
-
- static mlir::func::FuncOp addNamedFunction(mlir::Location loc,
- mlir::ModuleOp module,
- llvm::StringRef name,
- mlir::FunctionType ty) {
- if (auto func = getNamedFunction(module, name))
- return func;
- return createFunction(loc, module, name, ty);
- }
+ mlir::FunctionType ty,
+ mlir::SymbolTable *);
/// Cast the input value to IndexType.
mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) {
@@ -515,6 +511,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// FastMathFlags that need to be set for operations that support
/// mlir::arith::FastMathAttr.
mlir::arith::FastMathFlags fastMathFlags{};
+
+ /// fir::GlobalOp and func::FuncOp symbol table to speed-up
+ /// lookups.
+ mlir::SymbolTable *symbolTable = nullptr;
};
} // namespace fir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index e8226b6df58ca2..f29e44504acb63 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -52,16 +52,19 @@ inline bool pureCall(mlir::Operation *op) {
/// Get or create a FuncOp in a module.
///
/// If `module` already contains FuncOp `name`, it is returned. Otherwise, a new
-/// FuncOp is created, and that new FuncOp is returned.
-mlir::func::FuncOp
-createFuncOp(mlir::Location loc, mlir::ModuleOp module, llvm::StringRef name,
- mlir::FunctionType type,
- llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
-
-/// Get or create a GlobalOp in a module.
+/// FuncOp is created, and that new FuncOp is returned. A symbol table can
+/// be provided to speed-up the lookups.
+mlir::func::FuncOp createFuncOp(mlir::Location loc, mlir::ModuleOp module,
+ llvm::StringRef name, mlir::FunctionType type,
+ llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
+ const mlir::SymbolTable *symbolTable = nullptr);
+
+/// Get or create a GlobalOp in a module. A symbol table can be provided to
+/// speed-up the lookups.
fir::GlobalOp createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
llvm::StringRef name, mlir::Type type,
- llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
+ llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
+ const mlir::SymbolTable *symbolTable = nullptr);
/// Attribute to mark Fortran entities with the CONTIGUOUS attribute.
constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; }
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 91b898eb513e05..5bba0978617c79 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -273,7 +273,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
public:
explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
: Fortran::lower::AbstractConverter(bridge.getLoweringOptions()),
- bridge{bridge}, foldingContext{bridge.createFoldingContext()} {}
+ bridge{bridge}, foldingContext{bridge.createFoldingContext()},
+ mlirSymbolTable{bridge.getModule()} {}
virtual ~FirConverter() = default;
/// Convert the PFT to FIR.
@@ -329,8 +330,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
- builder = new fir::FirOpBuilder(bridge.getModule(),
- bridge.getKindMap());
+ builder = new fir::FirOpBuilder(
+ bridge.getModule(), bridge.getKindMap(), &mlirSymbolTable);
Fortran::lower::genOpenACCRoutineConstruct(
*this, bridge.getSemanticsContext(), bridge.getModule(),
d.routine, accRoutineInfos);
@@ -1036,6 +1037,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return {};
}
+ mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
+
/// Add the symbol to the local map and return `true`. If the symbol is
/// already in the map and \p forced is `false`, the map is not updated.
/// Instead the value `false` is returned.
@@ -4571,7 +4574,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
llvm::dbgs() << "\n");
Fortran::lower::CalleeInterface callee(funit, *this);
mlir::func::FuncOp func = callee.addEntryBlockAndMapArguments();
- builder = new fir::FirOpBuilder(func, bridge.getKindMap());
+ builder =
+ new fir::FirOpBuilder(func, bridge.getKindMap(), &mlirSymbolTable);
assert(builder && "FirOpBuilder did not instantiate");
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
builder->setInsertionPointToStart(&func.front());
@@ -4839,12 +4843,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// FIXME: get rid of the bogus function context and instantiate the
// globals directly into the module.
mlir::MLIRContext *context = &getMLIRContext();
+ mlir::SymbolTable *symbolTable = getMLIRSymbolTable();
mlir::func::FuncOp func = fir::FirOpBuilder::createFunction(
mlir::UnknownLoc::get(context), getModuleOp(),
fir::NameUniquer::doGenerated("Sham"),
- mlir::FunctionType::get(context, std::nullopt, std::nullopt));
+ mlir::FunctionType::get(context, std::nullopt, std::nullopt),
+ symbolTable);
func.addEntryBlock();
- builder = new fir::FirOpBuilder(func, bridge.getKindMap());
+ builder = new fir::FirOpBuilder(func, bridge.getKindMap(), symbolTable);
assert(builder && "FirOpBuilder did not instantiate");
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
createGlobals();
@@ -5336,6 +5342,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// utilities to deal with procedure pointer components whose arguments have
/// the type of the containing derived type.
Fortran::lower::TypeConstructionStack typeConstructionStack;
+ /// MLIR symbol table of the fir.global/func.func operations. Note that it is
+ /// not guaranteed to contain all operations of the ModuleOp with Symbol
+ /// attribute since mlirSymbolTable must pro-actively be maintained when
+ /// new Symbol operations are created.
+ mlir::SymbolTable mlirSymbolTable;
};
} // namespace
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index c65becc497459c..29cdb3cff589ba 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -667,11 +667,13 @@ void Fortran::lower::CallInterface<T>::declare() {
if (!side().isIndirectCall()) {
std::string name = side().getMangledName();
mlir::ModuleOp module = converter.getModuleOp();
- func = fir::FirOpBuilder::getNamedFunction(module, name);
+ mlir::SymbolTable *symbolTable = converter.getMLIRSymbolTable();
+ func = fir::FirOpBuilder::getNamedFunction(module, symbolTable, name);
if (!func) {
mlir::Location loc = side().getCalleeLocation();
mlir::FunctionType ty = genFunctionType();
- func = fir::FirOpBuilder::createFunction(loc, module, name, ty);
+ func =
+ fir::FirOpBuilder::createFunction(loc, module, name, ty, symbolTable);
if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol()) {
if (side().isMainProgram()) {
func->setAttr(fir::getSymbolAttrName(),
@@ -1644,7 +1646,8 @@ mlir::func::FuncOp Fortran::lower::getOrDeclareFunction(
Fortran::lower::AbstractConverter &converter) {
mlir::ModuleOp module = converter.getModuleOp();
std::string name = getProcMangledName(proc, converter);
- mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(module, name);
+ mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(
+ module, converter.getMLIRSymbolTable(), name);
if (func)
return func;
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 6e6714454f0591..d933c07aba0e0c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3821,7 +3821,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
std::string funcName;
if (name) {
funcName = converter.mangleName(*name->symbol);
- funcOp = builder.getNamedFunction(mod, funcName);
+ funcOp =
+ builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
} else {
Fortran::semantics::Scope &scope =
semanticsContext.FindScope(routineConstruct.source);
@@ -3833,7 +3834,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
: nullptr};
if (subpDetails && subpDetails->isInterface()) {
funcName = converter.mangleName(*progUnit.symbol());
- funcOp = builder.getNamedFunction(mod, funcName);
+ funcOp =
+ builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
} else {
funcOp = builder.getFunction();
funcName = funcOp.getName();
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 2bcd5e5914027d..e4362b2f9e6945 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -36,26 +36,56 @@ static llvm::cl::opt<std::size_t>
"name"),
llvm::cl::init(32));
-mlir::func::FuncOp fir::FirOpBuilder::createFunction(mlir::Location loc,
- mlir::ModuleOp module,
- llvm::StringRef name,
- mlir::FunctionType ty) {
- return fir::createFuncOp(loc, module, name, ty);
+mlir::func::FuncOp
+fir::FirOpBuilder::createFunction(mlir::Location loc, mlir::ModuleOp module,
+ llvm::StringRef name, mlir::FunctionType ty,
+ mlir::SymbolTable *symbolTable) {
+ return fir::createFuncOp(loc, module, name, ty, /*attrs*/ {}, symbolTable);
}
-mlir::func::FuncOp fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
- llvm::StringRef name) {
+mlir::func::FuncOp
+fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
+ const mlir::SymbolTable *symbolTable,
+ llvm::StringRef name) {
+ if (symbolTable)
+ if (auto func = symbolTable->lookup<mlir::func::FuncOp>(name)) {
+#ifdef EXPENSIVE_CHECKS
+ assert(func == modOp.lookupSymbol<mlir::func::FuncOp>(name) &&
+ "symbolTable and module out of sync");
+#endif
+ return func;
+ }
return modOp.lookupSymbol<mlir::func::FuncOp>(name);
}
mlir::func::FuncOp
fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
+ const mlir::SymbolTable *symbolTable,
mlir::SymbolRefAttr symbol) {
+ if (symbolTable)
+ if (auto func = symbolTable->lookup<mlir::func::FuncOp>(
+ symbol.getLeafReference())) {
+#ifdef EXPENSIVE_CHECKS
+ assert(func == modOp.lookupSymbol<mlir::func::FuncOp>(symbol) &&
+ "symbolTable and module out of sync");
+#endif
+ return func;
+ }
return modOp.lookupSymbol<mlir::func::FuncOp>(symbol);
}
-fir::GlobalOp fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
- llvm::StringRef name) {
+fir::GlobalOp
+fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
+ const mlir::SymbolTable *symbolTable,
+ llvm::StringRef name) {
+ if (symbolTable)
+ if (auto global = symbolTable->lookup<fir::GlobalOp>(name)) {
+#ifdef EXPENSIVE_CHECKS
+ assert(global == modOp.lookupSymbol<fir::GlobalOp>(name) &&
+ "symbolTable and module out of sync");
+#endif
+ return global;
+ }
return modOp.lookupSymbol<fir::GlobalOp>(name);
}
@@ -279,10 +309,10 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
mlir::Location loc, mlir::Type type, llvm::StringRef name,
mlir::StringAttr linkage, mlir::Attribute value, bool isConst,
bool isTarget, fir::CUDADataAttributeAttr cudaAttr) {
+ if (auto global = getNamedGlobal(name))
+ return global;
auto module = getModule();
auto insertPt = saveInsertionPoint();
- if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
- return glob;
setInsertionPoint(module.getBody(), module.getBody()->end());
llvm::SmallVector<mlir::NamedAttribute> attrs;
if (cudaAttr) {
@@ -294,6 +324,8 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
auto glob = create<fir::GlobalOp>(loc, name, isConst, isTarget, type, value,
linkage, attrs);
restoreInsertionPoint(insertPt);
+ if (symbolTable)
+ symbolTable->insert(glob);
return glob;
}
@@ -301,10 +333,10 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
mlir::Location loc, mlir::Type type, llvm::StringRef name, bool isConst,
bool isTarget, std::function<void(FirOpBuilder &)> bodyBuilder,
mlir::StringAttr linkage, fir::CUDADataAttributeAttr cudaAttr) {
+ if (auto global = getNamedGlobal(name))
+ return global;
auto module = getModule();
auto insertPt = saveInsertionPoint();
- if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
- return glob;
setInsertionPoint(module.getBody(), module.getBody()->end());
auto glob = create<fir::GlobalOp>(loc, name, isConst, isTarget, type,
mlir::Attribute{}, linkage);
@@ -314,6 +346,8 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
setInsertionPointToStart(&block);
bodyBuilder(*this);
restoreInsertionPoint(insertPt);
+ if (symbolTable)
+ symbolTable->insert(glob);
return glob;
}
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ea1ef1f08aba20..069ba81cfe96ab 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -724,7 +724,7 @@ mlir::Value genLibCall(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::func::FuncOp funcOp = builder.getNamedFunction(libFuncName);
if (!funcOp) {
- funcOp = builder.addNamedFunction(loc, libFuncName, libFuncType);
+ funcOp = builder.createFunction(loc, libFuncName, libFuncType);
// C-interoperability rules apply to these library functions.
funcOp->setAttr(fir::getSymbolAttrName(),
mlir::StringAttr::get(builder.getContext(), libFuncName));
@@ -1894,8 +1894,8 @@ mlir::func::FuncOp IntrinsicLibrary::getWrapper(GeneratorType generator,
// Create local context to emit code into the newly created function
// This new function is not linked to a source file location, only
// its calls will be.
- auto localBuilder =
- std::make_unique<fir::FirOpBuilder>(function, builder.getKindMap());
+ auto localBuilder = std::make_unique<fir::FirOpBuilder>(
+ function, builder.getKindMap(), builder.getMLIRSymbolTable());
localBuilder->setFastMathFlags(builder.getFastMathFlags());
localBuilder->setInsertionPointToStart(&function.front());
// Location of code inside wrapper of the wrapper is independent from
diff --git a/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp b/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
index 1d07b1e724d745..bb5f77d5d4d1de 100644
--- a/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
+++ b/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
@@ -27,8 +27,8 @@ mlir::func::FuncOp fir::factory::getLlvmMemcpy(fir::FirOpBuilder &builder) {
builder.getI1Type()};
auto memcpyTy =
mlir::FunctionType::get(builder.getContext(), args, std::nullopt);
- return builder.addNamedFunction(builder.getUnknownLoc(),
- "llvm.memcpy.p0.p0.i64", memcpyTy);
+ return builder.createFunction(builder.getUnknownLoc(),
+ "llvm.memcpy.p0.p0.i64", memcpyTy);
}
mlir::func::FuncOp fir::factory::getLlvmMemmove(fir::FirOpBuilder &builder) {
@@ -37,8 +37,8 @@ mlir::func::FuncOp fir::factory::getLlvmMemmove(fir::FirOpBuilder &builder) {
builder.getI1Type()};
auto memmoveTy =
mlir::FunctionType::get(builder.getContext(), args, std::nullopt);
- return builder.addNamedFunction(builder.getUnknownLoc(),
- "llvm.memmove.p0.p0.i64", memmoveTy);
+ return builder.createFunction(builder.getUnknownLoc(),
+ "llvm.memmove.p0.p0.i64", memmoveTy);
}
mlir::func::FuncOp fir::factory::getLlvmMemset(fir::FirOpBuilder &builder) {
@@ -47,16 +47,15 @@ mlir::func::FuncOp fir::factory::getLlvmMemset(fir::FirOpBuilder &builder) {
builder.getI1Type()};
auto memsetTy =
mlir::FunctionType::get(builder.getContext(), args, std::nullopt);
- return builder.addNamedFunction(builder.getUnknownLoc(),
- "llvm.memset.p0.p0.i64", memsetTy);
+ return builder.createFunction(builder.getUnknownLoc(),
+ "llvm.memset.p0.p0.i64", memsetTy);
}
mlir::func::FuncOp fir::factory::getRealloc(fir::FirOpBuilder &builder) {
auto ptrTy = builder.getRefType(builder.getIntegerType(8));
llvm::SmallVector<mlir::Type> args = {ptrTy, builder.getI64Type()};
auto reallocTy = mlir::FunctionType::get(builder.getContext(), args, {ptrTy});
- return builder.addNamedFunction(builder.getUnknownLoc(), "realloc",
- reallocTy);
+ return builder.createFunction(builder.getUnknownLoc(), "realloc", reallocTy);
}
mlir::func::FuncOp
@@ -64,8 +63,8 @@ fir::factory::getLlvmGetRounding(fir::FirOpBuilder &builder) {
auto int32Ty = builder.getIntegerType(32);
auto funcTy =
mlir::FunctionType::get(builder.getContext(), std::nullopt, {int32Ty});
- return builder.addNamedFunction(builder.getUnknownLoc(), "llvm.get.rounding",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "llvm.get.rounding",
+ funcTy);
}
mlir::func::FuncOp
@@ -73,8 +72,8 @@ fir::factory::getLlvmSetRounding(fir::FirOpBuilder &builder) {
auto int32Ty = builder.getIntegerType(32);
auto funcTy =
mlir::FunctionType::get(builder.getContext(), {int32Ty}, std::nullopt);
- return builder.addNamedFunction(builder.getUnknownLoc(), "llvm.set.rounding",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "llvm.set.rounding",
+ funcTy);
}
mlir::func::FuncOp fir::factory::getLlvmStackSave(fir::FirOpBuilder &builder) {
@@ -82,8 +81,8 @@ mlir::func::FuncOp fir::factory::getLlvmStackSave(fir::FirOpBuilder &builder) {
auto ptrTy = builder.getRefType(builder.getIntegerType(8));
auto funcTy =
mlir::FunctionType::get(builder.getContext(), std::nullopt, {ptrTy});
- return builder.addNamedFunction(builder.getUnknownLoc(), "llvm.stacksave.p0",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "llvm.stacksave.p0",
+ funcTy);
}
mlir::func::FuncOp
@@ -92,8 +91,8 @@ fir::factory::getLlvmStackRestore(fir::FirOpBuilder &builder) {
auto ptrTy = builder.getRefType(builder.getIntegerType(8));
auto funcTy =
mlir::FunctionType::get(builder.getContext(), {ptrTy}, std::nullopt);
- return builder.addNamedFunction(builder.getUnknownLoc(),
- "llvm.stackrestore.p0", funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "llvm.stackrestore.p0",
+ funcTy);
}
mlir::func::FuncOp
@@ -101,24 +100,24 @@ fir::factory::getLlvmInitTrampoline(fir::FirOpBuilder &builder) {
auto ptrTy = builder.getRefType(builder.getIntegerType(8));
auto funcTy = mlir::FunctionType::get(builder.getContext(),
{ptrTy, ptrTy, ptrTy}, std::nullopt);
- return builder.addNamedFunction(builder.getUnknownLoc(),
- "llvm.init.trampoline", funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "llvm.init.trampoline",
+ funcTy);
}
mlir::func::FuncOp
fir::factory::getLlvmAdjustTrampoline(fir::FirOpBuilder &builder) {
auto ptrTy = builder.getRefType(builder.getIntegerType(8));
auto funcTy = mlir::FunctionType::get(builder.getContext(), {ptrTy}, {ptrTy});
- return builder.addNamedFunction(builder.getUnknownLoc(),
- "llvm.adjust.trampoline", funcTy);
+ return builder.createFunction(builder.getUnknownLoc(),
+ "llvm.adjust.trampoline", funcTy);
}
mlir::func::FuncOp fir::factory::getFeclearexcept(fir::FirOpBuilder &builder) {
auto int32Ty = builder.getIntegerType(32);
auto funcTy =
mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
- return builder.addNamedFunction(builder.getUnknownLoc(), "feclearexcept",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "feclearexcept",
+ funcTy);
}
mlir::func::FuncOp
@@ -126,38 +125,37 @@ fir::factory::getFedisableexcept(fir::FirOpBuilder &builder) {
auto int32Ty = builder.getIntegerType(32);
auto funcTy =
mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
- return builder.addNamedFunction(builder.getUnknownLoc(), "fedisableexcept",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "fedisableexcept",
+ funcTy);
}
mlir::func::FuncOp fir::factory::getFeenableexcept(fir::FirOpBuilder &builder) {
auto int32Ty = builder.getIntegerType(32);
auto funcTy =
mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
- return builder.addNamedFunction(builder.getUnknownLoc(), "feenableexcept",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "feenableexcept",
+ funcTy);
}
mlir::func::FuncOp fir::factory::getFegetexcept(fir::FirOpBuilder &builder) {
auto int32Ty = builder.getIntegerType(32);
auto funcTy =
mlir::FunctionType::get(builder.getContext(), std::nullopt, {int32Ty});
- return builder.addNamedFunction(builder.getUnknownLoc(), "fegetexcept",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "fegetexcept", funcTy);
}
mlir::func::FuncOp fir::factory::getFeraiseexcept(fir::FirOpBuilder &builder) {
auto int32Ty = builder.getIntegerType(32);
auto funcTy =
mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
- return builder.addNamedFunction(builder.getUnknownLoc(), "feraiseexcept",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "feraiseexcept",
+ funcTy);
}
mlir::func::FuncOp fir::factory::getFetestexcept(fir::FirOpBuilder &builder) {
auto int32Ty = builder.getIntegerType(32);
auto funcTy =
mlir::FunctionType::get(builder.getContext(), {int32Ty}, {int32Ty});
- return builder.addNamedFunction(builder.getUnknownLoc(), "fetestexcept",
- funcTy);
+ return builder.createFunction(builder.getUnknownLoc(), "fetestexcept",
+ funcTy);
}
diff --git a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
index ab0d5079d8afe0..e588b19dded4f1 100644
--- a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
@@ -1084,11 +1084,11 @@ void PPCIntrinsicLibrary::genMtfsf(llvm::ArrayRef<fir::ExtendedValue> args) {
if (isImm) {
libFuncType = genFuncType<Ty::Void, Ty::Integer<4>, Ty::Integer<4>>(
builder.getContext(), builder);
- funcOp = builder.addNamedFunction(loc, "llvm.ppc.mtfsfi", libFuncType);
+ funcOp = builder.createFunction(loc, "llvm.ppc.mtfsfi", libFuncType);
} else {
libFuncType = genFuncType<Ty::Void, Ty::Integer<4>, Ty::Real<8>>(
builder.getContext(), builder);
- funcOp = builder.addNamedFunction(loc, "llvm.ppc.mtfsf", libFuncType);
+ funcOp = builder.createFunction(loc, "llvm.ppc.mtfsf", libFuncType);
}
builder.create<fir::CallOp>(loc, funcOp, scalarArgs);
}
@@ -1116,7 +1116,7 @@ PPCIntrinsicLibrary::genVecAbs(mlir::Type resultType,
genFuncType<Ty::RealVector<8>, Ty::RealVector<8>>(context, builder);
}
- funcOp = builder.addNamedFunction(loc, fname, ftype);
+ funcOp = builder.createFunction(loc, fname, ftype);
auto callOp{builder.create<fir::CallOp>(loc, funcOp, argBases[0])};
return callOp.getResult(0);
} else if (auto eleTy = vTypeInfo.eleTy.dyn_cast<mlir::IntegerType>()) {
@@ -1155,7 +1155,7 @@ PPCIntrinsicLibrary::genVecAbs(mlir::Type resultType,
default:
llvm_unreachable("invalid integer size");
}
- funcOp = builder.addNamedFunction(loc, fname, ftype);
+ funcOp = builder.createFunction(loc, fname, ftype);
mlir::Value args[] = {zeroSubVarg1, varg1};
auto callOp{builder.create<fir::CallOp>(loc, funcOp, args)};
@@ -1339,7 +1339,7 @@ PPCIntrinsicLibrary::genVecAnyCompare(mlir::Type resultType,
}
assert((!fname.empty() && ftype) && "invalid type");
- mlir::func::FuncOp funcOp{builder.addNamedFunction(loc, fname, ftype)};
+ mlir::func::FuncOp funcOp{builder.createFunction(loc, fname, ftype)};
auto callOp{builder.create<fir::CallOp>(loc, funcOp, cmpArgs)};
return callOp.getResult(0);
}
@@ -1445,7 +1445,7 @@ PPCIntrinsicLibrary::genVecCmp(mlir::Type resultType,
std::pair<llvm::StringRef, mlir::FunctionType> funcTyNam{
getVecCmpFuncTypeAndName(vecTyInfo, vop, builder)};
- mlir::func::FuncOp funcOp = builder.addNamedFunction(
+ mlir::func::FuncOp funcOp = builder.createFunction(
loc, std::get<0>(funcTyNam), std::get<1>(funcTyNam));
mlir::Value res{nullptr};
@@ -1572,7 +1572,7 @@ PPCIntrinsicLibrary::genVecConvert(mlir::Type resultType,
Ty::Integer<4>>(context, builder)};
const llvm::StringRef fname{(isUnsigned) ? "llvm.ppc.altivec.vcfux"
: "llvm.ppc.altivec.vcfsx"};
- auto funcOp{builder.addNamedFunction(loc, fname, ftype)};
+ auto funcOp{builder.createFunction(loc, fname, ftype)};
mlir::Value newArgs[] = {argBases[0], convArg};
auto callOp{builder.create<fir::CallOp>(loc, funcOp, newArgs)};
@@ -1627,7 +1627,7 @@ PPCIntrinsicLibrary::genVecConvert(mlir::Type resultType,
const llvm::StringRef fname{"llvm.ppc.vsx.xvcvspdp"};
auto ftype{
genFuncType<Ty::RealVector<8>, Ty::RealVector<4>>(context, builder)};
- auto funcOp{builder.addNamedFunction(loc, fname, ftype)};
+ auto funcOp{builder.createFunction(loc, fname, ftype)};
auto callOp{builder.create<fir::CallOp>(loc, funcOp, newArgs)};
return callOp.getResult(0);
@@ -1635,7 +1635,7 @@ PPCIntrinsicLibrary::genVecConvert(mlir::Type resultType,
const llvm::StringRef fname{"llvm.ppc.vsx.xvcvdpsp"};
auto ftype{
genFuncType<Ty::RealVector<4>, Ty::RealVector<8>>(context, builder)};
- auto funcOp{builder.addNamedFunction(loc, fname, ftype)};
+ auto funcOp{builder.createFunction(loc, fname, ftype)};
newArgs[0] =
builder.create<fir::CallOp>(loc, funcOp, newArgs).getResult(0);
auto fvf32Ty{newArgs[0].getType()};
@@ -1963,7 +1963,7 @@ PPCIntrinsicLibrary::genVecLdCallGrp(mlir::Type resultType,
auto funcType{
mlir::FunctionType::get(context, {addr.getType()}, {intrinResTy})};
- auto funcOp{builder.addNamedFunction(loc, fname, funcType)};
+ auto funcOp{builder.createFunction(loc, fname, funcType)};
auto result{
builder.create<fir::CallOp>(loc, funcOp, parsedArgs).getResult(0)};
@@ -2022,7 +2022,7 @@ PPCIntrinsicLibrary::genVecLvsGrp(mlir::Type resultType,
llvm_unreachable("invalid vector operation for generator");
}
auto funcType{mlir::FunctionType::get(context, {addr.getType()}, {mlirTy})};
- auto funcOp{builder.addNamedFunction(loc, fname, funcType)};
+ auto funcOp{builder.createFunction(loc, fname, funcType)};
auto result{
builder.create<fir::CallOp>(loc, funcOp, parsedArgs).getResult(0)};
@@ -2057,8 +2057,8 @@ PPCIntrinsicLibrary::genVecNmaddMsub(mlir::Type resultType,
genFuncType<Ty::RealVector<8>, Ty::RealVector<8>, Ty::RealVector<8>>(
context, builder))}};
- auto funcOp{builder.addNamedFunction(loc, std::get<0>(fmaMap[width]),
- std::get<1>(fmaMap[width]))};
+ auto funcOp{builder.createFunction(loc, std::get<0>(fmaMap[width]),
+ std::get<1>(fmaMap[width]))};
if (vop == VecOp::Nmadd) {
// vec_nmadd(arg1, arg2, arg3) = -fma(arg1, arg2, arg3)
auto callOp{builder.create<fir::CallOp>(loc, funcOp, newArgs)};
@@ -2110,7 +2110,7 @@ PPCIntrinsicLibrary::genVecPerm(mlir::Type resultType,
builder.create<mlir::LLVM::BitcastOp>(loc, vi32Ty, mArg1).getResult();
}
- auto funcOp{builder.addNamedFunction(
+ auto funcOp{builder.createFunction(
loc, "llvm.ppc.altivec.vperm",
genFuncType<Ty::IntegerVector<4>, Ty::IntegerVector<4>,
Ty::IntegerVector<4>, Ty::IntegerVector<1>>(context,
@@ -2307,7 +2307,7 @@ PPCIntrinsicLibrary::genVecShift(mlir::Type resultType,
}
auto funcTy{genFuncType<Ty::IntegerVector<4>, Ty::IntegerVector<4>,
Ty::IntegerVector<4>>(context, builder)};
- mlir::func::FuncOp funcOp{builder.addNamedFunction(loc, funcName, funcTy)};
+ mlir::func::FuncOp funcOp{builder.createFunction(loc, funcName, funcTy)};
auto callOp{builder.create<fir::CallOp>(loc, funcOp, mlirVecArgs)};
// If the result vector type is different from the original type, need
@@ -2755,7 +2755,7 @@ void PPCIntrinsicLibrary::genMmaIntr(llvm::ArrayRef<fir::ExtendedValue> args) {
auto context{builder.getContext()};
mlir::FunctionType intrFuncType{getMmaIrFuncType(context, IntrId)};
mlir::func::FuncOp funcOp{
- builder.addNamedFunction(loc, getMmaIrIntrName(IntrId), intrFuncType)};
+ builder.createFunction(loc, getMmaIrIntrName(IntrId), intrFuncType)};
llvm::SmallVector<mlir::Value> intrArgs;
// Depending on SubToFunc, change the subroutine call to a function call.
@@ -2892,7 +2892,7 @@ void PPCIntrinsicLibrary::genVecStore(llvm::ArrayRef<fir::ExtendedValue> args) {
auto funcType{
mlir::FunctionType::get(context, {stTy, addr.getType()}, std::nullopt)};
- mlir::func::FuncOp funcOp = builder.addNamedFunction(loc, fname, funcType);
+ mlir::func::FuncOp funcOp = builder.createFunction(loc, fname, funcType);
llvm::SmallVector<mlir::Value, 4> biArgs;
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 9bb10a42a3997c..dba2c30d1851bf 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3677,10 +3677,19 @@ fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result,
return mlir::success();
}
-mlir::func::FuncOp
-fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
- llvm::StringRef name, mlir::FunctionType type,
- llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+mlir::func::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
+ llvm::StringRef name,
+ mlir::FunctionType type,
+ llvm::ArrayRef<mlir::NamedAttribute> attrs,
+ const mlir::SymbolTable *symbolTable) {
+ if (symbolTable)
+ if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name)) {
+#ifdef EXPENSIVE_CHECKS
+ assert(f == module.lookupSymbol<mlir::func::FuncOp>(name) &&
+ "symbolTable and module out of sync");
+#endif
+ return f;
+ }
if (auto f = module.lookupSymbol<mlir::func::FuncOp>(name))
return f;
mlir::OpBuilder modBuilder(module.getBodyRegion());
@@ -3692,7 +3701,16 @@ fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
llvm::StringRef name, mlir::Type type,
- llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+ llvm::ArrayRef<mlir::NamedAttribute> attrs,
+ const mlir::SymbolTable *symbolTable) {
+ if (symbolTable)
+ if (auto g = symbolTable->lookup<fir::GlobalOp>(name)) {
+#ifdef EXPENSIVE_CHECKS
+ assert(g == module.lookupSymbol<fir::GlobalOp>(name) &&
+ "symbolTable and module out of sync");
+#endif
+ return g;
+ }
if (auto g = module.lookupSymbol<fir::GlobalOp>(name))
return g;
mlir::OpBuilder modBuilder(module.getBodyRegion());
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index a11aa38c771bd1..f7820b6b8170ba 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -1004,10 +1004,8 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
// We can also avoid this by using internal linkage, but
// this may increase the size of final executable/shared library.
std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
- mlir::ModuleOp module = builder.getModule();
// If we already have a function, just return it.
- mlir::func::FuncOp newFunc =
- fir::FirOpBuilder::getNamedFunction(module, replacementName);
+ mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName);
mlir::FunctionType fType = typeGenerator(builder);
if (newFunc) {
assert(newFunc.getFunctionType() == fType &&
@@ -1017,8 +1015,7 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
// Need to build the function!
auto loc = mlir::UnknownLoc::get(builder.getContext());
- newFunc =
- fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
+ newFunc = builder.createFunction(loc, replacementName, fType);
auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
auto linkage =
mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
More information about the flang-commits
mailing list