[Mlir-commits] [mlir] 3f5d91b - [Flang][OpenMP] Implement device clause lowering for target directive (#173509)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 6 09:10:11 PST 2026
Author: Chi-Chun, Chen
Date: 2026-01-06T11:10:03-06:00
New Revision: 3f5d91bfbc17a487fc14ac2c7f2d866fb97e3906
URL: https://github.com/llvm/llvm-project/commit/3f5d91bfbc17a487fc14ac2c7f2d866fb97e3906
DIFF: https://github.com/llvm/llvm-project/commit/3f5d91bfbc17a487fc14ac2c7f2d866fb97e3906.diff
LOG: [Flang][OpenMP] Implement device clause lowering for target directive (#173509)
Add lowering support for the OpenMP `device` clause on the `target`
directive in Flang.
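The device expression is propagated through the MLIR OpenMP dialect,
sign-extended or truncated to i64 during translation to LLVM IR, and
passed as the device ID argument of the host-side `__tgt_target_kernel` call.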
Added:
mlir/test/Target/LLVMIR/omptarget-device.mlir
Modified:
flang/docs/OpenMPSupport.md
flang/lib/Lower/OpenMP/OpenMP.cpp
flang/test/Lower/OpenMP/target.f90
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/openmp-todo.mlir
Removed:
################################################################################
diff --git a/flang/docs/OpenMPSupport.md b/flang/docs/OpenMPSupport.md
index c76cafd1b3a5f..21966c5489108 100644
--- a/flang/docs/OpenMPSupport.md
+++ b/flang/docs/OpenMPSupport.md
@@ -38,7 +38,7 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
| declare simd construct | N | |
| do simd construct | P | linear clause is not supported |
| target data construct | P | device clause not supported |
-| target construct | P | device clause not supported |
+| target construct | Y | |
| target update construct | P | device clause not supported |
| declare target directive | Y | |
| teams construct | Y | |
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 090d608503f26..4381d1e9064cf 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4122,7 +4122,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
!std::holds_alternative<clause::Mergeable>(clause.u) &&
!std::holds_alternative<clause::Untied>(clause.u) &&
!std::holds_alternative<clause::TaskReduction>(clause.u) &&
- !std::holds_alternative<clause::Detach>(clause.u)) {
+ !std::holds_alternative<clause::Detach>(clause.u) &&
+ !std::holds_alternative<clause::Device>(clause.u)) {
std::string name =
parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
if (!semaCtx.langOptions().OpenMPSimd)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index c5d39695e5389..55a6b7a595ed1 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -694,3 +694,44 @@ subroutine target_unstructured
!$omp end target
!CHECK: }
end subroutine target_unstructured
+
+!===============================================================================
+! Target `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device() {
+subroutine omp_target_device
+ integer :: dev32
+ integer(kind=8) :: dev64
+ integer(kind=2) :: dev16
+
+ dev32 = 1
+ dev64 = 2_8
+ dev16 = 3_2
+
+ !$omp target device(dev32)
+ !$omp end target
+ ! CHECK: %[[DEV32:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+ ! CHECK: omp.target device(%[[DEV32]] : i32)
+
+ !$omp target device(dev64)
+ !$omp end target
+ ! CHECK: %[[DEV64:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+ ! CHECK: omp.target device(%[[DEV64]] : i64)
+
+ !$omp target device(dev16)
+ !$omp end target
+ ! CHECK: %[[DEV16:.*]] = fir.load %{{.*}} : !fir.ref<i16>
+ ! CHECK: omp.target device(%[[DEV16]] : i16)
+
+ !$omp target device(2)
+ !$omp end target
+ ! CHECK: %[[C2:.*]] = arith.constant 2 : i32
+ ! CHECK: omp.target device(%[[C2]] : i32)
+
+ !$omp target device(5_8)
+ !$omp end target
+ ! CHECK: %[[C5:.*]] = arith.constant 5 : i64
+ ! CHECK: omp.target device(%[[C5]] : i64)
+
+end subroutine omp_target_device
\ No newline at end of file
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index c8db40d3cf51b..d6fc49afb6fdb 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2519,6 +2519,9 @@ class OpenMPIRBuilder {
/// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
/// it is a generic kernel.
Value *LoopTripCount = nullptr;
+
+ /// Device ID value used in the kernel launch.
+ Value *DeviceID = nullptr;
};
/// Data structure that contains the needed information to construct the
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index b6a3d9e66fb9c..a3e7c5ea8059b 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8712,8 +8712,6 @@ static void emitTargetCall(
}
unsigned NumTargetItems = Info.NumberOfPtrs;
- // TODO: Use correct device ID
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -8739,13 +8737,13 @@ static void emitTargetCall(
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask)
- return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
- Dependencies, KArgs.RTArgs,
- Info.HasNoWait);
+ return OMPBuilder.emitTargetTask(TaskBodyCB, RuntimeAttrs.DeviceID,
+ RTLoc, AllocaIP, Dependencies,
+ KArgs.RTArgs, Info.HasNoWait);
- return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
- EmitTargetCallFallbackCB, KArgs,
- DeviceID, RTLoc, AllocaIP);
+ return OMPBuilder.emitKernelLaunch(
+ Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
+ RuntimeAttrs.DeviceID, RTLoc, AllocaIP);
}());
Builder.restoreIP(AfterIP);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 7710e66cf87b4..4e35e6819076c 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6501,6 +6501,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+ RuntimeAttrs.DeviceID = Builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
@@ -6834,6 +6835,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
+ RuntimeAttrs.DeviceID = Builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
llvm::OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index b7630bd97dca4..614f06017a324 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -454,12 +454,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
.Case([&](omp::SimdOp op) { checkReduction(op, result); })
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
- .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
+ .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
[&](auto op) { checkDepend(op, result); })
+ .Case<omp::TargetUpdateOp>([&](auto op) {
+ checkDepend(op, result);
+ checkDevice(op, result);
+ })
+ .Case<omp::TargetDataOp>([&](auto op) { checkDevice(op, result); })
.Case([&](omp::TargetOp op) {
checkAllocate(op, result);
checkBare(op, result);
- checkDevice(op, result);
checkInReduction(op, result);
})
.Default([](Operation &) {
@@ -5998,6 +6002,13 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
{}, /*HasNUW=*/true);
}
}
+
+ attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
+ if (mlir::Value devId = targetOp.getDevice()) {
+ attrs.DeviceID = moduleTranslation.lookupValue(devId);
+ attrs.DeviceID =
+ builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
+ }
}
static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
new file mode 100644
index 0000000000000..b4c9744cc0c87
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+ llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+ %x = llvm.mlir.constant(0 : i32) : i32
+
+ // Constant i16 -> i64 in the runtime call.
+ %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+ omp.target device(%c1_i16 : i16)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Constant i32 -> i64 in the runtime call.
+ %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+ omp.target device(%c2_i32 : i32)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Constant i64 stays i64 in the runtime call.
+ %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+ omp.target device(%c3_i64 : i64)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i16 -> cast to i64.
+ omp.target device(%d16 : i16)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i32 -> cast to i64.
+ omp.target device(%d32 : i32)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i64 stays i64.
+ omp.target device(%d64 : i64)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ llvm.return
+ }
+}
+
+// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 1, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 2, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 3, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 396c57af81c44..d4cc9e215de1d 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -174,8 +174,6 @@ llvm.func @target_allocate(%x : !llvm.ptr) {
// -----
llvm.func @target_device(%x : i32) {
- // expected-error@below {{not yet implemented: Unhandled clause device in omp.target operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.target}}
omp.target device(%x : i32) {
omp.terminator
}
More information about the Mlir-commits
mailing list