[Mlir-commits] [mlir] [OpenMP][MLIR] Lowering task_reduction clause to LLVMIR (PR #111788)

Tom Eccles llvmlistbot at llvm.org
Thu Oct 10 08:51:56 PDT 2024


================
@@ -1142,19 +1141,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return bodyGenStatus;
 }
 
+template <typename OP>
+llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &Context = builder.getContext();
+  llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
+  // TODO: by-ref reduction variables are yet to be handled.
+  if (region.empty() || isByRef[Cnt]) {
+    return llvm::Constant::getNullValue(OpaquePtrTy);
+  }
+  llvm::FunctionType *funcType =
+      llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry =
+      llvm::BasicBlock::Create(Context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    mapInitializationArgs(op, moduleTranslation, reductionDecls,
+                          reductionVariableMap, Cnt);
+  } else if (name == "red_comb") {
+    llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
+    llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
+    moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
+    moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
----------------
tblah wrote:

nit: Passing a descriptive block name here (instead of the empty string `""`) makes this a lot easier to debug if the region is translated into multiple basic blocks.

https://github.com/llvm/llvm-project/pull/111788


More information about the Mlir-commits mailing list