[flang-commits] [flang] [Flang] [OpenMP] [Semantics] Add semantic support for IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on OMP TARGET directive and add more semantic checks for OMP TARGET. (PR #67290)

Raghu Maddhipatla via flang-commits flang-commits at lists.llvm.org
Mon Sep 25 00:44:17 PDT 2023


https://github.com/raghavendhra created https://github.com/llvm/llvm-project/pull/67290

Summary of this patch

- Add semantic support for HAS_DEVICE_ADDR and IS_DEVICE_PTR clauses.
- A list item that appears in an IS_DEVICE_PTR clause must be a valid device pointer for the device data environment.
- A list item may not be specified in both an IS_DEVICE_PTR clause and a HAS_DEVICE_ADDR clause on the directive.
- A list item that appears in an IS_DEVICE_PTR or a HAS_DEVICE_ADDR clause must not be specified in any data-sharing attribute clause on the same target construct.

>From 81f82c777f11c21571357aed07a5b1e1b9b4ad1f Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Mon, 25 Sep 2023 02:31:37 -0500
Subject: [PATCH] [Flang] [OpenMP] [Semantics] Add semantic support for
 IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on OMP TARGET directive and add
 more semantic checks for OMP TARGET.

---
 flang/include/flang/Semantics/symbol.h      |  3 +-
 flang/lib/Lower/OpenMP.cpp                  | 51 +++++++++++++++++++-
 flang/lib/Parser/openmp-parsers.cpp         |  5 +-
 flang/lib/Semantics/check-omp-structure.cpp | 53 ++++++++++++++++++++-
 flang/lib/Semantics/resolve-directives.cpp  | 28 +++++++++--
 flang/lib/Semantics/symbol.cpp              |  6 +++
 flang/test/Semantics/OpenMP/target01.f90    | 29 +++++++++--
 llvm/include/llvm/Frontend/OpenMP/OMP.td    |  6 +--
 8 files changed, 163 insertions(+), 18 deletions(-)

diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index a5f4ad76c26b784..f6f195b6bb95b20 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -695,7 +695,8 @@ class Symbol {
       OmpShared, OmpPrivate, OmpLinear, OmpFirstPrivate, OmpLastPrivate,
       // OpenMP data-mapping attribute
       OmpMapTo, OmpMapFrom, OmpMapToFrom, OmpMapAlloc, OmpMapRelease,
-      OmpMapDelete, OmpUseDevicePtr, OmpUseDeviceAddr,
+      OmpMapDelete, OmpUseDevicePtr, OmpUseDeviceAddr, OmpIsDevicePtr,
+      OmpHasDeviceAddr,
       // OpenMP data-copying attribute
       OmpCopyIn, OmpCopyPrivate,
       // OpenMP miscellaneous flags
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 5f5e968eaaa6414..bbeded3d590379d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -557,7 +557,18 @@ class ClauseProcessor {
                       llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
                       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
                           &useDeviceSymbols) const;
-
+  bool
+  processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
+                     llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+                     llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+                     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                         &isDeviceSymbols) const;
+  bool
+  processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
+                       llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+                       llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+                       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                           &isDeviceSymbols) const;
   // Call this method for these clauses that should be supported but are not
   // implemented yet. It triggers a compilation error if any of the given
   // clauses is found.
@@ -1824,6 +1835,34 @@ bool ClauseProcessor::processUseDevicePtr(
       });
 }
 
+bool ClauseProcessor::processIsDevicePtr(
+    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+    llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
+    const {
+  return findRepeatableClause<ClauseTy::IsDevicePtr>(
+      [&](const ClauseTy::IsDevicePtr *devPtrClause,
+          const Fortran::parser::CharBlock &) {
+        addUseDeviceClause(converter, devPtrClause->v, operands, isDeviceTypes,
+                           isDeviceLocs, isDeviceSymbols);
+      });
+}
+
+bool ClauseProcessor::processHasDeviceAddr(
+    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+    llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
+    const {
+  return findRepeatableClause<ClauseTy::HasDeviceAddr>(
+      [&](const ClauseTy::HasDeviceAddr *devAddrClause,
+          const Fortran::parser::CharBlock &) {
+        addUseDeviceClause(converter, devAddrClause->v, operands, isDeviceTypes,
+                           isDeviceLocs, isDeviceSymbols);
+      });
+}
+
 template <typename... Ts>
 void ClauseProcessor::processTODO(mlir::Location currentLocation,
                                   llvm::omp::Directive directive) const {
@@ -2411,6 +2450,10 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
   mlir::UnitAttr nowaitAttr;
   llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
+  llvm::SmallVector<mlir::Type> useDeviceTypes;
+  llvm::SmallVector<mlir::Location> useDeviceLocs;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
 
   ClauseProcessor cp(converter, clauseList);
   cp.processIf(stmtCtx,
@@ -2421,6 +2464,10 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
                 mapOperands);
+  cp.processIsDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
+                        useDeviceSymbols);
+  cp.processHasDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
+                          useDeviceSymbols);
   cp.processTODO<Fortran::parser::OmpClause::Private,
                  Fortran::parser::OmpClause::Depend,
                  Fortran::parser::OmpClause::Firstprivate,
@@ -2830,6 +2877,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
         !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
+        !std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
+        !std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
       TODO(clauseLocation, "OpenMP Block construct clause");
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index b220d578b3bbc7e..0271f699a687ff0 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -268,13 +268,16 @@ TYPE_PARSER(
                   parenthesized(Parser<OmpObjectList>{}))) ||
     "GRAINSIZE" >> construct<OmpClause>(construct<OmpClause::Grainsize>(
                        parenthesized(scalarIntExpr))) ||
+    "HAS_DEVICE_ADDR" >>
+        construct<OmpClause>(construct<OmpClause::HasDeviceAddr>(
+            parenthesized(Parser<OmpObjectList>{}))) ||
     "HINT" >> construct<OmpClause>(
                   construct<OmpClause::Hint>(parenthesized(constantExpr))) ||
     "IF" >> construct<OmpClause>(construct<OmpClause::If>(
                 parenthesized(Parser<OmpIfClause>{}))) ||
     "INBRANCH" >> construct<OmpClause>(construct<OmpClause::Inbranch>()) ||
     "IS_DEVICE_PTR" >> construct<OmpClause>(construct<OmpClause::IsDevicePtr>(
-                           parenthesized(nonemptyList(name)))) ||
+                           parenthesized(Parser<OmpObjectList>{}))) ||
     "LASTPRIVATE" >> construct<OmpClause>(construct<OmpClause::Lastprivate>(
                          parenthesized(Parser<OmpObjectList>{}))) ||
     "LINEAR" >> construct<OmpClause>(construct<OmpClause::Linear>(
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index ffd577aa203756d..7c5574541dc2a8e 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2109,8 +2109,6 @@ CHECK_SIMPLE_CLAUSE(Read, OMPC_read)
 CHECK_SIMPLE_CLAUSE(Threadprivate, OMPC_threadprivate)
 CHECK_SIMPLE_CLAUSE(Threads, OMPC_threads)
 CHECK_SIMPLE_CLAUSE(Inbranch, OMPC_inbranch)
-CHECK_SIMPLE_CLAUSE(IsDevicePtr, OMPC_is_device_ptr)
-CHECK_SIMPLE_CLAUSE(HasDeviceAddr, OMPC_has_device_addr)
 CHECK_SIMPLE_CLAUSE(Link, OMPC_link)
 CHECK_SIMPLE_CLAUSE(Indirect, OMPC_indirect)
 CHECK_SIMPLE_CLAUSE(Mergeable, OMPC_mergeable)
@@ -2827,6 +2825,57 @@ void OmpStructureChecker::Enter(const parser::OmpClause::UseDeviceAddr &x) {
   }
 }
 
+void OmpStructureChecker::Enter(const parser::OmpClause::IsDevicePtr &x) {
+  CheckAllowed(llvm::omp::Clause::OMPC_is_device_ptr);
+  SymbolSourceMap currSymbols;
+  GetSymbolsInObjectList(x.v, currSymbols);
+  semantics::UnorderedSymbolSet listVars;
+  auto isDevicePtrClauses{FindClauses(llvm::omp::Clause::OMPC_is_device_ptr)};
+  for (auto itr = isDevicePtrClauses.first; itr != isDevicePtrClauses.second;
+       ++itr) {
+    const auto &isDevicePtrClause{
+        std::get<parser::OmpClause::IsDevicePtr>(itr->second->u)};
+    const auto &isDevicePtrList{isDevicePtrClause.v};
+    std::list<parser::Name> isDevicePtrNameList;
+    for (const auto &ompObject : isDevicePtrList.v) {
+      if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) {
+        if (name->symbol) {
+          if (!(IsBuiltinCPtr(*(name->symbol)))) {
+            context_.Say(itr->second->source,
+                "Variable '%s' in IS_DEVICE_PTR clause must be of type C_PTR"_err_en_US,
+                name->ToString());
+          } else {
+            isDevicePtrNameList.push_back(*name);
+          }
+        }
+      }
+    }
+  }
+}
+
+void OmpStructureChecker::Enter(const parser::OmpClause::HasDeviceAddr &x) {
+  CheckAllowed(llvm::omp::Clause::OMPC_has_device_addr);
+  SymbolSourceMap currSymbols;
+  GetSymbolsInObjectList(x.v, currSymbols);
+  semantics::UnorderedSymbolSet listVars;
+  auto hasDeviceAddrClauses{
+      FindClauses(llvm::omp::Clause::OMPC_has_device_addr)};
+  for (auto itr = hasDeviceAddrClauses.first;
+       itr != hasDeviceAddrClauses.second; ++itr) {
+    const auto &hasDeviceAddrClause{
+        std::get<parser::OmpClause::HasDeviceAddr>(itr->second->u)};
+    const auto &hasDeviceAddrList{hasDeviceAddrClause.v};
+    std::list<parser::Name> hasDeviceAddrNameList;
+    for (const auto &ompObject : hasDeviceAddrList.v) {
+      if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) {
+        if (name->symbol) {
+          hasDeviceAddrNameList.push_back(*name);
+        }
+      }
+    }
+  }
+}
+
 llvm::StringRef OmpStructureChecker::getClauseName(llvm::omp::Clause clause) {
   return llvm::omp::getOpenMPClauseName(clause);
 }
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 63a3e1cd31549c0..794a78408863cb3 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -535,6 +535,16 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
     return false;
   }
 
+  bool Pre(const parser::OmpClause::IsDevicePtr &x) {
+    ResolveOmpObjectList(x.v, Symbol::Flag::OmpIsDevicePtr);
+    return false;
+  }
+
+  bool Pre(const parser::OmpClause::HasDeviceAddr &x) {
+    ResolveOmpObjectList(x.v, Symbol::Flag::OmpHasDeviceAddr);
+    return false;
+  }
+
   void Post(const parser::Name &);
 
   // Keep track of labels in the statements that causes jumps to target labels
@@ -634,7 +644,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
       Symbol::Flag::OmpLinear, Symbol::Flag::OmpFirstPrivate,
       Symbol::Flag::OmpLastPrivate, Symbol::Flag::OmpReduction,
       Symbol::Flag::OmpCriticalLock, Symbol::Flag::OmpCopyIn,
-      Symbol::Flag::OmpUseDevicePtr, Symbol::Flag::OmpUseDeviceAddr};
+      Symbol::Flag::OmpUseDevicePtr, Symbol::Flag::OmpUseDeviceAddr,
+      Symbol::Flag::OmpIsDevicePtr, Symbol::Flag::OmpHasDeviceAddr};
 
   Symbol::Flags ompFlagsRequireMark{
       Symbol::Flag::OmpThreadprivate, Symbol::Flag::OmpDeclareTarget};
@@ -2054,15 +2065,22 @@ void OmpAttributeVisitor::ResolveOmpObject(
                       symbol, Symbol::Flag::OmpLastPrivate);
                 }
                 if (llvm::omp::allTargetSet.test(GetContext().directive)) {
+                  checkExclusivelists(symbol, Symbol::Flag::OmpIsDevicePtr,
+                      symbol, Symbol::Flag::OmpHasDeviceAddr);
                   const auto *hostAssocSym{symbol};
-                  if (const auto *details{
-                          symbol->detailsIf<HostAssocDetails>()}) {
-                    hostAssocSym = &details->symbol();
+                  if (!(symbol->test(Symbol::Flag::OmpIsDevicePtr) ||
+                          symbol->test(Symbol::Flag::OmpHasDeviceAddr))) {
+                    if (const auto *details{
+                            symbol->detailsIf<HostAssocDetails>()}) {
+                      hostAssocSym = &details->symbol();
+                    }
                   }
                   Symbol::Flag dataMappingAttributeFlags[] = {
                       Symbol::Flag::OmpMapTo, Symbol::Flag::OmpMapFrom,
                       Symbol::Flag::OmpMapToFrom, Symbol::Flag::OmpMapAlloc,
-                      Symbol::Flag::OmpMapRelease, Symbol::Flag::OmpMapDelete};
+                      Symbol::Flag::OmpMapRelease, Symbol::Flag::OmpMapDelete,
+                      Symbol::Flag::OmpIsDevicePtr,
+                      Symbol::Flag::OmpHasDeviceAddr};
 
                   Symbol::Flag dataSharingAttributeFlags[] = {
                       Symbol::Flag::OmpPrivate, Symbol::Flag::OmpFirstPrivate,
diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp
index f4edc8a08fe699d..2ce8e83e4771467 100644
--- a/flang/lib/Semantics/symbol.cpp
+++ b/flang/lib/Semantics/symbol.cpp
@@ -801,6 +801,12 @@ std::string Symbol::OmpFlagToClauseName(Symbol::Flag ompFlag) {
   case Symbol::Flag::OmpCopyPrivate:
     clauseName = "COPYPRIVATE";
     break;
+  case Symbol::Flag::OmpIsDevicePtr:
+    clauseName = "IS_DEVICE_PTR";
+    break;
+  case Symbol::Flag::OmpHasDeviceAddr:
+    clauseName = "HAS_DEVICE_ADDR";
+    break;
   default:
     clauseName = "";
     break;
diff --git a/flang/test/Semantics/OpenMP/target01.f90 b/flang/test/Semantics/OpenMP/target01.f90
index 8418381740250b0..d672b905a70ad2f 100644
--- a/flang/test/Semantics/OpenMP/target01.f90
+++ b/flang/test/Semantics/OpenMP/target01.f90
@@ -1,10 +1,31 @@
 ! RUN: %python %S/../test_errors.py %s %flang_fc1 -fopenmp
-! OpenMP Version 4.5
-
-integer :: x
+ 
+use iso_c_binding
+integer :: x,y
+type(C_PTR) :: b
 !ERROR: Variable 'x' may not appear on both MAP and PRIVATE clauses on a TARGET construct
 !$omp target map(x) private(x)
-x = x + 1
+  x = x + 1
+!$omp end target
+
+!ERROR: Variable 'y' in IS_DEVICE_PTR clause must be of type C_PTR
+!$omp target map(x) is_device_ptr(y)
+  x = x + 1
+!$omp end target
+
+!ERROR: Variable 'b' may not appear on both IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on a TARGET construct
+!$omp target map(x) is_device_ptr(b) has_device_addr(b)
+  x = x + 1
+!$omp end target
+
+!ERROR: Variable 'b' may not appear on both IS_DEVICE_PTR and PRIVATE clauses on a TARGET construct
+!$omp target map(x) is_device_ptr(b) private(b)
+  x = x + 1
+!$omp end target
+
+!ERROR: Variable 'y' may not appear on both HAS_DEVICE_ADDR and FIRSTPRIVATE clauses on a TARGET construct
+!$omp target map(x) has_device_addr(y) firstprivate(y)
+  y = y - 1
 !$omp end target
 
 end
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index e23ab9893ba9532..a9bfa2fe8293638 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -299,13 +299,11 @@ def OMPC_UseDevicePtr : Clause<"use_device_ptr"> {
 }
 def OMPC_IsDevicePtr : Clause<"is_device_ptr"> {
   let clangClass = "OMPIsDevicePtrClause";
-  let flangClass = "Name";
-  let isValueList = true;
+  let flangClass = "OmpObjectList";
 }
 def OMPC_HasDeviceAddr : Clause<"has_device_addr"> {
   let clangClass = "OMPHasDeviceAddrClause";
-  let flangClass = "Name";
-  let isValueList = true;
+  let flangClass = "OmpObjectList";
 }
 def OMPC_TaskReduction : Clause<"task_reduction"> {
   let clangClass = "OMPTaskReductionClause";



More information about the flang-commits mailing list