[flang-commits] [flang] [flang][NFC] use mlir::SymbolTable in lowering (PR #86673)
via flang-commits
flang-commits at lists.llvm.org
Tue Mar 26 07:37:50 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-openacc
Author: None (jeanPerier)
<details>
<summary>Changes</summary>
Whenever lowering checked whether a function or global already existed in the mlir::Module, it called module->lookup.
On big programs (~5000 globals and functions), this causes significant slowdowns because these lookups are linear. Use mlir::SymbolTable to speed up these lookups. The SymbolTable has to be created from the ModuleOp and kept in sync with it. It is therefore placed in the converter, and FirOpBuilders can take a pointer to it to speed up their lookups.
This patch does not bring mlir::SymbolTable to FIR/HLFIR passes, although some passes that create many runtime calls could also benefit from it. More analysis is needed.
As an example, this patch reduces compilation time of Whizard compare_amplitude_UFO.F90 from 5 minutes to 2 minutes on my machine (there is still room for further speed-ups).
---
Patch is 40.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86673.diff
12 Files Affected:
- (modified) flang/include/flang/Lower/AbstractConverter.h (+13)
- (modified) flang/include/flang/Optimizer/Builder/FIRBuilder.h (+35-35)
- (modified) flang/include/flang/Optimizer/Dialect/FIROpsSupport.h (+11-8)
- (modified) flang/lib/Lower/Bridge.cpp (+17-6)
- (modified) flang/lib/Lower/CallInterface.cpp (+6-3)
- (modified) flang/lib/Lower/OpenACC.cpp (+4-2)
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+31-14)
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+3-3)
- (modified) flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp (+30-32)
- (modified) flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp (+17-17)
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+13-5)
- (modified) flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (+2-5)
``````````diff
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 32e7a5e2b04061..d5dab9040d22bd 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -23,6 +23,10 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/ArrayRef.h"
+namespace mlir {
+class SymbolTable;
+}
+
namespace fir {
class KindMapping;
class FirOpBuilder;
@@ -305,6 +309,15 @@ class AbstractConverter {
virtual Fortran::lower::SymbolBox
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
+ /// Return the mlir::SymbolTable associated to the ModuleOp.
+ /// Look-ups are faster using it than using module.lookup<>,
+ /// but the module op should be queried in case of failure
+ /// because this symbol table is not guaranteed to contain
+ /// all the symbols from the ModuleOp (the symbol table should
+ /// always be provided to the builder helper creating globals and
+ /// functions in order to be in sync).
+ virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
+
private:
/// Options controlling lowering behavior.
const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index d61bf681be6194..8537f29b2e549c 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -28,6 +28,10 @@
#include <optional>
#include <utility>
+namespace mlir {
+class SymbolTable;
+}
+
namespace fir {
class AbstractArrayBox;
class ExtendedValue;
@@ -42,8 +46,10 @@ class BoxValue;
/// patterns.
class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
public:
- explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap)
- : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)} {}
+ explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap,
+ mlir::SymbolTable *symbolTable = nullptr)
+ : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
+ symbolTable{symbolTable} {}
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap)
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
setListener(this);
@@ -69,13 +75,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
// The listener self-reference has to be updated in case of copy-construction.
FirOpBuilder(const FirOpBuilder &other)
: OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap},
- fastMathFlags{other.fastMathFlags} {
+ fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
setListener(this);
}
FirOpBuilder(FirOpBuilder &&other)
- : OpBuilder(other), OpBuilder::Listener(),
- kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags} {
+ : OpBuilder(other), OpBuilder::Listener(), kindMap{std::move(
+ other.kindMap)},
+ fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
setListener(this);
}
@@ -95,6 +102,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// Get a reference to the kind map.
const fir::KindMapping &getKindMap() { return kindMap; }
+ /// Get func.func/fir.global symbol table attached to this builder if any.
+ mlir::SymbolTable *getMLIRSymbolTable() { return symbolTable; }
+
/// Get the default integer type
[[maybe_unused]] mlir::IntegerType getDefaultIntegerType() {
return getIntegerType(
@@ -280,25 +290,28 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// Get a function by name. If the function exists in the current module, it
/// is returned. Otherwise, a null FuncOp is returned.
mlir::func::FuncOp getNamedFunction(llvm::StringRef name) {
- return getNamedFunction(getModule(), name);
+ return getNamedFunction(getModule(), name, getMLIRSymbolTable());
}
- static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
- llvm::StringRef name);
+ static mlir::func::FuncOp
+ getNamedFunction(mlir::ModuleOp module, llvm::StringRef name,
+ const mlir::SymbolTable *symbolTable);
/// Get a function by symbol name. The result will be null if there is no
/// function with the given symbol in the module.
mlir::func::FuncOp getNamedFunction(mlir::SymbolRefAttr symbol) {
- return getNamedFunction(getModule(), symbol);
+ return getNamedFunction(getModule(), symbol, getMLIRSymbolTable());
}
- static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
- mlir::SymbolRefAttr symbol);
+ static mlir::func::FuncOp
+ getNamedFunction(mlir::ModuleOp module, mlir::SymbolRefAttr symbol,
+ const mlir::SymbolTable *symbolTable);
fir::GlobalOp getNamedGlobal(llvm::StringRef name) {
- return getNamedGlobal(getModule(), name);
+ return getNamedGlobal(getModule(), name, getMLIRSymbolTable());
}
static fir::GlobalOp getNamedGlobal(mlir::ModuleOp module,
- llvm::StringRef name);
+ llvm::StringRef name,
+ const mlir::SymbolTable *symbolTable);
/// Lazy creation of fir.convert op.
mlir::Value createConvert(mlir::Location loc, mlir::Type toTy,
@@ -313,35 +326,18 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// result of the load if it was created, otherwise return \p val
mlir::Value loadIfRef(mlir::Location loc, mlir::Value val);
- /// Create a new FuncOp. If the function may have already been created, use
- /// `addNamedFunction` instead.
+ /// Determine if the named function is already in the module. Return the
+ /// instance if found, otherwise add a new named function to the module.
mlir::func::FuncOp createFunction(mlir::Location loc, llvm::StringRef name,
mlir::FunctionType ty) {
- return createFunction(loc, getModule(), name, ty);
+ return createFunction(loc, getModule(), name, ty, getMLIRSymbolTable());
}
static mlir::func::FuncOp createFunction(mlir::Location loc,
mlir::ModuleOp module,
llvm::StringRef name,
- mlir::FunctionType ty);
-
- /// Determine if the named function is already in the module. Return the
- /// instance if found, otherwise add a new named function to the module.
- mlir::func::FuncOp addNamedFunction(mlir::Location loc, llvm::StringRef name,
- mlir::FunctionType ty) {
- if (auto func = getNamedFunction(name))
- return func;
- return createFunction(loc, name, ty);
- }
-
- static mlir::func::FuncOp addNamedFunction(mlir::Location loc,
- mlir::ModuleOp module,
- llvm::StringRef name,
- mlir::FunctionType ty) {
- if (auto func = getNamedFunction(module, name))
- return func;
- return createFunction(loc, module, name, ty);
- }
+ mlir::FunctionType ty,
+ mlir::SymbolTable *);
/// Cast the input value to IndexType.
mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) {
@@ -515,6 +511,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// FastMathFlags that need to be set for operations that support
/// mlir::arith::FastMathAttr.
mlir::arith::FastMathFlags fastMathFlags{};
+
+ /// fir::GlobalOp and func::FuncOp symbol table to speed-up
+ /// lookups.
+ mlir::SymbolTable *symbolTable = nullptr;
};
} // namespace fir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index e8226b6df58ca2..f29e44504acb63 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -52,16 +52,19 @@ inline bool pureCall(mlir::Operation *op) {
/// Get or create a FuncOp in a module.
///
/// If `module` already contains FuncOp `name`, it is returned. Otherwise, a new
-/// FuncOp is created, and that new FuncOp is returned.
-mlir::func::FuncOp
-createFuncOp(mlir::Location loc, mlir::ModuleOp module, llvm::StringRef name,
- mlir::FunctionType type,
- llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
-
-/// Get or create a GlobalOp in a module.
+/// FuncOp is created, and that new FuncOp is returned. A symbol table can
+/// be provided to speed-up the lookups.
+mlir::func::FuncOp createFuncOp(mlir::Location loc, mlir::ModuleOp module,
+ llvm::StringRef name, mlir::FunctionType type,
+ llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
+ const mlir::SymbolTable *symbolTable = nullptr);
+
+/// Get or create a GlobalOp in a module. A symbol table can be provided to
+/// speed-up the lookups.
fir::GlobalOp createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
llvm::StringRef name, mlir::Type type,
- llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
+ llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
+ const mlir::SymbolTable *symbolTable = nullptr);
/// Attribute to mark Fortran entities with the CONTIGUOUS attribute.
constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; }
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 48830dc55578c2..46a259d9ae86c9 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -273,7 +273,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
public:
explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
: Fortran::lower::AbstractConverter(bridge.getLoweringOptions()),
- bridge{bridge}, foldingContext{bridge.createFoldingContext()} {}
+ bridge{bridge}, foldingContext{bridge.createFoldingContext()},
+ mlirSymbolTable{bridge.getModule()} {}
virtual ~FirConverter() = default;
/// Convert the PFT to FIR.
@@ -329,8 +330,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
- builder = new fir::FirOpBuilder(bridge.getModule(),
- bridge.getKindMap());
+ builder = new fir::FirOpBuilder(
+ bridge.getModule(), bridge.getKindMap(), &mlirSymbolTable);
Fortran::lower::genOpenACCRoutineConstruct(
*this, bridge.getSemanticsContext(), bridge.getModule(),
d.routine, accRoutineInfos);
@@ -1036,6 +1037,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return {};
}
+ mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
+
/// Add the symbol to the local map and return `true`. If the symbol is
/// already in the map and \p forced is `false`, the map is not updated.
/// Instead the value `false` is returned.
@@ -4570,7 +4573,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
llvm::dbgs() << "\n");
Fortran::lower::CalleeInterface callee(funit, *this);
mlir::func::FuncOp func = callee.addEntryBlockAndMapArguments();
- builder = new fir::FirOpBuilder(func, bridge.getKindMap());
+ builder =
+ new fir::FirOpBuilder(func, bridge.getKindMap(), &mlirSymbolTable);
assert(builder && "FirOpBuilder did not instantiate");
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
builder->setInsertionPointToStart(&func.front());
@@ -4838,12 +4842,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// FIXME: get rid of the bogus function context and instantiate the
// globals directly into the module.
mlir::MLIRContext *context = &getMLIRContext();
+ mlir::SymbolTable *symbolTable = getMLIRSymbolTable();
mlir::func::FuncOp func = fir::FirOpBuilder::createFunction(
mlir::UnknownLoc::get(context), getModuleOp(),
fir::NameUniquer::doGenerated("Sham"),
- mlir::FunctionType::get(context, std::nullopt, std::nullopt));
+ mlir::FunctionType::get(context, std::nullopt, std::nullopt),
+ symbolTable);
func.addEntryBlock();
- builder = new fir::FirOpBuilder(func, bridge.getKindMap());
+ builder = new fir::FirOpBuilder(func, bridge.getKindMap(), symbolTable);
assert(builder && "FirOpBuilder did not instantiate");
builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
createGlobals();
@@ -5335,6 +5341,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// utilities to deal with procedure pointer components whose arguments have
/// the type of the containing derived type.
Fortran::lower::TypeConstructionStack typeConstructionStack;
+ /// MLIR symbol table of the fir.global/func.func operations. Note that it is
+ /// not guaranteed to contain all operations of the ModuleOp with Symbol
+ /// attribute since mlirSymbolTable must pro-actively be maintained when
+ /// new Symbol operations are created.
+ mlir::SymbolTable mlirSymbolTable;
};
} // namespace
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index c65becc497459c..fef38da0133060 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -667,11 +667,13 @@ void Fortran::lower::CallInterface<T>::declare() {
if (!side().isIndirectCall()) {
std::string name = side().getMangledName();
mlir::ModuleOp module = converter.getModuleOp();
- func = fir::FirOpBuilder::getNamedFunction(module, name);
+ mlir::SymbolTable *symbolTable = converter.getMLIRSymbolTable();
+ func = fir::FirOpBuilder::getNamedFunction(module, name, symbolTable);
if (!func) {
mlir::Location loc = side().getCalleeLocation();
mlir::FunctionType ty = genFunctionType();
- func = fir::FirOpBuilder::createFunction(loc, module, name, ty);
+ func =
+ fir::FirOpBuilder::createFunction(loc, module, name, ty, symbolTable);
if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol()) {
if (side().isMainProgram()) {
func->setAttr(fir::getSymbolAttrName(),
@@ -1644,7 +1646,8 @@ mlir::func::FuncOp Fortran::lower::getOrDeclareFunction(
Fortran::lower::AbstractConverter &converter) {
mlir::ModuleOp module = converter.getModuleOp();
std::string name = getProcMangledName(proc, converter);
- mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(module, name);
+ mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(
+ module, name, converter.getMLIRSymbolTable());
if (func)
return func;
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 7b7e4a875cd8e8..0ef3baa19c0199 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3809,7 +3809,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
std::string funcName;
if (name) {
funcName = converter.mangleName(*name->symbol);
- funcOp = builder.getNamedFunction(mod, funcName);
+ funcOp =
+ builder.getNamedFunction(mod, funcName, builder.getMLIRSymbolTable());
} else {
Fortran::semantics::Scope &scope =
semanticsContext.FindScope(routineConstruct.source);
@@ -3821,7 +3822,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
: nullptr};
if (subpDetails && subpDetails->isInterface()) {
funcName = converter.mangleName(*progUnit.symbol());
- funcOp = builder.getNamedFunction(mod, funcName);
+ funcOp =
+ builder.getNamedFunction(mod, funcName, builder.getMLIRSymbolTable());
} else {
funcOp = builder.getFunction();
funcName = funcOp.getName();
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 2bcd5e5914027d..a8606a79af1671 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -36,26 +36,39 @@ static llvm::cl::opt<std::size_t>
"name"),
llvm::cl::init(32));
-mlir::func::FuncOp fir::FirOpBuilder::createFunction(mlir::Location loc,
- mlir::ModuleOp module,
- llvm::StringRef name,
- mlir::FunctionType ty) {
- return fir::createFuncOp(loc, module, name, ty);
+mlir::func::FuncOp
+fir::FirOpBuilder::createFunction(mlir::Location loc, mlir::ModuleOp module,
+ llvm::StringRef name, mlir::FunctionType ty,
+ mlir::SymbolTable *symbolTable) {
+ return fir::createFuncOp(loc, module, name, ty, /*attrs*/ {}, symbolTable);
}
-mlir::func::FuncOp fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
- llvm::StringRef name) {
+mlir::func::FuncOp
+fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp, llvm::StringRef name,
+ const mlir::SymbolTable *symbolTable) {
+ if (symbolTable)
+ if (auto func = symbolTable->lookup<mlir::func::FuncOp>(name))
+ return func;
return modOp.lookupSymbol<mlir::func::FuncOp>(name);
}
mlir::func::FuncOp
fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
- mlir::SymbolRefAttr symbol) {
+ mlir::SymbolRefAttr symbol,
+ const mlir::SymbolTable *symbolTable) {
+ if (symbolTable)
+ if (auto func =
+ symbolTable->lookup<mlir::func::FuncOp>(symbol.getLeafReference()))
+ return func;
return modOp.lookupSymbol<mlir::func::FuncOp>(symbol);
}
-fir::GlobalOp fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
- llvm::StringRef name) {
+fir::GlobalOp
+fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp, llvm::StringRef name,
+ const mlir::SymbolTable *symbolTable) {
+ if (symbolTable)
+ if (auto global = symbolTable->lookup<fir::GlobalOp>(name))
+ return global;
return modOp.lookupSymbol<fir::GlobalOp>(name);
}
@@ -279,10 +292,10 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
mlir::Location loc, mlir::Type type, llvm::StringRef name,
mlir::StringAttr linkage, mlir::Attribute value, bool isConst,
bool isTarget, fir::CUDADataAttributeAttr cudaAttr) {
+ if (auto global = getNamedGlobal(name))
+ return global;
auto module = getModule();
auto insertPt = saveInsertionPoint();
- if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
- return glob;
setInsertionPoint(module.getBody(), module.getBody()->end());
llvm::SmallVector<mlir::NamedAttribute> attrs;
if (cudaAttr) {
@@ -294,6 +307,8 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
auto glob = create<fir::GlobalOp>(loc, name, isConst, isTarget, type, value,
linkage, attrs);
restoreInsertionPoint(insertPt);
+ if (symbolTable)
+ symbolTable->insert(glob);
return glob;
}
@@ -301,10 +316,10 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
mlir::Location loc, mlir::Type type, llvm::StringRef name, bool isConst,
bool isTarget, std::function<void(FirOpBuilder &)> bodyBuilder,
mlir::StringAttr linkage, fir::CUDADataAttributeAttr cudaAttr) {
+ if (auto global = getNamedGlobal(name))
+ return global;
auto module = getModule();
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/86673
More information about the flang-commits mailing list