[clang] [llvm] [mlir] [MLIR][OpenMP] Add codegen for teams reductions (PR #133310)
Jan Leyonberg via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 3 09:01:34 PDT 2025
================
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
".omp.reduction.func", &M);
}
-OpenMPIRBuilder::InsertPointOrErrorTy
-OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
- InsertPointTy AllocaIP,
- ArrayRef<ReductionInfo> ReductionInfos,
- ArrayRef<bool> IsByRef, bool IsNoWait) {
- assert(ReductionInfos.size() == IsByRef.size());
- for (const ReductionInfo &RI : ReductionInfos) {
- (void)RI;
- assert(RI.Variable && "expected non-null variable");
- assert(RI.PrivateVariable && "expected non-null private variable");
- assert(RI.ReductionGen && "expected non-null reduction generator callback");
- assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
- "expected variables and their private equivalents to have the same "
- "type");
- assert(RI.Variable->getType()->isPointerTy() &&
- "expected variables to be pointers");
+static Error populateReductionFunction(
+ Function *ReductionFunc,
+ ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
+ IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
+ Module *Module = ReductionFunc->getParent();
+ BasicBlock *ReductionFuncBlock =
+ BasicBlock::Create(Module->getContext(), "", ReductionFunc);
+ Builder.SetInsertPoint(ReductionFuncBlock);
+ Value *LHSArrayPtr = nullptr;
+ Value *RHSArrayPtr = nullptr;
+ if (IsGPU) {
+ // Need to alloca memory here and deal with the pointers before getting
+ // LHS/RHS pointers out
+ //
+ Argument *Arg0 = ReductionFunc->getArg(0);
+ Argument *Arg1 = ReductionFunc->getArg(1);
+ Type *Arg0Type = Arg0->getType();
+ Type *Arg1Type = Arg1->getType();
+
+ Value *LHSAlloca =
+ Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
+ Value *RHSAlloca =
+ Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
+ Value *LHSAddrCast =
+ Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
+ Value *RHSAddrCast =
+ Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
+ Builder.CreateStore(Arg0, LHSAddrCast);
+ Builder.CreateStore(Arg1, RHSAddrCast);
+ LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
+ RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
+ } else {
+ LHSArrayPtr = ReductionFunc->getArg(0);
+ RHSArrayPtr = ReductionFunc->getArg(1);
}
+ unsigned NumReductions = ReductionInfos.size();
+ Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
+
+ for (auto En : enumerate(ReductionInfos)) {
+ const OpenMPIRBuilder::ReductionInfo &RI = En.value();
+ Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
+ RedArrayTy, LHSArrayPtr, 0, En.index());
+ Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
+ Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ LHSI8Ptr, RI.Variable->getType());
+ Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
+ Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
+ RedArrayTy, RHSArrayPtr, 0, En.index());
+ Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
+ Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ RHSI8Ptr, RI.PrivateVariable->getType());
+ Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
+ Value *Reduced;
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+ RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
+ if (!AfterIP)
+ return AfterIP.takeError();
+
+ Builder.restoreIP(*AfterIP);
+ // TODO: Consider flagging an error.
+ if (!Builder.GetInsertBlock())
+ return Error::success();
----------------
jsjodin wrote:
I'm going to skip this for now. I have some thoughts on how to improve the insert-point (IP) handling, but we can discuss that separately.
https://github.com/llvm/llvm-project/pull/133310
More information about the llvm-commits mailing list