[flang] [llvm] [flang][openmp] Fix GPU byref reduction descriptor initialization (PR #178934)

Sunil Shrestha via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 30 10:11:04 PST 2026


https://github.com/sshrestha-aa created https://github.com/llvm/llvm-project/pull/178934

When generating GPU reduction code for arrays passed by reference, only the base_ptr field was initialized in the shuffled descriptor, leaving extent, stride, and rank fields uninitialized. This caused garbage metadata to be passed to user reduction combiners, resulting in incorrect iteration bounds and crashes on GPU targets.

Fix by copying the entire source descriptor and then updating the base_ptr to point to thread-private storage. This preserves all metadata (extents, strides, rank) while correctly pointing to the shuffled data location.

The fix applies to three reduction helper functions:
- _omp_reduction_shuffle_and_reduce_func (warp-level shuffle)
- _omp_reduction_list_to_global_reduce_func (block-to-global)
- _omp_reduction_global_to_list_copy_func (global-to-block)

Fixes multi-dimensional array reductions on GPU target regions with teams distribute parallel for directives.

>From 99e9aaaf317462f2a34da303e11dfee31df3f659 Mon Sep 17 00:00:00 2001
From: Sunil Shrestha <sshrestha at pe28vega.hpc.amslabs.hpecorp.net>
Date: Wed, 14 Jan 2026 18:07:44 -0600
Subject: [PATCH] [flang][openmp] Fix GPU byref reduction descriptor
 initialization

When generating GPU reduction code for arrays passed by reference,
only the base_ptr field was initialized in the shuffled descriptor,
leaving extent, stride, and rank fields uninitialized. This caused
garbage metadata to be passed to user reduction combiners, resulting
in incorrect iteration bounds and crashes on GPU targets.

Fix by copying the entire source descriptor and then updating the
base_ptr to point to thread-private storage. This preserves all
metadata (extents, strides, rank) while correctly pointing to the
shuffled data location.

The fix applies to three reduction helper functions:
- _omp_reduction_shuffle_and_reduce_func (warp-level shuffle)
- _omp_reduction_list_to_global_reduce_func (block-to-global)
- _omp_reduction_global_to_list_copy_func (global-to-block)

Fixes multi-dimensional array reductions on GPU target regions with
teams distribute parallel for directives.
---
 .../target-reduction-array-descriptor.f90     | 36 +++++++++
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 15 ++++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 79 +++++++++++++++----
 3 files changed, 114 insertions(+), 16 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/target-reduction-array-descriptor.f90

diff --git a/flang/test/Lower/OpenMP/target-reduction-array-descriptor.f90 b/flang/test/Lower/OpenMP/target-reduction-array-descriptor.f90
new file mode 100644
index 0000000000000..66f590d7c4fec
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-reduction-array-descriptor.f90
@@ -0,0 +1,36 @@
+! RUN: %if amdgpu-registered-target %{ %flang_fc1 -triple amdgcn-amd-amdhsa -fopenmp -fopenmp-is-target-device -emit-llvm %s -o - | FileCheck %s %}
+! RUN: %if nvptx-registered-target %{ %flang_fc1 -triple nvptx64-nvidia-cuda -fopenmp -fopenmp-is-target-device -emit-llvm %s -o - | FileCheck %s %}
+
+! Test that array reductions on target regions properly generate descriptors
+! for GPU device code.
+
+subroutine test_array_reduction()
+  integer*4 :: red_array(4)
+  integer*4 :: input_array(4,1000)
+  integer :: i
+
+  red_array = 0
+  input_array = 1
+
+  !$omp target teams distribute parallel do reduction(+:red_array) map(tofrom:red_array,input_array)
+  do i = 1, 1000
+    red_array = red_array + input_array(:,i)
+  end do
+  !$omp end target teams distribute parallel do
+
+  print *, red_array
+end subroutine test_array_reduction
+
+! Verify all three reduction functions are generated for teams construct
+! CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func
+! CHECK-DAG: define internal void @_omp_reduction_list_to_global_reduce_func
+! CHECK-DAG: define internal void @_omp_reduction_global_to_list_copy_func
+
+! Verify descriptor is copied via memcpy in global_to_list_copy function
+! CHECK: call void @llvm.memcpy{{.*}}(ptr {{.*}}, ptr {{.*}}, i64 {{[0-9]+}}, i1 false)
+
+! Verify base_ptr is updated after memcpy
+! CHECK: getelementptr {{.*}} ptr {{%.*}}, i32 0, i32 0
+! CHECK: store ptr {{%.*}}, ptr
+
+! Verify the three reduction helper functions are generated
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 037fcaa863fe7..778c3cc72281b 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1935,6 +1935,21 @@ class OpenMPIRBuilder {
   /// Get the function name of a reduction function.
   std::string getReductionFuncName(StringRef Name) const;
 
+  /// Generate a Fortran descriptor for array reductions
+  ///
+  /// \param DescriptorAddr Address of the descriptor to initialize
+  /// \param DataPtr Pointer to the actual data the descriptor should reference
+  /// \param ElemType Type of elements in the array (may be array type)
+  /// \param DescriptorType Type of the descriptor structure
+  /// \param DataPtrPtrGen Callback to get the base_ptr field in the descriptor
+  ///
+  /// \return Error if DataPtrPtrGen fails, otherwise success.
+  InsertPointOrErrorTy generateReductionDescriptor(
+      Value *DescriptorAddr, Value *DataPtr, Value *SrcDescriptorAddr,
+      Type *DescriptorType,
+      function_ref<InsertPointOrErrorTy(InsertPointTy, Value *, Value *&)>
+          DataPtrPtrGen);
+
   /// Emits reduction function.
   /// \param ReducerName Name of the function calling the reduction.
   /// \param ReductionInfos Array type containing the ReductionOps.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 8d7a207a91f5a..06a93b8826839 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3076,19 +3076,16 @@ Error OpenMPIRBuilder::emitReductionListCopy(
                       RemoteLaneOffset, ReductionArrayTy, IsByRefElem);
 
       if (IsByRefElem) {
-        Value *GEP;
-        InsertPointOrErrorTy GenResult =
-            RI.DataPtrPtrGen(Builder.saveIP(),
-                             Builder.CreatePointerBitCastOrAddrSpaceCast(
-                                 DestAlloca, Builder.getPtrTy(), ".ascast"),
-                             GEP);
+        // Copy descriptor from source and update base_ptr to shuffled data
+        Value *DestDescriptorAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+            DestAlloca, Builder.getPtrTy(), ".ascast");
+
+        InsertPointOrErrorTy GenResult = generateReductionDescriptor(
+            DestDescriptorAddr, LocalStorage, SrcElementAddr,
+            RI.ByRefAllocatedType, RI.DataPtrPtrGen);
 
         if (!GenResult)
           return GenResult.takeError();
-
-        Builder.CreateStore(Builder.CreatePointerBitCastOrAddrSpaceCast(
-                                LocalStorage, Builder.getPtrTy(), ".ascast"),
-                            GEP);
       }
     } else {
       switch (RI.EvaluationKind) {
@@ -3578,6 +3575,37 @@ Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
   return SarFunc;
 }
 
+OpenMPIRBuilder::InsertPointOrErrorTy
+OpenMPIRBuilder::generateReductionDescriptor(
+    Value *DescriptorAddr, Value *DataPtr, Value *SrcDescriptorAddr,
+    Type *DescriptorType,
+    function_ref<InsertPointOrErrorTy(InsertPointTy, Value *, Value *&)>
+        DataPtrPtrGen) {
+
+  // Copy the source descriptor to preserve all metadata (rank, extents,
+  // strides, etc.)
+  Value *DescriptorSize =
+      Builder.getInt64(M.getDataLayout().getTypeStoreSize(DescriptorType));
+  Builder.CreateMemCpy(
+      DescriptorAddr, M.getDataLayout().getPrefTypeAlign(DescriptorType),
+      SrcDescriptorAddr, M.getDataLayout().getPrefTypeAlign(DescriptorType),
+      DescriptorSize);
+
+  // Update the base pointer field to point to the local shuffled data
+  Value *DataPtrField;
+  InsertPointOrErrorTy GenResult =
+      DataPtrPtrGen(Builder.saveIP(), DescriptorAddr, DataPtrField);
+
+  if (!GenResult)
+    return GenResult.takeError();
+
+  Builder.CreateStore(Builder.CreatePointerBitCastOrAddrSpaceCast(
+                          DataPtr, Builder.getPtrTy(), ".ascast"),
+                      DataPtrField);
+
+  return Builder.saveIP();
+}
+
 Expected<Function *> OpenMPIRBuilder::emitListToGlobalCopyFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
     AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
@@ -3790,15 +3818,24 @@ Expected<Function *> OpenMPIRBuilder::emitListToGlobalReduceFunction(
         ReductionsBufferTy, BufferVD, 0, En.index());
 
     if (!IsByRef.empty() && IsByRef[En.index()]) {
-      Value *ByRefDataPtr;
+      // Get source descriptor from the reduce list argument
+      Value *ReduceList =
+          Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+      Value *SrcElementPtrPtr =
+          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+                                    {ConstantInt::get(IndexTy, 0),
+                                     ConstantInt::get(IndexTy, En.index())});
+      Value *SrcDescriptorAddr =
+          Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrPtr);
 
+      // Copy descriptor from source and update base_ptr to global buffer data
       InsertPointOrErrorTy GenResult =
-          RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr);
+          generateReductionDescriptor(ByRefAlloc, GlobValPtr, SrcDescriptorAddr,
+                                      RI.ByRefAllocatedType, RI.DataPtrPtrGen);
 
       if (!GenResult)
         return GenResult.takeError();
 
-      Builder.CreateStore(GlobValPtr, ByRefDataPtr);
       Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
     } else {
       Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
@@ -4024,13 +4061,23 @@ Expected<Function *> OpenMPIRBuilder::emitGlobalToListReduceFunction(
         ReductionsBufferTy, BufferVD, 0, En.index());
 
     if (!IsByRef.empty() && IsByRef[En.index()]) {
-      Value *ByRefDataPtr;
+      // Get source descriptor from the reduce list
+      Value *ReduceListVal =
+          Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+      Value *SrcElementPtrPtr =
+          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceListVal,
+                                    {ConstantInt::get(IndexTy, 0),
+                                     ConstantInt::get(IndexTy, En.index())});
+      Value *SrcDescriptorAddr =
+          Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrPtr);
+
+      // Copy descriptor from source and update base_ptr to global buffer data
       InsertPointOrErrorTy GenResult =
-          RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr);
+          generateReductionDescriptor(ByRefAlloc, GlobValPtr, SrcDescriptorAddr,
+                                      RI.ByRefAllocatedType, RI.DataPtrPtrGen);
       if (!GenResult)
         return GenResult.takeError();
 
-      Builder.CreateStore(GlobValPtr, ByRefDataPtr);
       Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
     } else {
       Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);



More information about the llvm-commits mailing list