[llvm-branch-commits] [flang] [Flang][OpenMP] Improve entry block argument creation and binding (PR #110267)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Sep 27 06:22:28 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
The main purpose of this patch is to centralize the logic for creating MLIR operation entry blocks and for binding them to the corresponding symbols. This minimizes the chances of mixing arguments up for operations having multiple entry block argument-generating clauses and prevents divergence while binding arguments.
Some changes implemented to this end are:
- Split the creation of the entry block and the binding of its arguments to the corresponding Fortran symbols into two separate functions. This enables a significant simplification of the lowering of composite constructs, where it is no longer necessary to manually ensure that the lists of arguments and symbols refer to the same variables in the same order and also match the order expected by the `BlockArgOpenMPOpInterface`.
- Removed redundant and error-prone passing of types and locations from `ClauseProcessor` methods. Instead, these are obtained from the values in the appropriate clause operands structure. This also simplifies argument lists of several lowering functions.
- Access block arguments of already created MLIR operations through the `BlockArgOpenMPOpInterface` instead of directly indexing the argument list of the operation, which is not scalable as more entry block argument-generating clauses are added to an operation.
- Simplified the implementation of `genParallelOp` to no longer need to define different callbacks depending on whether delayed privatization is enabled.
---
Patch is 89.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110267.diff
7 Files Affected:
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+22-57)
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+12-26)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+520-499)
- (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+2-3)
- (modified) flang/lib/Lower/OpenMP/ReductionProcessor.h (+1-2)
- (modified) flang/lib/Lower/OpenMP/Utils.cpp (+1-8)
- (modified) flang/lib/Lower/OpenMP/Utils.h (+1-3)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index e9ef8579100e93..44f5ca7f342707 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -166,15 +166,11 @@ getIfClauseOperand(lower::AbstractConverter &converter,
static void addUseDeviceClause(
lower::AbstractConverter &converter, const omp::ObjectList &objects,
llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
genObjectList(objects, converter, operands);
- for (mlir::Value &operand : operands) {
+ for (mlir::Value &operand : operands)
checkMapType(operand.getLoc(), operand.getType());
- useDeviceTypes.push_back(operand.getType());
- useDeviceLocs.push_back(operand.getLoc());
- }
+
for (const omp::Object &object : objects)
useDeviceSyms.push_back(object.sym());
}
@@ -832,14 +828,12 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
bool ClauseProcessor::processHasDeviceAddr(
mlir::omp::HasDeviceAddrClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
- llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
return findRepeatableClause<omp::clause::HasDeviceAddr>(
[&](const omp::clause::HasDeviceAddr &devAddrClause,
const parser::CharBlock &) {
addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
- isDeviceTypes, isDeviceLocs, isDeviceSymbols);
+ isDeviceSyms);
});
}
@@ -864,14 +858,12 @@ bool ClauseProcessor::processIf(
bool ClauseProcessor::processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
- llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
return findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &devPtrClause,
const parser::CharBlock &) {
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
- isDeviceTypes, isDeviceLocs, isDeviceSymbols);
+ isDeviceSyms);
});
}
@@ -892,9 +884,7 @@ void ClauseProcessor::processMapObjects(
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
- llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
@@ -927,12 +917,7 @@ void ClauseProcessor::processMapObjects(
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
} else {
mapVars.push_back(mapOp);
- if (mapSyms)
- mapSyms->push_back(object.sym());
- if (mapSymTypes)
- mapSymTypes->push_back(baseOp.getType());
- if (mapSymLocs)
- mapSymLocs->push_back(baseOp.getLoc());
+ mapSyms.push_back(object.sym());
}
}
}
@@ -940,9 +925,7 @@ void ClauseProcessor::processMapObjects(
bool ClauseProcessor::processMap(
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms) const {
// We always require tracking of symbols, even if the caller does not,
// so we create an optionally used local set of symbols when the mapSyms
// argument is not present.
@@ -999,12 +982,11 @@ bool ClauseProcessor::processMap(
}
processMapObjects(stmtCtx, clauseLocation,
std::get<omp::ObjectList>(clause.t), mapTypeBits,
- parentMemberIndices, result.mapVars, ptrMapSyms,
- mapSymLocs, mapSymTypes);
+ parentMemberIndices, result.mapVars, *ptrMapSyms);
});
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
- *ptrMapSyms, mapSymTypes, mapSymLocs);
+ *ptrMapSyms);
return clauseFound;
}
@@ -1027,7 +1009,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
processMapObjects(stmtCtx, clauseLocation, std::get<ObjectList>(clause.t),
mapTypeBits, parentMemberIndices, result.mapVars,
- &mapSymbols);
+ mapSymbols);
};
bool clauseFound = findRepeatableClause<omp::clause::To>(callbackFn);
@@ -1035,8 +1017,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
- mapSymbols,
- /*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
+ mapSymbols);
return clauseFound;
}
@@ -1054,8 +1035,7 @@ bool ClauseProcessor::processNontemporal(
bool ClauseProcessor::processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
- llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> reductionVars;
@@ -1063,25 +1043,16 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
- rp.addDeclareReduction(
- currentLocation, converter, clause, reductionVars, reduceVarByRef,
- reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
+ rp.addDeclareReduction(currentLocation, converter, clause,
+ reductionVars, reduceVarByRef,
+ reductionDeclSymbols, reductionSyms);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionSyms));
-
- if (outReductionTypes) {
- outReductionTypes->reserve(outReductionTypes->size() +
- reductionVars.size());
- llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
- [](mlir::Value v) { return v.getType(); });
- }
-
- if (outReductionSyms)
- llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
+ llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
});
}
@@ -1107,8 +1078,6 @@ bool ClauseProcessor::processEnter(
bool ClauseProcessor::processUseDeviceAddr(
lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
@@ -1122,19 +1091,16 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
parentMemberIndices, result.useDeviceAddrVars,
- &useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
+ useDeviceSyms);
});
insertChildMapInfoIntoParent(converter, parentMemberIndices,
- result.useDeviceAddrVars, useDeviceSyms,
- &useDeviceTypes, &useDeviceLocs);
+ result.useDeviceAddrVars, useDeviceSyms);
return clauseFound;
}
bool ClauseProcessor::processUseDevicePtr(
lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
@@ -1148,12 +1114,11 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
parentMemberIndices, result.useDevicePtrVars,
- &useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
+ useDeviceSyms);
});
insertChildMapInfoIntoParent(converter, parentMemberIndices,
- result.useDevicePtrVars, useDeviceSyms,
- &useDeviceTypes, &useDeviceLocs);
+ result.useDevicePtrVars, useDeviceSyms);
return clauseFound;
}
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 0c8e7bd47ab5a6..f34121c70d0b44 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -68,9 +68,7 @@ class ClauseProcessor {
mlir::omp::FinalClauseOps &result) const;
bool processHasDeviceAddr(
mlir::omp::HasDeviceAddrClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
- llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
@@ -104,43 +102,33 @@ class ClauseProcessor {
mlir::omp::IfClauseOps &result) const;
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
- llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
// This method is used to process a map clause.
- // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
- // store the original type, location and Fortran symbol for the map operands.
- // They may be used later on to create the block_arguments for some of the
- // target directives that require it.
- bool processMap(
- mlir::Location currentLocation, lower::StatementContext &stmtCtx,
- mlir::omp::MapClauseOps &result,
- llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = nullptr,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
+ // The optional parameter mapSyms is used to store the original Fortran symbol
+ // for the map operands. It may be used later on to create the block_arguments
+ // for some of the directives that require it.
+ bool processMap(mlir::Location currentLocation,
+ lower::StatementContext &stmtCtx,
+ mlir::omp::MapClauseOps &result,
+ llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms =
+ nullptr) const;
bool processMotionClauses(lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result);
bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const;
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
- llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
- nullptr) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
mlir::omp::UseDeviceAddrClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
bool processUseDevicePtr(
lower::StatementContext &stmtCtx,
mlir::omp::UseDevicePtrClauseOps &result,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
// Call this method for these clauses that should be supported but are not
@@ -181,9 +169,7 @@ class ClauseProcessor {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
- llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
+ llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;
lower::AbstractConverter &converter;
semantics::SemanticsContext &semaCtx;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index e9095d631beb7b..e617734619f483 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -45,6 +45,36 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//
+namespace {
+struct EntryBlockArgsEntry {
+ llvm::ArrayRef<const semantics::Symbol *> syms;
+ llvm::ArrayRef<mlir::Value> vars;
+
+ bool isValid() const {
+ // This check allows specifying a smaller number of symbols than values
+ // because in some cases a single symbol generates multiple block
+ // arguments.
+ return syms.size() <= vars.size();
+ }
+};
+
+struct EntryBlockArgs {
+ EntryBlockArgsEntry inReduction;
+ EntryBlockArgsEntry map;
+ EntryBlockArgsEntry priv;
+ EntryBlockArgsEntry reduction;
+ EntryBlockArgsEntry taskReduction;
+ EntryBlockArgsEntry useDeviceAddr;
+ EntryBlockArgsEntry useDevicePtr;
+
+ bool isValid() const {
+ return inReduction.isValid() && map.isValid() && priv.isValid() &&
+ reduction.isValid() && taskReduction.isValid() &&
+ useDeviceAddr.isValid() && useDevicePtr.isValid();
+ }
+};
+} // namespace
+
static void genOMPDispatch(lower::AbstractConverter &converter,
lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
@@ -52,6 +82,163 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
const ConstructQueue &queue,
ConstructQueue::const_iterator item);
+/// Bind symbols to their corresponding entry block arguments.
+///
+/// The binding will be performed inside of the current block, which does not
+/// necessarily have to be part of the operation for which the binding is done.
+/// However, block arguments must be accessible. This enables controlling the
+/// insertion point of any new MLIR operations related to the binding of
+/// arguments of a loop wrapper operation.
+///
+/// \param [in] converter - PFT to MLIR conversion interface.
+/// \param [in] op - owner operation of the block arguments to bind.
+/// \param [in] args - entry block arguments information for the given
+/// operation.
+static void bindEntryBlockArgs(lower::AbstractConverter &converter,
+ mlir::omp::BlockArgOpenMPOpInterface op,
+ const EntryBlockArgs &args) {
+ assert(op != nullptr && "invalid block argument-defining operation");
+ assert(args.isValid() && "invalid args");
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+ auto bindSingleMapLike = [&converter,
+ &firOpBuilder](const semantics::Symbol &sym,
+ const mlir::BlockArgument &arg) {
+ // Clones the `bounds` placing them inside the entry block and returns
+ // them.
+ auto cloneBound = [&](mlir::Value bound) {
+ if (mlir::isMemoryEffectFree(bound.getDefiningOp())) {
+ mlir::Operation *clonedOp = firOpBuilder.clone(*bound.getDefiningOp());
+ return clonedOp->getResult(0);
+ }
+ TODO(converter.getCurrentLocation(),
+ "target map-like clause operand unsupported bound type");
+ };
+
+ auto cloneBounds = [cloneBound](llvm::ArrayRef<mlir::Value> bounds) {
+ llvm::SmallVector<mlir::Value> clonedBounds;
+ llvm::transform(bounds, std::back_inserter(clonedBounds),
+ [&](mlir::Value bound) { return cloneBound(bound); });
+ return clonedBounds;
+ };
+
+ fir::ExtendedValue extVal = converter.getSymbolExtendedValue(sym);
+ auto refType = mlir::dyn_cast<fir::ReferenceType>(arg.getType());
+ if (refType && fir::isa_builtin_cptr_type(refType.getElementType())) {
+ converter.bindSymbol(sym, arg);
+ } else {
+ extVal.match(
+ [&](const fir::BoxValue &v) {
+ converter.bindSymbol(sym,
+ fir::BoxValue(arg, cloneBounds(v.getLBounds()),
+ v.getExplicitParameters(),
+ v.getExplicitExtents()));
+ },
+ [&](const fir::MutableBoxValue &v) {
+ converter.bindSymbol(
+ sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()),
+ v.getMutableProperties()));
+ },
+ [&](const fir::ArrayBoxValue &v) {
+ converter.bindSymbol(
+ sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()),
+ cloneBounds(v.getLBounds()),
+ v.getSourceBox()));
+ },
+ [&](const fir::CharArrayBoxValue &v) {
+ converter.bindSymbol(
+ sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()),
+ cloneBounds(v.getExtents()),
+ cloneBounds(v.getLBounds())));
+ },
+ [&](const fir::CharBoxValue &v) {
+ converter.bindSymbol(
+ sym, fir::CharBoxValue(arg, cloneBound(v.getLen())));
+ },
+ [&](const fir::UnboxedValue &v) { converter.bindSymbol(sym, arg); },
+ [&](const auto &) {
+ TODO(converter.getCurrentLocation(),
+ "target map clause operand unsupported type");
+ });
+ }
+ };
+
+ auto bindMapLike =
+ [&bindSingl...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/110267
More information about the llvm-branch-commits
mailing list