[flang-commits] [flang] [llvm] [mlir] [Flang][OpenMP] Implement device clause lowering for target directive (PR #173509)

via flang-commits flang-commits at lists.llvm.org
Wed Dec 24 12:28:21 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Chi-Chun, Chen (chichunchen)

<details>
<summary>Changes</summary>

Add lowering support for the OpenMP `device` clause on the `target` directive in Flang.

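The device expression is carried on the `omp.target` operation in the MLIR OpenMP dialect; during translation to LLVM IR it is sign-extended or truncated to `i64` and passed as the device ID to the host-side `__tgt_target_kernel` call. When no `device` clause is present, `OMP_DEVICEID_UNDEF` is passed instead.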

---
Full diff: https://github.com/llvm/llvm-project/pull/173509.diff


7 Files Affected:

- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-1) 
- (modified) flang/test/Lower/OpenMP/target.f90 (+41) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+1-1) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+6-8) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+14-4) 
- (added) mlir/test/Target/LLVMIR/omptarget-device.mlir (+68) 
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-2) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7965119764e5d..4f2b8ef15519c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4087,7 +4087,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
         !std::holds_alternative<clause::Mergeable>(clause.u) &&
         !std::holds_alternative<clause::Untied>(clause.u) &&
         !std::holds_alternative<clause::TaskReduction>(clause.u) &&
-        !std::holds_alternative<clause::Detach>(clause.u)) {
+        !std::holds_alternative<clause::Detach>(clause.u) &&
+        !std::holds_alternative<clause::Device>(clause.u)) {
       std::string name =
           parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
       if (!semaCtx.langOptions().OpenMPSimd)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index c5d39695e5389..55a6b7a595ed1 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -694,3 +694,44 @@ subroutine target_unstructured
    !$omp end target
    !CHECK: }
 end subroutine target_unstructured
+
+!===============================================================================
+! Target `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device() {
+subroutine omp_target_device
+  integer            :: dev32
+  integer(kind=8)    :: dev64
+  integer(kind=2)    :: dev16
+
+  dev32 = 1
+  dev64 = 2_8
+  dev16 = 3_2
+
+  !$omp target device(dev32)
+  !$omp end target
+  ! CHECK: %[[DEV32:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+  ! CHECK: omp.target device(%[[DEV32]] : i32)
+
+  !$omp target device(dev64)
+  !$omp end target
+  ! CHECK: %[[DEV64:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+  ! CHECK: omp.target device(%[[DEV64]] : i64)
+
+  !$omp target device(dev16)
+  !$omp end target
+  ! CHECK: %[[DEV16:.*]] = fir.load %{{.*}} : !fir.ref<i16>
+  ! CHECK: omp.target device(%[[DEV16]] : i16)
+
+  !$omp target device(2)
+  !$omp end target
+  ! CHECK: %[[C2:.*]] = arith.constant 2 : i32
+  ! CHECK: omp.target device(%[[C2]] : i32)
+
+  !$omp target device(5_8)
+  !$omp end target
+  ! CHECK: %[[C5:.*]] = arith.constant 5 : i64
+  ! CHECK: omp.target device(%[[C5]] : i64)
+
+end subroutine omp_target_device
\ No newline at end of file
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f5eb6222fd58d..8103a7e9504ea 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3341,7 +3341,7 @@ class OpenMPIRBuilder {
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
-      TargetRegionEntryInfo &EntryInfo,
+      Value *DeviceID, TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
       const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 716f8582dd7b2..3be96350cb058 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8548,7 +8548,7 @@ Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
 static void emitTargetCall(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     OpenMPIRBuilder::InsertPointTy AllocaIP,
-    OpenMPIRBuilder::TargetDataInfo &Info,
+    OpenMPIRBuilder::TargetDataInfo &Info, Value *DeviceID,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
     Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
@@ -8680,8 +8680,6 @@ static void emitTargetCall(
     }
 
     unsigned NumTargetItems = Info.NumberOfPtrs;
-    // TODO: Use correct device ID
-    Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
     uint32_t SrcLocStrSize;
     Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
     Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -8740,7 +8738,7 @@ static void emitTargetCall(
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetDataInfo &Info,
+    InsertPointTy CodeGenIP, TargetDataInfo &Info, Value *DeviceID,
     TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
     const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
@@ -8770,10 +8768,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
-                   IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
-                   CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
-                   DynCGroupMemFallback);
+    emitTargetCall(*this, Builder, AllocaIP, Info, DeviceID, DefaultAttrs,
+                   RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Inputs,
+                   GenMapInfoCB, CustomMapperCB, Dependencies, HasNowait,
+                   DynCGroupMem, DynCGroupMemFallback);
   return Builder.saveIP();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 03d67a52853f6..ac2d6c93b890e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -320,7 +320,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("depend");
   };
   auto checkDevice = [&todo](auto op, LogicalResult &result) {
-    if (op.getDevice())
+    if (op.getDevice() && !isa<omp::TargetOp>(op))
       result = todo("device");
   };
   auto checkHint = [](auto op, LogicalResult &) {
@@ -5961,6 +5961,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   bool isTargetDevice = ompBuilder->Config.isTargetDevice();
   bool isGPU = ompBuilder->Config.isGPU();
+  llvm::Value *deviceIDValue = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
+
+  if (!isTargetDevice) {
+    if (mlir::Value devId = targetOp.getDevice()) {
+      deviceIDValue = moduleTranslation.lookupValue(devId);
+      deviceIDValue =
+          builder.CreateSExtOrTrunc(deviceIDValue, builder.getInt64Ty());
+    }
+  }
 
   auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
   auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
@@ -6235,9 +6244,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
-          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
-          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
-          argAccessorCB, customMapperCB, dds, targetOp.getNowait());
+          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info,
+          deviceIDValue, entryInfo, defaultAttrs, runtimeAttrs, ifCond,
+          kernelInput, genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
+          targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
     return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
new file mode 100644
index 0000000000000..b4c9744cc0c87
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %x  = llvm.mlir.constant(0 : i32) : i32
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target device(%c1_i16 : i16)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target device(%c2_i32 : i32)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target device(%c3_i64 : i64)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i16 -> cast to i64.
+    omp.target device(%d16 : i16)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i32 -> cast to i64.
+    omp.target device(%d32 : i32)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i64 stays i64.
+    omp.target device(%d64 : i64)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 1, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 2, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 3, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 396c57af81c44..d4cc9e215de1d 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -174,8 +174,6 @@ llvm.func @target_allocate(%x : !llvm.ptr) {
 // -----
 
 llvm.func @target_device(%x : i32) {
-  // expected-error@below {{not yet implemented: Unhandled clause device in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
   omp.target device(%x : i32) {
     omp.terminator
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/173509


More information about the flang-commits mailing list