[Mlir-commits] [mlir] [Flang][MLIR] Add support for translating in_reduction to LLVMIR (PR #121045)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 24 01:24:07 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
Author: Kaviya Rajendiran (kaviya2510)
<details>
<summary>Changes</summary>
This PR adds support for translating the `in_reduction` clause of `omp.task` to LLVM IR by emitting a runtime call to `__kmpc_task_reduction_get_th_data`.
---
Full diff: https://github.com/llvm/llvm-project/pull/121045.diff
3 Files Affected:
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+42-1)
- (added) mlir/test/Target/LLVMIR/openmp-task-inreduction.mlir (+195)
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-26)
``````````diff
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 060113c4123241..7b13179304519a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -256,7 +256,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::TaskOp op) {
checkAllocate(op, result);
- checkInReduction(op, result);
checkPriority(op, result);
checkUntied(op, result);
})
@@ -1748,6 +1747,48 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
// translate the body of the task:
builder.restoreIP(codegenIP);
+ if (taskOp.getInReductionSyms()) {
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ builder.restoreIP(codegenIP);
+ uint32_t SrcLocStrSize;
+ llvm::LLVMContext &Context = builder.getContext();
+ llvm::Constant *SrcLocStr =
+ moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(
+ ompLoc, SrcLocStrSize);
+ llvm::Value *Ident =
+ moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(SrcLocStr,
+ SrcLocStrSize);
+ llvm::Value *ThreadID =
+ moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
+ MutableArrayRef<BlockArgument> reductionArgs =
+ taskOp.getRegion().getArguments();
+ llvm::Value *tg =
+ llvm::ConstantPointerNull::get(llvm::PointerType::get(Context, 0));
+
+ // Check whether in_reduction is enclosed in a taskgroup
+ // and pass the taskgroup_id enclosing the task
+ auto op = taskOp->getParentOfType<omp::TaskgroupOp>();
+ if (op && moduleTranslation.lookupCall(op)) {
+ tg = moduleTranslation.lookupCall(op);
+ }
+ // Call back function for emitTaskReductionGetThdata
+ auto emitTaskReductionGetThdata = [&](llvm::Value *OrigVal) {
+ llvm::Value *args[] = {ThreadID, tg, OrigVal};
+ // Emit the runtime call
+ llvm::Function *ReductionGetFn =
+ moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
+ llvm::omp::OMPRTL___kmpc_task_reduction_get_th_data);
+ return builder.CreateCall(ReductionGetFn, args);
+ };
+ std::size_t index = 0;
+ for (mlir::Value inReductionVar : taskOp.getInReductionVars()) {
+ llvm::Value *Result = emitTaskReductionGetThdata(
+ moduleTranslation.lookupValue(inReductionVar));
+ moduleTranslation.mapValue(reductionArgs[index], Result);
+ index++;
+ }
+ }
+
auto continuationBlockOrError = convertOmpOpRegions(
taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
if (failed(handleError(continuationBlockOrError, *taskOp)))
diff --git a/mlir/test/Target/LLVMIR/openmp-task-inreduction.mlir b/mlir/test/Target/LLVMIR/openmp-task-inreduction.mlir
new file mode 100644
index 00000000000000..7de88e8c3c8bc5
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-inreduction.mlir
@@ -0,0 +1,195 @@
+ // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+ omp.declare_reduction @add_reduction_i32 : i32 init {
+ ^bb0(%arg0: i32):
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ omp.yield(%0 : i32)
+ } combiner {
+ ^bb0(%arg0: i32, %arg1: i32):
+ %0 = llvm.add %arg0, %arg1 : i32
+ omp.yield(%0 : i32)
+ }
+
+
+ llvm.func @_QPtest_inreduction() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x i32 {bindc_name = "x", pinned} : (i64) -> !llvm.ptr
+ %2 = llvm.mlir.constant(1 : i64) : i64
+ %3 = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+ %4 = llvm.mlir.constant(0 : i32) : i32
+ llvm.store %4, %3 : i32, !llvm.ptr
+ %5 = llvm.load %3 : !llvm.ptr -> i32
+ llvm.store %5, %1 : i32, !llvm.ptr
+ omp.task in_reduction(@add_reduction_i32 %3 -> %arg0 : !llvm.ptr) {
+ %6 = llvm.load %arg0 : !llvm.ptr -> i32
+ %7 = llvm.mlir.constant(1 : i32) : i32
+ %8 = llvm.add %6, %7 : i32
+ llvm.store %8, %arg0 : i32, !llvm.ptr
+ omp.terminator
+ }
+ llvm.return
+ }
+
+//CHECK-LABEL: define void @_QPtest_inreduction() {
+//CHECK: %[[STRUCTARG:.*]] = alloca { i32, ptr }, align 8
+//CHECK: %[[VAL1:.*]] = alloca i32, i64 1, align 4
+//CHECK: %[[VAL2:.*]] = alloca i32, i64 1, align 4
+//CHECK: store i32 0, ptr %[[VAL2]], align 4
+//CHECK: %[[VAL3:.*]] = load i32, ptr %[[VAL2]], align 4
+//CHECK: store i32 %[[VAL3]], ptr %[[VAL1]], align 4
+//CHECK: br label %entry
+
+//CHECK: entry:
+//CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK: br label %codeRepl
+
+//CHECK: codeRepl:
+//CHECK: %[[TID2:.*]] = getelementptr { i32, ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
+//CHECK: store i32 %[[TID]], ptr %[[TID2]], align 4
+//CHECK: %[[VAL4:.*]] = getelementptr { i32, ptr }, ptr %[[STRUCTARG]], i32 0, i32 1
+//CHECK: store ptr %[[VAL2]], ptr %[[VAL4]], align 8
+//CHECK: %[[TID3:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
+//CHECK: %[[VAL5:.*]] = call ptr @__kmpc_omp_task_alloc(ptr @1, i32 %[[TID3]], i32 1, i64 40, i64 16, ptr @_QPtest_inreduction..omp_par)
+//CHECK: %[[VAL6:.*]] = load ptr, ptr %[[VAL5]], align 8
+//CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL6]], ptr align 1 %[[STRUCTARG]], i64 16, i1 false)
+//CHECK: %[[VAL7:.*]] = call i32 @__kmpc_omp_task(ptr @1, i32 %[[TID3]], ptr %[[VAL5]])
+//CHECK: br label %task.exit
+
+//CHECK: task.exit:
+//CHECK: ret void
+//CHECK: }
+
+//CHECK-LABEL: define internal void @_QPtest_inreduction..omp_par(i32 %{{.*}}, ptr %{{.*}}) {
+//CHECK: task.alloca:
+//CHECK: %[[VAL9:.*]] = load ptr, ptr %{{.*}}, align 8
+//CHECK: %[[TID4:.*]] = getelementptr { i32, ptr }, ptr %[[VAL9]], i32 0, i32 0
+//CHECK: %[[VAL10:.*]] = load i32, ptr %[[TID4]], align 4
+//CHECK: %[[VAL11:.*]] = getelementptr { i32, ptr }, ptr %[[VAL9]], i32 0, i32 1
+//CHECK: %[[VAL12:.*]] = load ptr, ptr %[[VAL11]], align 8
+//CHECK: br label %task.body
+
+//CHECK: task.body:
+//CHECK: %[[VAL13:.*]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[VAL10]], ptr null, ptr %[[VAL12]])
+//CHECK: br label %omp.task.region
+
+//CHECK: omp.task.region:
+//CHECK: %[[VAL14:.*]] = load i32, ptr %[[VAL13]], align 4
+//CHECK: %[[VAL15:.*]] = add i32 %[[VAL14]], 1
+//CHECK: store i32 %[[VAL15]], ptr %[[VAL13]], align 4
+//CHECK: br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK: br label %task.exit.exitStub
+
+//CHECK: task.exit.exitStub:
+//CHECK: ret void
+//CHECK: }
+
+// -----
+
+ llvm.func @_QPtest() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+ %2 = llvm.mlir.constant(1 : i64) : i64
+ %3 = llvm.mlir.constant(1 : i32) : i32
+ llvm.store %3, %1 : i32, !llvm.ptr
+ omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
+ omp.task in_reduction(@add_reduction_i32 %1 -> %arg1 : !llvm.ptr) {
+ %4 = llvm.load %arg1 : !llvm.ptr -> i32
+ %5 = llvm.mlir.constant(1 : i32) : i32
+ %6 = llvm.add %4, %5 : i32
+ llvm.store %6, %arg1 : i32, !llvm.ptr
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+
+//CHECK-LABEL: define void @_QPtest() {
+//CHECK: %[[STRUCTARG:.*]] = alloca { i32, ptr, ptr }, align 8
+//CHECK: %[[X_VAL:.*]] = alloca i32, i64 1, align 4
+//CHECK: store i32 1, ptr %[[X_VAL]], align 4
+//CHECK: %[[KMP_TASKRED_ARRAY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
+//CHECK: br label %entry
+
+//CHECK: entry:
+//CHECK: %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+//CHECK: call void @__kmpc_taskgroup(ptr @1, i32 %[[OMP_GLOBAL_THREAD_NUM]])
+//CHECK: %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[KMP_TASKRED_ARRAY]], i32 0, i32 0
+//CHECK: %[[REDUCE_SHAR:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
+//CHECK: store ptr %[[X_VAL]], ptr %[[REDUCE_SHAR]], align 8
+//CHECK: %[[REDUCE_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
+//CHECK: store ptr %[[X_VAL]], ptr %[[REDUCE_ORIG]], align 8
+//CHECK: %[[REDUCE_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
+//CHECK: store i64 4, ptr %[[REDUCE_SIZE]], align 4
+//CHECK: %[[REDUCE_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
+//CHECK: store ptr @red_init, ptr %[[REDUCE_INIT]], align 8
+//CHECK: %[[REDUCE_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
+//CHECK: store ptr null, ptr %[[REDUCE_FINI]], align 8
+//CHECK: %[[REDUCE_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT:.*]], i32 0, i32 5
+//CHECK: store ptr @red_comb, ptr %[[REDUCE_COMB]], align 8
+//CHECK: %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
+//CHECK: store i64 0, ptr %[[FLAGS]], align 4
+//CHECK: %[[OMP_GLOBAL_THREAD_NUM1:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+//CHECK: %[[TASKRED_INIT_CALL:.*]] = call ptr @__kmpc_taskred_init(i32 %[[OMP_GLOBAL_THREAD_NUM1]], i32 1, ptr %[[KMP_TASKRED_ARRAY]])
+//CHECK: br label %omp.taskgroup.region
+
+//CHECK: omp.taskgroup.region:
+//CHECK: %[[OMP_GLOBAL_THREAD_NUM2:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+//CHECK: br label %codeRepl
+
+//CHECK: codeRepl:
+//CHECK: %[[GEP_OMP_GLOBAL_THREAD_NUM2:.*]] = getelementptr { i32, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
+//CHECK: store i32 %[[OMP_GLOBAL_THREAD_NUM2]], ptr %[[GEP_OMP_GLOBAL_THREAD_NUM2]], align 4
+//CHECK: %[[GEP_:.*]] = getelementptr { i32, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 1
+//CHECK: store ptr %[[TASKRED_INIT_CALL]], ptr %[[GEP_]], align 8
+//CHECK: %[[GEP_4:.*]] = getelementptr { i32, ptr, ptr }, ptr %[[STRUCTARG]], i32 0, i32 2
+//CHECK: store ptr %1, ptr %[[GEP_4]], align 8
+//CHECK: %[[OMP_GLOBAL_THREAD_NUM5:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+//CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc(ptr @1, i32 %[[OMP_GLOBAL_THREAD_NUM5]], i32 1, i64 40, i64 24, ptr @_QPtest..omp_par)
+//CHECK: %[[VAL1:.*]] = load ptr, ptr %[[TASK_ALLOC]], align 8
+//CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL1]], ptr align 1 %[[STRUCTARG]], i64 24, i1 false)
+//CHECK: %[[VAL2:.*]] = call i32 @__kmpc_omp_task(ptr @1, i32 %[[OMP_GLOBAL_THREAD_NUM5]], ptr %3)
+//CHECK: br label %task.exit
+
+//CHECK: task.exit:
+//CHECK: br label %omp.region.cont
+
+//CHECK: omp.region.cont:
+//CHECK: br label %taskgroup.exit
+
+//CHECK: taskgroup.exit:
+//CHECK: call void @__kmpc_end_taskgroup(ptr @1, i32 %[[OMP_GLOBAL_THREAD_NUM]])
+//CHECK: ret void
+//CHECK: }
+
+//CHECK: define internal void @_QPtest..omp_par(i32 %global.tid.val, ptr %0) {
+//CHECK: task.alloca:
+//CHECK: %[[VAL3:.*]] = load ptr, ptr %0, align 8
+//CHECK: %[[GEP_OMP_GLOBAL_THREAD_NUM2:.*]] = getelementptr { i32, ptr, ptr }, ptr %[[VAL3]], i32 0, i32 0
+//CHECK: %[[LOADGEP_OMP_GLOBAL_THREAD_NUM2:.*]] = load i32, ptr %[[GEP_OMP_GLOBAL_THREAD_NUM2:.*]], align 4
+//CHECK: %[[GEP_:.*]] = getelementptr { i32, ptr, ptr }, ptr %[[VAL3]], i32 0, i32 1
+//CHECK: %[[LOADGEP_:.*]] = load ptr, ptr %[[GEP_]], align 8
+//CHECK: %[[GEP_1:.*]] = getelementptr { i32, ptr, ptr }, ptr %[[VAL3]], i32 0, i32 2
+//CHECK: %[[LOADGEP_2:.*]] = load ptr, ptr %[[GEP_1]], align 8
+//CHECK: br label %task.body
+
+//CHECK: task.body:
+//CHECK: %[[GET_TH_DATA:.*]] = call ptr @__kmpc_task_reduction_get_th_data(i32 %[[LOADGEP_OMP_GLOBAL_THREAD_NUM2]], ptr %[[LOADGEP_]], ptr %[[LOADGEP_2]])
+//CHECK: br label %omp.task.region
+
+//CHECK: omp.task.region:
+//CHECK: %[[VAL4:.*]] = load i32, ptr %[[GET_TH_DATA]], align 4
+//CHECK: %[[VAL5:.*]] = add i32 %[[VAL4]], 1
+//CHECK: store i32 %[[VAL5]], ptr %[[GET_TH_DATA]], align 4
+//CHECK: br label %omp.region.cont3
+
+//CHECK: omp.region.cont3:
+//CHECK: br label %task.exit.exitStub
+
+//CHECK: task.exit.exitStub:
+//CHECK: ret void
+//CHECK: }
+
+// -----
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 8f3e466cfbbeb6..feb215470bd1a1 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -399,33 +399,7 @@ llvm.func @task_allocate(%x : !llvm.ptr) {
llvm.return
}
-// -----
-omp.declare_reduction @add_f32 : f32
-init {
-^bb0(%arg: f32):
- %0 = llvm.mlir.constant(0.0 : f32) : f32
- omp.yield (%0 : f32)
-}
-combiner {
-^bb1(%arg0: f32, %arg1: f32):
- %1 = llvm.fadd %arg0, %arg1 : f32
- omp.yield (%1 : f32)
-}
-atomic {
-^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
- %2 = llvm.load %arg3 : !llvm.ptr -> f32
- llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
- omp.yield
-}
-llvm.func @task_in_reduction(%x : !llvm.ptr) {
- // expected-error at below {{not yet implemented: Unhandled clause in_reduction in omp.task operation}}
- // expected-error at below {{LLVM Translation failed for operation: omp.task}}
- omp.task in_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
- omp.terminator
- }
- llvm.return
-}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/121045
More information about the Mlir-commits
mailing list