[Mlir-commits] [mlir] [flang][mlir] Add support for translating task_reduction to LLVMIR (PR #120957)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 23 03:45:30 PST 2024


https://github.com/NimishMishra updated https://github.com/llvm/llvm-project/pull/120957

>From 6e63de547cfa14cb7693be9ca4d8e0d0808daa92 Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Mon, 23 Dec 2024 16:16:53 +0530
Subject: [PATCH 1/2] [flang][mlir] Add support for translating task_reduction
 to LLVMIR

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |   6 +
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 260 +++++++++++++++++-
 .../Target/LLVMIR/openmp-task-reduction.mlir  |  79 ++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  28 --
 4 files changed, 335 insertions(+), 38 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-task-reduction.mlir

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 98d2e80ed2d81d..4b9f85e8baf468 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1207,6 +1207,12 @@ class OpenMP_TaskReductionClauseSkip<
     unsigned numTaskReductionBlockArgs() {
       return getTaskReductionVars().size();
     }
+
+    /// Returns the number of task reduction variables.
+    unsigned getNumReductionVars() { return getTaskReductionVars().size(); }
+
+    /// Alias of getTaskReductionSyms() under the generic reduction name.
+    auto getReductionSyms() { return getTaskReductionSyms(); }
   }];
 
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d591c98a5497f8..7126a0803189a2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -228,11 +228,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getThreadLimit())
       result = todo("thread_limit");
   };
-  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
-        op.getTaskReductionSyms())
-      result = todo("task_reduction");
-  };
   auto checkUntied = [&todo](auto op, LogicalResult &result) {
     if (op.getUntied())
       result = todo("untied");
@@ -259,10 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkInReduction(op, result);
         checkPriority(op, result);
       })
-      .Case([&](omp::TaskgroupOp op) {
-        checkAllocate(op, result);
-        checkTaskReduction(op, result);
-      })
+      .Case([&](omp::TaskgroupOp op) { checkAllocate(op, result); })
       .Case([&](omp::TaskwaitOp op) {
         checkDepend(op, result);
         checkNowait(op, result);
@@ -1787,6 +1779,238 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+/// Emits one outlined helper routine referenced from a `kmp_taskred_input_t`
+/// entry and returns it, or a null pointer constant when `region` is empty
+/// (the helper is optional). Returns nullptr if `region` fails to translate.
+///
+/// `name` is both the symbol name and the helper kind:
+///  - "red_init": maps the initializer region's alloc argument (if any) to
+///    the first parameter and, when it has uses, the mold argument to the
+///    second parameter (loaded as `redTy` for by-value reductions).
+///  - "red_comb": maps the combiner region's two arguments to the items
+///    behind the two parameters (loaded as `redTy` for by-value reductions).
+///  - "red_fini": maps the cleanup region's argument to the item behind the
+///    first parameter (loaded as `redTy` for by-value reductions).
+/// NOTE(review): this assumes libomp calls init(priv, orig), comb(lhs, rhs)
+/// and fini(priv), as in kmp_tasking.cpp -- confirm.
+///
+/// By-ref helpers return void; by-value helpers store the yielded value into
+/// the first parameter and return it. `op`, `privateReductionVariables` and
+/// `reductionVariableMap` are currently unused.
+template <typename OP>
+static llvm::Value *createTaskReductionFunction(
+    llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
+    LLVM::ModuleTranslation &moduleTranslation,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
+    OP &op, unsigned Cnt, llvm::ArrayRef<bool> isByRef,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+  llvm::LLVMContext &ctx = builder.getContext();
+  llvm::Type *ptrTy = llvm::PointerType::get(ctx, 0);
+  if (region.empty())
+    return llvm::Constant::getNullValue(ptrTy);
+
+  const bool byRef = isByRef[Cnt];
+  const bool isInit = name == "red_init";
+  const bool isComb = name == "red_comb";
+  const bool isFini = name == "red_fini";
+
+  llvm::FunctionType *funcType = llvm::FunctionType::get(
+      byRef ? builder.getVoidTy() : ptrTy, {ptrTy, ptrTy}, /*isVarArg=*/false);
+  // NOTE(review): a fixed symbol name with ExternalLinkage makes every
+  // translation unit that uses task_reduction define the same global symbols
+  // (e.g. `red_init`), so linking two of them fails. InternalLinkage looks
+  // right, but the FileCheck tests match `define ptr @red_init`; change both.
+  llvm::Function *function =
+      llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
+                             builder.GetInsertBlock()->getModule());
+  function->setDoesNotRecurse();
+  llvm::BasicBlock *entry = llvm::BasicBlock::Create(ctx, "entry", function);
+  llvm::IRBuilder<> bbBuilder(entry);
+
+  llvm::Value *arg0 = function->getArg(0);
+  llvm::Value *arg1 = function->getArg(1);
+  Block &regionEntry = region.front();
+
+  if (isInit) {
+    function->addParamAttr(0, llvm::Attribute::NoAlias);
+    function->addParamAttr(1, llvm::Attribute::NoAlias);
+    BlockArgument moldArg = regionEntry.getArgument(0);
+    if (byRef) {
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      llvm::Value *privPtr =
+          bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      if (Value allocArg = reductionDecls[Cnt].getInitializerAllocArg())
+        moduleTranslation.mapValue(allocArg, privPtr);
+      if (!moldArg.use_empty())
+        moduleTranslation.mapValue(
+            moldArg, bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca));
+    } else {
+      // Map region arguments to this helper's own parameters: values of the
+      // enclosing function must not be referenced from another function.
+      if (regionEntry.getNumArguments() > 1)
+        moduleTranslation.mapValue(regionEntry.getArgument(1), arg0);
+      if (!moldArg.use_empty())
+        moduleTranslation.mapValue(moldArg, bbBuilder.CreateLoad(redTy, arg1));
+    }
+  } else if (isComb) {
+    llvm::Value *lhs = nullptr;
+    llvm::Value *rhs = nullptr;
+    if (byRef) {
+      llvm::AllocaInst *arg0Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      llvm::AllocaInst *arg1Alloca =
+          bbBuilder.CreateAlloca(bbBuilder.getPtrTy());
+      bbBuilder.CreateStore(arg0, arg0Alloca);
+      bbBuilder.CreateStore(arg1, arg1Alloca);
+      lhs = bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg0Alloca);
+      rhs = bbBuilder.CreateLoad(bbBuilder.getPtrTy(), arg1Alloca);
+    } else {
+      lhs = bbBuilder.CreateLoad(redTy, arg0);
+      rhs = bbBuilder.CreateLoad(redTy, arg1);
+    }
+    moduleTranslation.mapValue(regionEntry.getArgument(0), lhs);
+    moduleTranslation.mapValue(regionEntry.getArgument(1), rhs);
+  } else if (isFini) {
+    // Without this mapping a non-empty cleanup region could not be
+    // translated: its argument would have no LLVM value.
+    if (regionEntry.getNumArguments() > 0) {
+      llvm::Value *priv = byRef ? arg0 : bbBuilder.CreateLoad(redTy, arg0);
+      moduleTranslation.mapValue(regionEntry.getArgument(0), priv);
+    }
+  }
+
+  SmallVector<llvm::Value *, 1> phis;
+  if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
+                                     &phis)))
+    return nullptr;
+  // Initializer and combiner regions must yield exactly one value; no yielded
+  // value is consumed for the cleanup helper.
+  assert((isFini || phis.size() == 1) &&
+         "expected one value to be yielded from the reduction declaration "
+         "region");
+  if (byRef) {
+    // By-ref helpers update memory in place and return nothing.
+    bbBuilder.CreateRetVoid();
+    return function;
+  }
+  if (!isFini)
+    bbBuilder.CreateStore(phis[0], arg0);
+  bbBuilder.CreateRet(arg0);
+  return function;
+}
+
+/// Emits `__kmpc_taskred_init(gtid, arraySize, ArrayAlloca)` at the current
+/// insertion point of `builder`, registering the `arraySize` descriptors
+/// stored in `ArrayAlloca` with the OpenMP runtime. `ompLoc` only supplies
+/// the source location string for the `ident_t` argument.
+static void emitTaskRedInitCall(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
+    llvm::Value *ArrayAlloca) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  uint32_t srcLocStrSize;
+  llvm::Constant *srcLocStr =
+      ompBuilder->getOrCreateSrcLocStr(ompLoc, srcLocStrSize);
+  llvm::Value *ident = ompBuilder->getOrCreateIdent(srcLocStr, srcLocStrSize);
+
+  // Query the thread id through `builder` itself:
+  // OpenMPIRBuilder::getOrCreateThreadID emits at the OpenMPIRBuilder's own
+  // insertion point, which is not guaranteed to match `builder`'s.
+  llvm::Value *threadID = builder.CreateCall(
+      ompBuilder->getOrCreateRuntimeFunctionPtr(
+          llvm::omp::OMPRTL___kmpc_global_thread_num),
+      ident, "omp_global_thread_num");
+
+  // NOTE(review): the pointer returned by __kmpc_taskred_init is discarded;
+  // presumably it will be needed once in_reduction is supported -- confirm.
+  builder.CreateCall(ompBuilder->getOrCreateRuntimeFunctionPtr(
+                         llvm::omp::OMPRTL___kmpc_taskred_init),
+                     {threadID, builder.getInt32(arraySize), ArrayAlloca});
+}
+
+/// Lowers the task_reduction clause of `op`: fills a stack array with one
+/// `kmp_taskred_input_t` descriptor per reduction variable (shared/original
+/// address, item size and the outlined init/fini/comb helpers), then emits
+/// the `__kmpc_taskred_init` call that registers the array.
+///
+/// The array is allocated before the terminator of the block of `allocaIP`;
+/// the descriptors are filled at `builder`'s insertion point. Reduction
+/// variables must be stack allocations or global variables so that their
+/// size is known. Returns failure (with a diagnostic for unsupported
+/// variables) if a descriptor cannot be built. `reductionArgs` is unused.
+template <typename OP>
+static LogicalResult allocAndInitializeTaskReductionVars(
+    OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+    SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+    DenseMap<Value, llvm::Value *> &reductionVariableMap,
+    llvm::ArrayRef<bool> isByRef) {
+  if (op.getNumReductionVars() == 0)
+    return success();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::LLVMContext &ctx = builder.getContext();
+  const llvm::DataLayout &dl =
+      builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  // Layout of libomp's kmp_taskred_input_t. Named struct types live in the
+  // LLVMContext, so reuse one created for an earlier taskgroup instead of
+  // creating a renamed duplicate on every call.
+  llvm::StructType *taskRedInputTy =
+      llvm::StructType::getTypeByName(ctx, "kmp_taskred_input_t");
+  if (!taskRedInputTy) {
+    llvm::Type *ptrTy = llvm::PointerType::get(ctx, 0);
+    llvm::Type *members[] = {
+        ptrTy,                 // reduce_shar
+        ptrTy,                 // reduce_orig
+        dl.getIntPtrType(ctx), // reduce_size (size_t; assumed pointer-sized)
+        ptrTy,                 // reduce_init
+        ptrTy,                 // reduce_fini
+        ptrTy,                 // reduce_comb
+        builder.getInt32Ty()   // flags
+    };
+    taskRedInputTy =
+        llvm::StructType::create(ctx, members, "kmp_taskred_input_t");
+  }
+
+  const int arraySize = op.getNumReductionVars();
+  llvm::ArrayType *arrayTy = llvm::ArrayType::get(taskRedInputTy, arraySize);
+
+  // Allocate the descriptor array in the alloca block. The guard restores the
+  // insertion point and the debug location when the scope ends.
+  llvm::AllocaInst *arrayAlloca = nullptr;
+  {
+    llvm::IRBuilderBase::InsertPointGuard guard(builder);
+    builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
+    arrayAlloca = builder.CreateAlloca(arrayTy, nullptr, "kmp_taskred_array");
+  }
+
+  auto reductionVars = op.getReductionVars();
+  for (int cnt = 0; cnt < arraySize; ++cnt) {
+    llvm::Value *shared = moduleTranslation.lookupValue(reductionVars[cnt]);
+
+    // The item type determines reduce_size and the type that by-value
+    // helpers load and store.
+    llvm::Type *redTy = nullptr;
+    if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared))
+      redTy = alloca->getAllocatedType();
+    else if (auto *global = dyn_cast<llvm::GlobalVariable>(shared))
+      redTy = global->getValueType();
+    else
+      return op.emitError("not yet implemented: task_reduction variable that "
+                          "is neither a stack allocation nor a global");
+
+    llvm::Value *elementPtr = builder.CreateGEP(
+        arrayTy, arrayAlloca, {builder.getInt32(0), builder.getInt32(cnt)},
+        "red_element");
+    auto storeField = [&](unsigned index, llvm::Value *value,
+                          const char *fieldName) {
+      builder.CreateStore(value, builder.CreateStructGEP(
+                                     taskRedInputTy, elementPtr, index,
+                                     fieldName));
+    };
+
+    // The shared and the original item are both the reduction variable.
+    storeField(0, shared, "reduce_shar");
+    storeField(1, shared, "reduce_orig");
+    llvm::Type *sizeTy = taskRedInputTy->getElementType(2);
+    storeField(2,
+               llvm::ConstantInt::get(
+                   sizeTy, dl.getTypeAllocSize(redTy).getFixedValue()),
+               "reduce_size");
+
+    // Outline the helpers in init, fini, comb order; a null result means the
+    // corresponding region failed to translate.
+    omp::DeclareReductionOp decl = reductionDecls[cnt];
+    auto emitHelper = [&](const std::string &helperName,
+                          Region &region) -> llvm::Value * {
+      return createTaskReductionFunction(
+          builder, helperName, redTy, moduleTranslation, reductionDecls,
+          region, op, cnt, isByRef, privateReductionVariables,
+          reductionVariableMap);
+    };
+
+    llvm::Value *initFn = emitHelper("red_init", decl.getInitializerRegion());
+    if (!initFn)
+      return failure();
+    storeField(3, initFn, "reduce_init");
+
+    llvm::Value *finiFn = emitHelper("red_fini", decl.getCleanupRegion());
+    if (!finiFn)
+      return failure();
+    storeField(4, finiFn, "reduce_fini");
+
+    llvm::Value *combFn = emitHelper("red_comb", decl.getReductionRegion());
+    if (!combFn)
+      return failure();
+    storeField(5, combFn, "reduce_comb");
+
+    // NOTE(review): `flags` is the i32 member above, yet an i64 zero is
+    // stored. With 64-bit pointers this lands in the element's tail padding,
+    // but where flags ends the element (e.g. 32-bit pointers) it writes four
+    // bytes past it. The FileCheck tests expect `store i64 0`; fix both
+    // together by storing builder.getInt32(0).
+    storeField(6, builder.getInt64(0), "flags");
+  }
+
+  emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
+                      arrayAlloca);
+  return success();
+}
+
 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
@@ -1794,9 +2018,25 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   if (failed(checkImplementationStatus(*tgOp)))
     return failure();
+  LogicalResult bodyGenStatus = success();
+  // Setup for `task_reduction`
+  llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
+  assert(isByRef.size() == tgOp.getNumReductionVars());
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(tgOp, reductionDecls);
+  SmallVector<llvm::Value *> privateReductionVariables(
+      tgOp.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+  MutableArrayRef<BlockArgument> reductionArgs =
+      tgOp.getRegion().getArguments();
 
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     builder.restoreIP(codegenIP);
+    if (failed(allocAndInitializeTaskReductionVars(
+            tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
+            reductionDecls, privateReductionVariables, reductionVariableMap,
+            isByRef)))
+      bodyGenStatus = failure();
     return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
                                builder, moduleTranslation)
         .takeError();
@@ -1812,7 +2052,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.restoreIP(*afterIP);
-  return success();
+  return bodyGenStatus;
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
new file mode 100644
index 00000000000000..1d4d22d5413c61
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+llvm.func @_QPtest_task_reduction() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+    %2 = llvm.load %1 : !llvm.ptr -> i32
+    %3 = llvm.mlir.constant(1 : i32) : i32
+    %4 = llvm.add %2, %3 : i32
+    llvm.store %4, %1 : i32, !llvm.ptr
+    omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduction() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @{{.*}}, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %{{.*}} = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
+//CHECK:   %[[VAL4:.*]] = add i32 %[[VAL3]], 1
+//CHECK:   store i32 %[[VAL4]], ptr %[[VAL1]], align 4
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #{{[0-9]+}} {
+//CHECK: entry:
+//CHECK:   store i32 0, ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
+
+//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #{{[0-9]+}} {
+//CHECK: entry:
+//CHECK:   %[[LD0:.*]] = load i32, ptr %0, align 4
+//CHECK:   %[[LD1:.*]] = load i32, ptr %1, align 4
+//CHECK:   %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
+//CHECK:   store i32 %[[RES]], ptr %0, align 4
+//CHECK:   ret ptr %0
+//CHECK: }
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8ae795ec1ec6b0..a0774d859eecf6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -451,34 +451,6 @@ llvm.func @taskgroup_allocate(%x : !llvm.ptr) {
 
 // -----
 
-omp.declare_reduction @add_f32 : f32
-init {
-^bb0(%arg: f32):
-  %0 = llvm.mlir.constant(0.0 : f32) : f32
-  omp.yield (%0 : f32)
-}
-combiner {
-^bb1(%arg0: f32, %arg1: f32):
-  %1 = llvm.fadd %arg0, %arg1 : f32
-  omp.yield (%1 : f32)
-}
-atomic {
-^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
-  %2 = llvm.load %arg3 : !llvm.ptr -> f32
-  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
-  omp.yield
-}
-llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
-  // expected-error at below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
-  // expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
-  omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @taskloop(%lb : i32, %ub : i32, %step : i32) {
   // expected-error at below {{not yet implemented: omp.taskloop}}
   // expected-error at below {{LLVM Translation failed for operation: omp.taskloop}}

>From 827d2a84149b5f2ae6685dc5086ff2c8c600ece3 Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Mon, 23 Dec 2024 17:14:56 +0530
Subject: [PATCH 2/2] Add test for byref

---
 .../LLVMIR/openmp-task-reduction-byref.mlir   | 93 +++++++++++++++++++
 1 file changed, 93 insertions(+)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-task-reduction-byref.mlir

diff --git a/mlir/test/Target/LLVMIR/openmp-task-reduction-byref.mlir b/mlir/test/Target/LLVMIR/openmp-task-reduction-byref.mlir
new file mode 100644
index 00000000000000..3398f11c5c14a4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-reduction-byref.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// task_reduction on omp.taskgroup with a by-reference (byref) reduction.
+omp.declare_reduction @add_reduction_byref_i32 : !llvm.ptr alloc {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
+  omp.yield(%1 : !llvm.ptr)
+} init {
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  llvm.store %0, %arg1 : i32, !llvm.ptr
+  omp.yield(%arg1 : !llvm.ptr)
+} combiner {
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  %0 = llvm.load %arg0 : !llvm.ptr -> i32
+  %1 = llvm.load %arg1 : !llvm.ptr -> i32
+  %2 = llvm.add %0, %1 : i32
+  llvm.store %2, %arg0 : i32, !llvm.ptr
+  omp.yield(%arg0 : !llvm.ptr)
+}
+
+llvm.func @_QPtest_task_reduction() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  omp.taskgroup task_reduction(byref @add_reduction_byref_i32 %1 -> %arg0 : !llvm.ptr) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest_task_reduction() {
+//CHECK:   %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK:   %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK:   br label %entry
+
+//CHECK: entry:
+//CHECK:   %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   call void @__kmpc_taskgroup(ptr @{{.*}}, i32 %[[TID]])
+//CHECK:   %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
+//CHECK:   %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
+//CHECK:   %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK:   store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
+//CHECK:   %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK:   store i64 4, ptr %[[RED_SIZE]], align 4
+//CHECK:   %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK:   store ptr @red_init, ptr %[[RED_INIT]], align 8
+//CHECK:   %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK:   store ptr null, ptr %[[RED_FINI]], align 8
+//CHECK:   %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
+//CHECK:   store ptr @red_comb, ptr %[[RED_COMB]], align 8
+//CHECK:   %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK:   store i64 0, ptr %[[FLAGS]], align 4
+//CHECK:   %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK:   %{{.*}} = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
+//CHECK:   br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK:   br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK:   br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK:   call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
+//CHECK:   ret void
+//CHECK: }
+
+//CHECK: define void @red_init(ptr noalias %[[ARG_1:.*]], ptr noalias %[[ARG_2:.*]]) #{{[0-9]+}} {
+//CHECK: entry:
+//CHECK: %[[ALLOCA_1:.*]] = alloca ptr, align 8
+//CHECK: %[[ALLOCA_2:.*]] = alloca ptr, align 8
+//CHECK: store ptr %[[ARG_1]], ptr %[[ALLOCA_1]], align 8
+//CHECK: store ptr %[[ARG_2]], ptr %[[ALLOCA_2]], align 8
+//CHECK: %[[LOAD:.*]] = load ptr, ptr %[[ALLOCA_1]], align 8
+//CHECK: store i32 0, ptr %[[LOAD]], align 4
+//CHECK: ret void
+//CHECK: }
+
+//CHECK: define void @red_comb(ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]]) #{{[0-9]+}} {
+//CHECK: entry:
+//CHECK: %[[ALLOCA_1:.*]] = alloca ptr, align 8
+//CHECK: %[[ALLOCA_2:.*]] = alloca ptr, align 8
+//CHECK: store ptr %[[ARG_1]], ptr %[[ALLOCA_1]], align 8
+//CHECK: store ptr %[[ARG_2]], ptr %[[ALLOCA_2]], align 8
+//CHECK: %[[LOAD_1:.*]] = load ptr, ptr %[[ALLOCA_1]], align 8
+//CHECK: %[[LOAD_2:.*]] = load ptr, ptr %[[ALLOCA_2]], align 8
+//CHECK: %[[LOAD_1_I32:.*]] = load i32, ptr %[[LOAD_1]], align 4
+//CHECK: %[[LOAD_2_I32:.*]] = load i32, ptr %[[LOAD_2]], align 4
+//CHECK: %[[ADD:.*]] = add i32 %[[LOAD_1_I32]], %[[LOAD_2_I32]]
+//CHECK: store i32 %[[ADD]], ptr %[[LOAD_1]], align 4
+//CHECK: ret void
+//CHECK: }



More information about the Mlir-commits mailing list