[Mlir-commits] [mlir] [flang][mlir] Add support for translating task_reduction to LLVMIR (PR #120957)

Kareem Ergawy llvmlistbot at llvm.org
Thu Jan 9 08:08:45 PST 2025


================
@@ -1787,16 +1779,264 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+template <typename OP> // Emits an LLVM function wrapping one region of reduction declaration `Cnt`.
+llvm::Value *createTaskReductionFunction( // Returns the function, a null pointer constant if `region` is empty, or nullptr if inlining fails.
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy, // `name` is the symbol name and also selects the "red_init"/"red_comb" argument mapping.
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef, // `Cnt` indexes both `reductionDecls` and `isByRef`.
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables, // Not referenced anywhere in this body.
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  if (region.empty()) { // Nothing to translate: hand back a null pointer instead of a function.
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+  }
+  llvm::FunctionType *funcType = nullptr;
+  if (isByRef[Cnt]) // By-ref signature: void(ptr, ptr); by-value signature: ptr(ptr, ptr).
+    funcType = llvm::FunctionType::get(builder.getVoidTy(),
+                                       {OpaquePtrTy, OpaquePtrTy}, false);
+  else
+    funcType =
+        llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  llvm::Function *function = // Created with external linkage in the module owning `builder`'s current block.
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry); // All body IR goes through this builder; `builder` itself is only queried.
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") { // Initializer callback: both pointer parameters are marked noalias.
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    if (isByRef[Cnt]) {
+      // TODO: Handle case where the initializer uses initialization from
+      // declare reduction construct using `arg1Alloca`.
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca); // Spill both arguments to stack slots; only arg0 is reloaded below.
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *LoadVal =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      moduleTranslation.mapValue(reductionDecls[Cnt].getInitializerAllocArg(), // Bind the reloaded arg0 pointer to the declaration's initializer alloc argument.
+                                 LoadVal);
+    } else {
+      mapInitializationArgs(op, moduleTranslation, reductionDecls, // By-value: argument mapping is delegated to this helper (defined outside this hunk).
+                            reductionVariableMap, Cnt);
+    }
+  } else if (name == "red_comb") { // Combiner callback: bind the region's two block arguments.
+    if (isByRef[Cnt]) {
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *arg0L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      llvm::Value *arg1L =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L); // By-ref: block arguments are bound to the (reloaded) pointers themselves.
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    } else {
+      llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0); // By-value: block arguments are bound to `redTy` values loaded through the pointers.
+      llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+      moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+      moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+    }
+  }
+
+  SmallVector<llvm::Value *, 1> phis; // Receives the value(s) yielded by the inlined region.
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr; // Region translation failed; callers must check for nullptr.
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+  if (!isByRef[Cnt]) { // By-value: write the yielded value back through arg0 and return that pointer.
+    bbBuilder.CreateStore(phis[0], arg0);
+    bbBuilder.CreateRet(arg0); // Return from the function
+  } else {
+    bbBuilder.CreateRet(nullptr); // By-ref: function type is void, so no value is returned (presumably emits `ret void` - confirm).
+  }
+  return function;
+}
+
+void emitTaskRedInitCall( // Emits a call to the __kmpc_taskred_init runtime function at `builder`'s insertion point.
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, // NOTE(review): not declared `static`, unlike allocAndInitializeTaskReductionVars - confirm external linkage is intended.
+    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize, // `arraySize` is passed to the runtime call as an i32 constant.
+    llvm::Value *ArrayAlloca) { // Forwarded unchanged as the third call argument.
+  llvm::LLVMContext &Context = builder.getContext();
+  uint32_t SrcLocStrSize; // Filled in by getOrCreateSrcLocStr below.
+  llvm::Constant *SrcLocStr = // Source-location string -> ident -> thread id, all obtained from the OpenMP IR builder.
+      moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
+                                                                 SrcLocStrSize);
+  llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
+      SrcLocStr, SrcLocStrSize);
+  llvm::Value *ThreadID =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+  llvm::Constant *ConstInt =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
+  llvm::Function *TaskRedInitFn =
+      moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca}); // Arguments: thread id, element count, array pointer; the call's result is not used here.
+}
+
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap,
+    llvm::ArrayRef<bool> isByRef) {
+
+  if (op.getNumReductionVars() == 0)
+    return success();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::LLVMContext &Context = builder.getContext();
+  SmallVector<DeferredStore> deferredStores;
+
+  // Save the current insertion point
+  auto oldIP = builder.saveIP();
+
+  // Set insertion point after the allocations
+  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+
+  // Define the kmp_taskred_input_t structure
+  llvm::StructType *kmp_taskred_input_t =
+      llvm::StructType::create(Context, "kmp_taskred_input_t");
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
+  llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
+  llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)
+
+  // Structure members
+  std::vector<llvm::Type *> structMembers = {
+      OpaquePtrTy, // reduce_shar (void*)
+      OpaquePtrTy, // reduce_orig (void*)
+      SizeTy,      // reduce_size (size_t)
+      OpaquePtrTy, // reduce_init (void*)
+      OpaquePtrTy, // reduce_fini (void*)
+      OpaquePtrTy, // reduce_comb (void*)
+      FlagsTy      // flags (i32)
+  };
+
+  kmp_taskred_input_t->setBody(structMembers);
+  int arraySize = op.getNumReductionVars();
+  llvm::ArrayType *ArrayTy =
+      llvm::ArrayType::get(kmp_taskred_input_t, arraySize);
+
+  // Allocate the array for kmp_taskred_input_t
+  llvm::AllocaInst *ArrayAlloca =
+      builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");
+
+  // Restore the insertion point
+  builder.restoreIP(oldIP);
+  llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
+    llvm::Value *shared =
+        moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
+    // Create a GEP to access the reduction element
+    llvm::Value *StructPtr = builder.CreateGEP(
+        ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
+        "red_element");
+    llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
+    builder.CreateStore(shared, FieldPtrReduceShar);
+
+    llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
+    builder.CreateStore(shared, FieldPtrReduceOrig);
+
+    // Store size of the reduction variable
+    llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
+        kmp_taskred_input_t, StructPtr, 2, "reduce_size");
+    llvm::Type *redTy;
+    if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
+      redTy = alloca->getAllocatedType();
+      uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
+      llvm::ConstantInt *sizeConst =
+          llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
+      builder.CreateStore(sizeConst, FieldPtrReduceSize);
+    } else {
+      llvm_unreachable("Non alloca instruction found.");
----------------
ergawy wrote:

Is this temporary (a todo) or permanent? Can we simply use `moduleTranslation.convertType(op.getReductionVars()[Cnt].getType())` (or something to the same effect)?

https://github.com/llvm/llvm-project/pull/120957


More information about the Mlir-commits mailing list