[flang-commits] [flang] [llvm] [mlir] [Flang][OpenMP] Implement device clause lowering for target directive (PR #173509)
via flang-commits
flang-commits at lists.llvm.org
Wed Dec 24 12:28:21 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Chi-Chun, Chen (chichunchen)
<details>
<summary>Changes</summary>
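Add lowering support for the OpenMP `device` clause on the `target` directive in Flang.
The device expression is propagated through the MLIR OpenMP dialect. When compiling for the host, it is sign-extended or truncated to `i64` and passed as the device ID argument of the host-side `__tgt_target_kernel` call. Previously, that argument was always `OMP_DEVICEID_UNDEF`.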
---
Full diff: https://github.com/llvm/llvm-project/pull/173509.diff
7 Files Affected:
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-1)
- (modified) flang/test/Lower/OpenMP/target.f90 (+41)
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+1-1)
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+6-8)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+14-4)
- (added) mlir/test/Target/LLVMIR/omptarget-device.mlir (+68)
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-2)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7965119764e5d..4f2b8ef15519c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4087,7 +4087,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
!std::holds_alternative<clause::Mergeable>(clause.u) &&
!std::holds_alternative<clause::Untied>(clause.u) &&
!std::holds_alternative<clause::TaskReduction>(clause.u) &&
- !std::holds_alternative<clause::Detach>(clause.u)) {
+ !std::holds_alternative<clause::Detach>(clause.u) &&
+ !std::holds_alternative<clause::Device>(clause.u)) {
std::string name =
parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
if (!semaCtx.langOptions().OpenMPSimd)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index c5d39695e5389..55a6b7a595ed1 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -694,3 +694,44 @@ subroutine target_unstructured
!$omp end target
!CHECK: }
end subroutine target_unstructured
+
+!===============================================================================
+! Target `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device() {
+subroutine omp_target_device
+ integer :: dev32
+ integer(kind=8) :: dev64
+ integer(kind=2) :: dev16
+
+ dev32 = 1
+ dev64 = 2_8
+ dev16 = 3_2
+
+ !$omp target device(dev32)
+ !$omp end target
+ ! CHECK: %[[DEV32:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+ ! CHECK: omp.target device(%[[DEV32]] : i32)
+
+ !$omp target device(dev64)
+ !$omp end target
+ ! CHECK: %[[DEV64:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+ ! CHECK: omp.target device(%[[DEV64]] : i64)
+
+ !$omp target device(dev16)
+ !$omp end target
+ ! CHECK: %[[DEV16:.*]] = fir.load %{{.*}} : !fir.ref<i16>
+ ! CHECK: omp.target device(%[[DEV16]] : i16)
+
+ !$omp target device(2)
+ !$omp end target
+ ! CHECK: %[[C2:.*]] = arith.constant 2 : i32
+ ! CHECK: omp.target device(%[[C2]] : i32)
+
+ !$omp target device(5_8)
+ !$omp end target
+ ! CHECK: %[[C5:.*]] = arith.constant 5 : i64
+ ! CHECK: omp.target device(%[[C5]] : i64)
+
+end subroutine omp_target_device
\ No newline at end of file
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f5eb6222fd58d..8103a7e9504ea 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3341,7 +3341,7 @@ class OpenMPIRBuilder {
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
- TargetRegionEntryInfo &EntryInfo,
+ Value *DeviceID, TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 716f8582dd7b2..3be96350cb058 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8548,7 +8548,7 @@ Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
static void emitTargetCall(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
- OpenMPIRBuilder::TargetDataInfo &Info,
+ OpenMPIRBuilder::TargetDataInfo &Info, Value *DeviceID,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
@@ -8680,8 +8680,6 @@ static void emitTargetCall(
}
unsigned NumTargetItems = Info.NumberOfPtrs;
- // TODO: Use correct device ID
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -8740,7 +8738,7 @@ static void emitTargetCall(
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetDataInfo &Info,
+ InsertPointTy CodeGenIP, TargetDataInfo &Info, Value *DeviceID,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
@@ -8770,10 +8768,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
- IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
- CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
- DynCGroupMemFallback);
+ emitTargetCall(*this, Builder, AllocaIP, Info, DeviceID, DefaultAttrs,
+ RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Inputs,
+ GenMapInfoCB, CustomMapperCB, Dependencies, HasNowait,
+ DynCGroupMem, DynCGroupMemFallback);
return Builder.saveIP();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 03d67a52853f6..ac2d6c93b890e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -320,7 +320,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("depend");
};
auto checkDevice = [&todo](auto op, LogicalResult &result) {
- if (op.getDevice())
+ if (op.getDevice() && !isa<omp::TargetOp>(op))
result = todo("device");
};
auto checkHint = [](auto op, LogicalResult &) {
@@ -5961,6 +5961,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool isTargetDevice = ompBuilder->Config.isTargetDevice();
bool isGPU = ompBuilder->Config.isGPU();
+ llvm::Value *deviceIDValue = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
+
+ if (!isTargetDevice) {
+ if (mlir::Value devId = targetOp.getDevice()) {
+ deviceIDValue = moduleTranslation.lookupValue(devId);
+ deviceIDValue =
+ builder.CreateSExtOrTrunc(deviceIDValue, builder.getInt64Ty());
+ }
+ }
auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
@@ -6235,9 +6244,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
- ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
- defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
- argAccessorCB, customMapperCB, dds, targetOp.getNowait());
+ ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info,
+ deviceIDValue, entryInfo, defaultAttrs, runtimeAttrs, ifCond,
+ kernelInput, genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
+ targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
new file mode 100644
index 0000000000000..b4c9744cc0c87
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+ llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+ %x = llvm.mlir.constant(0 : i32) : i32
+
+ // Constant i16 -> i64 in the runtime call.
+ %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+ omp.target device(%c1_i16 : i16)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Constant i32 -> i64 in the runtime call.
+ %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+ omp.target device(%c2_i32 : i32)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Constant i64 stays i64 in the runtime call.
+ %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+ omp.target device(%c3_i64 : i64)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i16 -> cast to i64.
+ omp.target device(%d16 : i16)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i32 -> cast to i64.
+ omp.target device(%d32 : i32)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i64 stays i64.
+ omp.target device(%d64 : i64)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ llvm.return
+ }
+}
+
+// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 1, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 2, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 3, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 396c57af81c44..d4cc9e215de1d 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -174,8 +174,6 @@ llvm.func @target_allocate(%x : !llvm.ptr) {
// -----
llvm.func @target_device(%x : i32) {
- // expected-error at below {{not yet implemented: Unhandled clause device in omp.target operation}}
- // expected-error at below {{LLVM Translation failed for operation: omp.target}}
omp.target device(%x : i32) {
omp.terminator
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/173509
More information about the flang-commits mailing list