[Mlir-commits] [mlir] [OpenMP][MLIR] Lowering task_reduction clause to LLVMIR (PR #111788)
Tom Eccles
llvmlistbot at llvm.org
Thu Oct 10 08:51:54 PDT 2024
================
@@ -1142,19 +1141,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return bodyGenStatus;
}
+/// Emits an out-of-line helper function `ptr @<name>(ptr, ptr)` for the
+/// reduction at index \p Cnt, with its body translated from \p region.
+///
+/// Two helper kinds are supported, selected by \p name:
+///  - "red_init": both parameters are marked noalias and the region's
+///    arguments are bound via mapInitializationArgs().
+///  - "red_comb": both parameters are loaded as \p redTy and the loaded
+///    values are bound to block arguments 0 and 1 of \p region.
+/// In both cases the single value yielded by \p region is stored through the
+/// first parameter, which is also the function's return value.
+///
+/// \returns a null pointer constant when \p region is empty or the reduction
+///          is by-reference (not yet supported); nullptr if translating
+///          \p region fails; otherwise the newly created function.
+/// NOTE(review): \p privateReductionVariables is not used by this helper.
+template <typename OP>
+llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  // TODO: by-ref reduction variables are yet to be handled.
+  if (region.empty() || isByRef[Cnt])
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+
+  // Any other name would leave the region's block arguments unmapped.
+  assert((name == "red_init" || name == "red_comb") &&
+         "unexpected task reduction helper name");
+
+  llvm::FunctionType *funcType =
+      llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  // Internal linkage: the helper only needs to be addressable from this
+  // module, and exporting fixed names such as "red_init" would produce
+  // duplicate symbols when several translation units use task reductions
+  // (clang emits its task-reduction helpers with internal linkage too).
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::GlobalValue::InternalLinkage,
+                             name, builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    mapInitializationArgs(op, moduleTranslation, reductionDecls,
+                          reductionVariableMap, Cnt);
+  } else if (name == "red_comb") {
+    // Combiner: feed the current values behind both pointers to the region.
+    llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
+    llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+    moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+    moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  // NOTE(review): on failure the partially built helper stays in the module.
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+
+  // Write the result back through the first argument and also return it.
+  bbBuilder.CreateStore(phis[0], arg0);
+  bbBuilder.CreateRet(arg0);
+  return function;
+}
+
+void emitTaskRedInitCall(
+ llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+ const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
+ llvm::Value *ArrayAlloca) {
+
+ llvm::LLVMContext &Context = builder.getContext();
+ uint32_t SrcLocStrSize;
+ llvm::Constant *SrcLocStr =
+ moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
+ SrcLocStrSize);
+ llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
+ SrcLocStr, SrcLocStrSize);
+ llvm::Value *ThreadID =
+ moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+ llvm::Constant *ConstInt =
+ llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);
+
+ llvm::Function *TaskRedInitFn =
+ moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+ llvm::omp::OMPRTL___kmpc_taskred_init);
+ builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
----------------
tblah wrote:
For non-task reductions, runtime calls like this one are generated by OpenMPIRBuilder so that we can share code with clang.
Are the clang people happy with us having a diverging implementation here? If so, I don't mind.
https://github.com/llvm/llvm-project/pull/111788
More information about the Mlir-commits
mailing list