[llvm] [mlir] [OpenMP][OpenMPIRBuilder] Add ReductionInfoManager (PR #81300)
Jan Leyonberg via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 9 11:08:06 PST 2024
https://github.com/jsjodin created https://github.com/llvm/llvm-project/pull/81300
Add a reduction info manager that will be used to communicate information between the distribute and workshare lowerings.
From 763b01fb06679f40efb235899476bb8e7378d487 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Fri, 9 Feb 2024 14:00:18 -0500
Subject: [PATCH] [OpenMP][OpenMPIRBuilder] Add ReductionInfoManager
Add reduction info manager that will be used to communicate information between
distribute and workshare lowerings.
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 63 ++++++++-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 120 +++++++-----------
mlir/test/Target/LLVMIR/openmp-reduction.mlir | 2 +-
3 files changed, 108 insertions(+), 77 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 2288969ecc95c4..6b1403afc80e78 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1236,15 +1236,15 @@ class OpenMPIRBuilder {
/// Functions used to generate reductions. Such functions take two Values
/// representing LHS and RHS of the reduction, respectively, and a reference
/// to the value that is updated to refer to the reduction result.
- using ReductionGenTy =
- function_ref<InsertPointTy(InsertPointTy, Value *, Value *, Value *&)>;
+ using ReductionGenTy = std::function<OpenMPIRBuilder::InsertPointTy(
+ OpenMPIRBuilder::InsertPointTy, Value *, Value *, Value *&)>;
/// Functions used to generate atomic reductions. Such functions take two
/// Values representing pointers to LHS and RHS of the reduction, as well as
/// the element type of these pointers. They are expected to atomically
/// update the LHS to the reduced value.
- using AtomicReductionGenTy =
- function_ref<InsertPointTy(InsertPointTy, Type *, Value *, Value *)>;
+ using AtomicReductionGenTy = std::function<OpenMPIRBuilder::InsertPointTy(
+ OpenMPIRBuilder::InsertPointTy, Type *, Value *, Value *)>;
/// Information about an OpenMP reduction.
struct ReductionInfo {
@@ -1254,6 +1254,10 @@ class OpenMPIRBuilder {
: ElementType(ElementType), Variable(Variable),
PrivateVariable(PrivateVariable), ReductionGen(ReductionGen),
AtomicReductionGen(AtomicReductionGen) {}
+ ReductionInfo(Value *PrivateVariable)
+ : ElementType(nullptr), Variable(nullptr),
+ PrivateVariable(PrivateVariable), ReductionGen(),
+ AtomicReductionGen() {}
/// Reduction element type, must match pointee type of variable.
Type *ElementType;
@@ -1276,6 +1280,54 @@ class OpenMPIRBuilder {
AtomicReductionGenTy AtomicReductionGen;
};
+ /// A class that manages the reduction info to facilitate lowering of
+ /// reductions at multiple levels of parallelism, for example handling teams
+ /// and parallel reductions on GPUs.
+
+ class ReductionInfoManager {
+ private:
+ SmallVector<ReductionInfo> ReductionInfos;
+ std::optional<InsertPointTy> PrivateVarAllocaIP;
+
+ public:
+ ReductionInfoManager(){};
+ void clear() {
+ ReductionInfos.clear();
+ PrivateVarAllocaIP.reset();
+ }
+
+ Value *
+ allocatePrivateReductionVar(IRBuilderBase &builder,
+ llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+ Type *VarType) {
+ llvm::Type *ptrTy = llvm::PointerType::getUnqual(builder.getContext());
+ llvm::Value *var = builder.CreateAlloca(VarType);
+ var->setName("private_redvar");
+ llvm::Value *castVar =
+ builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+ ReductionInfos.push_back(ReductionInfo(castVar));
+ return castVar;
+ }
+
+ ReductionInfo getReductionInfo(unsigned Index) {
+ return ReductionInfos[Index];
+ }
+ ReductionInfo setReductionInfo(unsigned Index, ReductionInfo &RI) {
+ return ReductionInfos[Index] = RI;
+ }
+ Value *getPrivateReductionVariable(unsigned Index) {
+ return ReductionInfos[Index].PrivateVariable;
+ }
+ SmallVector<ReductionInfo> &getReductionInfos() { return ReductionInfos; }
+
+ bool hasPrivateVarAllocaIP() { return PrivateVarAllocaIP.has_value(); }
+ InsertPointTy getPrivateVarAllocaIP() {
+ assert(PrivateVarAllocaIP.has_value() && "AllocaIP not set");
+ return *PrivateVarAllocaIP;
+ }
+ void setPrivateVarAllocaIP(InsertPointTy IP) { PrivateVarAllocaIP = IP; }
+ };
+
// TODO: provide atomic and non-atomic reduction generators for reduction
// operators defined by the OpenMP specification.
@@ -1481,6 +1533,9 @@ class OpenMPIRBuilder {
/// Info manager to keep track of target regions.
OffloadEntriesInfoManager OffloadInfoManager;
+ /// Info manager to keep track of reduction information.
+ ReductionInfoManager RIManager;
+
/// The target triple of the underlying module.
const Triple T;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 79956f82ed141a..535b40f7151e75 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -439,28 +439,16 @@ static LogicalResult inlineConvertOmpRegions(
return success();
}
-namespace {
-/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
-/// store lambdas with capture.
-using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
- llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
- llvm::Value *&)>;
-using OwningAtomicReductionGen =
- std::function<llvm::OpenMPIRBuilder::InsertPointTy(
- llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
- llvm::Value *)>;
-} // namespace
-
/// Create an OpenMPIRBuilder-compatible reduction generator for the given
/// reduction declaration. The generator uses `builder` but ignores its
/// insertion point.
-static OwningReductionGen
+static llvm::OpenMPIRBuilder::ReductionGenTy
makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
// The lambda is mutable because we need access to non-const methods of decl
// (which aren't actually mutating it), and we must capture decl by-value to
// avoid the dangling reference after the parent function returns.
- OwningReductionGen gen =
+ llvm::OpenMPIRBuilder::ReductionGenTy gen =
[&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
llvm::Value *lhs, llvm::Value *rhs,
llvm::Value *&result) mutable {
@@ -484,17 +472,17 @@ makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder,
/// given reduction declaration. The generator uses `builder` but ignores its
/// insertion point. Returns null if there is no atomic region available in the
/// reduction declaration.
-static OwningAtomicReductionGen
+static llvm::OpenMPIRBuilder::AtomicReductionGenTy
makeAtomicReductionGen(omp::ReductionDeclareOp decl,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
if (decl.getAtomicReductionRegion().empty())
- return OwningAtomicReductionGen();
+ return llvm::OpenMPIRBuilder::AtomicReductionGenTy();
// The lambda is mutable because we need access to non-const methods of decl
// (which aren't actually mutating it), and we must capture decl by-value to
// avoid the dangling reference after the parent function returns.
- OwningAtomicReductionGen atomicGen =
+ llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen =
[&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
llvm::Value *lhs, llvm::Value *rhs) mutable {
Region &atomicRegion = decl.getAtomicReductionRegion();
@@ -781,55 +769,50 @@ convertOmpTaskgroupOp(omp::TaskGroupOp tgOp, llvm::IRBuilderBase &builder,
template <typename T>
static void
allocReductionVars(T loop, llvm::IRBuilderBase &builder,
+ llvm::OpenMPIRBuilder &ompBuilder,
LLVM::ModuleTranslation &moduleTranslation,
llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
SmallVector<omp::ReductionDeclareOp> &reductionDecls,
- SmallVector<llvm::Value *> &privateReductionVariables,
DenseMap<Value, llvm::Value *> &reductionVariableMap) {
unsigned numReductions = loop.getNumReductionVars();
- privateReductionVariables.reserve(numReductions);
if (numReductions != 0) {
llvm::IRBuilderBase::InsertPointGuard guard(builder);
- builder.restoreIP(allocaIP);
+ llvm::OpenMPIRBuilder::InsertPointTy curIP = builder.saveIP();
+
+ if (!ompBuilder.RIManager.hasPrivateVarAllocaIP())
+ ompBuilder.RIManager.setPrivateVarAllocaIP(allocaIP);
+ builder.restoreIP(ompBuilder.RIManager.getPrivateVarAllocaIP());
for (unsigned i = 0; i < numReductions; ++i) {
- llvm::Value *var = builder.CreateAlloca(
+ llvm::Value *var = ompBuilder.RIManager.allocatePrivateReductionVar(
+ builder, allocaIP,
moduleTranslation.convertType(reductionDecls[i].getType()));
- privateReductionVariables.push_back(var);
reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
}
+ builder.restoreIP(curIP);
}
}
/// Collect reduction info
template <typename T>
-static void collectReductionInfo(
- T loop, llvm::IRBuilderBase &builder,
- LLVM::ModuleTranslation &moduleTranslation,
- SmallVector<omp::ReductionDeclareOp> &reductionDecls,
- SmallVector<OwningReductionGen> &owningReductionGens,
- SmallVector<OwningAtomicReductionGen> &owningAtomicReductionGens,
- const SmallVector<llvm::Value *> &privateReductionVariables,
- SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
+static void
+collectReductionInfo(T &loop, llvm::IRBuilderBase &builder,
+ llvm::OpenMPIRBuilder &ompBuilder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ SmallVector<omp::ReductionDeclareOp> &reductionDecls) {
unsigned numReductions = loop.getNumReductionVars();
for (unsigned i = 0; i < numReductions; ++i) {
- owningReductionGens.push_back(
- makeReductionGen(reductionDecls[i], builder, moduleTranslation));
- owningAtomicReductionGens.push_back(
- makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
- }
-
- // Collect the reduction information.
- reductionInfos.reserve(numReductions);
- for (unsigned i = 0; i < numReductions; ++i) {
- llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
- if (owningAtomicReductionGens[i])
- atomicGen = owningAtomicReductionGens[i];
llvm::Value *variable =
moduleTranslation.lookupValue(loop.getReductionVars()[i]);
- reductionInfos.push_back(
- {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
- privateReductionVariables[i], owningReductionGens[i], atomicGen});
+ llvm::OpenMPIRBuilder::ReductionInfo RI =
+ ompBuilder.RIManager.getReductionInfo(i);
+ RI.Variable = variable;
+ RI.ElementType = moduleTranslation.convertType(reductionDecls[i].getType());
+ RI.ReductionGen =
+ makeReductionGen(reductionDecls[i], builder, moduleTranslation);
+ RI.AtomicReductionGen =
+ makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation);
+ ompBuilder.RIManager.setReductionInfo(i, RI);
}
}
@@ -837,6 +820,7 @@ static void collectReductionInfo(
static LogicalResult
convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto loop = cast<omp::WsLoopOp>(opInst);
// TODO: this should be in the op verifier instead.
if (loop.getLowerBound().empty())
@@ -861,10 +845,9 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
- SmallVector<llvm::Value *> privateReductionVariables;
DenseMap<Value, llvm::Value *> reductionVariableMap;
- allocReductionVars(loop, builder, moduleTranslation, allocaIP, reductionDecls,
- privateReductionVariables, reductionVariableMap);
+ allocReductionVars(loop, builder, *ompBuilder, moduleTranslation, allocaIP,
+ reductionDecls, reductionVariableMap);
// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating
@@ -883,7 +866,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
assert(phis.size() == 1 && "expected one value to be yielded from the "
"reduction neutral element declaration region");
- builder.CreateStore(phis[0], privateReductionVariables[i]);
+ builder.CreateStore(phis[0],
+ ompBuilder->RIManager.getPrivateReductionVariable(i));
}
// Set up the source location value for OpenMP runtime.
@@ -918,7 +902,6 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
// i.e. it has a positive step, uses signed integer semantics. Reconsider
// this code when WsLoop clearly supports more cases.
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
llvm::Value *lowerBound =
moduleTranslation.lookupValue(loop.getLowerBound()[i]);
@@ -975,12 +958,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
// Create the reduction generators. We need to own them here because
// ReductionInfo only accepts references to the generators.
- SmallVector<OwningReductionGen> owningReductionGens;
- SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
- SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
- collectReductionInfo(loop, builder, moduleTranslation, reductionDecls,
- owningReductionGens, owningAtomicReductionGens,
- privateReductionVariables, reductionInfos);
+ collectReductionInfo(loop, builder, *ompBuilder, moduleTranslation,
+ reductionDecls);
// The call to createReductions below expects the block to have a
// terminator. Create an unreachable instruction to serve as terminator
@@ -988,7 +967,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
builder.SetInsertPoint(tempTerminator);
llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
- ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
+ ompBuilder->createReductions(builder.saveIP(), allocaIP,
+ ompBuilder->RIManager.getReductionInfos(),
loop.getNowait());
if (!contInsertPoint.getBlock())
return loop->emitOpError() << "failed to convert reductions";
@@ -996,7 +976,7 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
tempTerminator->eraseFromParent();
builder.restoreIP(nextInsertionPoint);
-
+ ompBuilder->RIManager.clear();
return success();
}
@@ -1016,11 +996,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
collectReductionDecls(opInst, reductionDecls);
// Allocate reduction vars
- SmallVector<llvm::Value *> privateReductionVariables;
DenseMap<Value, llvm::Value *> reductionVariableMap;
- allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
- reductionDecls, privateReductionVariables,
- reductionVariableMap);
+ allocReductionVars(opInst, builder, *ompBuilder, moduleTranslation,
+ allocaIP, reductionDecls, reductionVariableMap);
// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating
@@ -1040,7 +1018,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
"expected one value to be yielded from the "
"reduction neutral element declaration region");
builder.restoreIP(allocaIP);
- builder.CreateStore(phis[0], privateReductionVariables[i]);
+ builder.CreateStore(phis[0],
+ ompBuilder->RIManager.getPrivateReductionVariable(i));
}
// Save the alloca insertion point on ModuleTranslation stack for use in
@@ -1057,12 +1036,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// Process the reductions if required.
if (opInst.getNumReductionVars() > 0) {
// Collect reduction info
- SmallVector<OwningReductionGen> owningReductionGens;
- SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
- SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
- collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
- owningReductionGens, owningAtomicReductionGens,
- privateReductionVariables, reductionInfos);
+ collectReductionInfo(opInst, builder, *ompBuilder, moduleTranslation,
+ reductionDecls);
// Move to region cont block
builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1072,8 +1047,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
builder.SetInsertPoint(tempTerminator);
llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
- ompBuilder->createReductions(builder.saveIP(), allocaIP,
- reductionInfos, false);
+ ompBuilder->createReductions(
+ builder.saveIP(), allocaIP,
+ ompBuilder->RIManager.getReductionInfos(), false);
if (!contInsertPoint.getBlock()) {
bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
return;
@@ -1118,7 +1094,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
builder.restoreIP(
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
ifCond, numThreads, pbKind, isCancellable));
-
+ ompBuilder->RIManager.clear();
return bodyGenStatus;
}
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-reduction.mlir
index 93ab578df9e4e8..ea74ca474420e8 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction.mlir
@@ -532,7 +532,7 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) {
// CHECK: define internal void @[[OUTLINED]]
// Private reduction variable and its initialization.
-// CHECK: %[[PRIVATE:[0-9]+]] = alloca i32
+// CHECK: %[[PRIVATE:private_redvar]] = alloca i32
// CHECK: store i32 0, ptr %[[PRIVATE]]
// Loop exit:
More information about the llvm-commits
mailing list