[flang-commits] [flang] [mlir] [flang][OpenMP] Lower target in_reduction for host fallback (PR #199967)

Sairudra More via flang-commits flang-commits at lists.llvm.org
Wed Jun 24 02:02:27 PDT 2026


https://github.com/Saieiei updated https://github.com/llvm/llvm-project/pull/199967

>From 071dfde8281c6d272d5fa63df3b174072d40f274 Mon Sep 17 00:00:00 2001
From: Sairudra More <sairudra60 at gmail.com>
Date: Fri, 12 Jun 2026 04:57:07 -0500
Subject: [PATCH] [flang][OpenMP] Lower target in_reduction for host fallback

Enable host-fallback lowering for target in_reduction in Flang and MLIR OpenMP translation.

Model target in_reduction through the matching map entry, force address-preserving implicit mapping for Flang in_reduction list items, and emit the host-side task-reduction lookup with __kmpc_task_reduction_get_th_data. The runtime entry point takes and returns a generic, default-address-space pointer, so normalize a non-default-address-space captured pointer to the generic address space before the call and cast the returned private pointer back to the map block argument's address space, mirroring the in_reduction handling on omp.taskloop. Unsupported device/offload-entry and richer reduction forms remain diagnosed.

Add Flang lowering, MLIR verifier/translation, and LLVM IR tests for the supported host-fallback path, including a non-default-address-space case, and the remaining unsupported cases.
---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  63 ++++++--
 .../Lower/OpenMP/Todo/target-inreduction.f90  |  15 --
 .../OpenMP/target-inreduction-common.f90      |  30 ++++
 .../OpenMP/target-inreduction-llvmir.f90      |  42 +++++
 .../OpenMP/target-inreduction-unused.f90      |  27 ++++
 .../test/Lower/OpenMP/target-inreduction.f90  |  30 ++++
 .../OpenMP/function-filtering-host-ops.mlir   |  11 +-
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  14 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  26 ++-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 118 +++++++++++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 149 ++++++++++++++++--
 mlir/test/Dialect/OpenMP/invalid.mlir         |  60 +++++++
 .../openmp-target-in-reduction-multi.mlir     |  75 +++++++++
 .../LLVMIR/openmp-target-in-reduction.mlir    | 107 +++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 113 ++++++++++++-
 15 files changed, 796 insertions(+), 84 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/Todo/target-inreduction.f90
 create mode 100644 flang/test/Lower/OpenMP/target-inreduction-common.f90
 create mode 100644 flang/test/Lower/OpenMP/target-inreduction-llvmir.f90
 create mode 100644 flang/test/Lower/OpenMP/target-inreduction-unused.f90
 create mode 100644 flang/test/Lower/OpenMP/target-inreduction.f90
 create mode 100644 mlir/test/Target/LLVMIR/openmp-target-in-reduction-multi.mlir
 create mode 100644 mlir/test/Target/LLVMIR/openmp-target-in-reduction.mlir

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index facca9867e4bb..865b1b2b68d76 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1988,6 +1988,7 @@ genTargetClauses(lower::AbstractConverter &converter,
                  mlir::omp::TargetOperands &clauseOps,
                  DefaultMapsTy &defaultMaps,
                  llvm::SmallVectorImpl<Object> &hasDeviceAddrObjects,
+                 llvm::SmallVectorImpl<Object> &inReductionObjects,
                  llvm::SmallVectorImpl<Object> &isDevicePtrObjects,
                  llvm::SmallVectorImpl<Object> &mapObjects) {
   ClauseProcessor cp(converter, semaCtx, clauses);
@@ -2003,12 +2004,13 @@ genTargetClauses(lower::AbstractConverter &converter,
     hostEvalInfo->collectValues(clauseOps.hostEvalVars);
   }
   cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
+  cp.processInReduction(loc, clauseOps, inReductionObjects);
   cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrObjects);
   cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown,
                 &mapObjects);
   cp.processNowait(clauseOps);
   cp.processThreadLimit(stmtCtx, clauseOps);
-  cp.processTODO<clause::Allocate, clause::InReduction, clause::UsesAllocators>(
+  cp.processTODO<clause::Allocate, clause::UsesAllocators>(
       loc, llvm::omp::Directive::OMPD_target);
 }
 
@@ -3046,10 +3048,10 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   mlir::omp::TargetOperands clauseOps;
   DefaultMapsTy defaultMaps;
   llvm::SmallVector<Object> mapObjects, hasDeviceAddrObjects,
-      isDevicePtrObjects;
+      inReductionObjects, isDevicePtrObjects;
   genTargetClauses(converter, semaCtx, symTable, stmtCtx, eval, item->clauses,
                    loc, clauseOps, defaultMaps, hasDeviceAddrObjects,
-                   isDevicePtrObjects, mapObjects);
+                   inReductionObjects, isDevicePtrObjects, mapObjects);
 
   if (!isDevicePtrObjects.empty()) {
     // is_device_ptr maps get duplicated so the clause and synthesized
@@ -3103,7 +3105,16 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   // symbols used inside the region that do not have explicit data-environment
   // attribute clauses (neither data-sharing; e.g. `private`, nor `map`
   // clauses).
-  auto captureImplicitMap = [&](const semantics::Symbol &sym) {
+  //
+  // When `forceAddressPreserving` is set, the symbol is force-mapped as an
+  // address-preserving `capture(ByRef)` with implicit `tofrom` flags,
+  // bypassing the scalar default capture rules. This is used for `target
+  // in_reduction` list items, whose mapped pointer is passed as the `orig`
+  // argument of `__kmpc_task_reduction_get_th_data`; a ByCopy scalar capture
+  // would break the runtime lookup against the enclosing taskgroup's
+  // task_reduction descriptor.
+  auto captureImplicitMap = [&](const semantics::Symbol &sym,
+                                bool forceAddressPreserving = false) {
     // Structure component symbols don't have bindings, and can only be
     // explicitly mapped individually. If a member is captured implicitly
     // we map the entirety of the derived type when we find its symbol.
@@ -3112,12 +3123,14 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
     // if the symbol is part of an already mapped common block, do not make a
     // map for it.
-    if (const Fortran::semantics::Symbol *common =
-            Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
-      if (llvm::any_of(mapObjects, [=](const Object &object) {
-            return object.sym() == common;
-          }))
-        return;
+    if (!forceAddressPreserving) {
+      if (const Fortran::semantics::Symbol *common =
+              Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
+        if (llvm::any_of(mapObjects, [=](const Object &object) {
+              return object.sym() == common;
+            }))
+          return;
+    }
 
     // If we come across a symbol without a symbol address, we
     // return as we cannot process it, this is intended as a
@@ -3167,13 +3180,21 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
         eleType = refType.getElementType();
 
       std::pair<mlir::omp::ClauseMapFlags, mlir::omp::VariableCaptureKind>
-          mapFlagAndKind = getImplicitMapTypeAndKind(
-              firOpBuilder, converter, defaultMaps, eleType, loc, sym);
+          mapFlagAndKind;
+      if (forceAddressPreserving)
+        mapFlagAndKind = {mlir::omp::ClauseMapFlags::implicit |
+                              mlir::omp::ClauseMapFlags::to |
+                              mlir::omp::ClauseMapFlags::from,
+                          mlir::omp::VariableCaptureKind::ByRef};
+      else
+        mapFlagAndKind = getImplicitMapTypeAndKind(
+            firOpBuilder, converter, defaultMaps, eleType, loc, sym);
 
       mlir::FlatSymbolRefAttr mapperId;
       auto defaultmapBehaviour = getDefaultmapIfPresent(defaultMaps, eleType);
-      if (defaultmapBehaviour ==
-          clause::Defaultmap::ImplicitBehavior::Default) {
+      if (!forceAddressPreserving &&
+          defaultmapBehaviour ==
+              clause::Defaultmap::ImplicitBehavior::Default) {
         const semantics::DerivedTypeSpec *typeSpec =
             sym.GetType() ? sym.GetType()->AsDerived() : nullptr;
         if (typeSpec) {
@@ -3227,6 +3248,15 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
           Object{const_cast<semantics::Symbol *>(&sym), std::nullopt});
     }
   };
+  // OpenMP requires `in_reduction` list items on `target` to be implicitly
+  // data-mapped. Force-map them as address-preserving captures before the
+  // generic implicit-map walk so that walk treats the symbols as already
+  // mapped via `isDuplicateMappedSymbol` and does not downgrade them to
+  // ByCopy.
+  for (const Object &object : inReductionObjects)
+    if (const semantics::Symbol *sym = object.sym())
+      captureImplicitMap(*sym, /*forceAddressPreserving=*/true);
+
   lower::pft::visitAllSymbols(eval, captureImplicitMap);
 
   auto targetOp = mlir::omp::TargetOp::create(firOpBuilder, loc, clauseOps);
@@ -3239,7 +3269,10 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   args.hasDeviceAddr.objects = hasDeviceAddrObjects;
   args.hasDeviceAddr.vars = hasDeviceAddrBaseValues;
   args.hostEvalVars = clauseOps.hostEvalVars;
-  // TODO: Add in_reduction syms and vars.
+  // `in_reduction` list items do not get their own entry block argument on
+  // `omp.target`; they are implicitly mapped (see the force-map above) and the
+  // target body accesses them through their `map_entries` block argument. The
+  // `in_reduction` operands remain on the op as host-side metadata.
   args.map.objects = mapObjects;
   args.map.vars = mapBaseValues;
   args.priv.objects = makeObjects(dsp.getDelayedPrivSymbols());
diff --git a/flang/test/Lower/OpenMP/Todo/target-inreduction.f90 b/flang/test/Lower/OpenMP/Todo/target-inreduction.f90
deleted file mode 100644
index e5a9cffac5a11..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/target-inreduction.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
-
-!===============================================================================
-! `mergeable` clause
-!===============================================================================
-
-! CHECK: not yet implemented: Unhandled clause IN_REDUCTION in TARGET construct
-subroutine omp_target_inreduction()
-  integer i
-  i = 0
-  !$omp target in_reduction(+:i)
-  i = i + 1
-  !$omp end target
-end subroutine omp_target_inreduction
diff --git a/flang/test/Lower/OpenMP/target-inreduction-common.f90 b/flang/test/Lower/OpenMP/target-inreduction-common.f90
new file mode 100644
index 0000000000000..0988d5edf778d
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-inreduction-common.f90
@@ -0,0 +1,30 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! An in_reduction list item that is a member of an explicitly mapped common
+! block must still get its own address-preserving (ByRef + implicit tofrom) map
+! entry, separate from the common-block map. Otherwise the in_reduction operand
+! would have no matching map_entries entry for the host-fallback redirect.
+! Verify that the common block and the member are mapped independently and that
+! the in_reduction clause references the member.
+
+!CHECK-LABEL: func.func @_QPomp_target_in_reduction_common()
+!CHECK:       %[[CB:.*]] = fir.address_of(@cb_) : !fir.ref<!fir.array<8xi8>>
+!CHECK:       %[[IDECL:.*]]:2 = hlfir.declare %{{.*}} storage(%[[CB]][0]) {uniq_name = "_QFomp_target_in_reduction_commonEi"}
+! The whole common block is mapped by the explicit map clause.
+!CHECK:       %[[CBMAP:.*]] = omp.map.info var_ptr(%[[CB]] : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "cb"}
+! The in_reduction member additionally gets its own implicit ByRef map.
+!CHECK:       %[[IMAP:.*]] = omp.map.info var_ptr(%[[IDECL]]#1 : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "i"}
+!CHECK:       omp.target in_reduction(@{{.*}} %[[IDECL]]#0 : !fir.ref<i32>)
+!CHECK-SAME:    map_entries(%[[CBMAP]] -> %{{[^ ]+}}, %[[IMAP]] -> %[[IARG:[^ ]+]] : !fir.ref<!fir.array<8xi8>>, !fir.ref<i32>)
+!CHECK:         hlfir.declare %[[IARG]]
+!CHECK:         omp.terminator
+
+subroutine omp_target_in_reduction_common()
+  integer :: i, k
+  common /cb/ i, k
+  i = 0
+  !$omp target map(tofrom: /cb/) in_reduction(+:i)
+  i = i + 1
+  !$omp end target
+end subroutine omp_target_in_reduction_common
diff --git a/flang/test/Lower/OpenMP/target-inreduction-llvmir.f90 b/flang/test/Lower/OpenMP/target-inreduction-llvmir.f90
new file mode 100644
index 0000000000000..c10618aa17125
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-inreduction-llvmir.f90
@@ -0,0 +1,42 @@
+! RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-version=50 -o - %s | FileCheck %s
+
+! End-to-end coverage: Flang lowers a taskgroup task_reduction enclosing a
+! target in_reduction all the way to LLVM IR. Verifies that the full host-
+! fallback path emits __kmpc_taskred_init for the enclosing taskgroup and
+! __kmpc_task_reduction_get_th_data for the target's in_reduction lookup,
+! with load/store through the returned private pointer and no direct update
+! through the original shared pointer.
+
+subroutine target_in_reduction_e2e()
+  integer :: i
+  i = 0
+  !$omp taskgroup task_reduction(+:i)
+    !$omp target in_reduction(+:i)
+    i = i + 1
+    !$omp end target
+  !$omp end taskgroup
+end subroutine target_in_reduction_e2e
+
+! CHECK-LABEL: define void @target_in_reduction_e2e_()
+! The enclosing taskgroup emits __kmpc_taskred_init to register the
+! task_reduction descriptor.
+! CHECK:         call ptr @__kmpc_taskred_init(i32 %{{.+}}, i32 1, ptr %{{.+}})
+
+! The host stub calls the outlined target body passing the captured pointer.
+! CHECK:         call void @__omp_offloading_{{.*}}_target_in_reduction_e2e_{{.*}}(ptr %{{.+}}, ptr null)
+
+! Inside the outlined target body, the in_reduction private pointer is
+! obtained from the runtime using the captured original pointer with a NULL
+! descriptor (the runtime walks enclosing taskgroups). All loads and stores
+! go through the returned private pointer.
+! CHECK-LABEL: define internal void @__omp_offloading_{{.*}}_target_in_reduction_e2e_
+! CHECK-SAME:    (ptr %[[ORIG:.+]], ptr %{{.+}})
+! CHECK:         %[[GTID:.+]] = call i32 @__kmpc_global_thread_num(
+! CHECK:         %[[PRIV:.+]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[GTID]], ptr null, ptr %[[ORIG]])
+! CHECK:         %[[VAL:.+]] = load i32, ptr %[[PRIV]]
+! CHECK:         %[[SUM:.+]] = add i32 %[[VAL]], 1
+! CHECK:         store i32 %[[SUM]], ptr %[[PRIV]]
+
+! The outlined body must not store directly through the captured original
+! pointer; all updates go through the runtime-returned private copy.
+! CHECK-NOT:     store i32 %{{.+}}, ptr %[[ORIG]]
diff --git a/flang/test/Lower/OpenMP/target-inreduction-unused.f90 b/flang/test/Lower/OpenMP/target-inreduction-unused.f90
new file mode 100644
index 0000000000000..c002846494916
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-inreduction-unused.f90
@@ -0,0 +1,27 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! Per the OpenMP spec, an in_reduction list item on a target construct is
+! implicitly data-mapped. The lowering must not rely on the variable being
+! referenced inside the target body to discover that map: here `i` only
+! appears in the in_reduction clause and is never read or written inside
+! the region. Verify that an omp.map.info for `i` is still emitted and
+! flows into the omp.target's map_entries.
+
+!CHECK-LABEL: func.func @_QPomp_target_in_reduction_unused()
+!CHECK:       %[[IDECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFomp_target_in_reduction_unusedEi"}
+!CHECK:       %[[IMAP:.*]] = omp.map.info var_ptr(%[[IDECL]]#1 : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "i"}
+!CHECK:       omp.target in_reduction(@{{[^ ]+}} %[[IDECL]]#0 : !fir.ref<i32>)
+!CHECK-SAME:    map_entries(%[[IMAP]] -> %{{[^ ]+}} : !fir.ref<i32>)
+
+subroutine omp_target_in_reduction_unused()
+  interface
+    subroutine sub()
+    end subroutine
+  end interface
+  integer i
+  i = 0
+  !$omp target in_reduction(+:i)
+  call sub()
+  !$omp end target
+end subroutine omp_target_in_reduction_unused
diff --git a/flang/test/Lower/OpenMP/target-inreduction.f90 b/flang/test/Lower/OpenMP/target-inreduction.f90
new file mode 100644
index 0000000000000..40935dc109e94
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-inreduction.f90
@@ -0,0 +1,30 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! Verify that in_reduction on a target construct is lowered to an
+! omp.target with both an in_reduction clause and an implicit map_entries
+! entry for the same variable. The in_reduction clause does not define an
+! entry block argument: inside the target body the variable is accessed
+! through its map_entries block argument. The implicit map also captures the
+! original pointer into the target region so the MLIR -> LLVM IR translation
+! can pass it to __kmpc_task_reduction_get_th_data.
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME:  @[[RED_I32_NAME:.*]] : i32 init {
+
+!CHECK-LABEL: func.func @_QPomp_target_in_reduction()
+!CHECK:       %[[IDECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFomp_target_in_reductionEi"}
+!CHECK:       %[[IMAP:.*]] = omp.map.info var_ptr(%[[IDECL]]#1 : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "i"}
+!CHECK:       omp.target in_reduction(@[[RED_I32_NAME]] %[[IDECL]]#0 : !fir.ref<i32>)
+!CHECK-SAME:    map_entries(%[[IMAP]] -> %[[MAPARG:[^ ]+]] : !fir.ref<i32>)
+!CHECK:         hlfir.declare %[[MAPARG]]
+!CHECK:         omp.terminator
+!CHECK:       }
+
+subroutine omp_target_in_reduction()
+  integer i
+  i = 0
+  !$omp target in_reduction(+:i)
+  i = i + 1
+  !$omp end target
+end subroutine omp_target_in_reduction
diff --git a/flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir b/flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir
index 2df9c5a8c0713..4c397db867a04 100644
--- a/flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir
+++ b/flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir
@@ -432,12 +432,17 @@ module attributes {omp.is_target_device = true} {
     omp.target_data device(%int : i32) if(%bool) map_entries(%m0 : !fir.ref<i32>) {
       omp.terminator
     }
-    // CHECK-NEXT: omp.target allocate({{[^)]*}}) thread_limit({{[^)]*}}) in_reduction({{[^)]*}}) private({{[^)]*}}) {
+    // The `in_reduction` list item is force-mapped, so it is also captured by a
+    // matching `map_entries` entry referring to the same variable.
+    // CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[REF]] : !fir.ref<i32>, i32)
+    // CHECK-NEXT: omp.target allocate({{[^)]*}}) in_reduction({{[^)]*}}) thread_limit({{[^)]*}}) map_entries(%[[MAP]] -> {{[^)]*}}) private({{[^)]*}}) {
+    %m1 = omp.map.info var_ptr(%ref : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32>
     omp.target allocate(%ref : !fir.ref<i32> -> %ref : !fir.ref<i32>)
                depend(taskdependin -> %ref : !fir.ref<i32>)
                device(%int : i32) if(%bool) thread_limit(%int : i32)
-               in_reduction(@reduction %ref -> %arg0 : !fir.ref<i32>)
-               private(@privatizer %ref -> %arg1 : !fir.ref<i32>) {
+               in_reduction(@reduction %ref : !fir.ref<i32>)
+               map_entries(%m1 -> %arg1 : !fir.ref<i32>)
+               private(@privatizer %ref -> %arg2 : !fir.ref<i32>) {
       omp.terminator
     }
     // CHECK-NOT: omp.target_enter_data
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 71e66018bf10e..b521f730ec130 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -790,9 +790,17 @@ class OpenMP_InReductionClauseSkip<
     OptionalAttr<SymbolRefArrayAttr>:$in_reduction_syms
   );
 
-  // Description varies depending on the operation. Assembly format not defined
-  // because this clause must be processed together with the first region of the
-  // operation, as it defines entry block arguments.
+  // This assembly format should only be used by operations where `in_reduction`
+  // does not define entry block arguments (e.g. `omp.target`). Otherwise, it
+  // must be printed and parsed together with the corresponding region, because
+  // it defines entry block arguments.
+  let optAssemblyFormat = [{
+    `in_reduction` `(`
+      custom<InReductionClause>($in_reduction_vars, type($in_reduction_vars),
+                                $in_reduction_byref, $in_reduction_syms) `)`
+  }];
+
+  // Description varies depending on the operation.
 }
 
 def OpenMP_InReductionClause : OpenMP_InReductionClauseSkip<>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 4f54d4b8f524b..3be6c0dee759e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -966,9 +966,10 @@ def TaskOp
                     // TODO: Complete clause list (detach).
                     OpenMP_AffinityClause, OpenMP_AllocateClause,
                     OpenMP_DependClause, OpenMP_FinalClause, OpenMP_IfClause,
-                    OpenMP_InReductionClause, OpenMP_MergeableClause,
-                    OpenMP_PriorityClause, OpenMP_PrivateClause,
-                    OpenMP_UntiedClause, OpenMP_DetachClause],
+                    OpenMP_InReductionClauseSkip<assemblyFormat = true>,
+                    OpenMP_MergeableClause, OpenMP_PriorityClause,
+                    OpenMP_PrivateClause, OpenMP_UntiedClause,
+                    OpenMP_DetachClause],
                 singleRegion = true> {
   let summary = "task construct";
   let description = [{
@@ -1011,9 +1012,10 @@ def TaskloopContextOp : OpenMP_Op<"taskloop.context", traits = [
     DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_FinalClause, OpenMP_GrainsizeClause,
-    OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_MergeableClause,
-    OpenMP_NogroupClause, OpenMP_NumTasksClause, OpenMP_PriorityClause,
-    OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_UntiedClause
+    OpenMP_IfClause, OpenMP_InReductionClauseSkip<assemblyFormat = true>,
+    OpenMP_MergeableClause, OpenMP_NogroupClause, OpenMP_NumTasksClause,
+    OpenMP_PriorityClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
+    OpenMP_UntiedClause
   ], singleRegion = true> {
   let summary = "OutlinableOpenMPOpInterface wrapper for taskloop construct";
   let description = [{
@@ -1617,6 +1619,15 @@ def TargetOp : OpenMP_Op<"target", traits = [
   ];
 
   let extraClassDeclaration = [{
+    // Override BlockArgOpenMPOpInterface method because `in_reduction` list
+    // items on `omp.target` do not define entry block arguments. The reduction
+    // variable is accessed inside the target body through its matching
+    // `map_entries` block argument; the `in_reduction` operands are kept only
+    // as host-side metadata used to look up the per-task private storage.
+    unsigned numInReductionBlockArgs() {
+      return 0;
+    }
+
     mlir::Value getMappedValueForPrivateVar(unsigned privVarIdx) {
       std::optional<DenseI64ArrayAttr> privateMapIdices = getPrivateMapsAttr();
 
@@ -1660,8 +1671,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
     ( `map_iterated` `(` $map_iterated^ `:` type($map_iterated) `)` )?
     custom<TargetOpRegion>(
         $region, $has_device_addr_vars, type($has_device_addr_vars),
-        $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
-        type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
+        $host_eval_vars, type($host_eval_vars),
         $map_vars, type($map_vars), $private_vars, type($private_vars),
         $private_syms, $private_needs_barrier, $private_maps) attr-dict
   }];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c2956a1cf7b79..2488720f3b604 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1132,8 +1132,10 @@ static ParseResult parseClauseWithRegionArgs(
         if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
           return failure();
 
-        if (parser.parseOperand(operands.emplace_back()) ||
-            parser.parseArrow() ||
+        if (parser.parseOperand(operands.emplace_back()))
+          return failure();
+
+        if (parser.parseArrow() ||
             parser.parseArgument(regionPrivateArgs.emplace_back()))
           return failure();
 
@@ -1197,6 +1199,42 @@ static ParseResult parseClauseWithRegionArgs(
   return success();
 }
 
+/// Parses an `in_reduction` clause for an operation whose list items do not
+/// get entry block arguments (e.g. `omp.target`). The expected format is a
+/// comma-separated list of `[byref] @sym %var` followed by `: type-list`.
+static ParseResult parseInReductionClause(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
+    SmallVectorImpl<Type> &inReductionTypes,
+    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms) {
+  SmallVector<SymbolRefAttr> symbolVec;
+  SmallVector<bool> isByRefVec;
+
+  if (parser.parseCommaSeparatedList([&]() {
+        isByRefVec.push_back(parser.parseOptionalKeyword("byref").succeeded());
+        if (parser.parseAttribute(symbolVec.emplace_back()) ||
+            parser.parseOperand(inReductionVars.emplace_back()))
+          return failure();
+        return success();
+      }))
+    return failure();
+
+  if (parser.parseColon())
+    return failure();
+  // Anchor a type/operand count mismatch diagnostic at the type list.
+  llvm::SMLoc typesLoc = parser.getCurrentLocation();
+  if (parser.parseCommaSeparatedList(
+          [&]() { return parser.parseType(inReductionTypes.emplace_back()); }))
+    return failure();
+  if (inReductionVars.size() != inReductionTypes.size())
+    return parser.emitError(typesLoc, "expected as many types as operands");
+
+  inReductionByref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
+  SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
+  inReductionSyms = ArrayAttr::get(parser.getContext(), symbolAttrs);
+  return success();
+}
+
 static ParseResult parseBlockArgClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
@@ -1305,9 +1343,6 @@ static ParseResult parseTargetOpRegion(
     SmallVectorImpl<Type> &hasDeviceAddrTypes,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
     SmallVectorImpl<Type> &hostEvalTypes,
-    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
-    SmallVectorImpl<Type> &inReductionTypes,
-    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
@@ -1316,8 +1351,6 @@ static ParseResult parseTargetOpRegion(
   AllRegionParseArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
-  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
-                               inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            privateNeedsBarrier, &privateMaps);
@@ -1507,6 +1540,42 @@ static void printClauseWithRegionArgs(
     p << getPrivateNeedsBarrierSpelling() << " ";
 }
 
+/// Prints an `in_reduction` clause for an operation whose list items do not
+/// get entry block arguments (e.g. `omp.target`). Absent byref/sym attributes
+/// are treated as all-false/all-null. Mirrors `parseInReductionClause`.
+static void printInReductionClause(OpAsmPrinter &p, Operation *op,
+                                   ValueRange inReductionVars,
+                                   TypeRange inReductionTypes,
+                                   DenseBoolArrayAttr inReductionByref,
+                                   ArrayAttr inReductionSyms) {
+  MLIRContext *ctx = op->getContext();
+
+  ArrayAttr syms = inReductionSyms;
+  if (!syms) {
+    SmallVector<Attribute> values(inReductionVars.size(), nullptr);
+    syms = ArrayAttr::get(ctx, values);
+  }
+
+  DenseBoolArrayAttr byref = inReductionByref;
+  if (!byref) {
+    SmallVector<bool> values(inReductionVars.size(), false);
+    byref = DenseBoolArrayAttr::get(ctx, values);
+  }
+
+  llvm::interleaveComma(
+      llvm::zip_equal(inReductionVars, syms.getValue(), byref.asArrayRef()), p,
+      [&p](auto t) {
+        auto [var, sym, isByRef] = t;
+        if (isByRef)
+          p << "byref ";
+        if (sym)
+          p << sym << " ";
+        p << var;
+      });
+  p << " : ";
+  llvm::interleaveComma(inReductionTypes, p);
+}
+
 static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
                                 StringRef clauseName, ValueRange argsSubrange,
                                 std::optional<MapPrintArgs> mapArgs) {
@@ -1568,20 +1637,18 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 // These parseXyz functions correspond to the custom<Xyz> definitions
 // in the .td file(s).
-static void printTargetOpRegion(
-    OpAsmPrinter &p, Operation *op, Region &region,
-    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
-    ValueRange hostEvalVars, TypeRange hostEvalTypes,
-    ValueRange inReductionVars, TypeRange inReductionTypes,
-    DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
-    ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
-    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
-    DenseI64ArrayAttr privateMaps) {
+static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region,
+                                ValueRange hasDeviceAddrVars,
+                                TypeRange hasDeviceAddrTypes,
+                                ValueRange hostEvalVars,
+                                TypeRange hostEvalTypes, ValueRange mapVars,
+                                TypeRange mapTypes, ValueRange privateVars,
+                                TypeRange privateTypes, ArrayAttr privateSyms,
+                                UnitAttr privateNeedsBarrier,
+                                DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
-  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
-                               inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            privateNeedsBarrier, privateMaps);
@@ -2579,8 +2646,7 @@ LogicalResult TargetUpdateOp::verify() {
 void TargetOp::build(OpBuilder &builder, OperationState &state,
                      const TargetOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
-  // inReductionByref, inReductionSyms.
+  // TODO Store clauses in op: allocateVars, allocatorVars.
   TargetOp::build(
       builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
       makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
@@ -2588,9 +2654,10 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
       clauses.device, clauses.dynGroupprivateAccessGroup,
       clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
       clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
-      /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
-      /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
-      clauses.mapIterated, clauses.nowait, clauses.privateVars,
+      clauses.inReductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
+      makeArrayAttr(ctx, clauses.inReductionSyms), clauses.isDevicePtrVars,
+      clauses.mapVars, clauses.mapIterated, clauses.nowait, clauses.privateVars,
       makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
       clauses.threadLimitVars,
       /*private_maps=*/nullptr);
@@ -2617,6 +2684,11 @@ LogicalResult TargetOp::verify() {
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  if (failed(verifyReductionVarList(*this, getInReductionSyms(),
+                                    getInReductionVars(),
+                                    getInReductionByref())))
+    return failure();
+
   return verifyPrivateVarsMapping(*this);
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 703f72d1ab5bc..f2512f82fe22c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -370,9 +370,46 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       op.emitWarning("hint clause discarded");
   };
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
-        op.getInReductionSyms())
+    // The op type is fixed per instantiation of this generic lambda, so the
+    // dispatch below is resolved at compile time. omp.target, omp.task and
+    // omp.taskloop.context accept in_reduction subject to the checks below;
+    // for any other op, any in_reduction use is reported as a TODO.
+    using OpTy = decltype(op);
+    if constexpr (llvm::is_one_of<OpTy, omp::TargetOp, omp::TaskOp,
+                                  omp::TaskloopContextOp>::value) {
+      // By-reference (byref) in_reduction items are not lowered yet.
+      if (auto byrefAttr = op.getInReductionByref())
+        for (bool isByRef : *byrefAttr)
+          if (isByRef) {
+            result = todo("in_reduction with byref modifier");
+            return;
+          }
+      // omp.target additionally requires every referenced declare_reduction
+      // to have a single-argument initializer and no cleanup region.
+      if constexpr (std::is_same_v<OpTy, omp::TargetOp>) {
+        if (auto inReductionSyms = op.getInReductionSyms()) {
+          for (auto sym :
+               inReductionSyms->template getAsRange<SymbolRefAttr>()) {
+            auto decl =
+                SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
+                    op, sym);
+            assert(decl &&
+                   "symbol resolution should be guaranteed by the op verifier");
+            if (decl.getInitializerRegion().front().getNumArguments() != 1) {
+              result = todo("in_reduction with two-argument initializer");
+              return;
+            }
+            if (!decl.getCleanupRegion().empty()) {
+              result = todo("in_reduction with cleanup region");
+              return;
+            }
+          }
+        }
+      }
+    } else if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
+               op.getInReductionSyms()) {
       result = todo("in_reduction");
+    }
   };
   auto checkNowait = [&todo](auto op, LogicalResult &result) {
     if (op.getNowait())
@@ -411,14 +440,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           return;
         }
   };
-  auto checkInReductionByref = [&todo](auto op, LogicalResult &result) {
-    if (auto byrefAttr = op.getInReductionByref())
-      for (bool isByRef : *byrefAttr)
-        if (isByRef) {
-          result = todo("in_reduction with byref modifier");
-          return;
-        }
-  };
   auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
     if (op.hasNumTeamsMultiDim())
       result = todo("num_teams with multi-dimensional values");
@@ -470,7 +491,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::TaskOp op) {
         checkAllocate(op, result);
-        checkInReductionByref(op, result);
+        checkInReduction(op, result);
       })
       .Case([&](omp::TaskgroupOp op) {
         checkAllocate(op, result);
@@ -482,7 +503,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::TaskloopContextOp op) {
         checkAllocate(op, result);
-        checkInReductionByref(op, result);
+        checkInReduction(op, result);
         checkReduction(op, result);
         checkReductionByref(op, result);
       })
@@ -8363,6 +8384,52 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   bool isOffloadEntry =
       isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
 
+  // Resolve in_reduction clauses on omp.target for the host. From the target
+  // device's perspective an in_reduction list item behaves as a regular
+  // map(tofrom) variable, so no special handling is needed there; only the
+  // host redirects the mapped value to the per-task reduction-private storage
+  // returned by __kmpc_task_reduction_get_th_data (emitted inside the
+  // to-be-outlined target task body). This applies to both offloading and
+  // non-offloading host modules.
+  //
+  // The target body has no dedicated in_reduction block argument: each
+  // in_reduction variable is accessed through its map_entries block argument.
+  // So each in_reduction variable must also be captured by a matching
+  // map_entries entry referring to the same value; without one the outlined
+  // body would reference a value defined in the host function. Record, for each
+  // in_reduction variable, the position of that map entry so the corresponding
+  // map block argument can be redirected inside the body. The mapped pointer is
+  // also used as the `orig` argument of the runtime lookup.
+  SmallVector<llvm::Value *> inRedOrigPtrs;
+  SmallVector<unsigned> inRedMapArgIdx;
+  if (!targetOp.getInReductionVars().empty() && !isTargetDevice) {
+    llvm::SmallDenseMap<Value, unsigned> mapVarPtrToArgIdx;
+    llvm::SmallDenseSet<Value> duplicateMapVarPtrs;
+    for (auto [idx, mapV] : llvm::enumerate(targetOp.getMapVars())) {
+      auto mapInfo = mapV.getDefiningOp<omp::MapInfoOp>();
+      auto [it, inserted] =
+          mapVarPtrToArgIdx.try_emplace(mapInfo.getVarPtr(), idx);
+      if (!inserted)
+        duplicateMapVarPtrs.insert(mapInfo.getVarPtr());
+    }
+    inRedOrigPtrs.reserve(targetOp.getInReductionVars().size());
+    inRedMapArgIdx.reserve(targetOp.getInReductionVars().size());
+    for (Value v : targetOp.getInReductionVars()) {
+      if (duplicateMapVarPtrs.contains(v))
+        return targetOp.emitError()
+               << "in_reduction variable on omp.target has multiple matching "
+                  "map_entries entries for the same var_ptr; the redirect "
+                  "target is ambiguous";
+      auto it = mapVarPtrToArgIdx.find(v);
+      if (it == mapVarPtrToArgIdx.end())
+        return targetOp.emitError()
+               << "not yet implemented: in_reduction variable on omp.target "
+                  "must also be captured by a matching map_entries entry";
+      inRedMapArgIdx.push_back(it->second);
+      inRedOrigPtrs.push_back(moduleTranslation.lookupValue(v));
+    }
+  }
+
   // For some private variables, the MapsForPrivatizedVariablesPass
   // creates MapInfoOp instances. Go through the private variables and
   // the mapped variables so that during codegeneration we are able
@@ -8438,8 +8505,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
         attr.isStringAttribute())
       llvmOutlinedFn->addFnAttr(attr);
 
-    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
-      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
+    for (auto [idx, arg] : llvm::enumerate(mapBlockArgs)) {
+      // in_reduction list items on omp.target are accessed through their
+      // map_entries block argument, which is redirected below to the per-task
+      // reduction-private storage returned by the runtime. Skip the default
+      // host-value mapping for those block arguments so the write-once
+      // mapValue mapping is free to be set to the private pointer.
+      if (llvm::is_contained(inRedMapArgIdx, idx))
+        continue;
+      auto mapInfoOp = cast<omp::MapInfoOp>(mapVars[idx].getDefiningOp());
       llvm::Value *mapOpValue =
           moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
       moduleTranslation.mapValue(arg, mapOpValue);
@@ -8475,6 +8549,53 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
             targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    // The target body accesses each in_reduction variable through its
+    // map_entries block argument. Redirect that block argument to the per-task
+    // private storage returned by __kmpc_task_reduction_get_th_data so the body
+    // accumulates into the reduction-private copy rather than the mapped
+    // original. The lookup must run inside the target task body so the gtid
+    // corresponds to the executing thread. The descriptor argument is NULL: the
+    // runtime walks enclosing taskgroups to locate the matching task_reduction
+    // registration for `origPtr`. Mirrors the in_reduction handling on
+    // omp.taskloop.context.
+    if (!inRedOrigPtrs.empty()) {
+      llvm::OpenMPIRBuilder &ompB = *ompBuilder;
+      llvm::Module *m = moduleTranslation.getLLVMModule();
+      llvm::LLVMContext &llvmCtx = m->getContext();
+      uint32_t srcLocSize;
+      llvm::Constant *srcLocStr = ompB.getOrCreateDefaultSrcLocStr(srcLocSize);
+      llvm::Value *bodyIdent = ompB.getOrCreateIdent(srcLocStr, srcLocSize);
+      llvm::Function *gtidFn = ompB.getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_global_thread_num);
+      llvm::Value *bodyGtid =
+          builder.CreateCall(gtidFn, {bodyIdent}, "omp_global_thread_num");
+      llvm::FunctionCallee getThData = ompB.getOrCreateRuntimeFunction(
+          *m, llvm::omp::OMPRTL___kmpc_task_reduction_get_th_data);
+      llvm::Type *ptrTy = llvm::PointerType::getUnqual(llvmCtx);
+      llvm::Value *nullDesc = llvm::ConstantPointerNull::get(ptrTy);
+      for (auto [mapArgIdx, origPtr] :
+           llvm::zip_equal(inRedMapArgIdx, inRedOrigPtrs)) {
+        // The runtime entry point takes (and returns) a generic,
+        // default-address-space `ptr`, so normalize a
+        // non-default-address-space original pointer to the generic address
+        // space before the call, and cast the returned private pointer back to
+        // the map block argument's address space when that differs. Mirrors the
+        // in_reduction handling on omp.taskloop.context.
+        BlockArgument mapBlockArg = mapBlockArgs[mapArgIdx];
+        if (auto *origPtrTy =
+                llvm::dyn_cast<llvm::PointerType>(origPtr->getType());
+            origPtrTy && origPtrTy->getAddressSpace() != 0)
+          origPtr = builder.CreateAddrSpaceCast(origPtr, ptrTy);
+        llvm::Value *priv = builder.CreateCall(
+            getThData, {bodyGtid, nullDesc, origPtr}, "omp.inred.priv");
+        if (auto *argPtrTy = llvm::dyn_cast<llvm::PointerType>(
+                moduleTranslation.convertType(mapBlockArg.getType()));
+            argPtrTy && argPtrTy->getAddressSpace() != 0)
+          priv = builder.CreateAddrSpaceCast(priv, argPtrTy);
+        moduleTranslation.mapValue(mapBlockArg, priv);
+      }
+    }
+
     LLVM::ModuleTranslation::SaveStack<OpenMPAllocStackFrame> frame(
         moduleTranslation, allocaIP, deallocBlocks);
     llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index b72370edb3cb4..f9c55830afbc7 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3129,6 +3129,66 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
 
 // -----
 
+func.func @omp_target_in_reduction_unresolved(%ptr: !llvm.ptr) {
+  // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}}
+  omp.target in_reduction(@add_f32 %ptr : !llvm.ptr) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+func.func @omp_target_in_reduction_duplicate(%ptr: !llvm.ptr) {
+  // expected-error @below {{op accumulator variable used more than once}}
+  omp.target in_reduction(@add_f32 %ptr, @add_f32 %ptr : !llvm.ptr, !llvm.ptr) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg: i32):
+  %0 = arith.constant 0 : i32
+  omp.yield (%0 : i32)
+}
+combiner {
+^bb1(%arg0: i32, %arg1: i32):
+  %1 = arith.addi %arg0, %arg1 : i32
+  omp.yield (%1 : i32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+  %2 = llvm.load %arg3 : !llvm.ptr -> i32
+  llvm.atomicrmw add %arg2, %2 monotonic : !llvm.ptr, i32
+  omp.yield
+}
+
+func.func @omp_target_in_reduction_type_mismatch(%mem: memref<1xf32>) {
+  // expected-error @below {{op expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr')}}
+  omp.target in_reduction(@add_i32 %mem : memref<1xf32>) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @omp_distribute_schedule(%chunk_size : i32, %lb : i32, %ub : i32, %step : i32) -> () {
   // expected-error @below {{op chunk size set without dist_schedule_static being present}}
   "omp.distribute"(%chunk_size) <{operandSegmentSizes = array<i32: 0, 0, 1, 0>}> ({
diff --git a/mlir/test/Target/LLVMIR/openmp-target-in-reduction-multi.mlir b/mlir/test/Target/LLVMIR/openmp-target-in-reduction-multi.mlir
new file mode 100644
index 0000000000000..8083f3c299ce0
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-in-reduction-multi.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// Multiple in_reduction items on omp.target. Each item is captured into the
+// target region through its own map_entries entry and accessed inside the
+// body via the corresponding map_entries block argument. For the host
+// fallback path every item performs an independent
+// __kmpc_task_reduction_get_th_data lookup using its own captured original
+// pointer, and the returned per-task private pointer is bound to that item's
+// map block argument. This test pins down the pairing so it cannot pass if the
+// two items were swapped or collapsed onto a single pointer.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+llvm.func @target_inreduction_multi(%x : !llvm.ptr, %y : !llvm.ptr) {
+  %mx = omp.map.info var_ptr(%x : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr
+  %my = omp.map.info var_ptr(%y : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr
+  omp.target in_reduction(@add_i32 %x, @add_i32 %y : !llvm.ptr, !llvm.ptr)
+      map_entries(%mx -> %mxarg, %my -> %myarg : !llvm.ptr, !llvm.ptr) {
+    // First item (x): load, += 1, store back.
+    %vx = llvm.load %mxarg : !llvm.ptr -> i32
+    %c1 = llvm.mlir.constant(1 : i32) : i32
+    %sx = llvm.add %vx, %c1 : i32
+    llvm.store %sx, %mxarg : i32, !llvm.ptr
+    // Second item (y): load, += 2, store back.
+    %vy = llvm.load %myarg : !llvm.ptr -> i32
+    %c2 = llvm.mlir.constant(2 : i32) : i32
+    %sy = llvm.add %vy, %c2 : i32
+    llvm.store %sy, %myarg : i32, !llvm.ptr
+    omp.terminator
+  }
+  llvm.return
+}
+
+// The host stub forwards both captured pointers into the outlined target
+// kernel (the trailing `ptr null` is an extra argument, not a descriptor).
+// CHECK-LABEL: define void @target_inreduction_multi(
+// CHECK:         call void @__omp_offloading_{{.*}}_target_inreduction_multi_{{.*}}(ptr %{{.+}}, ptr %{{.+}}, ptr null)
+
+// The two captured original pointers arrive as distinct kernel arguments.
+// CHECK-LABEL: define internal void @__omp_offloading_{{.*}}_target_inreduction_multi_
+// CHECK-SAME:    (ptr %[[CAPTX:.+]], ptr %[[CAPTY:.+]], ptr %{{.+}})
+
+// A single gtid is shared by both lookups; each item then performs its own
+// __kmpc_task_reduction_get_th_data call against its own captured pointer.
+// CHECK:         %[[GTID:.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK:         %[[PRIVX:.+]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[GTID]], ptr null, ptr %[[CAPTX]])
+// CHECK:         %[[PRIVY:.+]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[GTID]], ptr null, ptr %[[CAPTY]])
+
+// The first item's private storage is the base of the +1 load/store; the
+// CHECK-NOT below ensures the second item's pointer is not referenced between
+// the +1 add and the store back to the first item's private copy (i.e. the
+// items are not swapped or merged onto a single private pointer).
+// CHECK:         %[[LX:.+]] = load i32, ptr %[[PRIVX]]
+// CHECK:         %[[SX:.+]] = add i32 %[[LX]], 1
+// CHECK-NOT:     %[[PRIVY]]
+// CHECK:         store i32 %[[SX]], ptr %[[PRIVX]]
+
+// The second item's private storage is the base of the +2 load/store.
+// CHECK:         %[[LY:.+]] = load i32, ptr %[[PRIVY]]
+// CHECK:         %[[SY:.+]] = add i32 %[[LY]], 2
+// CHECK:         store i32 %[[SY]], ptr %[[PRIVY]]
+
+// No further reduction lookup is emitted after the second item's store. The
+// `call` form is used so this does not match the runtime declaration.
+// CHECK-NOT:     call ptr @__kmpc_task_reduction_get_th_data
diff --git a/mlir/test/Target/LLVMIR/openmp-target-in-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-target-in-reduction.mlir
new file mode 100644
index 0000000000000..361cb699bc21a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-in-reduction.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// in_reduction on omp.target: the in_reduction variable is also captured
+// into the target region as a map entry (the Flang front-end emits this
+// implicit map). The in_reduction clause does not define an entry block
+// argument; inside the target body the variable is accessed through its
+// map_entries block argument. The captured pointer is passed to
+// __kmpc_task_reduction_get_th_data with a NULL descriptor; the runtime
+// walks enclosing taskgroups to locate the matching task_reduction
+// registration. The returned per-task private pointer is bound to the
+// map_entries block argument so subsequent loads/stores inside the region
+// use the private copy.
+
+omp.declare_reduction @add_i32 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+llvm.func @target_inreduction(%x : !llvm.ptr) {
+  %m = omp.map.info var_ptr(%x : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr
+  omp.target in_reduction(@add_i32 %x : !llvm.ptr) map_entries(%m -> %marg : !llvm.ptr) {
+    %v = llvm.load %marg : !llvm.ptr -> i32
+    %c1 = llvm.mlir.constant(1 : i32) : i32
+    %s = llvm.add %v, %c1 : i32
+    llvm.store %s, %marg : i32, !llvm.ptr
+    omp.terminator
+  }
+  llvm.return
+}
+
+// The host stub forwards the captured pointer into the outlined target
+// kernel.
+// CHECK-LABEL: define void @target_inreduction(
+// CHECK:         call void @__omp_offloading_{{.*}}_target_inreduction_{{.*}}(ptr %{{.+}}, ptr null)
+
+// In the outlined target body the in_reduction private pointer is
+// obtained from the runtime using the captured original pointer; that
+// pointer is then the base of the load and store inside the region.
+// CHECK-LABEL: define internal void @__omp_offloading_{{.*}}_target_inreduction_
+// CHECK-SAME:    (ptr %[[CAPT:.+]], ptr %{{.+}})
+// CHECK:         %[[GTID:.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK:         %[[PRIV:.+]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[GTID]], ptr null, ptr %[[CAPT]])
+// CHECK:         %[[LOADED:.+]] = load i32, ptr %[[PRIV]]
+// CHECK:         %[[SUM:.+]] = add i32 %[[LOADED]], 1
+// CHECK:         store i32 %[[SUM]], ptr %[[PRIV]]
+
+// -----
+
+// Same as the first case but the in_reduction variable lives in a non-default
+// address space (addrspace(1)). __kmpc_task_reduction_get_th_data is declared
+// to take and return a generic (default-AS) `ptr`, so the host-fallback
+// lowering must (1) addrspacecast the captured addrspace(1) original pointer to
+// generic `ptr` before the runtime lookup, and (2) addrspacecast the returned
+// generic private pointer back to addrspace(1) so the body's load/store use the
+// private copy in the right address space. This pins down the address-space
+// handling so a regression that passed the addrspace(1) pointer straight into
+// the runtime call (a bad-signature crash) cannot pass.
+
+omp.declare_reduction @add_i32_as1 : i32
+init {
+^bb0(%arg0: i32):
+  %c0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%c0 : i32)
+}
+combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %s = llvm.add %arg0, %arg1 : i32
+  omp.yield(%s : i32)
+}
+
+llvm.func @target_inreduction_as1(%x : !llvm.ptr<1>) {
+  %m = omp.map.info var_ptr(%x : !llvm.ptr<1>, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<1>
+  omp.target in_reduction(@add_i32_as1 %x : !llvm.ptr<1>) map_entries(%m -> %marg : !llvm.ptr<1>) {
+    %v = llvm.load %marg : !llvm.ptr<1> -> i32
+    %c1 = llvm.mlir.constant(1 : i32) : i32
+    %s = llvm.add %v, %c1 : i32
+    llvm.store %s, %marg : i32, !llvm.ptr<1>
+    omp.terminator
+  }
+  llvm.return
+}
+
+// In the outlined target body the addrspace(1) original pointer is normalized
+// to generic `ptr` before the NULL-descriptor in_reduction lookup, and the
+// returned generic private pointer is cast back to addrspace(1) for the body's
+// load and store. The original captured pointer arrives as the addrspace(1)
+// kernel argument and is only used to derive the generic lookup pointer.
+// CHECK-LABEL: define internal void @__omp_offloading_{{.*}}_target_inreduction_as1_
+// CHECK-SAME:    (ptr addrspace(1) %[[CAPT_AS1:.+]], ptr %{{.+}})
+// CHECK:         %[[GTID_AS1:.+]] = call i32 @__kmpc_global_thread_num(
+// CHECK:         %[[ORIG_GEN:.+]] = addrspacecast ptr addrspace(1) %[[CAPT_AS1]] to ptr
+// CHECK:         %[[PRIV_AS1_GEN:.+]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[GTID_AS1]], ptr null, ptr %[[ORIG_GEN]])
+// CHECK:         %[[PRIV_AS1:.+]] = addrspacecast ptr %[[PRIV_AS1_GEN]] to ptr addrspace(1)
+// CHECK:         %[[LOADED_AS1:.+]] = load i32, ptr addrspace(1) %[[PRIV_AS1]]
+// CHECK:         %[[SUM_AS1:.+]] = add i32 %[[LOADED_AS1]], 1
+// CHECK:         store i32 %[[SUM_AS1]], ptr addrspace(1) %[[PRIV_AS1]]
+
+// After the private-copy accesses above, no store may go back through the
+// original captured addrspace(1) pointer.
+// CHECK-NOT:     store i32 %{{.+}}, ptr addrspace(1) %[[CAPT_AS1]]
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 377a5bb799be4..efafdb0535560 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -190,10 +190,117 @@ atomic {
   llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
   omp.yield
 }
-llvm.func @target_in_reduction(%x : !llvm.ptr) {
-  // expected-error at below {{not yet implemented: Unhandled clause in_reduction in omp.target operation}}
+llvm.func @target_in_reduction_byref(%x : !llvm.ptr) {
+  // expected-error at below {{not yet implemented: Unhandled clause in_reduction with byref modifier in omp.target operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.target}}
-  omp.target in_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
+  omp.target in_reduction(byref @add_f32 %x : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_cleanup_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = llvm.mlir.constant(0.0 : f32) : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = llvm.fadd %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+cleanup {
+^bb2(%arg2: f32):
+  omp.yield
+}
+llvm.func @target_in_reduction_cleanup(%x : !llvm.ptr) {
+  // expected-error at below {{not yet implemented: Unhandled clause in_reduction with cleanup region in omp.target operation}}
+  // expected-error at below {{LLVM Translation failed for operation: omp.target}}
+  omp.target in_reduction(@add_cleanup_f32 %x : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_two_arg_init_i32 : !llvm.ptr alloc {
+^bb0(%arg: !llvm.ptr):
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
+  omp.yield(%1 : !llvm.ptr)
+} init {
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  llvm.store %0, %arg1 : i32, !llvm.ptr
+  omp.yield(%arg1 : !llvm.ptr)
+} combiner {
+^bb1(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  %0 = llvm.load %arg0 : !llvm.ptr -> i32
+  %1 = llvm.load %arg1 : !llvm.ptr -> i32
+  %2 = llvm.add %0, %1 : i32
+  llvm.store %2, %arg0 : i32, !llvm.ptr
+  omp.yield(%arg0 : !llvm.ptr)
+}
+llvm.func @target_in_reduction_two_arg_init(%x : !llvm.ptr) {
+  // expected-error at below {{not yet implemented: Unhandled clause in_reduction with two-argument initializer in omp.target operation}}
+  // expected-error at below {{LLVM Translation failed for operation: omp.target}}
+  omp.target in_reduction(@add_two_arg_init_i32 %x : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_no_map_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = llvm.mlir.constant(0.0 : f32) : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = llvm.fadd %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+llvm.func @target_in_reduction_no_map(%x : !llvm.ptr) {
+  // The in_reduction variable %x has no matching map_entries entry. The
+  // outlined target kernel would otherwise reference %x across function
+  // boundaries; the translation must reject this up front.
+  // expected-error at below {{not yet implemented: in_reduction variable on omp.target must also be captured by a matching map_entries entry}}
+  // expected-error at below {{LLVM Translation failed for operation: omp.target}}
+  omp.target in_reduction(@add_no_map_f32 %x : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+omp.declare_reduction @add_dup_map_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = llvm.mlir.constant(0.0 : f32) : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb0(%arg0: f32, %arg1: f32):
+  %1 = llvm.fadd %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+llvm.func @target_in_reduction_duplicate_map(%x : !llvm.ptr) {
+  // The in_reduction variable %x has two matching map_entries entries for the
+  // same var_ptr. The translation cannot disambiguate which map block argument
+  // to redirect, so it must reject this as ambiguous.
+  %m1 = omp.map.info var_ptr(%x : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr
+  %m2 = omp.map.info var_ptr(%x : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr
+  // expected-error at below {{in_reduction variable on omp.target has multiple matching map_entries entries for the same var_ptr; the redirect target is ambiguous}}
+  // expected-error at below {{LLVM Translation failed for operation: omp.target}}
+  omp.target in_reduction(@add_dup_map_f32 %x : !llvm.ptr) map_entries(%m1 -> %arg1, %m2 -> %arg2 : !llvm.ptr, !llvm.ptr) {
     omp.terminator
   }
   llvm.return



More information about the flang-commits mailing list