[Mlir-commits] [mlir] [mlir][OpenMP] Translate task_reduction on omp.taskgroup (PR #199565)
Sairudra More
llvmlistbot at llvm.org
Thu May 28 00:36:52 PDT 2026
https://github.com/Saieiei updated https://github.com/llvm/llvm-project/pull/199565
>From c3847fed57dfdc41061291b90c6e6ca1ff950d28 Mon Sep 17 00:00:00 2001
From: saieiei <sairudra60 at gmail.com>
Date: Mon, 25 May 2026 12:23:14 -0500
Subject: [PATCH] [mlir][OpenMP] Translate task_reduction on taskgroup
Add LLVM IR translation support for the task_reduction clause on
omp.taskgroup.
The translation builds task-reduction descriptors (kmp_taskred_input_t) for
the listed reduction variables. It emits the __kmpc_taskred_init runtime call
at the start of the taskgroup region, after __kmpc_taskgroup and before the
translated body.
The reducer init and combiner callbacks are generated from the corresponding
omp.declare_reduction regions.
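This patch keeps taskloop reduction and in_reduction translation unsupported;
those remain follow-up work. Unsupported task_reduction forms are diagnosed as
not yet implemented instead of being lowered incorrectly. These forms are the
byref modifier and declare_reduction ops that have a two-argument initializer
or a cleanup region.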
Add MLIR translation tests for taskgroup task_reduction, multiple reducers,
plain taskgroup translation, and remaining unsupported cases.
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 243 +++++++++++++++++-
.../openmp-taskgroup-task-reduction.mlir | 153 +++++++++++
mlir/test/Target/LLVMIR/openmp-todo.mlir | 58 ++++-
3 files changed, 444 insertions(+), 10 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/openmp-taskgroup-task-reduction.mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f0511bb4be7dd..dfec09a53075c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -370,10 +370,13 @@ static LogicalResult checkImplementationStatus(Operation &op) {
op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
result = todo("reduction with modifier");
};
- auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
- if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
- op.getTaskReductionSyms())
- result = todo("task_reduction");
+ auto checkTaskReductionByref = [&todo](auto op, LogicalResult &result) {
+ if (auto byrefAttr = op.getTaskReductionByref())
+ for (bool isByRef : *byrefAttr)
+ if (isByRef) {
+ result = todo("task_reduction with byref modifier");
+ return;
+ }
};
auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
if (op.hasNumTeamsMultiDim())
@@ -426,7 +429,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::TaskgroupOp op) {
checkAllocate(op, result);
- checkTaskReduction(op, result);
+ checkTaskReductionByref(op, result);
})
.Case([&](omp::TaskwaitOp op) {
checkDepend(op, result);
@@ -3643,6 +3646,183 @@ convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
return success();
}
+/// Build an outlined init helper for a task_reduction declare_reduction op.
+///
+/// Signature: `void(ptr %priv, ptr %orig)`. Only the non-byref form is
+/// handled: the init region's mold argument is mapped to the value loaded
+/// from `%orig`, and the single value yielded by the region is stored into
+/// `%priv`.
+///
+/// Returns nullptr if the initializer cannot be translated. A diagnostic is
+/// emitted on \p decl when the region yields an unexpected number of values.
+/// On every exit path the MLIR->LLVM mappings created for the region are
+/// dropped, so no mapping is left pointing into an erased function.
+static llvm::Function *
+emitTaskReductionInitFn(omp::DeclareReductionOp decl, StringRef baseName,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::LLVMContext &ctx = llvmModule->getContext();
+  llvm::Type *voidTy = llvm::Type::getVoidTy(ctx);
+  llvm::Type *ptrTy = llvm::PointerType::getUnqual(ctx);
+  llvm::FunctionType *fty =
+      llvm::FunctionType::get(voidTy, {ptrTy, ptrTy}, /*isVarArg=*/false);
+  llvm::Function *fn =
+      llvm::Function::Create(fty, llvm::GlobalValue::InternalLinkage,
+                             baseName + ".red.init", llvmModule);
+  fn->setDoesNotRecurse();
+  fn->getArg(0)->setName("priv");
+  fn->getArg(1)->setName("orig");
+
+  llvm::BasicBlock *entry = llvm::BasicBlock::Create(ctx, "entry", fn);
+  llvm::IRBuilder<> b(entry);
+
+  Region &initRegion = decl.getInitializerRegion();
+  // Failure path: forget the region's mappings *before* erasing the helper,
+  // otherwise the mold argument stays mapped to a load in a deleted function.
+  auto fail = [&]() -> llvm::Function * {
+    moduleTranslation.forgetMapping(initRegion);
+    fn->eraseFromParent();
+    return nullptr;
+  };
+
+  llvm::Type *elemTy = moduleTranslation.convertType(decl.getType());
+  llvm::Value *origVal = b.CreateLoad(elemTy, fn->getArg(1), "omp.orig");
+  moduleTranslation.mapValue(decl.getInitializerMoldArg(), origVal);
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(initRegion, "omp.taskred.init", b,
+                                     moduleTranslation, &phis)))
+    return fail();
+  // Checked at runtime rather than asserted so that a malformed initializer
+  // is diagnosed instead of reading past the end of `phis` in release builds.
+  if (phis.size() != 1) {
+    decl.emitError("expected one value yielded from task_reduction "
+                   "initializer");
+    return fail();
+  }
+  b.CreateStore(phis[0], fn->getArg(0));
+  b.CreateRetVoid();
+
+  moduleTranslation.forgetMapping(initRegion);
+  return fn;
+}
+
+/// Build an outlined combiner helper for a task_reduction declare_reduction
+/// op.
+///
+/// Signature: `void(ptr %lhs, ptr %rhs)`. Only the non-byref form is handled:
+/// the values at `*%lhs` and `*%rhs` are loaded and fed into the combiner
+/// region. The single value yielded by the region is stored back into
+/// `*%lhs`.
+///
+/// Returns nullptr if the combiner cannot be translated. A diagnostic is
+/// emitted on \p decl when the region yields an unexpected number of values.
+/// The region's value mappings are dropped on every exit path.
+static llvm::Function *
+emitTaskReductionCombFn(omp::DeclareReductionOp decl, StringRef baseName,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::LLVMContext &ctx = llvmModule->getContext();
+  llvm::Type *voidTy = llvm::Type::getVoidTy(ctx);
+  llvm::Type *ptrTy = llvm::PointerType::getUnqual(ctx);
+  llvm::FunctionType *fty =
+      llvm::FunctionType::get(voidTy, {ptrTy, ptrTy}, /*isVarArg=*/false);
+  llvm::Function *fn =
+      llvm::Function::Create(fty, llvm::GlobalValue::InternalLinkage,
+                             baseName + ".red.comb", llvmModule);
+  fn->setDoesNotRecurse();
+  fn->getArg(0)->setName("lhs");
+  fn->getArg(1)->setName("rhs");
+
+  llvm::BasicBlock *entry = llvm::BasicBlock::Create(ctx, "entry", fn);
+  llvm::IRBuilder<> b(entry);
+
+  Region &combRegion = decl.getReductionRegion();
+  // Failure path: forget the region's mappings *before* erasing the helper,
+  // otherwise the combiner block arguments stay mapped to deleted loads.
+  auto fail = [&]() -> llvm::Function * {
+    moduleTranslation.forgetMapping(combRegion);
+    fn->eraseFromParent();
+    return nullptr;
+  };
+
+  Block &combBlock = combRegion.front();
+  assert(combBlock.getNumArguments() == 2 &&
+         "expected two arguments in declare_reduction combiner");
+  llvm::Type *elemTy = moduleTranslation.convertType(decl.getType());
+  llvm::Value *lhsVal = b.CreateLoad(elemTy, fn->getArg(0), "omp.lhs");
+  llvm::Value *rhsVal = b.CreateLoad(elemTy, fn->getArg(1), "omp.rhs");
+  moduleTranslation.mapValue(combBlock.getArgument(0), lhsVal);
+  moduleTranslation.mapValue(combBlock.getArgument(1), rhsVal);
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(combRegion, "omp.taskred.comb", b,
+                                     moduleTranslation, &phis)))
+    return fail();
+  // Checked at runtime rather than asserted so that a malformed combiner is
+  // diagnosed instead of reading past the end of `phis` in release builds.
+  if (phis.size() != 1) {
+    decl.emitError("expected one value yielded from task_reduction combiner");
+    return fail();
+  }
+  b.CreateStore(phis[0], fn->getArg(0));
+  b.CreateRetVoid();
+
+  moduleTranslation.forgetMapping(combRegion);
+  return fn;
+}
+
+/// Emit the per-taskgroup task_reduction descriptor array and the
+/// `__kmpc_taskred_init` runtime call.
+///
+/// Must be called with \p builder positioned inside the taskgroup body (after
+/// `__kmpc_taskgroup`). The descriptor stores, the thread-id query and the
+/// runtime call are all emitted at that insertion point. The descriptor array
+/// itself is allocated at \p allocaIP.
+///
+/// Init/combiner helpers are emitted once per distinct declare_reduction op
+/// listed on \p tgOp. Repeating a reduction symbol in the clause therefore
+/// reuses the same helpers instead of creating `.red.init.N` duplicates.
+///
+/// Only the non-byref form is handled here. Byref task_reduction has already
+/// been rejected by `checkImplementationStatus`.
+static LogicalResult emitTaskgroupTaskReductionInit(
+    omp::TaskgroupOp tgOp, ArrayRef<omp::DeclareReductionOp> redDecls,
+    llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
+    LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::LLVMContext &ctx = llvmModule->getContext();
+  const llvm::DataLayout &dl = llvmModule->getDataLayout();
+
+  // Kept as PointerType: ConstantPointerNull::get below requires it.
+  llvm::PointerType *ptrTy = llvm::PointerType::getUnqual(ctx);
+  llvm::Type *i32Ty = llvm::Type::getInt32Ty(ctx);
+  llvm::Type *sizeTy =
+      llvm::Type::getIntNTy(ctx, dl.getPointerSizeInBits(/*AddrSpace=*/0));
+
+  // Identified `kmp_taskred_input_t` struct, matching the layout used by
+  // Clang's CGOpenMPRuntime::emitTaskReductionInit:
+  //   { reduce_shar, reduce_orig, reduce_size, reduce_init, reduce_fini,
+  //     reduce_comb, flags }
+  llvm::StructType *redInputTy =
+      llvm::StructType::getTypeByName(ctx, "kmp_taskred_input_t");
+  if (!redInputTy)
+    redInputTy = llvm::StructType::create(
+        ctx, {ptrTy, ptrTy, sizeTy, ptrTy, ptrTy, ptrTy, i32Ty},
+        "kmp_taskred_input_t");
+
+  unsigned numReductions = redDecls.size();
+  llvm::ArrayType *arrTy = llvm::ArrayType::get(redInputTy, numReductions);
+
+  // Allocate the descriptor array in the enclosing function's alloca block.
+  llvm::AllocaInst *arrAlloca;
+  {
+    llvm::IRBuilderBase::InsertPointGuard guard(builder);
+    builder.restoreIP(allocaIP);
+    arrAlloca =
+        builder.CreateAlloca(arrTy, /*ArraySize=*/nullptr, ".taskred.input");
+  }
+
+  // (init, comb) helpers already emitted for a declare_reduction op.
+  DenseMap<Operation *, std::pair<llvm::Function *, llvm::Function *>>
+      helpersByDecl;
+
+  // Fill each descriptor entry inside the taskgroup body.
+  llvm::Value *zero = builder.getInt32(0);
+  for (unsigned i = 0; i < numReductions; ++i) {
+    omp::DeclareReductionOp decl = redDecls[i];
+    llvm::Value *orig =
+        moduleTranslation.lookupValue(tgOp.getTaskReductionVars()[i]);
+    llvm::Type *elemTy = moduleTranslation.convertType(decl.getType());
+    uint64_t size = dl.getTypeAllocSize(elemTy).getFixedValue();
+
+    auto &helpers = helpersByDecl[decl.getOperation()];
+    if (!helpers.first) {
+      std::string baseName =
+          (llvm::Twine("__omp_taskred_") + decl.getSymName()).str();
+      llvm::Function *initFn =
+          emitTaskReductionInitFn(decl, baseName, moduleTranslation);
+      if (!initFn)
+        return failure();
+      llvm::Function *combFn =
+          emitTaskReductionCombFn(decl, baseName, moduleTranslation);
+      if (!combFn) {
+        // Do not leave an orphaned init helper behind in the module.
+        initFn->eraseFromParent();
+        return failure();
+      }
+      helpers = {initFn, combFn};
+    }
+
+    llvm::Value *elemPtr = builder.CreateInBoundsGEP(
+        arrTy, arrAlloca, {zero, builder.getInt32(i)}, ".taskred.elem");
+    auto storeField = [&](unsigned fieldIdx, llvm::Value *val) {
+      llvm::Value *fieldPtr =
+          builder.CreateStructGEP(redInputTy, elemPtr, fieldIdx);
+      builder.CreateStore(val, fieldPtr);
+    };
+    storeField(0, orig);                                  // reduce_shar
+    storeField(1, orig);                                  // reduce_orig
+    storeField(2, llvm::ConstantInt::get(sizeTy, size));  // reduce_size
+    storeField(3, helpers.first);                         // reduce_init
+    storeField(4, llvm::ConstantPointerNull::get(ptrTy)); // reduce_fini
+    storeField(5, helpers.second);                        // reduce_comb
+    storeField(6, llvm::ConstantInt::get(i32Ty, 0));      // flags
+  }
+
+  // Emit: __kmpc_taskred_init(gtid, numReductions, &descriptors).
+  // The thread-id query is emitted through `builder` itself, so it is placed
+  // at the current insertion point and dominates the call. It no longer
+  // depends on OpenMPIRBuilder's internal builder being positioned at the
+  // same place. The ident picks up the builder's current debug location
+  // when one is set.
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  uint32_t srcLocSize;
+  llvm::Constant *srcLocStr =
+      ompBuilder->getOrCreateSrcLocStr(ompLoc, srcLocSize);
+  llvm::Value *ident = ompBuilder->getOrCreateIdent(srcLocStr, srcLocSize);
+  llvm::Value *gtid = builder.CreateCall(
+      ompBuilder->getOrCreateRuntimeFunction(
+          *llvmModule, llvm::omp::OMPRTL___kmpc_global_thread_num),
+      {ident}, "omp_global_thread_num");
+  llvm::FunctionCallee taskredInit = ompBuilder->getOrCreateRuntimeFunction(
+      *llvmModule, llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(taskredInit,
+                     {gtid, builder.getInt32(numReductions), arrAlloca});
+  return success();
+}
+
/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
@@ -3651,9 +3831,58 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*tgOp)))
return failure();
- auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
- llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
+ // Resolve and validate task_reduction declarations up front. We only handle
+ // declare_reduction ops shaped like a non-byref scalar reduction in this
+ // first cut; richer shapes (two-argument initializer, cleanup region,
+ // missing combiner) require additional infrastructure.
+ SmallVector<omp::DeclareReductionOp> redDecls;
+ if (auto syms = tgOp.getTaskReductionSyms()) {
+ redDecls.reserve(syms->size());
+ for (auto sym : syms->getAsRange<SymbolRefAttr>()) {
+ auto decl = SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
+ tgOp, sym);
+ if (!decl)
+ return tgOp.emitError()
+ << "failed to resolve task_reduction declare_reduction symbol "
+ << sym.getRootReference() << " in omp.taskgroup";
+ if (decl.getInitializerRegion().front().getNumArguments() != 1)
+ return tgOp.emitError("not yet implemented: task_reduction with "
+ "two-argument initializer in omp.taskgroup");
+ if (!decl.getCleanupRegion().empty())
+ return tgOp.emitError("not yet implemented: task_reduction with "
+ "cleanup region in omp.taskgroup");
+ if (decl.getReductionRegion().empty())
+ return tgOp.emitError("task_reduction declare_reduction is missing a "
+ "combiner region");
+ redDecls.push_back(decl);
+ }
+ }
+
+ auto bodyCB =
+ [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
+ llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
builder.restoreIP(codegenIP);
+
+ if (!redDecls.empty()) {
+ if (failed(emitTaskgroupTaskReductionInit(tgOp, redDecls, builder,
+ allocaIP, moduleTranslation)))
+ return llvm::createStringError(
+ llvm::inconvertibleErrorCode(),
+ "failed to emit task_reduction initialization for omp.taskgroup");
+ }
+
+ // Inside the taskgroup body, each task_reduction block argument refers to
+ // the same shared/original storage that the runtime now knows about via
+ // the descriptor array. Inner tasks that declare in_reduction look up
+ // per-task private copies through the runtime; the taskgroup body itself
+ // uses the original variable.
+ for (auto [i, blockArg] :
+ llvm::enumerate(tgOp.getRegion().getArguments())) {
+ llvm::Value *orig =
+ moduleTranslation.lookupValue(tgOp.getTaskReductionVars()[i]);
+ moduleTranslation.mapValue(blockArg, orig);
+ }
+
return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
builder, moduleTranslation)
.takeError();
diff --git a/mlir/test/Target/LLVMIR/openmp-taskgroup-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-taskgroup-task-reduction.mlir
new file mode 100644
index 0000000000000..353ce6218d9f3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-taskgroup-task-reduction.mlir
@@ -0,0 +1,153 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// Single scalar task_reduction on omp.taskgroup. Verifies that the
+// kmp_taskred_input_t descriptor is allocated, populated, and handed off to
+// __kmpc_taskred_init, and that init / combiner helper functions are emitted.
+
+// Non-byref i32 sum: the initializer ignores its mold argument and yields 0,
+// and the combiner adds its two operands.
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+ %s = llvm.add %arg0, %arg1 : i32
+ omp.yield(%s : i32)
+}
+
+// Single reduction item: %x is the original variable stored into the
+// descriptor, and %prv is the taskgroup region's block argument for it (the
+// body below does not use it).
+llvm.func @taskgroup_task_reduction_single(%x: !llvm.ptr) {
+ omp.taskgroup task_reduction(@add_i32 %x -> %prv : !llvm.ptr) {
+ omp.terminator
+ }
+ llvm.return
+}
+
+// CHECK: %kmp_taskred_input_t = type { ptr, ptr, i64, ptr, ptr, ptr, i32 }
+
+// CHECK-LABEL: define void @taskgroup_task_reduction_single(
+// CHECK-SAME: ptr %[[X:.+]])
+// CHECK: %[[ARR:.+]] = alloca [1 x %kmp_taskred_input_t]
+// CHECK: call void @__kmpc_taskgroup(
+// Descriptor entry 0.
+// CHECK: %[[GEP0:.+]] = getelementptr inbounds [1 x %kmp_taskred_input_t], ptr %[[ARR]], i32 0, i32 0
+// CHECK: %[[SHAR:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 0
+// CHECK: store ptr %[[X]], ptr %[[SHAR]]
+// CHECK: %[[ORIG:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 1
+// CHECK: store ptr %[[X]], ptr %[[ORIG]]
+// CHECK: %[[SZF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 2
+// CHECK: store i64 4, ptr %[[SZF]]
+// CHECK: %[[INITF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 3
+// CHECK: store ptr @__omp_taskred_add_i32.red.init, ptr %[[INITF]]
+// CHECK: %[[FINIF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 4
+// CHECK: store ptr null, ptr %[[FINIF]]
+// CHECK: %[[COMBF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 5
+// CHECK: store ptr @__omp_taskred_add_i32.red.comb, ptr %[[COMBF]]
+// CHECK: %[[FLAGSF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 6
+// CHECK: store i32 0, ptr %[[FLAGSF]]
+// CHECK: call ptr @__kmpc_taskred_init(i32 %{{.+}}, i32 1, ptr %[[ARR]])
+// CHECK: call void @__kmpc_end_taskgroup(
+
+// CHECK-LABEL: define internal void @__omp_taskred_add_i32.red.init(
+// CHECK-SAME: ptr %priv, ptr %orig)
+// CHECK: load i32, ptr %orig
+// CHECK: store i32 0, ptr %priv
+// CHECK: ret void
+
+// CHECK-LABEL: define internal void @__omp_taskred_add_i32.red.comb(
+// CHECK-SAME: ptr %lhs, ptr %rhs)
+// CHECK: %[[L:.+]] = load i32, ptr %lhs
+// CHECK: %[[R:.+]] = load i32, ptr %rhs
+// CHECK: %[[S:.+]] = add i32 %[[L]], %[[R]]
+// CHECK: store i32 %[[S]], ptr %lhs
+// CHECK: ret void
+
+// -----
+
+// Multiple task_reduction items on the same taskgroup.
+
+// i32 sum with identity 0; its descriptor entry reports a reduce_size of 4.
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+ %s = llvm.add %arg0, %arg1 : i32
+ omp.yield(%s : i32)
+}
+
+// i64 product with identity 1; its descriptor entry reports a reduce_size of
+// 8.
+omp.declare_reduction @mul_i64 : i64
+init {
+^bb0(%arg0: i64):
+ %c1 = llvm.mlir.constant(1 : i64) : i64
+ omp.yield(%c1 : i64)
+}
+combiner {
+^bb0(%arg0: i64, %arg1: i64):
+ %p = llvm.mul %arg0, %arg1 : i64
+ omp.yield(%p : i64)
+}
+
+// Two reduction items on one taskgroup. Descriptor entries follow clause
+// order: entry 0 is @add_i32 on %x, entry 1 is @mul_i64 on %y.
+llvm.func @taskgroup_task_reduction_multi(%x: !llvm.ptr, %y: !llvm.ptr) {
+ omp.taskgroup task_reduction(@add_i32 %x -> %a, @mul_i64 %y -> %b : !llvm.ptr, !llvm.ptr) {
+ omp.terminator
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: define void @taskgroup_task_reduction_multi(
+// CHECK-SAME: ptr %[[XA:[^,)]+]], ptr %[[YA:[^,)]+]])
+// CHECK: %[[ARR2:.+]] = alloca [2 x %kmp_taskred_input_t]
+// CHECK: call void @__kmpc_taskgroup(
+// Descriptor entry 0: @add_i32 on %x.
+// CHECK: %[[E0:.+]] = getelementptr inbounds [2 x %kmp_taskred_input_t], ptr %[[ARR2]], i32 0, i32 0
+// CHECK: %[[E0_SHAR:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 0
+// CHECK: store ptr %[[XA]], ptr %[[E0_SHAR]]
+// CHECK: %[[E0_ORIG:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 1
+// CHECK: store ptr %[[XA]], ptr %[[E0_ORIG]]
+// CHECK: %[[E0_SZ:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 2
+// CHECK: store i64 4, ptr %[[E0_SZ]]
+// CHECK: %[[E0_INIT:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 3
+// CHECK: store ptr @__omp_taskred_add_i32.red.init, ptr %[[E0_INIT]]
+// CHECK: %[[E0_FINI:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 4
+// CHECK: store ptr null, ptr %[[E0_FINI]]
+// CHECK: %[[E0_COMB:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 5
+// CHECK: store ptr @__omp_taskred_add_i32.red.comb, ptr %[[E0_COMB]]
+// CHECK: %[[E0_FL:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 6
+// CHECK: store i32 0, ptr %[[E0_FL]]
+// Descriptor entry 1: @mul_i64 on %y.
+// CHECK: %[[E1:.+]] = getelementptr inbounds [2 x %kmp_taskred_input_t], ptr %[[ARR2]], i32 0, i32 1
+// CHECK: %[[E1_SHAR:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 0
+// CHECK: store ptr %[[YA]], ptr %[[E1_SHAR]]
+// CHECK: %[[E1_ORIG:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 1
+// CHECK: store ptr %[[YA]], ptr %[[E1_ORIG]]
+// CHECK: %[[E1_SZ:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 2
+// CHECK: store i64 8, ptr %[[E1_SZ]]
+// CHECK: %[[E1_INIT:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 3
+// CHECK: store ptr @__omp_taskred_mul_i64.red.init, ptr %[[E1_INIT]]
+// CHECK: %[[E1_FINI:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 4
+// CHECK: store ptr null, ptr %[[E1_FINI]]
+// CHECK: %[[E1_COMB:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 5
+// CHECK: store ptr @__omp_taskred_mul_i64.red.comb, ptr %[[E1_COMB]]
+// CHECK: %[[E1_FL:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 6
+// CHECK: store i32 0, ptr %[[E1_FL]]
+// CHECK: call ptr @__kmpc_taskred_init(i32 %{{.+}}, i32 2, ptr %[[ARR2]])
+
+// -----
+
+// Plain taskgroup without task_reduction must still translate (regression
+// guard for the rewrite of convertOmpTaskgroupOp).
+
+// Taskgroup with no clauses and an empty body.
+llvm.func @taskgroup_plain() {
+ omp.taskgroup {
+ omp.terminator
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: define void @taskgroup_plain()
+// CHECK: call void @__kmpc_taskgroup(
+// CHECK-NOT: call ptr @__kmpc_taskred_init(
+// CHECK: call void @__kmpc_end_taskgroup(
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 295ba54dbfb38..d5d0a96779db0 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -301,10 +301,62 @@ atomic {
llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
omp.yield
}
-llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
- // expected-error at below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
+llvm.func @taskgroup_task_reduction_byref(%x : !llvm.ptr) {
+ // expected-error at below {{not yet implemented: Unhandled clause task_reduction with byref modifier in omp.taskgroup operation}}
// expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
- omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
+ omp.taskgroup task_reduction(byref @add_f32 %x -> %prv : !llvm.ptr) {
+ omp.terminator
+ }
+ llvm.return
+}
+// -----
+
+// Non-byref i32 sum whose declaration also carries a cleanup region. The
+// taskgroup translation reports this shape as not yet implemented.
+omp.declare_reduction @add_i32_cleanup : i32
+init {
+^bb0(%arg: i32):
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%a: i32, %b: i32):
+ %s = llvm.add %a, %b : i32
+ omp.yield(%s : i32)
+}
+cleanup {
+^bb0(%a: i32):
+ omp.yield
+}
+llvm.func @taskgroup_task_reduction_cleanup(%x : !llvm.ptr) {
+ // expected-error@below {{not yet implemented: task_reduction with cleanup region in omp.taskgroup}}
+ // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
+ omp.taskgroup task_reduction(@add_i32_cleanup %x -> %prv : !llvm.ptr) {
+ omp.terminator
+ }
+ llvm.return
+}
+// -----
+
+// Declaration with an alloc region whose initializer takes two block
+// arguments (%mold, %alloc) instead of one.
+omp.declare_reduction @add_i32_2arg_init : !llvm.ptr
+alloc {
+^bb0(%mold: !llvm.ptr):
+ %c1 = llvm.mlir.constant(1 : i32) : i32
+ %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+ omp.yield(%0 : !llvm.ptr)
+}
+init {
+^bb0(%mold: !llvm.ptr, %alloc: !llvm.ptr):
+ %c0 = llvm.mlir.constant(0 : i32) : i32
+ llvm.store %c0, %alloc : i32, !llvm.ptr
+ omp.yield(%alloc : !llvm.ptr)
+}
+combiner {
+^bb0(%a: !llvm.ptr, %b: !llvm.ptr):
+ omp.yield(%a : !llvm.ptr)
+}
+llvm.func @taskgroup_task_reduction_two_arg_init(%x : !llvm.ptr) {
+ // expected-error at below {{not yet implemented: task_reduction with two-argument initializer in omp.taskgroup}}
+ // expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
+ omp.taskgroup task_reduction(@add_i32_2arg_init %x -> %prv : !llvm.ptr) {
omp.terminator
}
llvm.return
More information about the Mlir-commits
mailing list