[Mlir-commits] [mlir] 3f5d91b - [Flang][OpenMP] Implement device clause lowering for target directive (#173509)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 6 09:10:11 PST 2026


Author: Chi-Chun, Chen
Date: 2026-01-06T11:10:03-06:00
New Revision: 3f5d91bfbc17a487fc14ac2c7f2d866fb97e3906

URL: https://github.com/llvm/llvm-project/commit/3f5d91bfbc17a487fc14ac2c7f2d866fb97e3906
DIFF: https://github.com/llvm/llvm-project/commit/3f5d91bfbc17a487fc14ac2c7f2d866fb97e3906.diff

LOG: [Flang][OpenMP] Implement device clause lowering for target directive (#173509)

Add lowering support for the OpenMP `device` clause on the `target`
directive in Flang.

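The device expression is propagated through the MLIR OpenMP dialect,
sign-extended or truncated to `i64` during translation to LLVM IR, and
passed as the device ID to the host-side `__tgt_target_kernel` call. When
no `device` clause is present, `OMP_DEVICEID_UNDEF` is passed instead.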

Added: 
    mlir/test/Target/LLVMIR/omptarget-device.mlir

Modified: 
    flang/docs/OpenMPSupport.md
    flang/lib/Lower/OpenMP/OpenMP.cpp
    flang/test/Lower/OpenMP/target.f90
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Target/LLVMIR/openmp-todo.mlir

Removed: 
    


################################################################################
diff  --git a/flang/docs/OpenMPSupport.md b/flang/docs/OpenMPSupport.md
index c76cafd1b3a5f..21966c5489108 100644
--- a/flang/docs/OpenMPSupport.md
+++ b/flang/docs/OpenMPSupport.md
@@ -38,7 +38,7 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
 | declare simd construct                                     | N      | |
 | do simd construct                                          | P      | linear clause is not supported |
 | target data construct                                      | P      | device clause not supported |
-| target construct                                           | P      | device clause not supported |
+| target construct                                           | Y      | |
 | target update construct                                    | P      | device clause not supported |
 | declare target directive                                   | Y      | |
 | teams construct                                            | Y      | |

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 090d608503f26..4381d1e9064cf 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4122,7 +4122,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
         !std::holds_alternative<clause::Mergeable>(clause.u) &&
         !std::holds_alternative<clause::Untied>(clause.u) &&
         !std::holds_alternative<clause::TaskReduction>(clause.u) &&
-        !std::holds_alternative<clause::Detach>(clause.u)) {
+        !std::holds_alternative<clause::Detach>(clause.u) &&
+        !std::holds_alternative<clause::Device>(clause.u)) {
       std::string name =
           parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
       if (!semaCtx.langOptions().OpenMPSimd)

diff  --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index c5d39695e5389..55a6b7a595ed1 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -694,3 +694,44 @@ subroutine target_unstructured
    !$omp end target
    !CHECK: }
 end subroutine target_unstructured
+
+!===============================================================================
+! Target `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device() {
+subroutine omp_target_device
+  integer            :: dev32
+  integer(kind=8)    :: dev64
+  integer(kind=2)    :: dev16
+
+  dev32 = 1
+  dev64 = 2_8
+  dev16 = 3_2
+
+  !$omp target device(dev32)
+  !$omp end target
+  ! CHECK: %[[DEV32:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+  ! CHECK: omp.target device(%[[DEV32]] : i32)
+
+  !$omp target device(dev64)
+  !$omp end target
+  ! CHECK: %[[DEV64:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+  ! CHECK: omp.target device(%[[DEV64]] : i64)
+
+  !$omp target device(dev16)
+  !$omp end target
+  ! CHECK: %[[DEV16:.*]] = fir.load %{{.*}} : !fir.ref<i16>
+  ! CHECK: omp.target device(%[[DEV16]] : i16)
+
+  !$omp target device(2)
+  !$omp end target
+  ! CHECK: %[[C2:.*]] = arith.constant 2 : i32
+  ! CHECK: omp.target device(%[[C2]] : i32)
+
+  !$omp target device(5_8)
+  !$omp end target
+  ! CHECK: %[[C5:.*]] = arith.constant 5 : i64
+  ! CHECK: omp.target device(%[[C5]] : i64)
+
+end subroutine omp_target_device
\ No newline at end of file

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index c8db40d3cf51b..d6fc49afb6fdb 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2519,6 +2519,9 @@ class OpenMPIRBuilder {
     /// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
     /// it is a generic kernel.
     Value *LoopTripCount = nullptr;
+
+    /// Device ID value used in the kernel launch.
+    Value *DeviceID = nullptr;
   };
 
   /// Data structure that contains the needed information to construct the

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index b6a3d9e66fb9c..a3e7c5ea8059b 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8712,8 +8712,6 @@ static void emitTargetCall(
     }
 
     unsigned NumTargetItems = Info.NumberOfPtrs;
-    // TODO: Use correct device ID
-    Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
     uint32_t SrcLocStrSize;
     Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
     Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -8739,13 +8737,13 @@ static void emitTargetCall(
       // The presence of certain clauses on the target directive require the
       // explicit generation of the target task.
       if (RequiresOuterTargetTask)
-        return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
-                                         Dependencies, KArgs.RTArgs,
-                                         Info.HasNoWait);
+        return OMPBuilder.emitTargetTask(TaskBodyCB, RuntimeAttrs.DeviceID,
+                                         RTLoc, AllocaIP, Dependencies,
+                                         KArgs.RTArgs, Info.HasNoWait);
 
-      return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
-                                         EmitTargetCallFallbackCB, KArgs,
-                                         DeviceID, RTLoc, AllocaIP);
+      return OMPBuilder.emitKernelLaunch(
+          Builder, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
+          RuntimeAttrs.DeviceID, RTLoc, AllocaIP);
     }());
 
     Builder.restoreIP(AfterIP);

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 7710e66cf87b4..4e35e6819076c 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6501,6 +6501,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
   RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
   RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+  RuntimeAttrs.DeviceID = Builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
 
   ASSERT_EXPECTED_INIT(
       OpenMPIRBuilder::InsertPointTy, AfterIP,
@@ -6834,6 +6835,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
       /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
   RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
+  RuntimeAttrs.DeviceID = Builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
   llvm::OpenMPIRBuilder::TargetDataInfo Info(
       /*RequiresDevicePointerInfo=*/false,
       /*SeparateBeginEndCalls=*/true);

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index b7630bd97dca4..614f06017a324 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -454,12 +454,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::SimdOp op) { checkReduction(op, result); })
       .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
             omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
-      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
+      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
           [&](auto op) { checkDepend(op, result); })
+      .Case<omp::TargetUpdateOp>([&](auto op) {
+        checkDepend(op, result);
+        checkDevice(op, result);
+      })
+      .Case<omp::TargetDataOp>([&](auto op) { checkDevice(op, result); })
       .Case([&](omp::TargetOp op) {
         checkAllocate(op, result);
         checkBare(op, result);
-        checkDevice(op, result);
         checkInReduction(op, result);
       })
       .Default([](Operation &) {
@@ -5998,6 +6002,13 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                                               {}, /*HasNUW=*/true);
     }
   }
+
+  attrs.DeviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
+  if (mlir::Value devId = targetOp.getDevice()) {
+    attrs.DeviceID = moduleTranslation.lookupValue(devId);
+    attrs.DeviceID =
+        builder.CreateSExtOrTrunc(attrs.DeviceID, builder.getInt64Ty());
+  }
 }
 
 static LogicalResult

diff  --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
new file mode 100644
index 0000000000000..b4c9744cc0c87
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %x  = llvm.mlir.constant(0 : i32) : i32
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target device(%c1_i16 : i16)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target device(%c2_i32 : i32)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target device(%c3_i64 : i64)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i16 -> cast to i64.
+    omp.target device(%d16 : i16)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i32 -> cast to i64.
+    omp.target device(%d32 : i32)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i64 stays i64.
+    omp.target device(%d64 : i64)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 1, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 2, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 3, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file

diff  --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 396c57af81c44..d4cc9e215de1d 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -174,8 +174,6 @@ llvm.func @target_allocate(%x : !llvm.ptr) {
 // -----
 
 llvm.func @target_device(%x : i32) {
-  // expected-error @below {{not yet implemented: Unhandled clause device in omp.target operation}}
-  // expected-error @below {{LLVM Translation failed for operation: omp.target}}
   omp.target device(%x : i32) {
     omp.terminator
   }


        


More information about the Mlir-commits mailing list