[llvm] [mlir] [OpenMP][OpenMPIRBuilder] Add ReductionInfoManager (PR #81300)

Jan Leyonberg via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 9 11:08:06 PST 2024


https://github.com/jsjodin created https://github.com/llvm/llvm-project/pull/81300

Add a reduction info manager that will be used to communicate information between the distribute and workshare lowerings.


>From 763b01fb06679f40efb235899476bb8e7378d487 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Fri, 9 Feb 2024 14:00:18 -0500
Subject: [PATCH] [OpenMP][OpenMPIRBuilder] Add ReductionInfoManager

Add a reduction info manager that will be used to communicate information
between the distribute and workshare lowerings.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  63 ++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 120 +++++++-----------
 mlir/test/Target/LLVMIR/openmp-reduction.mlir |   2 +-
 3 files changed, 108 insertions(+), 77 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 2288969ecc95c4..6b1403afc80e78 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1236,15 +1236,15 @@ class OpenMPIRBuilder {
   /// Functions used to generate reductions. Such functions take two Values
   /// representing LHS and RHS of the reduction, respectively, and a reference
   /// to the value that is updated to refer to the reduction result.
-  using ReductionGenTy =
-      function_ref<InsertPointTy(InsertPointTy, Value *, Value *, Value *&)>;
+  using ReductionGenTy = std::function<OpenMPIRBuilder::InsertPointTy(
+      OpenMPIRBuilder::InsertPointTy, Value *, Value *, Value *&)>;
 
   /// Functions used to generate atomic reductions. Such functions take two
   /// Values representing pointers to LHS and RHS of the reduction, as well as
   /// the element type of these pointers. They are expected to atomically
   /// update the LHS to the reduced value.
-  using AtomicReductionGenTy =
-      function_ref<InsertPointTy(InsertPointTy, Type *, Value *, Value *)>;
+  using AtomicReductionGenTy = std::function<OpenMPIRBuilder::InsertPointTy(
+      OpenMPIRBuilder::InsertPointTy, Type *, Value *, Value *)>;
 
   /// Information about an OpenMP reduction.
   struct ReductionInfo {
@@ -1254,6 +1254,10 @@ class OpenMPIRBuilder {
         : ElementType(ElementType), Variable(Variable),
           PrivateVariable(PrivateVariable), ReductionGen(ReductionGen),
           AtomicReductionGen(AtomicReductionGen) {}
+    ReductionInfo(Value *PrivateVariable)
+        : ElementType(nullptr), Variable(nullptr),
+          PrivateVariable(PrivateVariable), ReductionGen(),
+          AtomicReductionGen() {}
 
     /// Reduction element type, must match pointee type of variable.
     Type *ElementType;
@@ -1276,6 +1280,54 @@ class OpenMPIRBuilder {
     AtomicReductionGenTy AtomicReductionGen;
   };
 
+  /// A class that manages the reduction info to facilitate lowering of
+  /// reductions at multiple levels of parallelism. For example, handling teams
+  /// and parallel reductions on GPUs.
+
+  class ReductionInfoManager {
+  private:
+    SmallVector<ReductionInfo> ReductionInfos;
+    std::optional<InsertPointTy> PrivateVarAllocaIP;
+
+  public:
+    ReductionInfoManager(){};
+    void clear() {
+      ReductionInfos.clear();
+      PrivateVarAllocaIP.reset();
+    }
+
+    Value *
+    allocatePrivateReductionVar(IRBuilderBase &builder,
+                                llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+                                Type *VarType) {
+      llvm::Type *ptrTy = llvm::PointerType::getUnqual(builder.getContext());
+      llvm::Value *var = builder.CreateAlloca(VarType);
+      var->setName("private_redvar");
+      llvm::Value *castVar =
+          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+      ReductionInfos.push_back(ReductionInfo(castVar));
+      return castVar;
+    }
+
+    ReductionInfo getReductionInfo(unsigned Index) {
+      return ReductionInfos[Index];
+    }
+    ReductionInfo setReductionInfo(unsigned Index, ReductionInfo &RI) {
+      return ReductionInfos[Index] = RI;
+    }
+    Value *getPrivateReductionVariable(unsigned Index) {
+      return ReductionInfos[Index].PrivateVariable;
+    }
+    SmallVector<ReductionInfo> &getReductionInfos() { return ReductionInfos; }
+
+    bool hasPrivateVarAllocaIP() { return PrivateVarAllocaIP.has_value(); }
+    InsertPointTy getPrivateVarAllocaIP() {
+      assert(PrivateVarAllocaIP.has_value() && "AllocaIP not set");
+      return *PrivateVarAllocaIP;
+    }
+    void setPrivateVarAllocaIP(InsertPointTy IP) { PrivateVarAllocaIP = IP; }
+  };
+
   // TODO: provide atomic and non-atomic reduction generators for reduction
   // operators defined by the OpenMP specification.
 
@@ -1481,6 +1533,9 @@ class OpenMPIRBuilder {
   /// Info manager to keep track of target regions.
   OffloadEntriesInfoManager OffloadInfoManager;
 
+  /// Info manager to keep track of reduction information.
+  ReductionInfoManager RIManager;
+
   /// The target triple of the underlying module.
   const Triple T;
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 79956f82ed141a..535b40f7151e75 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -439,28 +439,16 @@ static LogicalResult inlineConvertOmpRegions(
   return success();
 }
 
-namespace {
-/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
-/// store lambdas with capture.
-using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
-    llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
-    llvm::Value *&)>;
-using OwningAtomicReductionGen =
-    std::function<llvm::OpenMPIRBuilder::InsertPointTy(
-        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
-        llvm::Value *)>;
-} // namespace
-
 /// Create an OpenMPIRBuilder-compatible reduction generator for the given
 /// reduction declaration. The generator uses `builder` but ignores its
 /// insertion point.
-static OwningReductionGen
+static llvm::OpenMPIRBuilder::ReductionGenTy
 makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
   // The lambda is mutable because we need access to non-const methods of decl
   // (which aren't actually mutating it), and we must capture decl by-value to
   // avoid the dangling reference after the parent function returns.
-  OwningReductionGen gen =
+  llvm::OpenMPIRBuilder::ReductionGenTy gen =
       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                 llvm::Value *lhs, llvm::Value *rhs,
                 llvm::Value *&result) mutable {
@@ -484,17 +472,17 @@ makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder,
 /// given reduction declaration. The generator uses `builder` but ignores its
 /// insertion point. Returns null if there is no atomic region available in the
 /// reduction declaration.
-static OwningAtomicReductionGen
+static llvm::OpenMPIRBuilder::AtomicReductionGenTy
 makeAtomicReductionGen(omp::ReductionDeclareOp decl,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
   if (decl.getAtomicReductionRegion().empty())
-    return OwningAtomicReductionGen();
+    return llvm::OpenMPIRBuilder::AtomicReductionGenTy();
 
   // The lambda is mutable because we need access to non-const methods of decl
   // (which aren't actually mutating it), and we must capture decl by-value to
   // avoid the dangling reference after the parent function returns.
-  OwningAtomicReductionGen atomicGen =
+  llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen =
       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                 llvm::Value *lhs, llvm::Value *rhs) mutable {
         Region &atomicRegion = decl.getAtomicReductionRegion();
@@ -781,55 +769,50 @@ convertOmpTaskgroupOp(omp::TaskGroupOp tgOp, llvm::IRBuilderBase &builder,
 template <typename T>
 static void
 allocReductionVars(T loop, llvm::IRBuilderBase &builder,
+                   llvm::OpenMPIRBuilder &ompBuilder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    SmallVector<omp::ReductionDeclareOp> &reductionDecls,
-                   SmallVector<llvm::Value *> &privateReductionVariables,
                    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
   unsigned numReductions = loop.getNumReductionVars();
-  privateReductionVariables.reserve(numReductions);
   if (numReductions != 0) {
     llvm::IRBuilderBase::InsertPointGuard guard(builder);
-    builder.restoreIP(allocaIP);
+    llvm::OpenMPIRBuilder::InsertPointTy curIP = builder.saveIP();
+
+    if (!ompBuilder.RIManager.hasPrivateVarAllocaIP())
+      ompBuilder.RIManager.setPrivateVarAllocaIP(allocaIP);
+    builder.restoreIP(ompBuilder.RIManager.getPrivateVarAllocaIP());
     for (unsigned i = 0; i < numReductions; ++i) {
-      llvm::Value *var = builder.CreateAlloca(
+      llvm::Value *var = ompBuilder.RIManager.allocatePrivateReductionVar(
+          builder, allocaIP,
           moduleTranslation.convertType(reductionDecls[i].getType()));
-      privateReductionVariables.push_back(var);
       reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
     }
+    builder.restoreIP(curIP);
   }
 }
 
 /// Collect reduction info
 template <typename T>
-static void collectReductionInfo(
-    T loop, llvm::IRBuilderBase &builder,
-    LLVM::ModuleTranslation &moduleTranslation,
-    SmallVector<omp::ReductionDeclareOp> &reductionDecls,
-    SmallVector<OwningReductionGen> &owningReductionGens,
-    SmallVector<OwningAtomicReductionGen> &owningAtomicReductionGens,
-    const SmallVector<llvm::Value *> &privateReductionVariables,
-    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
+static void
+collectReductionInfo(T &loop, llvm::IRBuilderBase &builder,
+                     llvm::OpenMPIRBuilder &ompBuilder,
+                     LLVM::ModuleTranslation &moduleTranslation,
+                     SmallVector<omp::ReductionDeclareOp> &reductionDecls) {
   unsigned numReductions = loop.getNumReductionVars();
 
   for (unsigned i = 0; i < numReductions; ++i) {
-    owningReductionGens.push_back(
-        makeReductionGen(reductionDecls[i], builder, moduleTranslation));
-    owningAtomicReductionGens.push_back(
-        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
-  }
-
-  // Collect the reduction information.
-  reductionInfos.reserve(numReductions);
-  for (unsigned i = 0; i < numReductions; ++i) {
-    llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
-    if (owningAtomicReductionGens[i])
-      atomicGen = owningAtomicReductionGens[i];
     llvm::Value *variable =
         moduleTranslation.lookupValue(loop.getReductionVars()[i]);
-    reductionInfos.push_back(
-        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
-         privateReductionVariables[i], owningReductionGens[i], atomicGen});
+    llvm::OpenMPIRBuilder::ReductionInfo RI =
+        ompBuilder.RIManager.getReductionInfo(i);
+    RI.Variable = variable;
+    RI.ElementType = moduleTranslation.convertType(reductionDecls[i].getType());
+    RI.ReductionGen =
+        makeReductionGen(reductionDecls[i], builder, moduleTranslation);
+    RI.AtomicReductionGen =
+        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation);
+    ompBuilder.RIManager.setReductionInfo(i, RI);
   }
 }
 
@@ -837,6 +820,7 @@ static void collectReductionInfo(
 static LogicalResult
 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   auto loop = cast<omp::WsLoopOp>(opInst);
   // TODO: this should be in the op verifier instead.
   if (loop.getLowerBound().empty())
@@ -861,10 +845,9 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
-  SmallVector<llvm::Value *> privateReductionVariables;
   DenseMap<Value, llvm::Value *> reductionVariableMap;
-  allocReductionVars(loop, builder, moduleTranslation, allocaIP, reductionDecls,
-                     privateReductionVariables, reductionVariableMap);
+  allocReductionVars(loop, builder, *ompBuilder, moduleTranslation, allocaIP,
+                     reductionDecls, reductionVariableMap);
 
   // Store the mapping between reduction variables and their private copies on
   // ModuleTranslation stack. It can be then recovered when translating
@@ -883,7 +866,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
       return failure();
     assert(phis.size() == 1 && "expected one value to be yielded from the "
                                "reduction neutral element declaration region");
-    builder.CreateStore(phis[0], privateReductionVariables[i]);
+    builder.CreateStore(phis[0],
+                        ompBuilder->RIManager.getPrivateReductionVariable(i));
   }
 
   // Set up the source location value for OpenMP runtime.
@@ -918,7 +902,6 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
   // i.e. it has a positive step, uses signed integer semantics. Reconsider
   // this code when WsLoop clearly supports more cases.
-  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
     llvm::Value *lowerBound =
         moduleTranslation.lookupValue(loop.getLowerBound()[i]);
@@ -975,12 +958,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
 
   // Create the reduction generators. We need to own them here because
   // ReductionInfo only accepts references to the generators.
-  SmallVector<OwningReductionGen> owningReductionGens;
-  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
-  SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
-  collectReductionInfo(loop, builder, moduleTranslation, reductionDecls,
-                       owningReductionGens, owningAtomicReductionGens,
-                       privateReductionVariables, reductionInfos);
+  collectReductionInfo(loop, builder, *ompBuilder, moduleTranslation,
+                       reductionDecls);
 
   // The call to createReductions below expects the block to have a
   // terminator. Create an unreachable instruction to serve as terminator
@@ -988,7 +967,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
   builder.SetInsertPoint(tempTerminator);
   llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
-      ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
+      ompBuilder->createReductions(builder.saveIP(), allocaIP,
+                                   ompBuilder->RIManager.getReductionInfos(),
                                    loop.getNowait());
   if (!contInsertPoint.getBlock())
     return loop->emitOpError() << "failed to convert reductions";
@@ -996,7 +976,7 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
       ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
   tempTerminator->eraseFromParent();
   builder.restoreIP(nextInsertionPoint);
-
+  ompBuilder->RIManager.clear();
   return success();
 }
 
@@ -1016,11 +996,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     collectReductionDecls(opInst, reductionDecls);
 
     // Allocate reduction vars
-    SmallVector<llvm::Value *> privateReductionVariables;
     DenseMap<Value, llvm::Value *> reductionVariableMap;
-    allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
-                       reductionDecls, privateReductionVariables,
-                       reductionVariableMap);
+    allocReductionVars(opInst, builder, *ompBuilder, moduleTranslation,
+                       allocaIP, reductionDecls, reductionVariableMap);
 
     // Store the mapping between reduction variables and their private copies on
     // ModuleTranslation stack. It can be then recovered when translating
@@ -1040,7 +1018,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
              "expected one value to be yielded from the "
              "reduction neutral element declaration region");
       builder.restoreIP(allocaIP);
-      builder.CreateStore(phis[0], privateReductionVariables[i]);
+      builder.CreateStore(phis[0],
+                          ompBuilder->RIManager.getPrivateReductionVariable(i));
     }
 
     // Save the alloca insertion point on ModuleTranslation stack for use in
@@ -1057,12 +1036,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     // Process the reductions if required.
     if (opInst.getNumReductionVars() > 0) {
       // Collect reduction info
-      SmallVector<OwningReductionGen> owningReductionGens;
-      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
-      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
-      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
-                           owningReductionGens, owningAtomicReductionGens,
-                           privateReductionVariables, reductionInfos);
+      collectReductionInfo(opInst, builder, *ompBuilder, moduleTranslation,
+                           reductionDecls);
 
       // Move to region cont block
       builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1072,8 +1047,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
       builder.SetInsertPoint(tempTerminator);
 
       llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
-          ompBuilder->createReductions(builder.saveIP(), allocaIP,
-                                       reductionInfos, false);
+          ompBuilder->createReductions(
+              builder.saveIP(), allocaIP,
+              ompBuilder->RIManager.getReductionInfos(), false);
       if (!contInsertPoint.getBlock()) {
         bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
         return;
@@ -1118,7 +1094,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   builder.restoreIP(
       ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                  ifCond, numThreads, pbKind, isCancellable));
-
+  ompBuilder->RIManager.clear();
   return bodyGenStatus;
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-reduction.mlir
index 93ab578df9e4e8..ea74ca474420e8 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction.mlir
@@ -532,7 +532,7 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) {
 // CHECK: define internal void @[[OUTLINED]]
 
 // Private reduction variable and its initialization.
-// CHECK: %[[PRIVATE:[0-9]+]] = alloca i32
+// CHECK: %[[PRIVATE:private_redvar]] = alloca i32
 // CHECK: store i32 0, ptr %[[PRIVATE]]
 
 // Loop exit:



More information about the llvm-commits mailing list