[llvm-branch-commits] [flang] [flang][openacc] Carry device dependent info for acc routine in the module file (PR #77804)

Valentin Clement バレンタイン クレメン via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 11 09:37:48 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/77804

- Move the DeviceType enumeration to `Fortran::common` so it can be used outside of the parser.
- Store the device_type-dependent information on the symbol and reproduce it in the module file.

>From d417332a130814e34fcc2448430d71b0fd5376b5 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 10 Jan 2024 21:26:53 -0800
Subject: [PATCH] [flang][openacc] Carry device dependent info for routine in
 the module file

---
 flang/include/flang/Common/Fortran.h         |  4 ++
 flang/include/flang/Parser/dump-parse-tree.h |  3 +-
 flang/include/flang/Parser/parse-tree.h      |  4 +-
 flang/include/flang/Semantics/symbol.h       | 29 +++++++++--
 flang/lib/Lower/OpenACC.cpp                  | 16 +++---
 flang/lib/Parser/openacc-parsers.cpp         | 14 ++---
 flang/lib/Semantics/mod-file.cpp             | 54 ++++++++++++++------
 flang/lib/Semantics/resolve-directives.cpp   | 50 +++++++++++++++---
 flang/test/Semantics/OpenACC/acc-module.f90  | 13 +++++
 9 files changed, 142 insertions(+), 45 deletions(-)

diff --git a/flang/include/flang/Common/Fortran.h b/flang/include/flang/Common/Fortran.h
index 4007bfc7994f98..1d3a85e2500733 100644
--- a/flang/include/flang/Common/Fortran.h
+++ b/flang/include/flang/Common/Fortran.h
@@ -87,6 +87,10 @@ ENUM_CLASS(CUDASubprogramAttrs, Host, Device, HostDevice, Global, Grid_Global)
 // CUDA data attributes; mutually exclusive
 ENUM_CLASS(CUDADataAttr, Constant, Device, Managed, Pinned, Shared, Texture)
 
+// OpenACC device types
+ENUM_CLASS(
+    OpenACCDeviceType, Star, Default, Nvidia, Radeon, Host, Multicore, None)
+
 // OpenMP atomic_default_mem_order clause allowed values
 ENUM_CLASS(OmpAtomicDefaultMemOrderType, SeqCst, AcqRel, Relaxed)
 
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 1defbf132327c4..d067a7273540f0 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -48,6 +48,7 @@ class ParseTreeDumper {
   NODE(std, uint64_t)
   NODE_ENUM(common, CUDADataAttr)
   NODE_ENUM(common, CUDASubprogramAttrs)
+  NODE_ENUM(common, OpenACCDeviceType)
   NODE(format, ControlEditDesc)
   NODE(format::ControlEditDesc, Kind)
   NODE(format, DerivedTypeDataEditDesc)
@@ -101,7 +102,7 @@ class ParseTreeDumper {
   NODE(parser, AccSelfClause)
   NODE(parser, AccStandaloneDirective)
   NODE(parser, AccDeviceTypeExpr)
-  NODE_ENUM(parser::AccDeviceTypeExpr, Device)
+
   NODE(parser, AccDeviceTypeExprList)
   NODE(parser, AccTileExpr)
   NODE(parser, AccTileExprList)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 71195f2bb9ddc4..e9bfb728a2bef6 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4072,8 +4072,8 @@ struct AccWaitArgument {
 };
 
 struct AccDeviceTypeExpr {
-  ENUM_CLASS(Device, Star, Default, Nvidia, Radeon, Host, Multicore)
-  WRAPPER_CLASS_BOILERPLATE(AccDeviceTypeExpr, Device);
+  WRAPPER_CLASS_BOILERPLATE(
+      AccDeviceTypeExpr, Fortran::common::OpenACCDeviceType);
   CharBlock source;
 };
 
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index f6f195b6bb95b2..2b6b18e40e6528 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -112,7 +112,8 @@ class WithBindName {
   bool isExplicitBindName_{false};
 };
 
-class OpenACCRoutineInfo {
+// Device type specific OpenACC routine information
+class OpenACCRoutineDeviceTypeInfo {
 public:
   bool isSeq() const { return isSeq_; }
   void set_isSeq(bool value = true) { isSeq_ = value; }
@@ -124,12 +125,12 @@ class OpenACCRoutineInfo {
   void set_isGang(bool value = true) { isGang_ = value; }
   unsigned gangDim() const { return gangDim_; }
   void set_gangDim(unsigned value) { gangDim_ = value; }
-  bool isNohost() const { return isNohost_; }
-  void set_isNohost(bool value = true) { isNohost_ = value; }
   const std::string *bindName() const {
     return bindName_ ? &*bindName_ : nullptr;
   }
   void set_bindName(std::string &&name) { bindName_ = std::move(name); }
+  void set_dType(Fortran::common::OpenACCDeviceType dType) { dType_ = dType; }
+  Fortran::common::OpenACCDeviceType dType() const { return dType_; }
 
 private:
   bool isSeq_{false};
@@ -137,8 +138,28 @@ class OpenACCRoutineInfo {
   bool isWorker_{false};
   bool isGang_{false};
   unsigned gangDim_{0};
-  bool isNohost_{false};
   std::optional<std::string> bindName_;
+  Fortran::common::OpenACCDeviceType dType_{
+      Fortran::common::OpenACCDeviceType::None};
+};
+
+// OpenACC routine information. Device-independent info is stored on the
+// OpenACCRoutineInfo instance, while device-dependent info is stored
+// as objects in the OpenACCRoutineDeviceTypeInfo list.
+class OpenACCRoutineInfo : public OpenACCRoutineDeviceTypeInfo {
+public:
+  bool isNohost() const { return isNohost_; }
+  void set_isNohost(bool value = true) { isNohost_ = value; }
+  std::list<OpenACCRoutineDeviceTypeInfo> &deviceTypeInfos() {
+    return deviceTypeInfos_;
+  }
+  void add_deviceTypeInfo(OpenACCRoutineDeviceTypeInfo &info) {
+    deviceTypeInfos_.push_back(info);
+  }
+
+private:
+  std::list<OpenACCRoutineDeviceTypeInfo> deviceTypeInfos_;
+  bool isNohost_{false};
 };
 
 // A subroutine or function definition, or a subprogram interface defined
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index d24c369d81bed4..db9ed72bc87257 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1483,20 +1483,22 @@ genAsyncClause(Fortran::lower::AbstractConverter &converter,
 }
 
 static mlir::acc::DeviceType
-getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
+getDeviceType(Fortran::common::OpenACCDeviceType device) {
   switch (device) {
-  case Fortran::parser::AccDeviceTypeExpr::Device::Star:
+  case Fortran::common::OpenACCDeviceType::Star:
     return mlir::acc::DeviceType::Star;
-  case Fortran::parser::AccDeviceTypeExpr::Device::Default:
+  case Fortran::common::OpenACCDeviceType::Default:
     return mlir::acc::DeviceType::Default;
-  case Fortran::parser::AccDeviceTypeExpr::Device::Nvidia:
+  case Fortran::common::OpenACCDeviceType::Nvidia:
     return mlir::acc::DeviceType::Nvidia;
-  case Fortran::parser::AccDeviceTypeExpr::Device::Radeon:
+  case Fortran::common::OpenACCDeviceType::Radeon:
     return mlir::acc::DeviceType::Radeon;
-  case Fortran::parser::AccDeviceTypeExpr::Device::Host:
+  case Fortran::common::OpenACCDeviceType::Host:
     return mlir::acc::DeviceType::Host;
-  case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
+  case Fortran::common::OpenACCDeviceType::Multicore:
     return mlir::acc::DeviceType::Multicore;
+  case Fortran::common::OpenACCDeviceType::None:
+    return mlir::acc::DeviceType::None;
   }
   return mlir::acc::DeviceType::None;
 }
diff --git a/flang/lib/Parser/openacc-parsers.cpp b/flang/lib/Parser/openacc-parsers.cpp
index 5b9267e0e17c6d..946b33d0084a96 100644
--- a/flang/lib/Parser/openacc-parsers.cpp
+++ b/flang/lib/Parser/openacc-parsers.cpp
@@ -54,13 +54,13 @@ TYPE_PARSER(construct<AccSizeExpr>(scalarIntExpr) ||
 TYPE_PARSER(construct<AccSizeExprList>(nonemptyList(Parser<AccSizeExpr>{})))
 
 TYPE_PARSER(sourced(construct<AccDeviceTypeExpr>(
-    first("*" >> pure(AccDeviceTypeExpr::Device::Star),
-        "DEFAULT" >> pure(AccDeviceTypeExpr::Device::Default),
-        "NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia),
-        "ACC_DEVICE_NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia),
-        "RADEON" >> pure(AccDeviceTypeExpr::Device::Radeon),
-        "HOST" >> pure(AccDeviceTypeExpr::Device::Host),
-        "MULTICORE" >> pure(AccDeviceTypeExpr::Device::Multicore)))))
+    first("*" >> pure(Fortran::common::OpenACCDeviceType::Star),
+        "DEFAULT" >> pure(Fortran::common::OpenACCDeviceType::Default),
+        "NVIDIA" >> pure(Fortran::common::OpenACCDeviceType::Nvidia),
+        "ACC_DEVICE_NVIDIA" >> pure(Fortran::common::OpenACCDeviceType::Nvidia),
+        "RADEON" >> pure(Fortran::common::OpenACCDeviceType::Radeon),
+        "HOST" >> pure(Fortran::common::OpenACCDeviceType::Host),
+        "MULTICORE" >> pure(Fortran::common::OpenACCDeviceType::Multicore)))))
 
 TYPE_PARSER(
     construct<AccDeviceTypeExprList>(nonemptyList(Parser<AccDeviceTypeExpr>{})))
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 70b6bbf8b557ac..e87125580cc1d4 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -491,31 +491,51 @@ void ModFileWriter::PutDECStructure(
 static const Attrs subprogramPrefixAttrs{Attr::ELEMENTAL, Attr::IMPURE,
     Attr::MODULE, Attr::NON_RECURSIVE, Attr::PURE, Attr::RECURSIVE};
 
+static void PutOpenACCDeviceTypeRoutineInfo(
+    llvm::raw_ostream &os, const OpenACCRoutineDeviceTypeInfo &info) {
+  if (info.isSeq()) {
+    os << " seq";
+  }
+  if (info.isGang()) {
+    os << " gang";
+    if (info.gangDim() > 0) {
+      os << "(dim: " << info.gangDim() << ")";
+    }
+  }
+  if (info.isVector()) {
+    os << " vector";
+  }
+  if (info.isWorker()) {
+    os << " worker";
+  }
+  if (info.bindName()) {
+    os << " bind(" << *info.bindName() << ")";
+  }
+}
+
 static void PutOpenACCRoutineInfo(
     llvm::raw_ostream &os, const SubprogramDetails &details) {
   for (auto info : details.openACCRoutineInfos()) {
     os << "!$acc routine";
-    if (info.isSeq()) {
-      os << " seq";
-    }
-    if (info.isGang()) {
-      os << " gang";
-      if (info.gangDim() > 0) {
-        os << "(dim: " << info.gangDim() << ")";
-      }
-    }
-    if (info.isVector()) {
-      os << " vector";
-    }
-    if (info.isWorker()) {
-      os << " worker";
-    }
+
+    PutOpenACCDeviceTypeRoutineInfo(os, info);
+
     if (info.isNohost()) {
       os << " nohost";
     }
-    if (info.bindName()) {
-      os << " bind(" << *info.bindName() << ")";
+
+    for (auto dtype : info.deviceTypeInfos()) {
+      os << " device_type(";
+      if (dtype.dType() == common::OpenACCDeviceType::Star) {
+        os << "*";
+      } else {
+        os << parser::ToLowerCaseLetters(common::EnumToString(dtype.dType()));
+      }
+      os << ")";
+
+      PutOpenACCDeviceTypeRoutineInfo(os, dtype);
     }
+
     os << "\n";
   }
 }
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index da6c865ad56a3b..d7b13631ab4df0 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -945,25 +945,45 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
     const auto &clauses = std::get<Fortran::parser::AccClauseList>(x.t);
     for (const Fortran::parser::AccClause &clause : clauses.v) {
       if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-        info.set_isSeq();
+        if (info.deviceTypeInfos().empty()) {
+          info.set_isSeq();
+        } else {
+          info.deviceTypeInfos().back().set_isSeq();
+        }
       } else if (const auto *gangClause =
                      std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
-        info.set_isGang();
+        if (info.deviceTypeInfos().empty()) {
+          info.set_isGang();
+        } else {
+          info.deviceTypeInfos().back().set_isGang();
+        }
         if (gangClause->v) {
           const Fortran::parser::AccGangArgList &x = *gangClause->v;
           for (const Fortran::parser::AccGangArg &gangArg : x.v) {
             if (const auto *dim =
                     std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u)) {
               if (const auto v{EvaluateInt64(context_, dim->v)}) {
-                info.set_gangDim(*v);
+                if (info.deviceTypeInfos().empty()) {
+                  info.set_gangDim(*v);
+                } else {
+                  info.deviceTypeInfos().back().set_gangDim(*v);
+                }
               }
             }
           }
         }
       } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
-        info.set_isVector();
+        if (info.deviceTypeInfos().empty()) {
+          info.set_isVector();
+        } else {
+          info.deviceTypeInfos().back().set_isVector();
+        }
       } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
-        info.set_isWorker();
+        if (info.deviceTypeInfos().empty()) {
+          info.set_isWorker();
+        } else {
+          info.deviceTypeInfos().back().set_isWorker();
+        }
       } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
         info.set_isNohost();
       } else if (const auto *bindClause =
@@ -971,7 +991,12 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
         if (const auto *name =
                 std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
           if (Symbol *sym = ResolveFctName(*name)) {
-            info.set_bindName(sym->name().ToString());
+            if (info.deviceTypeInfos().empty()) {
+              info.set_bindName(sym->name().ToString());
+            } else {
+              info.deviceTypeInfos().back().set_bindName(
+                  sym->name().ToString());
+            }
           } else {
             context_.Say((*name).source,
                 "No function or subroutine declared for '%s'"_err_en_US,
@@ -986,8 +1011,19 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
           std::string str{std::get<std::string>(charConst->t)};
           std::stringstream bindName;
           bindName << "\"" << str << "\"";
-          info.set_bindName(bindName.str());
+          if (info.deviceTypeInfos().empty()) {
+            info.set_bindName(bindName.str());
+          } else {
+            info.deviceTypeInfos().back().set_bindName(bindName.str());
+          }
         }
+      } else if (const auto *dType =
+                     std::get_if<Fortran::parser::AccClause::DeviceType>(
+                         &clause.u)) {
+        const parser::AccDeviceTypeExprList &deviceTypeExprList = dType->v;
+        OpenACCRoutineDeviceTypeInfo dtypeInfo;
+        dtypeInfo.set_dType(deviceTypeExprList.v.front().v);
+        info.add_deviceTypeInfo(dtypeInfo);
       }
     }
     symbol.get<SubprogramDetails>().add_openACCRoutineInfo(info);
diff --git a/flang/test/Semantics/OpenACC/acc-module.f90 b/flang/test/Semantics/OpenACC/acc-module.f90
index f552816d698823..7f034d8ae54f0a 100644
--- a/flang/test/Semantics/OpenACC/acc-module.f90
+++ b/flang/test/Semantics/OpenACC/acc-module.f90
@@ -60,6 +60,13 @@ subroutine sub9()
   subroutine sub10()
   end subroutine
 
+  subroutine sub11()
+    !$acc routine device_type(nvidia) gang device_type(*) seq
+  end subroutine
+
+  subroutine sub12()
+    !$acc routine device_type(host) bind(sub7) device_type(multicore) bind(sub8)
+  end subroutine
 end module
 
 !Expect: acc_mod.mod
@@ -107,4 +114,10 @@ subroutine sub10()
 ! subroutine sub10()
 ! !$acc routine seq
 ! end
+! subroutinesub11()
+! !$acc routine device_type(nvidia) gang device_type(*) seq
+! end
+! subroutinesub12()
+! !$acc routine device_type(host) bind(sub7) device_type(multicore) bind(sub8)
+! end
 ! end



More information about the llvm-branch-commits mailing list