[flang-commits] [flang] [mlir] [Flang] [OpenMP] [Semantics] [MLIR] [Lowering] Add lowering support for IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on OMP TARGET directive. (PR #74187)

Raghu Maddhipatla via flang-commits flang-commits at lists.llvm.org
Fri Dec 1 23:36:34 PST 2023


https://github.com/raghavendhra created https://github.com/llvm/llvm-project/pull/74187

Here is the description of the changes.
1. Changed the semantic check so that it only reports "Variable in IS_DEVICE_PTR clause must be a dummy argument that does not have the ALLOCATABLE, POINTER or VALUE attribute." Previously there was an additional semantic check, "Variable in IS_DEVICE_PTR clause must be a dummy argument", which might be incorrect.
2. Added lowering support for the IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on the OMP TARGET directive, and added related tests for these changes.

>From 5460fa08f5c36e590bf7591880a8c4700e7d3dd6 Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Wed, 25 Oct 2023 18:19:08 -0500
Subject: [PATCH] [Flang] [OpenMP] [Semantics] [MLIR] [Lowering] Add lowering
 support for IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on OMP TARGET
 directive.

---
 flang/lib/Lower/OpenMP.cpp                    | 54 +++++++++++++++++--
 flang/lib/Semantics/check-omp-structure.cpp   |  7 +--
 flang/test/Lower/OpenMP/FIR/target.f90        | 41 +++++++++++++-
 flang/test/Semantics/OpenMP/target01.f90      |  1 -
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 18 +++++--
 mlir/test/Dialect/OpenMP/ops.mlir             |  8 +--
 6 files changed, 112 insertions(+), 17 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index a49589d8c59ff81..09d6358c3382617 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -604,6 +604,18 @@ class ClauseProcessor {
                       llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
                       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
                           &useDeviceSymbols) const;
+  bool
+  processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
+                     llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+                     llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+                     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                         &isDeviceSymbols) const;
+  bool
+  processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
+                       llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+                       llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+                       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                           &isDeviceSymbols) const;
 
   // Call this method for these clauses that should be supported but are not
   // implemented yet. It triggers a compilation error if any of the given
@@ -1890,6 +1902,34 @@ bool ClauseProcessor::processUseDevicePtr(
       });
 }
 
+bool ClauseProcessor::processIsDevicePtr(
+    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+    llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
+    const {
+  return findRepeatableClause<ClauseTy::IsDevicePtr>(
+      [&](const ClauseTy::IsDevicePtr *devPtrClause,
+          const Fortran::parser::CharBlock &) {
+        addUseDeviceClause(converter, devPtrClause->v, operands, isDeviceTypes,
+                           isDeviceLocs, isDeviceSymbols);
+      });
+}
+
+bool ClauseProcessor::processHasDeviceAddr(
+    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+    llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
+    const {
+  return findRepeatableClause<ClauseTy::HasDeviceAddr>(
+      [&](const ClauseTy::HasDeviceAddr *devAddrClause,
+          const Fortran::parser::CharBlock &) {
+        addUseDeviceClause(converter, devAddrClause->v, operands, isDeviceTypes,
+                           isDeviceLocs, isDeviceSymbols);
+      });
+}
+
 template <typename... Ts>
 void ClauseProcessor::processTODO(mlir::Location currentLocation,
                                   llvm::omp::Directive directive) const {
@@ -2648,6 +2688,10 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Type> mapSymTypes;
   llvm::SmallVector<mlir::Location> mapSymLocs;
   llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
+  llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
+  llvm::SmallVector<mlir::Type> useDeviceTypes;
+  llvm::SmallVector<mlir::Location> useDeviceLocs;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
 
   ClauseProcessor cp(converter, clauseList);
   cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
@@ -2657,11 +2701,13 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
                 mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
+  cp.processIsDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
+                        useDeviceSymbols);
+  cp.processHasDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
+                          useDeviceSymbols);
   cp.processTODO<Fortran::parser::OmpClause::Private,
                  Fortran::parser::OmpClause::Depend,
                  Fortran::parser::OmpClause::Firstprivate,
-                 Fortran::parser::OmpClause::IsDevicePtr,
-                 Fortran::parser::OmpClause::HasDeviceAddr,
                  Fortran::parser::OmpClause::Reduction,
                  Fortran::parser::OmpClause::InReduction,
                  Fortran::parser::OmpClause::Allocate,
@@ -2736,7 +2782,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
 
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
-      nowaitAttr, mapOperands);
+      nowaitAttr, devicePtrOperands, deviceAddrOperands, mapOperands);
 
   genBodyOfTargetOp(converter, eval, targetOp, mapSymTypes, mapSymLocs,
                     mapSymbols, currentLocation);
@@ -3132,6 +3178,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
         !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
+        !std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
+        !std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
       TODO(clauseLocation, "OpenMP Block construct clause");
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index 53bdf57ff8efa5a..5d5ccceca1b2cab 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2963,11 +2963,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::IsDevicePtr &x) {
         context_.Say(itr->second->source,
             "Variable '%s' in IS_DEVICE_PTR clause must be of type C_PTR"_err_en_US,
             source.ToString());
-      } else if (!(IsDummy(*symbol))) {
-        context_.Say(itr->second->source,
-            "Variable '%s' in IS_DEVICE_PTR clause must be a dummy argument"_err_en_US,
-            source.ToString());
-      } else if (IsAllocatableOrPointer(*symbol) || IsValue(*symbol)) {
+      } else if (IsDummy(*symbol) &&
+          (IsAllocatableOrPointer(*symbol) || IsValue(*symbol))) {
         context_.Say(itr->second->source,
             "Variable '%s' in IS_DEVICE_PTR clause must be a dummy argument that does not have the ALLOCATABLE, POINTER or VALUE attribute."_err_en_US,
             source.ToString());
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index 2034ac84334e546..8f38261fc1aa62f 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -411,4 +411,43 @@ subroutine omp_target_parallel_do
    !CHECK: omp.terminator
    !CHECK: }
    !$omp end target parallel do
- end subroutine omp_target_parallel_do
+end subroutine omp_target_parallel_do
+
+!===============================================================================
+! Target `is_device_ptr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
+subroutine omp_target_is_device_ptr
+   use iso_c_binding, only : c_ptr, c_loc
+   !CHECK: %[[DEV_PTR:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "a", uniq_name = "_QFomp_target_is_device_ptrEa"}
+   type(c_ptr) :: a
+   !CHECK %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "b", fir.target, uniq_name = "_QFomp_target_is_device_ptrEb"}
+   integer, target :: b
+   !CHECK: %[[MAP_0:.*]] = omp.map_info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
+   !CHECK: %[[MAP_1:.*]] = omp.map_info var_ptr(%[[VAL_0:.*]] : !fir.ref<i32>, i32)   map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "b"}
+   !CHECK: omp.target is_device_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) map_entries(%[[MAP_0:.*]], %[[MAP_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<i32>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+   !$omp target map(tofrom: a,b) is_device_ptr(a)
+      !CHECK: {{.*}} = fir.coordinate_of %[[DEV_PTR:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+      a = c_loc(b)
+   !CHECK: omp.terminator
+   !$omp end target
+   !CHECK: }
+end subroutine omp_target_is_device_ptr
+
+ !===============================================================================
+ ! Target `has_device_addr` clause
+ !===============================================================================
+
+ !CHECK-LABEL: func.func @_QPomp_target_has_device_addr() {
+ subroutine omp_target_has_device_addr
+   integer, pointer :: a
+   !CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_has_device_addrEa"}
+   !CHECK: omp.target has_device_addr(%[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+   !$omp target has_device_addr(a)
+   !CHECK: {{.*}} = fir.load %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+      a = 10
+   !CHECK: omp.terminator
+   !$omp end target
+   !CHECK: }
+end subroutine omp_target_has_device_addr
diff --git a/flang/test/Semantics/OpenMP/target01.f90 b/flang/test/Semantics/OpenMP/target01.f90
index 485fa1f2530c3b7..2ce9a1af7cc80cd 100644
--- a/flang/test/Semantics/OpenMP/target01.f90
+++ b/flang/test/Semantics/OpenMP/target01.f90
@@ -39,7 +39,6 @@ subroutine bar(b1, b2, b3)
   type(c_ptr), pointer :: b2
   type(c_ptr), value :: b3
 
-  !ERROR: Variable 'c' in IS_DEVICE_PTR clause must be a dummy argument
   !$omp target is_device_ptr(c)
     y = y + 1
   !$omp end target
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8ff5380f71ad453..4946cfe6880e382 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1389,10 +1389,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
 
     The optional $thread_limit specifies the limit on the number of threads
 
-    The optional $nowait elliminates the implicit barrier so the parent task can make progress
+    The optional $nowait eliminates the implicit barrier so the parent task can make progress
     even if the target task is not yet completed.
 
-    TODO:  is_device_ptr, depend, defaultmap, in_reduction
+    The optional $is_device_ptr indicates list items are device pointers.
+
+    The optional $has_device_addr indicates that list items already have device
+    addresses, so may be directly accessed from target device. May include array
+    sections.
+
+    The optional $map_operands maps data from the task’s environment to the
+    device environment.
+
+    TODO:  depend, defaultmap, in_reduction
 
   }];
 
@@ -1400,8 +1409,9 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
                        Optional<AnyInteger>:$device,
                        Optional<AnyInteger>:$thread_limit,
                        UnitAttr:$nowait,
+                       Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
+                       Variadic<OpenMP_PointerLikeType>:$has_device_addr,
                        Variadic<AnyType>:$map_operands);
-
   let regions = (region AnyRegion:$region);
 
   let assemblyFormat = [{
@@ -1409,6 +1419,8 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
     | `device` `(` $device `:` type($device) `)`
     | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
     | `nowait` $nowait
+    | `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
+    | `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
     | `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
     ) $region attr-dict
   }];
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4d88d9ac86fe16c..b153b1b8221d807 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -480,22 +480,22 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
 }
 
 // CHECK-LABEL: omp_target
-func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
+func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
 
     // Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
     // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
     // CHECK: %[[MAP_B:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
-    // CHECK: omp.target map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
+    // CHECK: omp.target is_device_ptr(%[[VAL_4:.*]] : memref<i32>) has_device_addr(%[[VAL_5:.*]] : memref<?xi32>) map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
     %mapv1 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
     %mapv2 = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>)   map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
-    omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
+    omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%device_addr : memref<?xi32>) {
     ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
       omp.terminator
     }



More information about the flang-commits mailing list