[Mlir-commits] [mlir] [mlir][OpenMP] Translate reductions on taskloop (PR #199670)

Sairudra More llvmlistbot at llvm.org
Thu May 28 00:57:11 PDT 2026


https://github.com/Saieiei updated https://github.com/llvm/llvm-project/pull/199670

>From c3847fed57dfdc41061291b90c6e6ca1ff950d28 Mon Sep 17 00:00:00 2001
From: saieiei <sairudra60 at gmail.com>
Date: Mon, 25 May 2026 12:23:14 -0500
Subject: [PATCH 1/2] [mlir][OpenMP] Translate task_reduction on taskgroup

Add LLVM IR translation support for the task_reduction clause on
omp.taskgroup.

The translation builds task-reduction descriptors for the listed reduction
variables and emits the runtime initialization before the taskgroup body.
The reducer init and combiner callbacks are generated from the corresponding
omp.declare_reduction regions.

This patch keeps taskloop reduction and in_reduction translation unsupported;
those remain follow-up work. Unsupported task_reduction forms are diagnosed
instead of being lowered incorrectly.

Add MLIR translation tests for taskgroup task_reduction, multiple reducers,
plain taskgroup translation, and remaining unsupported cases.
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 243 +++++++++++++++++-
 .../openmp-taskgroup-task-reduction.mlir      | 153 +++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  58 ++++-
 3 files changed, 444 insertions(+), 10 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-taskgroup-task-reduction.mlir

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f0511bb4be7dd..dfec09a53075c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -370,10 +370,13 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
       result = todo("reduction with modifier");
   };
-  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
-        op.getTaskReductionSyms())
-      result = todo("task_reduction");
+  auto checkTaskReductionByref = [&todo](auto op, LogicalResult &result) {
+    if (auto byrefAttr = op.getTaskReductionByref())
+      for (bool isByRef : *byrefAttr)
+        if (isByRef) {
+          result = todo("task_reduction with byref modifier");
+          return;
+        }
   };
   auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
     if (op.hasNumTeamsMultiDim())
@@ -426,7 +429,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::TaskgroupOp op) {
         checkAllocate(op, result);
-        checkTaskReduction(op, result);
+        checkTaskReductionByref(op, result);
       })
       .Case([&](omp::TaskwaitOp op) {
         checkDepend(op, result);
@@ -3643,6 +3646,183 @@ convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
   return success();
 }
 
+/// Build an outlined init helper for a task_reduction declare_reduction op.
+/// Signature: void(ptr %priv, ptr %orig). For non-byref reductions, the init
+/// region's mold argument is mapped to the value loaded from %orig, and the
+/// yielded scalar is stored into %priv.
+static llvm::Function *
+emitTaskReductionInitFn(omp::DeclareReductionOp decl, StringRef baseName,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::LLVMContext &ctx = llvmModule->getContext();
+  llvm::Type *voidTy = llvm::Type::getVoidTy(ctx);
+  llvm::Type *ptrTy = llvm::PointerType::getUnqual(ctx);
+  llvm::FunctionType *fty =
+      llvm::FunctionType::get(voidTy, {ptrTy, ptrTy}, false);
+  llvm::Function *fn =
+      llvm::Function::Create(fty, llvm::GlobalValue::InternalLinkage,
+                             baseName + ".red.init", llvmModule);
+  fn->setDoesNotRecurse();
+  fn->getArg(0)->setName("priv");
+  fn->getArg(1)->setName("orig");
+
+  llvm::BasicBlock *entry = llvm::BasicBlock::Create(ctx, "entry", fn);
+  llvm::IRBuilder<> b(entry);
+
+  llvm::Type *elemTy = moduleTranslation.convertType(decl.getType());
+  llvm::Value *origVal = b.CreateLoad(elemTy, fn->getArg(1), "omp.orig");
+  moduleTranslation.mapValue(decl.getInitializerMoldArg(), origVal);
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(decl.getInitializerRegion(),
+                                     "omp.taskred.init", b, moduleTranslation,
+                                     &phis))) {
+    fn->eraseFromParent();
+    return nullptr;
+  }
+  assert(phis.size() == 1 &&
+         "expected one value yielded from reduction initializer");
+  b.CreateStore(phis[0], fn->getArg(0));
+  b.CreateRetVoid();
+
+  moduleTranslation.forgetMapping(decl.getInitializerRegion());
+  return fn;
+}
+
+/// Build an outlined combiner helper for a task_reduction declare_reduction op.
+/// Signature: void(ptr %lhs, ptr %rhs). For non-byref reductions, the values
+/// at *%lhs and *%rhs are loaded, fed into the combiner region, and the
+/// yielded scalar is stored back into *%lhs.
+static llvm::Function *
+emitTaskReductionCombFn(omp::DeclareReductionOp decl, StringRef baseName,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::LLVMContext &ctx = llvmModule->getContext();
+  llvm::Type *voidTy = llvm::Type::getVoidTy(ctx);
+  llvm::Type *ptrTy = llvm::PointerType::getUnqual(ctx);
+  llvm::FunctionType *fty =
+      llvm::FunctionType::get(voidTy, {ptrTy, ptrTy}, false);
+  llvm::Function *fn =
+      llvm::Function::Create(fty, llvm::GlobalValue::InternalLinkage,
+                             baseName + ".red.comb", llvmModule);
+  fn->setDoesNotRecurse();
+  fn->getArg(0)->setName("lhs");
+  fn->getArg(1)->setName("rhs");
+
+  llvm::BasicBlock *entry = llvm::BasicBlock::Create(ctx, "entry", fn);
+  llvm::IRBuilder<> b(entry);
+
+  llvm::Type *elemTy = moduleTranslation.convertType(decl.getType());
+  Block &combBlock = decl.getReductionRegion().front();
+  assert(combBlock.getNumArguments() == 2 &&
+         "expected two arguments in declare_reduction combiner");
+  llvm::Value *lhsVal = b.CreateLoad(elemTy, fn->getArg(0), "omp.lhs");
+  llvm::Value *rhsVal = b.CreateLoad(elemTy, fn->getArg(1), "omp.rhs");
+  moduleTranslation.mapValue(combBlock.getArgument(0), lhsVal);
+  moduleTranslation.mapValue(combBlock.getArgument(1), rhsVal);
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
+                                     "omp.taskred.comb", b, moduleTranslation,
+                                     &phis))) {
+    fn->eraseFromParent();
+    return nullptr;
+  }
+  assert(phis.size() == 1 &&
+         "expected one value yielded from reduction combiner");
+  b.CreateStore(phis[0], fn->getArg(0));
+  b.CreateRetVoid();
+
+  moduleTranslation.forgetMapping(decl.getReductionRegion());
+  return fn;
+}
+
+/// Emit the per-taskgroup task_reduction descriptor array and the
+/// `__kmpc_taskred_init` runtime call. Must be called with `builder` set to a
+/// point inside the taskgroup body (after `__kmpc_taskgroup`). The descriptor
+/// array itself is allocated at \p allocaIP.
+///
+/// Only the non-byref form is handled here. Byref task_reduction has already
+/// been rejected by `checkImplementationStatus`.
+static LogicalResult emitTaskgroupTaskReductionInit(
+    omp::TaskgroupOp tgOp, ArrayRef<omp::DeclareReductionOp> redDecls,
+    llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
+    LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::LLVMContext &ctx = llvmModule->getContext();
+  const llvm::DataLayout &dl = llvmModule->getDataLayout();
+
+  llvm::Type *ptrTy = llvm::PointerType::getUnqual(ctx);
+  llvm::Type *i32Ty = llvm::Type::getInt32Ty(ctx);
+  llvm::Type *sizeTy =
+      llvm::Type::getIntNTy(ctx, dl.getPointerSizeInBits(/*AddrSpace=*/0));
+
+  // Identified `kmp_taskred_input_t` struct, matching the layout used by
+  // Clang's CGOpenMPRuntime::emitTaskReductionInit.
+  llvm::StructType *redInputTy =
+      llvm::StructType::getTypeByName(ctx, "kmp_taskred_input_t");
+  if (!redInputTy)
+    redInputTy = llvm::StructType::create(
+        ctx, {ptrTy, ptrTy, sizeTy, ptrTy, ptrTy, ptrTy, i32Ty},
+        "kmp_taskred_input_t");
+
+  unsigned n = redDecls.size();
+  llvm::ArrayType *arrTy = llvm::ArrayType::get(redInputTy, n);
+
+  // Allocate the descriptor array in the enclosing function's alloca block.
+  llvm::AllocaInst *arrAlloca;
+  {
+    llvm::IRBuilderBase::InsertPointGuard guard(builder);
+    builder.restoreIP(allocaIP);
+    arrAlloca =
+        builder.CreateAlloca(arrTy, /*ArraySize=*/nullptr, ".taskred.input");
+  }
+
+  // Fill each descriptor entry inside the taskgroup body.
+  llvm::Value *zero = builder.getInt32(0);
+  for (unsigned i = 0; i < n; ++i) {
+    omp::DeclareReductionOp decl = redDecls[i];
+    llvm::Value *orig =
+        moduleTranslation.lookupValue(tgOp.getTaskReductionVars()[i]);
+    llvm::Type *elemTy = moduleTranslation.convertType(decl.getType());
+    uint64_t size = dl.getTypeAllocSize(elemTy).getFixedValue();
+
+    std::string baseName =
+        (llvm::Twine("__omp_taskred_") + decl.getSymName()).str();
+    llvm::Function *initFn =
+        emitTaskReductionInitFn(decl, baseName, moduleTranslation);
+    llvm::Function *combFn =
+        emitTaskReductionCombFn(decl, baseName, moduleTranslation);
+    if (!initFn || !combFn)
+      return failure();
+    llvm::Value *elemPtr = builder.CreateInBoundsGEP(
+        arrTy, arrAlloca, {zero, builder.getInt32(i)}, ".taskred.elem");
+    auto storeField = [&](unsigned fieldIdx, llvm::Value *val) {
+      llvm::Value *fieldPtr =
+          builder.CreateStructGEP(redInputTy, elemPtr, fieldIdx);
+      builder.CreateStore(val, fieldPtr);
+    };
+    storeField(0, orig);                                  // reduce_shar
+    storeField(1, orig);                                  // reduce_orig
+    storeField(2, llvm::ConstantInt::get(sizeTy, size));  // reduce_size
+    storeField(3, initFn);                                // reduce_init
+    storeField(4, llvm::ConstantPointerNull::get(ptrTy)); // reduce_fini
+    storeField(5, combFn);                                // reduce_comb
+    storeField(6, llvm::ConstantInt::get(i32Ty, 0));      // flags
+  }
+
+  // Emit call: __kmpc_taskred_init(gtid, num, &arr).
+  uint32_t srcLocSize;
+  llvm::Constant *srcLocStr =
+      ompBuilder->getOrCreateDefaultSrcLocStr(srcLocSize);
+  llvm::Value *ident = ompBuilder->getOrCreateIdent(srcLocStr, srcLocSize);
+  llvm::Value *gtid = ompBuilder->getOrCreateThreadID(ident);
+  llvm::FunctionCallee taskredInit = ompBuilder->getOrCreateRuntimeFunction(
+      *llvmModule, llvm::omp::OMPRTL___kmpc_taskred_init);
+  builder.CreateCall(taskredInit, {gtid, builder.getInt32(n), arrAlloca});
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
@@ -3651,9 +3831,58 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
   if (failed(checkImplementationStatus(*tgOp)))
     return failure();
 
-  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
-                    llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) {
+  // Resolve and validate task_reduction declarations up front. We only handle
+  // declare_reduction ops shaped like a non-byref scalar reduction in this
+  // first cut; richer shapes (two-argument initializer, cleanup region,
+  // missing combiner) require additional infrastructure.
+  SmallVector<omp::DeclareReductionOp> redDecls;
+  if (auto syms = tgOp.getTaskReductionSyms()) {
+    redDecls.reserve(syms->size());
+    for (auto sym : syms->getAsRange<SymbolRefAttr>()) {
+      auto decl = SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
+          tgOp, sym);
+      if (!decl)
+        return tgOp.emitError()
+               << "failed to resolve task_reduction declare_reduction symbol "
+               << sym.getRootReference() << " in omp.taskgroup";
+      if (decl.getInitializerRegion().front().getNumArguments() != 1)
+        return tgOp.emitError("not yet implemented: task_reduction with "
+                              "two-argument initializer in omp.taskgroup");
+      if (!decl.getCleanupRegion().empty())
+        return tgOp.emitError("not yet implemented: task_reduction with "
+                              "cleanup region in omp.taskgroup");
+      if (decl.getReductionRegion().empty())
+        return tgOp.emitError("task_reduction declare_reduction is missing a "
+                              "combiner region");
+      redDecls.push_back(decl);
+    }
+  }
+
+  auto bodyCB =
+      [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
+          llvm::ArrayRef<llvm::BasicBlock *> deallocBlocks) -> llvm::Error {
     builder.restoreIP(codegenIP);
+
+    if (!redDecls.empty()) {
+      if (failed(emitTaskgroupTaskReductionInit(tgOp, redDecls, builder,
+                                                allocaIP, moduleTranslation)))
+        return llvm::createStringError(
+            llvm::inconvertibleErrorCode(),
+            "failed to emit task_reduction initialization for omp.taskgroup");
+    }
+
+    // Inside the taskgroup body, each task_reduction block argument refers to
+    // the same shared/original storage that the runtime now knows about via
+    // the descriptor array. Inner tasks that declare in_reduction look up
+    // per-task private copies through the runtime; the taskgroup body itself
+    // uses the original variable.
+    for (auto [i, blockArg] :
+         llvm::enumerate(tgOp.getRegion().getArguments())) {
+      llvm::Value *orig =
+          moduleTranslation.lookupValue(tgOp.getTaskReductionVars()[i]);
+      moduleTranslation.mapValue(blockArg, orig);
+    }
+
     return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
                                builder, moduleTranslation)
         .takeError();
diff --git a/mlir/test/Target/LLVMIR/openmp-taskgroup-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-taskgroup-task-reduction.mlir
new file mode 100644
index 0000000000000..353ce6218d9f3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-taskgroup-task-reduction.mlir
@@ -0,0 +1,153 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// Single scalar task_reduction on omp.taskgroup. Verifies that the
+// kmp_taskred_input_t descriptor is allocated, populated, and handed off to
+// __kmpc_taskred_init, and that init / combiner helper functions are emitted.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+llvm.func @taskgroup_task_reduction_single(%x: !llvm.ptr) {
+  omp.taskgroup task_reduction(@add_i32 %x -> %prv : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK: %kmp_taskred_input_t = type { ptr, ptr, i64, ptr, ptr, ptr, i32 }
+
+// CHECK-LABEL: define void @taskgroup_task_reduction_single(
+// CHECK-SAME:    ptr %[[X:.+]])
+// CHECK:         %[[ARR:.+]] = alloca [1 x %kmp_taskred_input_t]
+// CHECK:         call void @__kmpc_taskgroup(
+// Descriptor entry 0.
+// CHECK:         %[[GEP0:.+]] = getelementptr inbounds [1 x %kmp_taskred_input_t], ptr %[[ARR]], i32 0, i32 0
+// CHECK:         %[[SHAR:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 0
+// CHECK:         store ptr %[[X]], ptr %[[SHAR]]
+// CHECK:         %[[ORIG:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 1
+// CHECK:         store ptr %[[X]], ptr %[[ORIG]]
+// CHECK:         %[[SZF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 2
+// CHECK:         store i64 4, ptr %[[SZF]]
+// CHECK:         %[[INITF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 3
+// CHECK:         store ptr @__omp_taskred_add_i32.red.init, ptr %[[INITF]]
+// CHECK:         %[[FINIF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 4
+// CHECK:         store ptr null, ptr %[[FINIF]]
+// CHECK:         %[[COMBF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 5
+// CHECK:         store ptr @__omp_taskred_add_i32.red.comb, ptr %[[COMBF]]
+// CHECK:         %[[FLAGSF:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[GEP0]], i32 0, i32 6
+// CHECK:         store i32 0, ptr %[[FLAGSF]]
+// CHECK:         call ptr @__kmpc_taskred_init(i32 %{{.+}}, i32 1, ptr %[[ARR]])
+// CHECK:         call void @__kmpc_end_taskgroup(
+
+// CHECK-LABEL: define internal void @__omp_taskred_add_i32.red.init(
+// CHECK-SAME:    ptr %priv, ptr %orig)
+// CHECK:         load i32, ptr %orig
+// CHECK:         store i32 0, ptr %priv
+// CHECK:         ret void
+
+// CHECK-LABEL: define internal void @__omp_taskred_add_i32.red.comb(
+// CHECK-SAME:    ptr %lhs, ptr %rhs)
+// CHECK:         %[[L:.+]] = load i32, ptr %lhs
+// CHECK:         %[[R:.+]] = load i32, ptr %rhs
+// CHECK:         %[[S:.+]] = add i32 %[[L]], %[[R]]
+// CHECK:         store i32 %[[S]], ptr %lhs
+// CHECK:         ret void
+
+// -----
+
+// Multiple task_reduction items on the same taskgroup.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+omp.declare_reduction @mul_i64 : i64
+init {
+^bb0(%arg0: i64):
+  %c1 = llvm.mlir.constant(1 : i64) : i64
+  omp.yield(%c1 : i64)
+}
+combiner {
+^bb0(%arg0: i64, %arg1: i64):
+  %p = llvm.mul %arg0, %arg1 : i64
+  omp.yield(%p : i64)
+}
+
+llvm.func @taskgroup_task_reduction_multi(%x: !llvm.ptr, %y: !llvm.ptr) {
+  omp.taskgroup task_reduction(@add_i32 %x -> %a, @mul_i64 %y -> %b : !llvm.ptr, !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define void @taskgroup_task_reduction_multi(
+// CHECK-SAME:    ptr %[[XA:[^,)]+]], ptr %[[YA:[^,)]+]])
+// CHECK:         %[[ARR2:.+]] = alloca [2 x %kmp_taskred_input_t]
+// CHECK:         call void @__kmpc_taskgroup(
+// Descriptor entry 0: @add_i32 on %x.
+// CHECK:         %[[E0:.+]] = getelementptr inbounds [2 x %kmp_taskred_input_t], ptr %[[ARR2]], i32 0, i32 0
+// CHECK:         %[[E0_SHAR:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 0
+// CHECK:         store ptr %[[XA]], ptr %[[E0_SHAR]]
+// CHECK:         %[[E0_ORIG:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 1
+// CHECK:         store ptr %[[XA]], ptr %[[E0_ORIG]]
+// CHECK:         %[[E0_SZ:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 2
+// CHECK:         store i64 4, ptr %[[E0_SZ]]
+// CHECK:         %[[E0_INIT:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 3
+// CHECK:         store ptr @__omp_taskred_add_i32.red.init, ptr %[[E0_INIT]]
+// CHECK:         %[[E0_FINI:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 4
+// CHECK:         store ptr null, ptr %[[E0_FINI]]
+// CHECK:         %[[E0_COMB:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 5
+// CHECK:         store ptr @__omp_taskred_add_i32.red.comb, ptr %[[E0_COMB]]
+// CHECK:         %[[E0_FL:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E0]], i32 0, i32 6
+// CHECK:         store i32 0, ptr %[[E0_FL]]
+// Descriptor entry 1: @mul_i64 on %y.
+// CHECK:         %[[E1:.+]] = getelementptr inbounds [2 x %kmp_taskred_input_t], ptr %[[ARR2]], i32 0, i32 1
+// CHECK:         %[[E1_SHAR:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 0
+// CHECK:         store ptr %[[YA]], ptr %[[E1_SHAR]]
+// CHECK:         %[[E1_ORIG:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 1
+// CHECK:         store ptr %[[YA]], ptr %[[E1_ORIG]]
+// CHECK:         %[[E1_SZ:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 2
+// CHECK:         store i64 8, ptr %[[E1_SZ]]
+// CHECK:         %[[E1_INIT:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 3
+// CHECK:         store ptr @__omp_taskred_mul_i64.red.init, ptr %[[E1_INIT]]
+// CHECK:         %[[E1_FINI:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 4
+// CHECK:         store ptr null, ptr %[[E1_FINI]]
+// CHECK:         %[[E1_COMB:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 5
+// CHECK:         store ptr @__omp_taskred_mul_i64.red.comb, ptr %[[E1_COMB]]
+// CHECK:         %[[E1_FL:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[E1]], i32 0, i32 6
+// CHECK:         store i32 0, ptr %[[E1_FL]]
+// CHECK:         call ptr @__kmpc_taskred_init(i32 %{{.+}}, i32 2, ptr %[[ARR2]])
+
+// -----
+
+// Plain taskgroup without task_reduction must still translate (regression
+// guard for the rewrite of convertOmpTaskgroupOp).
+
+llvm.func @taskgroup_plain() {
+  omp.taskgroup {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define void @taskgroup_plain()
+// CHECK:         call void @__kmpc_taskgroup(
+// CHECK-NOT:     call ptr @__kmpc_taskred_init(
+// CHECK:         call void @__kmpc_end_taskgroup(
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 295ba54dbfb38..d5d0a96779db0 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -301,10 +301,62 @@ atomic {
   llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
   omp.yield
 }
-llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
-  // expected-error at below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
+llvm.func @taskgroup_task_reduction_byref(%x : !llvm.ptr) {
+  // expected-error at below {{not yet implemented: Unhandled clause task_reduction with byref modifier in omp.taskgroup operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
-  omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
+  omp.taskgroup task_reduction(byref @add_f32 %x -> %prv : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+// -----
+
+omp.declare_reduction @add_i32_cleanup : i32
+init {
+^bb0(%arg: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%a: i32, %b: i32):
+  %s = llvm.add %a, %b : i32
+  omp.yield(%s : i32)
+}
+cleanup {
+^bb0(%a: i32):
+  omp.yield
+}
+llvm.func @taskgroup_task_reduction_cleanup(%x : !llvm.ptr) {
+  // expected-error at below {{not yet implemented: task_reduction with cleanup region in omp.taskgroup}}
+  // expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
+  omp.taskgroup task_reduction(@add_i32_cleanup %x -> %prv : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+// -----
+
+omp.declare_reduction @add_i32_2arg_init : !llvm.ptr
+alloc {
+^bb0(%mold: !llvm.ptr):
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+  omp.yield(%0 : !llvm.ptr)
+}
+init {
+^bb0(%mold: !llvm.ptr, %alloc: !llvm.ptr):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  llvm.store %c0, %alloc : i32, !llvm.ptr
+  omp.yield(%alloc : !llvm.ptr)
+}
+combiner {
+^bb0(%a: !llvm.ptr, %b: !llvm.ptr):
+  omp.yield(%a : !llvm.ptr)
+}
+llvm.func @taskgroup_task_reduction_two_arg_init(%x : !llvm.ptr) {
+  // expected-error at below {{not yet implemented: task_reduction with two-argument initializer in omp.taskgroup}}
+  // expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
+  omp.taskgroup task_reduction(@add_i32_2arg_init %x -> %prv : !llvm.ptr) {
     omp.terminator
   }
   llvm.return

>From 1ca9edc6fc44ae75eda94bf65e0cd4d2bca2620c Mon Sep 17 00:00:00 2001
From: Sairudra More <moresair at pe31.hpc.amslabs.hpecorp.net>
Date: Tue, 26 May 2026 04:30:16 -0500
Subject: [PATCH 2/2] [mlir][OpenMP] Translate reductions on taskloop

Add LLVM IR translation for reduction and in_reduction clauses on omp.taskloop.context.

For taskloop reduction, emit the implicit taskgroup reduction setup and map each generated task to runtime-provided private reduction storage through __kmpc_task_reduction_get_th_data. For in_reduction, use the same runtime lookup path with a null descriptor to join an enclosing task reduction context.

Unsupported byref, cleanup, and two-argument initializer forms remain diagnosed.

Add MLIR translation tests for the supported taskloop reduction and in_reduction cases.
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 227 ++++++++++++++--
 .../LLVMIR/openmp-taskloop-reduction.mlir     | 245 ++++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 102 +++++++-
 3 files changed, 543 insertions(+), 31 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-taskloop-reduction.mlir

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index dfec09a53075c..1120d9fc38d0a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -362,7 +362,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("privatization");
   };
   auto checkReduction = [&todo](auto op, LogicalResult &result) {
-    if (isa<omp::TeamsOp>(op) || isa<omp::TaskloopContextOp>(op))
+    if (isa<omp::TeamsOp>(op))
       if (!op.getReductionVars().empty() || op.getReductionByref() ||
           op.getReductionSyms())
         result = todo("reduction");
@@ -378,6 +378,22 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           return;
         }
   };
+  auto checkReductionByref = [&todo](auto op, LogicalResult &result) {
+    if (auto byrefAttr = op.getReductionByref())
+      for (bool isByRef : *byrefAttr)
+        if (isByRef) {
+          result = todo("reduction with byref modifier");
+          return;
+        }
+  };
+  auto checkInReductionByref = [&todo](auto op, LogicalResult &result) {
+    if (auto byrefAttr = op.getInReductionByref())
+      for (bool isByRef : *byrefAttr)
+        if (isByRef) {
+          result = todo("in_reduction with byref modifier");
+          return;
+        }
+  };
   auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
     if (op.hasNumTeamsMultiDim())
       result = todo("num_teams with multi-dimensional values");
@@ -437,8 +453,9 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::TaskloopContextOp op) {
         checkAllocate(op, result);
-        checkInReduction(op, result);
+        checkInReductionByref(op, result);
         checkReduction(op, result);
+        checkReductionByref(op, result);
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
@@ -3327,6 +3344,15 @@ computeTaskloopBounds(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder,
   return llvm::Error::success();
 }
 
+// Forward declaration: defined alongside the taskgroup task_reduction
+// lowering further down in this file. Shared between omp.taskgroup and
+// omp.taskloop.context translation.
+static llvm::Value *emitTaskReductionInitCall(
+    ArrayRef<omp::DeclareReductionOp> redDecls,
+    ArrayRef<llvm::Value *> origPtrs, StringRef helperNamePrefix,
+    llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
+    LLVM::ModuleTranslation &moduleTranslation);
+
 // Converts an OpenMP taskloop construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
@@ -3417,6 +3443,90 @@ convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
   // Set up inserttion point for call to createTaskloop()
   builder.SetInsertPoint(taskloopStartBlock);
 
+  // Resolve and validate reduction / in_reduction declarations. Only the
+  // non-byref, single-init-arg, no-cleanup form is supported in this first
+  // cut; richer shapes have been rejected by checkImplementationStatus
+  // (byref) or are rejected here.
+  auto resolveRedDecls =
+      [&](std::optional<ArrayAttr> syms, StringRef clauseName,
+          SmallVectorImpl<omp::DeclareReductionOp> &out) -> LogicalResult {
+    if (!syms)
+      return success();
+    out.reserve(syms->size());
+    for (auto sym : syms->getAsRange<SymbolRefAttr>()) {
+      auto decl = SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
+          contextOp, sym);
+      if (!decl)
+        return contextOp.emitError()
+               << "failed to resolve " << clauseName
+               << " declare_reduction symbol " << sym.getRootReference()
+               << " in omp.taskloop.context";
+      if (decl.getInitializerRegion().front().getNumArguments() != 1)
+        return contextOp.emitError()
+               << "not yet implemented: " << clauseName
+               << " with two-argument initializer in omp.taskloop.context";
+      if (!decl.getCleanupRegion().empty())
+        return contextOp.emitError()
+               << "not yet implemented: " << clauseName
+               << " with cleanup region in omp.taskloop.context";
+      if (decl.getReductionRegion().empty())
+        return contextOp.emitError()
+               << clauseName
+               << " declare_reduction is missing a combiner region";
+      out.push_back(decl);
+    }
+    return success();
+  };
+
+  SmallVector<omp::DeclareReductionOp> redDecls;
+  if (failed(
+          resolveRedDecls(contextOp.getReductionSyms(), "reduction", redDecls)))
+    return failure();
+  SmallVector<omp::DeclareReductionOp> inRedDecls;
+  if (failed(resolveRedDecls(contextOp.getInReductionSyms(), "in_reduction",
+                             inRedDecls)))
+    return failure();
+
+  // The op verifier rejects nogroup + reduction, so no check is needed here.
+
+  SmallVector<llvm::Value *> redOrigPtrs;
+  redOrigPtrs.reserve(redDecls.size());
+  for (Value v : contextOp.getReductionVars())
+    redOrigPtrs.push_back(moduleTranslation.lookupValue(v));
+  SmallVector<llvm::Value *> inRedOrigPtrs;
+  inRedOrigPtrs.reserve(inRedDecls.size());
+  for (Value v : contextOp.getInReductionVars())
+    inRedOrigPtrs.push_back(moduleTranslation.lookupValue(v));
+
+  llvm::OpenMPIRBuilder &ompBuilderRef = *moduleTranslation.getOpenMPBuilder();
+  llvm::Module *llvmModuleForRed = moduleTranslation.getLLVMModule();
+
+  // If we have reduction items, we must emit our own implicit
+  // __kmpc_taskgroup so that the descriptor returned by __kmpc_taskred_init
+  // is associated with that taskgroup. We then force NoGroup=true so that
+  // OpenMPIRBuilder::createTaskloop does not emit a second taskgroup.
+  bool implicitTaskgroup = !redDecls.empty();
+  llvm::Value *redDesc = nullptr;
+  if (implicitTaskgroup) {
+    uint32_t srcLocSize;
+    llvm::Constant *srcLocStr =
+        ompBuilderRef.getOrCreateDefaultSrcLocStr(srcLocSize);
+    llvm::Value *ident = ompBuilderRef.getOrCreateIdent(srcLocStr, srcLocSize);
+    llvm::Function *gtidFn = ompBuilderRef.getOrCreateRuntimeFunctionPtr(
+        llvm::omp::OMPRTL___kmpc_global_thread_num);
+    llvm::Value *outerGtid =
+        builder.CreateCall(gtidFn, {ident}, "omp_global_thread_num");
+    llvm::FunctionCallee taskgroupFn = ompBuilderRef.getOrCreateRuntimeFunction(
+        *llvmModuleForRed, llvm::omp::OMPRTL___kmpc_taskgroup);
+    builder.CreateCall(taskgroupFn, {ident, outerGtid});
+
+    redDesc = emitTaskReductionInitCall(redDecls, redOrigPtrs,
+                                        "__omp_taskloop_taskred_", builder,
+                                        allocaIP, moduleTranslation);
+    if (!redDesc)
+      return failure();
+  }
+
   auto loopOp = cast<omp::LoopNestOp>(loopWrapperOp.getWrappedLoop());
   llvm::Value *lbVal = nullptr;
   llvm::Value *ubVal = nullptr;
@@ -3491,6 +3601,49 @@ convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
       moduleTranslation.mapValue(blockArg, llvmPrivateVar);
     }
 
+    // Map reduction and in_reduction block arguments to the per-task private
+    // storage returned by __kmpc_task_reduction_get_th_data. This call must
+    // be emitted inside the to-be-outlined task body so that it is passed the
+    // *executing* thread's gtid (not the encountering thread's). The
+    // taskgroup descriptor `redDesc` is computed in the outer scope and is
+    // auto-captured into the task shareds aggregate by CodeExtractor during
+    // OpenMPIRBuilder::finalize. For in_reduction the descriptor is NULL:
+    // the runtime walks up enclosing taskgroups to find the matching
+    // task_reduction registration for `origPtr`.
+    if (!redDecls.empty() || !inRedDecls.empty()) {
+      auto iface =
+          cast<omp::BlockArgOpenMPOpInterface>(contextOp.getOperation());
+      llvm::OpenMPIRBuilder &ompB = *moduleTranslation.getOpenMPBuilder();
+      llvm::Module *m = moduleTranslation.getLLVMModule();
+      llvm::LLVMContext &llvmCtx = m->getContext();
+      uint32_t srcLocSize;
+      llvm::Constant *srcLocStr = ompB.getOrCreateDefaultSrcLocStr(srcLocSize);
+      llvm::Value *bodyIdent = ompB.getOrCreateIdent(srcLocStr, srcLocSize);
+      llvm::Function *gtidFn = ompB.getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_global_thread_num);
+      llvm::Value *bodyGtid =
+          builder.CreateCall(gtidFn, {bodyIdent}, "omp_global_thread_num");
+      llvm::FunctionCallee getThData = ompB.getOrCreateRuntimeFunction(
+          *m, llvm::omp::OMPRTL___kmpc_task_reduction_get_th_data);
+      llvm::Type *ptrTy = llvm::PointerType::getUnqual(llvmCtx);
+
+      ArrayRef<BlockArgument> redBlockArgs = iface.getReductionBlockArgs();
+      for (auto [blockArg, origPtr] :
+           llvm::zip_equal(redBlockArgs, redOrigPtrs)) {
+        llvm::Value *priv = builder.CreateCall(
+            getThData, {bodyGtid, redDesc, origPtr}, "omp.taskred.priv");
+        moduleTranslation.mapValue(blockArg, priv);
+      }
+      ArrayRef<BlockArgument> inRedBlockArgs = iface.getInReductionBlockArgs();
+      llvm::Value *nullDesc = llvm::ConstantPointerNull::get(ptrTy);
+      for (auto [blockArg, origPtr] :
+           llvm::zip_equal(inRedBlockArgs, inRedOrigPtrs)) {
+        llvm::Value *priv = builder.CreateCall(
+            getThData, {bodyGtid, nullDesc, origPtr}, "omp.inred.priv");
+        moduleTranslation.mapValue(blockArg, priv);
+      }
+    }
+
     // Lower the contents of the taskloop context region: this is the body of
     // the generated task, not the loop.
     auto continuationBlockOrError = convertOmpOpRegions(
@@ -3626,12 +3779,12 @@ convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
                            llvm::omp::Directive::OMPD_taskgroup);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  bool effectiveNoGroup = contextOp.getNogroup() || implicitTaskgroup;
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTaskloop(
           ompLoc, allocaIP, deallocBlocks, bodyCB, loopInfo, lbVal, ubVal,
-          stepVal, contextOp.getUntied(), ifCond, grainsize,
-          contextOp.getNogroup(), sched,
-          moduleTranslation.lookupValue(contextOp.getFinal()),
+          stepVal, contextOp.getUntied(), ifCond, grainsize, effectiveNoGroup,
+          sched, moduleTranslation.lookupValue(contextOp.getFinal()),
           contextOp.getMergeable(),
           moduleTranslation.lookupValue(contextOp.getPriority()),
           loopOp.getCollapseNumLoops(), taskDupOrNull,
@@ -3643,6 +3796,23 @@ convertOmpTaskloopContextOp(omp::TaskloopContextOp contextOp,
   popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
 
   builder.restoreIP(*afterIP);
+
+  // Close the implicit taskgroup we opened for task_reduction. The end call
+  // must execute on the encountering thread, so its gtid is re-queried here.
+  if (implicitTaskgroup) {
+    uint32_t srcLocSize;
+    llvm::Constant *srcLocStr =
+        ompBuilder.getOrCreateDefaultSrcLocStr(srcLocSize);
+    llvm::Value *ident = ompBuilder.getOrCreateIdent(srcLocStr, srcLocSize);
+    llvm::Function *gtidFn = ompBuilder.getOrCreateRuntimeFunctionPtr(
+        llvm::omp::OMPRTL___kmpc_global_thread_num);
+    llvm::Value *outerGtid =
+        builder.CreateCall(gtidFn, {ident}, "omp_global_thread_num");
+    llvm::FunctionCallee endTgFn = ompBuilder.getOrCreateRuntimeFunction(
+        *moduleTranslation.getLLVMModule(),
+        llvm::omp::OMPRTL___kmpc_end_taskgroup);
+    builder.CreateCall(endTgFn, {ident, outerGtid});
+  }
   return success();
 }
 
@@ -3737,16 +3907,25 @@ emitTaskReductionCombFn(omp::DeclareReductionOp decl, StringRef baseName,
 }
 
 /// Emit the per-taskgroup task_reduction descriptor array and the
-/// `__kmpc_taskred_init` runtime call. Must be called with `builder` set to a
-/// point inside the taskgroup body (after `__kmpc_taskgroup`). The descriptor
-/// array itself is allocated at \p allocaIP.
+/// `__kmpc_taskred_init` runtime call. \p origPtrs holds the LLVM values for
+/// the original (shared) variables, one per declaration in \p redDecls.
+/// `builder` must be set to the point at which the descriptor stores and the
+/// init call should be emitted; the descriptor array itself is allocated at
+/// \p allocaIP. \p helperNamePrefix is used to disambiguate the generated
+/// init/combiner helper symbol names between taskgroup and taskloop callers.
+///
+/// Returns the `ptr` value produced by `__kmpc_taskred_init` (the taskgroup
+/// reduction handle), or null on failure.
 ///
-/// Only the non-byref form is handled here. Byref task_reduction has already
+/// Only the non-byref form is handled here. Byref reductions have already
 /// been rejected by `checkImplementationStatus`.
-static LogicalResult emitTaskgroupTaskReductionInit(
-    omp::TaskgroupOp tgOp, ArrayRef<omp::DeclareReductionOp> redDecls,
+static llvm::Value *emitTaskReductionInitCall(
+    ArrayRef<omp::DeclareReductionOp> redDecls,
+    ArrayRef<llvm::Value *> origPtrs, StringRef helperNamePrefix,
     llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
     LLVM::ModuleTranslation &moduleTranslation) {
+  assert(redDecls.size() == origPtrs.size() &&
+         "expected one orig pointer per reduction decl");
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
   llvm::LLVMContext &ctx = llvmModule->getContext();
@@ -3778,23 +3957,22 @@ static LogicalResult emitTaskgroupTaskReductionInit(
         builder.CreateAlloca(arrTy, /*ArraySize=*/nullptr, ".taskred.input");
   }
 
-  // Fill each descriptor entry inside the taskgroup body.
+  // Fill each descriptor entry at the current builder insertion point.
   llvm::Value *zero = builder.getInt32(0);
   for (unsigned i = 0; i < n; ++i) {
     omp::DeclareReductionOp decl = redDecls[i];
-    llvm::Value *orig =
-        moduleTranslation.lookupValue(tgOp.getTaskReductionVars()[i]);
+    llvm::Value *orig = origPtrs[i];
     llvm::Type *elemTy = moduleTranslation.convertType(decl.getType());
     uint64_t size = dl.getTypeAllocSize(elemTy).getFixedValue();
 
     std::string baseName =
-        (llvm::Twine("__omp_taskred_") + decl.getSymName()).str();
+        (llvm::Twine(helperNamePrefix) + decl.getSymName()).str();
     llvm::Function *initFn =
         emitTaskReductionInitFn(decl, baseName, moduleTranslation);
     llvm::Function *combFn =
         emitTaskReductionCombFn(decl, baseName, moduleTranslation);
     if (!initFn || !combFn)
-      return failure();
+      return nullptr;
     llvm::Value *elemPtr = builder.CreateInBoundsGEP(
         arrTy, arrAlloca, {zero, builder.getInt32(i)}, ".taskred.elem");
     auto storeField = [&](unsigned fieldIdx, llvm::Value *val) {
@@ -3816,11 +3994,14 @@ static LogicalResult emitTaskgroupTaskReductionInit(
   llvm::Constant *srcLocStr =
       ompBuilder->getOrCreateDefaultSrcLocStr(srcLocSize);
   llvm::Value *ident = ompBuilder->getOrCreateIdent(srcLocStr, srcLocSize);
-  llvm::Value *gtid = ompBuilder->getOrCreateThreadID(ident);
+  llvm::Function *gtidFn = ompBuilder->getOrCreateRuntimeFunctionPtr(
+      llvm::omp::OMPRTL___kmpc_global_thread_num);
+  llvm::Value *gtid =
+      builder.CreateCall(gtidFn, {ident}, "omp_global_thread_num");
   llvm::FunctionCallee taskredInit = ompBuilder->getOrCreateRuntimeFunction(
       *llvmModule, llvm::omp::OMPRTL___kmpc_taskred_init);
-  builder.CreateCall(taskredInit, {gtid, builder.getInt32(n), arrAlloca});
-  return success();
+  return builder.CreateCall(taskredInit, {gtid, builder.getInt32(n), arrAlloca},
+                            ".taskred.desc");
 }
 
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
@@ -3864,8 +4045,12 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
     builder.restoreIP(codegenIP);
 
     if (!redDecls.empty()) {
-      if (failed(emitTaskgroupTaskReductionInit(tgOp, redDecls, builder,
-                                                allocaIP, moduleTranslation)))
+      SmallVector<llvm::Value *> origPtrs;
+      origPtrs.reserve(redDecls.size());
+      for (Value v : tgOp.getTaskReductionVars())
+        origPtrs.push_back(moduleTranslation.lookupValue(v));
+      if (!emitTaskReductionInitCall(redDecls, origPtrs, "__omp_taskred_",
+                                     builder, allocaIP, moduleTranslation))
         return llvm::createStringError(
             llvm::inconvertibleErrorCode(),
             "failed to emit task_reduction initialization for omp.taskgroup");
diff --git a/mlir/test/Target/LLVMIR/openmp-taskloop-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-taskloop-reduction.mlir
new file mode 100644
index 0000000000000..0043f75bfe227
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-taskloop-reduction.mlir
@@ -0,0 +1,245 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// Single scalar reduction on omp.taskloop.context. The lowering must:
+//   1. Emit an implicit __kmpc_taskgroup in the encountering function (since
+//      the user did not write nogroup);
+//   2. Build a kmp_taskred_input_t descriptor array and call
+//      __kmpc_taskred_init, capturing the returned descriptor handle;
+//   3. Force nogroup=1 on the inner __kmpc_taskloop call so that the
+//      OpenMPIRBuilder does not emit a second taskgroup;
+//   4. Inside the outlined task body, call __kmpc_global_thread_num to obtain
+//      the executing thread's gtid, then look up the per-task private storage
+//      via __kmpc_task_reduction_get_th_data(gtid, redDesc, orig);
+//   5. Close the implicit taskgroup with __kmpc_end_taskgroup.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+llvm.func @taskloop_reduction_single(%x : !llvm.ptr, %lb : i32, %ub : i32, %step : i32) {
+  omp.taskloop.context reduction(@add_i32 %x -> %prv : !llvm.ptr) {
+    omp.taskloop.wrapper {
+      omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+        %v = llvm.load %prv : !llvm.ptr -> i32
+        %s = llvm.add %v, %iv : i32
+        llvm.store %s, %prv : i32, !llvm.ptr
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK: %kmp_taskred_input_t = type { ptr, ptr, i64, ptr, ptr, ptr, i32 }
+
+// Encountering function emits taskgroup + descriptor + taskred_init.
+// CHECK-LABEL: define void @taskloop_reduction_single(
+// CHECK-SAME:    ptr %[[X:[^,]+]],
+// CHECK:         %[[ARR:.+]] = alloca [1 x %kmp_taskred_input_t]
+// CHECK:         call void @__kmpc_taskgroup(
+// CHECK:         %[[ELEM:.+]] = getelementptr inbounds [1 x %kmp_taskred_input_t], ptr %[[ARR]], i32 0, i32 0
+// CHECK:         %[[SHAR:.+]] = getelementptr {{.+}} %kmp_taskred_input_t, ptr %[[ELEM]], i32 0, i32 0
+// CHECK:         store ptr %[[X]], ptr %[[SHAR]]
+// CHECK:         store ptr @__omp_taskloop_taskred_add_i32.red.init
+// CHECK:         store ptr @__omp_taskloop_taskred_add_i32.red.comb
+// CHECK:         %[[DESC:.+]] = call ptr @__kmpc_taskred_init(i32 %{{.+}}, i32 1, ptr %[[ARR]])
+// The returned descriptor is stored into the structArg captured by
+// __kmpc_omp_task_alloc so the outlined task body can load it back.
+// CHECK:         store ptr %[[DESC]], ptr %{{.+}}
+// __kmpc_taskloop must be called with nogroup=1 because we already opened
+// our own taskgroup above.
+// CHECK:         call void @__kmpc_taskloop(ptr {{.+}}, i32 {{.+}}, ptr {{.+}}, i32 1,
+// CHECK:         call void @__kmpc_end_taskgroup(
+
+// Outlined task body looks up per-task storage via the runtime, passing the
+// reloaded descriptor (not null) as the second argument.
+// CHECK-LABEL: define internal void @taskloop_reduction_single..omp_par(
+// CHECK:         %[[BODY_DESC:.+]] = load ptr, ptr %gep_.taskred.desc
+// CHECK:         %[[BODY_ORIG:.+]] = load ptr, ptr %gep_,
+// CHECK:         %[[BODY_GTID:.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK:         %[[PRIV:.+]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[BODY_GTID]], ptr %[[BODY_DESC]], ptr %[[BODY_ORIG]])
+// CHECK:         load i32, ptr %[[PRIV]]
+// CHECK:         store i32 %{{.+}}, ptr %[[PRIV]]
+
+// -----
+
+// Multiple reductions: each entry in the descriptor array gets distinct
+// init / combiner helpers and the body issues one
+// __kmpc_task_reduction_get_th_data per reduction.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+omp.declare_reduction @mul_i64 : i64
+init {
+^bb0(%arg0: i64):
+  %c1 = llvm.mlir.constant(1 : i64) : i64
+  omp.yield(%c1 : i64)
+}
+combiner {
+^bb0(%arg0: i64, %arg1: i64):
+  %p = llvm.mul %arg0, %arg1 : i64
+  omp.yield(%p : i64)
+}
+
+llvm.func @taskloop_reduction_multi(%x : !llvm.ptr, %y : !llvm.ptr, %lb : i32, %ub : i32, %step : i32) {
+  omp.taskloop.context reduction(@add_i32 %x -> %a, @mul_i64 %y -> %b : !llvm.ptr, !llvm.ptr) {
+    omp.taskloop.wrapper {
+      omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+        %va = llvm.load %a : !llvm.ptr -> i32
+        %vai = llvm.add %va, %iv : i32
+        llvm.store %vai, %a : i32, !llvm.ptr
+        %vb = llvm.load %b : !llvm.ptr -> i64
+        %iv64 = llvm.sext %iv : i32 to i64
+        %vbi = llvm.mul %vb, %iv64 : i64
+        llvm.store %vbi, %b : i64, !llvm.ptr
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define void @taskloop_reduction_multi(
+// CHECK:         %[[ARR2:.+]] = alloca [2 x %kmp_taskred_input_t]
+// CHECK:         call void @__kmpc_taskgroup(
+// CHECK:         store i64 4
+// CHECK:         store ptr @__omp_taskloop_taskred_add_i32.red.init
+// CHECK:         store ptr @__omp_taskloop_taskred_add_i32.red.comb
+// CHECK:         store i64 8
+// CHECK:         store ptr @__omp_taskloop_taskred_mul_i64.red.init
+// CHECK:         store ptr @__omp_taskloop_taskred_mul_i64.red.comb
+// CHECK:         %[[DESC2:.+]] = call ptr @__kmpc_taskred_init(i32 %{{.+}}, i32 2, ptr %[[ARR2]])
+// The descriptor is captured into structArg so the outlined task can reload it.
+// CHECK:         store ptr %[[DESC2]], ptr %{{.+}}
+// CHECK:         call void @__kmpc_taskloop(ptr {{.+}}, i32 {{.+}}, ptr {{.+}}, i32 1,
+// CHECK:         call void @__kmpc_end_taskgroup(
+
+// CHECK-LABEL: define internal void @taskloop_reduction_multi..omp_par(
+// CHECK:         %[[BODY_GTID2:.+]] = call i32 @__kmpc_global_thread_num(
+// Both get_th_data calls share the same body gtid; the descriptor argument
+// must be a reloaded SSA value (not null).
+// CHECK:         call ptr @__kmpc_task_reduction_get_th_data(i32 %[[BODY_GTID2]], ptr %{{[^,]+}}, ptr %{{.+}})
+// CHECK:         call ptr @__kmpc_task_reduction_get_th_data(i32 %[[BODY_GTID2]], ptr %{{[^,]+}}, ptr %{{.+}})
+
+// -----
+
+// in_reduction on omp.taskloop.context nested inside an outer taskgroup
+// task_reduction. No new __kmpc_taskgroup must be emitted for the taskloop
+// itself (the user did not write reduction on it), and the get_th_data call
+// must pass a NULL descriptor so the runtime walks up to the enclosing
+// taskgroup.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+llvm.func @taskloop_inreduction(%x : !llvm.ptr, %lb : i32, %ub : i32, %step : i32) {
+  omp.taskgroup task_reduction(@add_i32 %x -> %tg : !llvm.ptr) {
+    omp.taskloop.context in_reduction(@add_i32 %x -> %prv : !llvm.ptr) {
+      omp.taskloop.wrapper {
+        omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+          %v = llvm.load %prv : !llvm.ptr -> i32
+          %s = llvm.add %v, %iv : i32
+          llvm.store %s, %prv : i32, !llvm.ptr
+          omp.yield
+        }
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define void @taskloop_inreduction(
+// Outer taskgroup opens once; we expect only ONE __kmpc_taskgroup for the
+// outer construct (the taskloop itself must not open a second one).
+// CHECK:         call void @__kmpc_taskgroup(
+// CHECK-NOT:     call void @__kmpc_taskgroup(
+// The outer descriptor is built; the taskloop must NOT build its own
+// taskred_init.
+// CHECK:         call ptr @__kmpc_taskred_init(
+// CHECK-NOT:     call ptr @__kmpc_taskred_init(
+// CHECK:         call void @__kmpc_taskloop(
+// CHECK:         call void @__kmpc_end_taskgroup(
+
+// In the outlined taskloop task body, the in_reduction lookup passes NULL
+// as the descriptor argument so the runtime walks up enclosing taskgroups.
+// CHECK-LABEL: define internal void @taskloop_inreduction..omp_par(
+// CHECK:         call i32 @__kmpc_global_thread_num(
+// CHECK:         call ptr @__kmpc_task_reduction_get_th_data(i32 %{{.+}}, ptr null, ptr %{{.+}})
+
+// -----
+
+// nogroup + in_reduction: the user wrote `nogroup` on the taskloop and only an
+// in_reduction clause, so the translator must NOT open an implicit taskgroup
+// and must NOT build a taskred descriptor for the taskloop itself; `nogroup`
+// must be propagated to __kmpc_taskloop as 1, and the outlined body must look
+// up the participant with a NULL descriptor so the runtime walks up.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+llvm.func @taskloop_nogroup_inreduction(%x : !llvm.ptr, %lb : i32, %ub : i32, %step : i32) {
+  omp.taskloop.context nogroup in_reduction(@add_i32 %x -> %prv : !llvm.ptr) {
+    omp.taskloop.wrapper {
+      omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+        %v = llvm.load %prv : !llvm.ptr -> i32
+        %s = llvm.add %v, %iv : i32
+        llvm.store %s, %prv : i32, !llvm.ptr
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// Outer caller: no implicit taskgroup, no taskred_init, nogroup=1 to taskloop.
+// CHECK-LABEL: define void @taskloop_nogroup_inreduction(
+// CHECK-NOT:     call void @__kmpc_taskgroup(
+// CHECK-NOT:     call ptr @__kmpc_taskred_init(
+// CHECK-NOT:     call void @__kmpc_end_taskgroup(
+// CHECK:         call void @__kmpc_taskloop(ptr {{[^,]+}}, i32 {{[^,]+}}, ptr {{[^,]+}}, i32 1,
+
+// In the outlined task body, the in_reduction lookup uses a NULL descriptor.
+// CHECK-LABEL: define internal void @taskloop_nogroup_inreduction..omp_par(
+// CHECK:         call i32 @__kmpc_global_thread_num(
+// CHECK:         call ptr @__kmpc_task_reduction_get_th_data(i32 %{{.+}}, ptr null, ptr %{{.+}})
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index d5d0a96779db0..5c22f7f081bb5 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -378,20 +378,20 @@ llvm.func @taskloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr)
 }
 
 // -----
- omp.declare_reduction @add_reduction_i32 : i32 init {
+omp.declare_reduction @add_reduction_i32 : i32 init {
   ^bb0(%arg0: i32):
     %0 = llvm.mlir.constant(0 : i32) : i32
     omp.yield(%0 : i32)
-  }combiner {
+  } combiner {
   ^bb0(%arg0: i32, %arg1: i32):
     %0 = llvm.add %arg0, %arg1 : i32
     omp.yield(%0 : i32)
   }
 
-llvm.func @taskloop_inreduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+llvm.func @taskloop_inreduction_byref(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: Unhandled clause in_reduction with byref modifier in omp.taskloop.context operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.taskloop.context}}
-  // expected-error@below {{not yet implemented: Unhandled clause in_reduction in omp.taskloop.context operation}}
-  omp.taskloop.context in_reduction(@add_reduction_i32 %x -> %arg0 : !llvm.ptr) {
+  omp.taskloop.context in_reduction(byref @add_reduction_i32 %x -> %arg0 : !llvm.ptr) {
     omp.taskloop.wrapper {
       omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
         omp.yield
@@ -403,20 +403,102 @@ llvm.func @taskloop_inreduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.pt
 }
 
 // -----
- omp.declare_reduction @add_reduction_i32 : i32 init {
+omp.declare_reduction @add_reduction_i32 : i32 init {
   ^bb0(%arg0: i32):
     %0 = llvm.mlir.constant(0 : i32) : i32
     omp.yield(%0 : i32)
-  }combiner {
+  } combiner {
   ^bb0(%arg0: i32, %arg1: i32):
     %0 = llvm.add %arg0, %arg1 : i32
     omp.yield(%0 : i32)
   }
 
-llvm.func @taskloop_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+llvm.func @taskloop_reduction_byref(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: Unhandled clause reduction with byref modifier in omp.taskloop.context operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.taskloop.context}}
-  // expected-error@below {{not yet implemented: Unhandled clause reduction in omp.taskloop.context operation}}
-  omp.taskloop.context reduction(@add_reduction_i32 %x -> %arg0 : !llvm.ptr) {
+  omp.taskloop.context reduction(byref @add_reduction_i32 %x -> %arg0 : !llvm.ptr) {
+    omp.taskloop.wrapper {
+      omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+omp.declare_reduction @add_reduction_cleanup_i32 : i32 init {
+  ^bb0(%arg0: i32):
+    %0 = llvm.mlir.constant(0 : i32) : i32
+    omp.yield(%0 : i32)
+  } combiner {
+  ^bb0(%arg0: i32, %arg1: i32):
+    %0 = llvm.add %arg0, %arg1 : i32
+    omp.yield(%0 : i32)
+  } cleanup {
+  ^bb0(%arg0: i32):
+    omp.yield
+  }
+
+llvm.func @taskloop_reduction_cleanup(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: reduction with cleanup region in omp.taskloop.context}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.taskloop.context}}
+  omp.taskloop.context reduction(@add_reduction_cleanup_i32 %x -> %arg0 : !llvm.ptr) {
+    omp.taskloop.wrapper {
+      omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_reduction_modifier_i32 : i32 init {
+  ^bb0(%arg0: i32):
+    %0 = llvm.mlir.constant(0 : i32) : i32
+    omp.yield(%0 : i32)
+  } combiner {
+  ^bb0(%arg0: i32, %arg1: i32):
+    %0 = llvm.add %arg0, %arg1 : i32
+    omp.yield(%0 : i32)
+  }
+
+llvm.func @taskloop_reduction_modifier(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: Unhandled clause reduction with modifier in omp.taskloop.context operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.taskloop.context}}
+  omp.taskloop.context reduction(mod:inscan, @add_reduction_modifier_i32 %x -> %arg0 : !llvm.ptr) {
+    omp.taskloop.wrapper {
+      omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_reduction_two_arg_init_i32 : !llvm.ptr alloc {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
+  omp.yield(%1 : !llvm.ptr)
+} init {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+    omp.yield(%arg1 : !llvm.ptr)
+} combiner {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+    omp.yield(%arg0 : !llvm.ptr)
+}
+
+llvm.func @taskloop_reduction_two_arg_init(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: reduction with two-argument initializer in omp.taskloop.context}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.taskloop.context}}
+  omp.taskloop.context reduction(@add_reduction_two_arg_init_i32 %x -> %arg0 : !llvm.ptr) {
     omp.taskloop.wrapper {
       omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
         omp.yield



More information about the Mlir-commits mailing list