[Mlir-commits] [llvm] [mlir] [OpenMP][flang] Support GPU team reductions on allocatables (PR #169651)

Kareem Ergawy llvmlistbot at llvm.org
Thu Nov 27 01:10:29 PST 2025


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/169651

>From 008decad1c3b1d5fdbff85b0aeb36853aa919926 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 26 Nov 2025 04:42:19 -0600
Subject: [PATCH] [OpenMP][flang] Support GPU team reductions on allocatables

Extends the work started in #165714 by supporting team reductions.
Similar to what was done in #165714, this PR introduces proper
allocations, loads, and stores for by-ref reductions in teams-related
callbacks:
* `_omp_reduction_list_to_global_copy_func`,
* `_omp_reduction_list_to_global_reduce_func`,
* `_omp_reduction_global_to_list_copy_func`, and
* `_omp_reduction_global_to_list_reduce_func`.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  24 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 195 +++++++++++++-----
 .../LLVMIR/allocatable_gpu_reduction.mlir     |   2 +
 .../allocatable_gpu_reduction_teams.mlir      | 121 +++++++++++
 4 files changed, 285 insertions(+), 57 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 7b097d1ac0ee0..71234c7ee898a 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1731,9 +1731,10 @@ class OpenMPIRBuilder {
   ///                  need to be copied to the new function.
   ///
   /// \return The ListToGlobalCopy function.
-  Function *emitListToGlobalCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
-                                         Type *ReductionsBufferTy,
-                                         AttributeList FuncAttrs);
+  Expected<Function *>
+  emitListToGlobalCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
+                               Type *ReductionsBufferTy,
+                               AttributeList FuncAttrs, ArrayRef<bool> IsByRef);
 
   /// This function emits a helper that copies all the reduction variables from
   /// the team into the provided global buffer for the reduction variables.
@@ -1748,9 +1749,10 @@ class OpenMPIRBuilder {
   ///                  need to be copied to the new function.
   ///
   /// \return The GlobalToList function.
-  Function *emitGlobalToListCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
-                                         Type *ReductionsBufferTy,
-                                         AttributeList FuncAttrs);
+  Expected<Function *>
+  emitGlobalToListCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
+                               Type *ReductionsBufferTy,
+                               AttributeList FuncAttrs, ArrayRef<bool> IsByRef);
 
   /// This function emits a helper that reduces all the reduction variables from
   /// the team into the provided global buffer for the reduction variables.
@@ -1769,10 +1771,11 @@ class OpenMPIRBuilder {
   ///                  need to be copied to the new function.
   ///
   /// \return The ListToGlobalReduce function.
-  Function *
+  Expected<Function *>
   emitListToGlobalReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,
                                  Function *ReduceFn, Type *ReductionsBufferTy,
-                                 AttributeList FuncAttrs);
+                                 AttributeList FuncAttrs,
+                                 ArrayRef<bool> IsByRef);
 
   /// This function emits a helper that reduces all the reduction variables from
   /// the team into the provided global buffer for the reduction variables.
@@ -1791,10 +1794,11 @@ class OpenMPIRBuilder {
   ///                  need to be copied to the new function.
   ///
   /// \return The GlobalToListReduce function.
-  Function *
+  Expected<Function *>
   emitGlobalToListReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,
                                  Function *ReduceFn, Type *ReductionsBufferTy,
-                                 AttributeList FuncAttrs);
+                                 AttributeList FuncAttrs,
+                                 ArrayRef<bool> IsByRef);
 
   /// Get the function name of a reduction function.
   std::string getReductionFuncName(StringRef Name) const;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index c962368859730..4264d7cb986af 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3151,9 +3151,9 @@ Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
   return SarFunc;
 }
 
-Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
+Expected<Function *> OpenMPIRBuilder::emitListToGlobalCopyFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
-    AttributeList FuncAttrs) {
+    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
   LLVMContext &Ctx = M.getContext();
   FunctionType *FuncTy = FunctionType::get(
@@ -3223,7 +3223,21 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
 
     switch (RI.EvaluationKind) {
     case EvalKind::Scalar: {
-      Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
+      Value *TargetElement;
+
+      if (IsByRef.empty() || !IsByRef[En.index()]) {
+        TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
+      } else {
+        InsertPointOrErrorTy GenResult =
+            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);
+
+        if (!GenResult)
+          return GenResult.takeError();
+
+        ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
+        TargetElement = Builder.CreateLoad(RI.ByRefElementType, ElemPtr);
+      }
+
       Builder.CreateStore(TargetElement, GlobVal);
       break;
     }
@@ -3261,9 +3275,9 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
   return LtGCFunc;
 }
 
-Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
+Expected<Function *> OpenMPIRBuilder::emitListToGlobalReduceFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
-    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
   LLVMContext &Ctx = M.getContext();
   FunctionType *FuncTy = FunctionType::get(
@@ -3302,6 +3316,8 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
   Value *LocalReduceList =
       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
 
+  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
+
   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
       BufferArgAlloca, Builder.getPtrTy(),
       BufferArgAlloca->getName() + ".ascast");
@@ -3323,6 +3339,20 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
   Type *IndexTy = Builder.getIndexTy(
       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
   for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *ByRefAlloc;
+
+    if (!IsByRef.empty() && IsByRef[En.index()]) {
+      InsertPointTy OldIP = Builder.saveIP();
+      Builder.restoreIP(AllocaIP);
+
+      ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
+      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");
+
+      Builder.restoreIP(OldIP);
+    }
+
     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
         RedListArrayTy, LocalReduceListAddrCast,
         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
@@ -3331,7 +3361,21 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
     // Global = Buffer.VD[Idx];
     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
         ReductionsBufferTy, BufferVD, 0, En.index());
-    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+
+    if (!IsByRef.empty() && IsByRef[En.index()]) {
+      Value *ByRefDataPtr;
+
+      InsertPointOrErrorTy GenResult =
+          RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr);
+
+      if (!GenResult)
+        return GenResult.takeError();
+
+      Builder.CreateStore(GlobValPtr, ByRefDataPtr);
+      Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
+    } else {
+      Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+    }
   }
 
   // Call reduce_function(GlobalReduceList, ReduceList)
@@ -3344,32 +3388,32 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
   return LtGRFunc;
 }
 
-Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
+Expected<Function *> OpenMPIRBuilder::emitGlobalToListCopyFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
-    AttributeList FuncAttrs) {
+    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
   LLVMContext &Ctx = M.getContext();
   FunctionType *FuncTy = FunctionType::get(
       Builder.getVoidTy(),
       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
       /* IsVarArg */ false);
-  Function *LtGCFunc =
+  Function *GtLCFunc =
       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                        "_omp_reduction_global_to_list_copy_func", &M);
-  LtGCFunc->setAttributes(FuncAttrs);
-  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
-  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
-  LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+  GtLCFunc->setAttributes(FuncAttrs);
+  GtLCFunc->addParamAttr(0, Attribute::NoUndef);
+  GtLCFunc->addParamAttr(1, Attribute::NoUndef);
+  GtLCFunc->addParamAttr(2, Attribute::NoUndef);
 
-  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLCFunc);
   Builder.SetInsertPoint(EntryBlock);
 
   // Buffer: global reduction buffer.
-  Argument *BufferArg = LtGCFunc->getArg(0);
+  Argument *BufferArg = GtLCFunc->getArg(0);
   // Idx: index of the buffer.
-  Argument *IdxArg = LtGCFunc->getArg(1);
+  Argument *IdxArg = GtLCFunc->getArg(1);
   // ReduceList: thread local Reduce list.
-  Argument *ReduceListArg = LtGCFunc->getArg(2);
+  Argument *ReduceListArg = GtLCFunc->getArg(2);
 
   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                 BufferArg->getName() + ".addr");
@@ -3413,7 +3457,20 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
 
     switch (RI.EvaluationKind) {
     case EvalKind::Scalar: {
-      Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
+      Type *ElemType = RI.ElementType;
+
+      if (!IsByRef.empty() && IsByRef[En.index()]) {
+        ElemType = RI.ByRefElementType;
+        InsertPointOrErrorTy GenResult =
+            RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr);
+
+        if (!GenResult)
+          return GenResult.takeError();
+
+        ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
+      }
+
+      Value *TargetElement = Builder.CreateLoad(ElemType, GlobValPtr);
       Builder.CreateStore(TargetElement, ElemPtr);
       break;
     }
@@ -3449,35 +3506,35 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
 
   Builder.CreateRetVoid();
   Builder.restoreIP(OldIP);
-  return LtGCFunc;
+  return GtLCFunc;
 }
 
-Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
+Expected<Function *> OpenMPIRBuilder::emitGlobalToListReduceFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
-    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
   LLVMContext &Ctx = M.getContext();
   auto *FuncTy = FunctionType::get(
       Builder.getVoidTy(),
       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
       /* IsVarArg */ false);
-  Function *LtGRFunc =
+  Function *GtLRFunc =
       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                        "_omp_reduction_global_to_list_reduce_func", &M);
-  LtGRFunc->setAttributes(FuncAttrs);
-  LtGRFunc->addParamAttr(0, Attribute::NoUndef);
-  LtGRFunc->addParamAttr(1, Attribute::NoUndef);
-  LtGRFunc->addParamAttr(2, Attribute::NoUndef);
+  GtLRFunc->setAttributes(FuncAttrs);
+  GtLRFunc->addParamAttr(0, Attribute::NoUndef);
+  GtLRFunc->addParamAttr(1, Attribute::NoUndef);
+  GtLRFunc->addParamAttr(2, Attribute::NoUndef);
 
-  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLRFunc);
   Builder.SetInsertPoint(EntryBlock);
 
   // Buffer: global reduction buffer.
-  Argument *BufferArg = LtGRFunc->getArg(0);
+  Argument *BufferArg = GtLRFunc->getArg(0);
   // Idx: index of the buffer.
-  Argument *IdxArg = LtGRFunc->getArg(1);
+  Argument *IdxArg = GtLRFunc->getArg(1);
   // ReduceList: thread local Reduce list.
-  Argument *ReduceListArg = LtGRFunc->getArg(2);
+  Argument *ReduceListArg = GtLRFunc->getArg(2);
 
   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                 BufferArg->getName() + ".addr");
@@ -3493,6 +3550,8 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
   Value *LocalReduceList =
       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
 
+  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
+
   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
       BufferArgAlloca, Builder.getPtrTy(),
       BufferArgAlloca->getName() + ".ascast");
@@ -3514,6 +3573,20 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
   Type *IndexTy = Builder.getIndexTy(
       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
   for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *ByRefAlloc;
+
+    if (!IsByRef.empty() && IsByRef[En.index()]) {
+      InsertPointTy OldIP = Builder.saveIP();
+      Builder.restoreIP(AllocaIP);
+
+      ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
+      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");
+
+      Builder.restoreIP(OldIP);
+    }
+
     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
         RedListArrayTy, ReductionList,
         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
@@ -3522,7 +3595,19 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
         ReductionsBufferTy, BufferVD, 0, En.index());
-    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+
+    if (!IsByRef.empty() && IsByRef[En.index()]) {
+      Value *ByRefDataPtr;
+      InsertPointOrErrorTy GenResult =
+          RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr);
+      if (!GenResult)
+        return GenResult.takeError();
+
+      Builder.CreateStore(GlobValPtr, ByRefDataPtr);
+      Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
+    } else {
+      Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+    }
   }
 
   // Call reduce_function(ReduceList, GlobalReduceList)
@@ -3532,7 +3617,7 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
       ->addFnAttr(Attribute::NoUnwind);
   Builder.CreateRetVoid();
   Builder.restoreIP(OldIP);
-  return LtGRFunc;
+  return GtLRFunc;
 }
 
 std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
@@ -3788,7 +3873,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
     auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
     if (Size > MaxDataSize)
       MaxDataSize = Size;
-    ReductionTypeArgs.emplace_back(En.value().ElementType);
+    Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
+                           ? En.value().ByRefElementType
+                           : En.value().ElementType;
+    ReductionTypeArgs.emplace_back(RedTypeArg);
   }
   Value *ReductionDataSize =
       Builder.getInt64(MaxDataSize * ReductionInfos.size());
@@ -3806,20 +3894,33 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
     CodeGenIP = Builder.saveIP();
     StructType *ReductionsBufferTy = StructType::create(
         Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
-    Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr(
+    Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
         RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
-    Function *LtGCFunc = emitListToGlobalCopyFunction(
-        ReductionInfos, ReductionsBufferTy, FuncAttrs);
-    Function *LtGRFunc = emitListToGlobalReduceFunction(
-        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
-    Function *GtLCFunc = emitGlobalToListCopyFunction(
-        ReductionInfos, ReductionsBufferTy, FuncAttrs);
-    Function *GtLRFunc = emitGlobalToListReduceFunction(
-        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+
+    Expected<Function *> LtGCFunc = emitListToGlobalCopyFunction(
+        ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
+    if (!LtGCFunc)
+      return LtGCFunc.takeError();
+
+    Expected<Function *> LtGRFunc = emitListToGlobalReduceFunction(
+        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
+    if (!LtGRFunc)
+      return LtGRFunc.takeError();
+
+    Expected<Function *> GtLCFunc = emitGlobalToListCopyFunction(
+        ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
+    if (!GtLCFunc)
+      return GtLCFunc.takeError();
+
+    Expected<Function *> GtLRFunc = emitGlobalToListReduceFunction(
+        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
+    if (!GtLRFunc)
+      return GtLRFunc.takeError();
+
     Builder.restoreIP(CodeGenIP);
 
     Value *KernelTeamsReductionPtr = createRuntimeFunctionCall(
-        RedFixedBuferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
+        RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
 
     Value *Args3[] = {SrcLocInfo,
                       KernelTeamsReductionPtr,
@@ -3828,10 +3929,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
                       RL,
                       *SarFunc,
                       WcFunc,
-                      LtGCFunc,
-                      LtGRFunc,
-                      GtLCFunc,
-                      GtLRFunc};
+                      *LtGCFunc,
+                      *LtGRFunc,
+                      *GtLCFunc,
+                      *GtLRFunc};
 
     Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
         RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
index df606150b760a..95d12f304aca0 100644
--- a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
+++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
@@ -1,3 +1,5 @@
+// Tests single-team by-ref GPU reductions.
+
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir
new file mode 100644
index 0000000000000..1c73a49b0bf9f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir
@@ -0,0 +1,121 @@
+// Tests cross-teams by-ref GPU reductions.
+
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+  omp.private {type = private} @_QFfooEi_private_i32 : i32
+  omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5>
+    %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+    omp.yield(%2 : !llvm.ptr)
+  } init {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+    omp.yield(%arg1 : !llvm.ptr)
+  } combiner {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+    %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+    %3 = llvm.mlir.constant(1 : i32) : i32
+    %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+    %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+    %6 = llvm.mlir.constant(24 : i32) : i32
+    "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+    %7 = llvm.mlir.constant(24 : i32) : i32
+    "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+    %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+    %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr
+    %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+    %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
+    %12 = llvm.load %9 : !llvm.ptr -> f32
+    %13 = llvm.load %11 : !llvm.ptr -> f32
+    %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32
+    llvm.store %14, %9 : f32, !llvm.ptr
+    omp.yield(%arg0 : !llvm.ptr)
+  } data_ptr_ptr {
+  ^bb0(%arg0: !llvm.ptr):
+    %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+    omp.yield(%0 : !llvm.ptr)
+  }
+
+  llvm.func @foo_() {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5>
+    %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+    %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+    %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""}
+    %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, descriptor, to, attach) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"}
+    omp.target map_entries(%10 -> %arg0 : !llvm.ptr) {
+      %14 = llvm.mlir.constant(1000000 : i32) : i32
+      %15 = llvm.mlir.constant(1 : i32) : i32
+      omp.teams reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg3 : !llvm.ptr) {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg3 -> %arg5 : !llvm.ptr) {
+              omp.loop_nest (%arg6) : i32 = (%15) to (%14) inclusive step (%15) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK: %[[GLOBALIZED_LOCALS:.*]] = type { float }
+
+// CHECK: define internal void @_omp_reduction_list_to_global_copy_func({{.*}}) {{.*}} {
+// CHECK:   %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK:   %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LIST]], align 8
+// CHECK:   %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0
+// CHECK:   %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_PTR]], i32 0, i32 0
+// CHECK:   %[[ALLOC_PTR:.*]] = load ptr, ptr %[[ALLOC_PTR_PTR]], align 8
+// CHECK:   %[[ALLOC_VAL:.*]] = load float, ptr %[[ALLOC_PTR]], align 4
+// Verify that the actual value managed by the descriptor is stored in the globalized
+// locals array, rather than a pointer to the descriptor or a pointer to the value.
+// CHECK:   store float %[[ALLOC_VAL]], ptr %[[GLOB_ELEM_PTR]], align 4
+// CHECK: }
+
+// CHECK: define internal void @_omp_reduction_list_to_global_reduce_func({{.*}}) {{.*}} {
+// Allocate a descriptor to manage the element retrieved from the globalized local array.
+// CHECK:   %[[ALLOC_DESC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5)
+// CHECK:   %[[ALLOC_DESC_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ALLOC_DESC]] to ptr
+
+// CHECK:   %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK:   %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0
+// CHECK:   %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOC_DESC_ASCAST]], i32 0, i32 0
+// Store the pointer to the globalized local element into the locally allocated descriptor.
+// CHECK:   store ptr %[[GLOB_ELEM_PTR]], ptr %[[ALLOC_PTR_PTR]], align 8
+// CHECK:   store ptr %[[ALLOC_DESC_ASCAST]], ptr %[[RED_ARR_LIST]], align 8
+// CHECK: }
+
+// CHECK: define internal void @_omp_reduction_global_to_list_copy_func({{.*}}) {{.*}} {
+// CHECK:   %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK:   %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LIST]], align 8
+// CHECK:   %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0
+// CHECK:   %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_PTR]], i32 0, i32 0
+// Similar to _omp_reduction_list_to_global_copy_func(...) but in the reverse direction; i.e.
+// the globalized local array is copied from rather than copied to.
+// CHECK:   %[[ALLOC_PTR:.*]] = load ptr, ptr %[[ALLOC_PTR_PTR]], align 8
+// CHECK:   %[[ALLOC_VAL:.*]] = load float, ptr %[[GLOB_ELEM_PTR]], align 4
+// CHECK:   store float %[[ALLOC_VAL]], ptr %[[ALLOC_PTR]], align 4
+// CHECK: }
+
+// CHECK: define internal void @_omp_reduction_global_to_list_reduce_func({{.*}}) {{.*}} {
+// Allocate a descriptor to manage the element retrieved from the globalized local array.
+// CHECK:   %[[ALLOC_DESC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5)
+// CHECK:   %[[ALLOC_DESC_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ALLOC_DESC]] to ptr
+
+// CHECK:   %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK:   %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0
+// CHECK:   %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOC_DESC_ASCAST]], i32 0, i32 0
+// Store the pointer to the globalized local element into the locally allocated descriptor.
+// CHECK:   store ptr %[[GLOB_ELEM_PTR]], ptr %[[ALLOC_PTR_PTR]], align 8
+// CHECK:   store ptr %[[ALLOC_DESC_ASCAST]], ptr %[[RED_ARR_LIST]], align 8
+// CHECK: }



More information about the Mlir-commits mailing list