[Mlir-commits] [mlir] [mlir] Translating task_reduction clause for pass-by-value vars to LLVMIR (PR #125218)
Jack Styles
llvmlistbot at llvm.org
Mon Jan 5 04:12:31 PST 2026
================
@@ -2469,6 +2473,240 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return success();
}
+/*
+ * Utility function for translating `red_init`, `red_comb`, and `red_fini` to
+ * LLVM IR. The utility first (commonly) generates a skeleton for any of the
+ * three functions, and then generates the function body based on the
+ * specific operations involved in `red_init` (codegen related to initialization
+ * of task reduction variables) and `red_comb` (codegen related to combination).
+ * Currently, codegen for `red_fini` is skipped since finalization is optional
+ * for the `task_reduction` clause, but this utility has the capability of
+ * defining finalization if needed. Finally, the returned `llvm::Function` is
+ * used to populate the relevant entries in the task reduction specific data
+ * structure.
+ *
+ * Every emitted helper has the signature `ptr name(ptr, ptr)` and returns its
+ * first argument; for a non-empty region it first stores the value yielded by
+ * `region` through that argument. Results:
+ *  - a null pointer constant when `name` is "red_fini" and `region` is empty;
+ *  - nullptr when translating `region` fails (callers must check for this
+ *    before storing the result);
+ *  - otherwise the newly created function.
+ *
+ * The helpers get internal linkage: they are only reachable through the task
+ * reduction descriptor built in this module, and fixed external names such as
+ * "red_init" would collide when several modules using task reductions are
+ * linked together.
+ */
+static llvm::Value *createTaskReductionFunction(
+    omp::TaskgroupOp &op, llvm::IRBuilderBase &builder, const std::string &name,
+    llvm::Type *redTy, LLVM::ModuleTranslation &moduleTranslation,
+    omp::DeclareReductionOp &reductionDecl, Region &region, unsigned cnt,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  // TODO: by-ref reduction variables are yet to be handled.
+  llvm::LLVMContext &context = builder.getContext();
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  // Bind by reference: copying llvm::DataLayout is needlessly expensive.
+  const llvm::DataLayout &dl = module->getDataLayout();
+  llvm::PointerType *opaquePtrTy =
+      llvm::PointerType::get(context, dl.getProgramAddressSpace());
+
+  // Finalization is optional for reductions.
+  if (region.empty() && name == "red_fini")
+    return llvm::Constant::getNullValue(opaquePtrTy);
+
+  // Prepare a general structure of the function to be emitted.
+  llvm::FunctionType *funcType =
+      llvm::FunctionType::get(opaquePtrTy, {opaquePtrTy, opaquePtrTy}, false);
+  llvm::Function *function = llvm::Function::Create(
+      funcType, llvm::GlobalValue::InternalLinkage, name, module);
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entryBlock =
+      llvm::BasicBlock::Create(context, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entryBlock);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+
+  if (name == "red_init") {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+  }
+
+  // Emit an empty function body in case of an empty region. This must come
+  // before any access to a region's entry block below, which is invalid on
+  // an empty region.
+  if (region.empty()) {
+    bbBuilder.CreateRet(arg0);
+    return function;
+  }
+
+  if (name == "red_init") {
+    // Map the initializer's mold argument to the shared reduction variable
+    // and, when the initializer also takes an alloc argument, map it to the
+    // allocation recorded for that variable.
+    Region &initializerRegion = reductionDecl.getInitializerRegion();
+    Block &initEntry = initializerRegion.front();
+    mlir::Value mlirSource = op.getTaskReductionVars()[cnt];
+    llvm::Value *origVal = moduleTranslation.lookupValue(mlirSource);
+    moduleTranslation.mapValue(reductionDecl.getInitializerMoldArg(), origVal);
+    if (initEntry.getNumArguments() > 1) {
+      llvm::Value *allocation = reductionVariableMap.lookup(mlirSource);
+      moduleTranslation.mapValue(reductionDecl.getInitializerAllocArg(),
+                                 allocation);
+    }
+  } else if (name == "red_comb") {
+    // Load both operands and map them to the combiner region's arguments.
+    llvm::Value *lhs = bbBuilder.CreateLoad(redTy, arg0);
+    llvm::Value *rhs = bbBuilder.CreateLoad(redTy, arg1);
+    moduleTranslation.mapValue(region.front().getArgument(0), lhs);
+    moduleTranslation.mapValue(region.front().getArgument(1), rhs);
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  assert(
+      phis.size() == 1 &&
+      "expected one value to be yielded from the reduction declaration region");
+  // Write the yielded value back through the first argument and return it.
+  bbBuilder.CreateStore(phis[0], arg0);
+  bbBuilder.CreateRet(arg0);
+  return function;
+}
+
+/// Emits `__kmpc_taskred_init(gtid, arraySize, ArrayAlloca)` at the current
+/// insertion point of `builder`. The call registers the `arraySize` task
+/// reduction descriptors stored in `ArrayAlloca` with the OpenMP runtime.
+///
+/// The thread id (`__kmpc_global_thread_num`) is emitted through the
+/// OpenMPIRBuilder's own IRBuilder, which is distinct from `builder`. That
+/// builder is therefore first moved to `builder`'s insertion point. Otherwise
+/// the thread-id call would land wherever the OpenMPIRBuilder happened to be
+/// positioned last.
+///
+/// NOTE(review): the value returned by `__kmpc_taskred_init` is currently
+/// discarded.
+static void
+emitTaskRedInitCall(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc,
+                    int arraySize, llvm::Value *ArrayAlloca) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
+  // Position the OpenMPIRBuilder exactly where `builder` will emit the runtime
+  // call, keeping the caller's debug location, so the thread id is defined
+  // immediately before its use.
+  llvm::OpenMPIRBuilder::LocationDescription callLoc(builder.saveIP(),
+                                                     ompLoc.DL);
+  if (!ompBuilder->updateToLocation(callLoc))
+    return;
+
+  uint32_t srcLocStrSize;
+  llvm::Constant *srcLocStr =
+      ompBuilder->getOrCreateSrcLocStr(callLoc, srcLocStrSize);
+  llvm::Constant *ident =
+      ompBuilder->getOrCreateIdent(srcLocStr, srcLocStrSize);
+  llvm::Value *threadId = ompBuilder->getOrCreateThreadID(ident);
+
+  llvm::Function *taskRedInitFn = ompBuilder->getOrCreateRuntimeFunctionPtr(
+      llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(taskRedInitFn,
+                     {threadId, builder.getInt32(arraySize), ArrayAlloca});
+}
+
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+ OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+ SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+ SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+ DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+
+ if (op.getNumReductionVars() == 0)
+ return success();
+
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::LLVMContext &Context = builder.getContext();
+ llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();
+
+ // Define the kmp_taskred_input_t structure
+ llvm::StructType *kmp_taskred_input_t =
+ llvm::StructType::create(Context, "kmp_taskred_input_t");
+ llvm::Type *OpaquePtrTy =
+ llvm::PointerType::get(Context,
+ DL.getProgramAddressSpace()); // void*
+ llvm::Type *SizeTy = DL.getIntPtrType(Context); // size_t
+ llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)
+
+ // Structure members
+ std::vector<llvm::Type *> structMembers = {
+ OpaquePtrTy, // reduce_shar (void*)
+ OpaquePtrTy, // reduce_orig (void*)
+ SizeTy, // reduce_size (size_t)
+ OpaquePtrTy, // reduce_init (void*)
+ OpaquePtrTy, // reduce_fini (void*)
+ OpaquePtrTy, // reduce_comb (void*)
+ FlagsTy // flags (i32)
+ };
+
+ kmp_taskred_input_t->setBody(structMembers);
+ int arraySize = op.getTaskReductionVars().size();
+ llvm::ArrayType *ArrayTy =
+ llvm::ArrayType::get(kmp_taskred_input_t, arraySize);
+
+ // Save the current insertion point
+ auto oldIP = builder.saveIP();
+
+ // Set insertion point after the allocations
+ builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+
+ // Allocate the array for kmp_taskred_input_t
+ llvm::AllocaInst *ArrayAlloca =
+ builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");
+
+ // Restore the insertion point
+ builder.restoreIP(oldIP);
+
+ for (int cnt = 0; cnt < arraySize; ++cnt) {
+ llvm::Value *shared =
+ moduleTranslation.lookupValue(op.getTaskReductionVars()[cnt]);
+
+ // Create a GEP to access the reduction element
+ llvm::Value *StructPtr = builder.CreateGEP(
+ ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(cnt)},
+ "red_element");
+
+ llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
+ kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
+ builder.CreateStore(shared, FieldPtrReduceShar);
+
+ llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
+ kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
+ builder.CreateStore(shared, FieldPtrReduceOrig);
+
+ // Store size of the reduction variable
+ llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
+ kmp_taskred_input_t, StructPtr, 2, "reduce_size");
+ llvm::Type *redTy =
+ moduleTranslation.convertType(reductionDecls[cnt].getType());
+ uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);
+ llvm::ConstantInt *sizeConst =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
+ builder.CreateStore(sizeConst, FieldPtrReduceSize);
+
+ // Initialize reduction variable
+ llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
+ kmp_taskred_input_t, StructPtr, 3, "reduce_init");
+ llvm::Value *initFunction = createTaskReductionFunction(
+ op, builder, "red_init", redTy, moduleTranslation, reductionDecls[cnt],
+ reductionDecls[cnt].getInitializerRegion(), cnt,
+ privateReductionVariables, reductionVariableMap);
+ builder.CreateStore(initFunction, FieldPtrReduceInit);
+
+ // Create finish and combine functions
+ llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
+ kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
+ llvm::Value *finiFunction = createTaskReductionFunction(
+ op, builder, "red_fini", redTy, moduleTranslation, reductionDecls[cnt],
+ reductionDecls[cnt].getCleanupRegion(), cnt, privateReductionVariables,
+ reductionVariableMap);
+ builder.CreateStore(finiFunction, FieldPtrReduceFini);
+
+ llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
+ kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
+ llvm::Value *combFunction = createTaskReductionFunction(
+ op, builder, "red_comb", redTy, moduleTranslation, reductionDecls[cnt],
+ reductionDecls[cnt].getReductionRegion(), cnt,
+ privateReductionVariables, reductionVariableMap);
+ builder.CreateStore(combFunction, FieldPtrReduceComb);
----------------
Stylie777 wrote:
Same here for the assert to check that `combFunction` is not nullptr
https://github.com/llvm/llvm-project/pull/125218
More information about the Mlir-commits
mailing list