[flang-commits] [flang] 22209a6 - [flang][openacc] Keep routine information in the module file

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Wed Aug 23 08:58:11 PDT 2023


Author: Valentin Clement
Date: 2023-08-23T08:56:55-07:00
New Revision: 22209a673ec278d3dcbd10c53be3b646bc0bbb7a

URL: https://github.com/llvm/llvm-project/commit/22209a673ec278d3dcbd10c53be3b646bc0bbb7a
DIFF: https://github.com/llvm/llvm-project/commit/22209a673ec278d3dcbd10c53be3b646bc0bbb7a.diff

LOG: [flang][openacc] Keep routine information in the module file

This patch propagates the acc routine information
to the module file so that callers can use it.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D158541

Added: 
    

Modified: 
    flang/include/flang/Semantics/symbol.h
    flang/include/flang/Semantics/tools.h
    flang/lib/Lower/OpenACC.cpp
    flang/lib/Semantics/mod-file.cpp
    flang/lib/Semantics/resolve-directives.cpp
    flang/test/Semantics/OpenACC/acc-module.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 93ed272149f307..98ea0adc829324 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -82,6 +82,42 @@ class WithBindName {
   bool isExplicitBindName_{false};
 };
 
+// Clauses of one `!$acc routine` directive that applies to a subprogram.
+// A subprogram may carry several of these, one per directive; they are
+// written to the module file so that callers compiled elsewhere see them.
+class OpenACCRoutineInfo {
+public:
+  bool isSeq() const { return isSeq_; }
+  void set_isSeq(bool value = true) { isSeq_ = value; }
+  bool isVector() const { return isVector_; }
+  void set_isVector(bool value = true) { isVector_ = value; }
+  bool isWorker() const { return isWorker_; }
+  void set_isWorker(bool value = true) { isWorker_ = value; }
+  bool isGang() const { return isGang_; }
+  void set_isGang(bool value = true) { isGang_ = value; }
+  // DIM argument of the GANG clause; 0 when no DIM was given.
+  unsigned gangDim() const { return gangDim_; }
+  void set_gangDim(unsigned value) { gangDim_ = value; }
+  bool isNohost() const { return isNohost_; }
+  void set_isNohost(bool value = true) { isNohost_ = value; }
+  // Argument of the BIND clause, spelled as it is written back to the
+  // module file (a procedure name, or a string literal with its quotes);
+  // null when the directive has no BIND clause.
+  const std::string *bindName() const {
+    return bindName_ ? &*bindName_ : nullptr;
+  }
+  void set_bindName(std::string &&name) { bindName_ = std::move(name); }
+
+private:
+  bool isSeq_{false};
+  bool isVector_{false};
+  bool isWorker_{false};
+  bool isGang_{false};
+  unsigned gangDim_{0};
+  bool isNohost_{false};
+  std::optional<std::string> bindName_;
+};
+
 // A subroutine or function definition, or a subprogram interface defined
 // in an INTERFACE block as part of the definition of a dummy procedure
 // or a procedure pointer (with just POINTER).
@@ -137,6 +173,14 @@ class SubprogramDetails : public WithBindName {
   void set_cudaClusterDims(std::vector<std::int64_t> &&x) {
     cudaClusterDims_ = std::move(x);
   }
+  // OpenACC routine information, one entry per applicable directive, in
+  // the order the entries were added.
+  const std::vector<OpenACCRoutineInfo> &openACCRoutineInfos() const {
+    return openACCRoutineInfos_;
+  }
+  void add_openACCRoutineInfo(OpenACCRoutineInfo info) {
+    openACCRoutineInfos_.push_back(std::move(info));
+  }
 
 private:
   bool isInterface_{false}; // true if this represents an interface-body
@@ -154,6 +198,8 @@ class SubprogramDetails : public WithBindName {
   std::optional<common::CUDASubprogramAttrs> cudaSubprogramAttrs_;
   // CUDA LAUNCH_BOUNDS(...) & CLUSTER_DIMS(...) from prefix
   std::vector<std::int64_t> cudaLaunchBounds_, cudaClusterDims_;
+  // OpenACC routine information
+  std::vector<OpenACCRoutineInfo> openACCRoutineInfos_;
 
   friend llvm::raw_ostream &operator<<(
       llvm::raw_ostream &, const SubprogramDetails &);

diff  --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index acd34e9781ee51..0eed9937e7d78e 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -664,5 +664,24 @@ inline const parser::Name *getDesignatorNameIfDataRef(
 
 bool CouldBeDataPointerValuedFunction(const Symbol *);
 
+// Folds the analyzed expression of parse-tree node `x` and returns its value
+// when it is a scalar default-kind CHARACTER constant, or std::nullopt
+// otherwise (including when `x` has no analyzed expression). Only
+// R = std::string is supported; other result types fail at compile time.
+template <typename R, typename T>
+std::optional<R> GetConstExpr(
+    Fortran::semantics::SemanticsContext &semanticsContext, const T &x) {
+  static_assert(std::is_same_v<R, std::string>,
+      "GetConstExpr only supports std::string results");
+  using DefaultCharConstantType = Fortran::evaluate::Ascii;
+  if (const auto *expr{Fortran::semantics::GetExpr(semanticsContext, x)}) {
+    const auto foldExpr{Fortran::evaluate::Fold(
+        semanticsContext.foldingContext(), Fortran::common::Clone(*expr))};
+    return Fortran::evaluate::GetScalarConstantValue<DefaultCharConstantType>(
+        foldExpr);
+  }
+  return std::nullopt;
+}
+
 } // namespace Fortran::semantics
 #endif // FORTRAN_SEMANTICS_TOOLS_H_

diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ce7635a5c1d7e6..1aa97efe3aadc9 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2929,22 +2929,6 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   llvm_unreachable("unsupported declarative directive");
 }
 
-template <typename R, typename T>
-std::optional<R>
-GetConstExpr(Fortran::semantics::SemanticsContext &semanticsContext,
-             const T &x) {
-  using DefaultCharConstantType = Fortran::evaluate::Ascii;
-  if (const auto *expr{Fortran::semantics::GetExpr(semanticsContext, x)}) {
-    const auto foldExpr{Fortran::evaluate::Fold(
-        semanticsContext.foldingContext(), Fortran::common::Clone(*expr))};
-    if constexpr (std::is_same_v<R, std::string>) {
-      return Fortran::evaluate::GetScalarConstantValue<DefaultCharConstantType>(
-          foldExpr);
-    }
-  }
-  return std::nullopt;
-}
-
 static void attachRoutineInfo(mlir::func::FuncOp func,
                               mlir::SymbolRefAttr routineAttr) {
   llvm::SmallVector<mlir::SymbolRefAttr> routines;
@@ -3030,7 +3014,11 @@ genACC(Fortran::lower::AbstractConverter &converter,
                      std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
                          &bindClause->v.u)) {
         const std::optional<std::string> bindName =
-            GetConstExpr<std::string>(semanticsContext, *charExpr);
+            Fortran::semantics::GetConstExpr<std::string>(semanticsContext,
+                                                          *charExpr);
+        // NOTE(review): if GetConstExpr returns std::nullopt, the error below
+        // is emitted but *bindName is still dereferenced on the next line;
+        // the setBindName call needs to be guarded.
         if (!bindName)
           routineOp.emitError("Could not retrieve the bind name");
         routineOp.setBindName(builder.getStringAttr(*bindName));

diff  --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 6671151777af46..3925f3b0ef0335 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -412,6 +412,40 @@ void ModFileWriter::PutDECStructure(
 static const Attrs subprogramPrefixAttrs{Attr::ELEMENTAL, Attr::IMPURE,
     Attr::MODULE, Attr::NON_RECURSIVE, Attr::PURE, Attr::RECURSIVE};
 
+// Writes one `!$acc routine` line for each OpenACC routine directive recorded
+// on `details`, with its SEQ, GANG[(DIM:n)], VECTOR, WORKER, NOHOST and BIND
+// clauses, so that the information survives into the module file.
+static void PutOpenACCRoutineInfo(
+    llvm::raw_ostream &os, const SubprogramDetails &details) {
+  for (const auto &info : details.openACCRoutineInfos()) {
+    os << "!$acc routine";
+    if (info.isSeq()) {
+      os << " seq";
+    }
+    if (info.isGang()) {
+      os << " gang";
+      if (info.gangDim() > 0) {
+        os << "(dim:" << info.gangDim() << ")";
+      }
+    }
+    if (info.isVector()) {
+      os << " vector";
+    }
+    if (info.isWorker()) {
+      os << " worker";
+    }
+    if (info.isNohost()) {
+      os << " nohost";
+    }
+    // The stored bind name already carries quotes when it came from a
+    // string literal, so it is emitted verbatim.
+    if (info.bindName()) {
+      os << " bind(" << *info.bindName() << ")";
+    }
+    os << "\n";
+  }
+}
+
 void ModFileWriter::PutSubprogram(const Symbol &symbol) {
   auto &details{symbol.get<SubprogramDetails>()};
   if (const Symbol * interface{details.moduleInterface()}) {
@@ -513,6 +547,7 @@ void ModFileWriter::PutSubprogram(const Symbol &symbol) {
     decls_ << "import::" << import << "\n";
   }
   os << writer.decls_.str();
+  PutOpenACCRoutineInfo(os, details);
   os << "end\n";
   if (isInterface) {
     os << "end interface\n";

diff  --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 69cd6930c3cf70..d758e6ad61139a 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -273,6 +273,8 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
   void DoNotAllowAssumedSizedArray(const parser::AccObjectList &objectList);
   void EnsureAllocatableOrPointer(
       const llvm::acc::Clause clause, const parser::AccObjectList &objectList);
+  void AddRoutineInfoToSymbol(
+      Symbol &, const parser::OpenACCRoutineConstruct &);
 };
 
 // Data-sharing and Data-mapping attributes for data-refs in OpenMP construct
@@ -832,14 +834,81 @@ Symbol *AccAttributeVisitor::ResolveName(
   return prev;
 }
 
+// Records the clauses of an `!$acc routine` directive on `symbol` so that
+// they can be written to the module file. Symbols without SubprogramDetails
+// are left untouched.
+void AccAttributeVisitor::AddRoutineInfoToSymbol(
+    Symbol &symbol, const parser::OpenACCRoutineConstruct &x) {
+  if (!symbol.has<SubprogramDetails>()) {
+    return;
+  }
+  OpenACCRoutineInfo info;
+  const auto &clauses{std::get<parser::AccClauseList>(x.t)};
+  for (const parser::AccClause &clause : clauses.v) {
+    if (std::get_if<parser::AccClause::Seq>(&clause.u)) {
+      info.set_isSeq();
+    } else if (const auto *gangClause{
+                   std::get_if<parser::AccClause::Gang>(&clause.u)}) {
+      info.set_isGang();
+      if (gangClause->v) {
+        const parser::AccGangArgList &gangArgs{*gangClause->v};
+        for (const parser::AccGangArg &gangArg : gangArgs.v) {
+          if (const auto *dim{
+                  std::get_if<parser::AccGangArg::Dim>(&gangArg.u)}) {
+            if (const auto v{EvaluateInt64(context_, dim->v)}) {
+              info.set_gangDim(*v);
+            }
+          }
+        }
+      }
+    } else if (std::get_if<parser::AccClause::Vector>(&clause.u)) {
+      info.set_isVector();
+    } else if (std::get_if<parser::AccClause::Worker>(&clause.u)) {
+      info.set_isWorker();
+    } else if (std::get_if<parser::AccClause::Nohost>(&clause.u)) {
+      info.set_isNohost();
+    } else if (const auto *bindClause{
+                   std::get_if<parser::AccClause::Bind>(&clause.u)}) {
+      if (const auto *name{std::get_if<parser::Name>(&bindClause->v.u)}) {
+        if (Symbol *sym{ResolveName(*name, true)}) {
+          info.set_bindName(sym->name().ToString());
+        } else {
+          context_.Say(name->source,
+              "No function or subroutine declared for '%s'"_err_en_US,
+              name->source);
+        }
+      } else if (const auto *charExpr{
+                     std::get_if<parser::ScalarDefaultCharExpr>(
+                         &bindClause->v.u)}) {
+        // Only a character literal can be recorded for the module file;
+        // other string expressions yield no literal here and are skipped
+        // instead of being dereferenced through a null pointer.
+        if (const auto *charConst{
+                parser::Unwrap<parser::CharLiteralConstant>(*charExpr)}) {
+          // Keep the quotes so the module file re-emits a string argument.
+          info.set_bindName(
+              "\"" + std::get<std::string>(charConst->t) + "\"");
+        }
+      }
+    }
+  }
+  symbol.get<SubprogramDetails>().add_openACCRoutineInfo(std::move(info));
+}
+
 bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
   const auto &optName{std::get<std::optional<parser::Name>>(x.t)};
   if (optName) {
-    if (!ResolveName(*optName, true)) {
+    if (Symbol *sym = ResolveName(*optName, true)) {
+      Symbol &ultimate{sym->GetUltimate()};
+      AddRoutineInfoToSymbol(ultimate, x);
+    } else {
       context_.Say((*optName).source,
           "No function or subroutine declared for '%s'"_err_en_US,
           (*optName).source);
     }
+  } else if (Symbol *scopeSymbol{currScope().symbol()}) {
+    // Unnamed form: attach the clauses to the symbol of the current scope.
+    AddRoutineInfoToSymbol(*scopeSymbol, x);
   }
   return true;
 }

diff  --git a/flang/test/Semantics/OpenACC/acc-module.f90 b/flang/test/Semantics/OpenACC/acc-module.f90
index d618650f5c439a..f552816d698823 100644
--- a/flang/test/Semantics/OpenACC/acc-module.f90
+++ b/flang/test/Semantics/OpenACC/acc-module.f90
@@ -15,6 +15,55 @@ module acc_mod
 
   integer :: data_link(50)
   !$acc declare link(data_link)
+
+  !$acc routine(sub10) seq
+
+contains
+  subroutine sub1()
+    !$acc routine
+  end subroutine
+
+  subroutine sub2()
+    !$acc routine seq
+  end subroutine
+
+  subroutine sub3()
+    !$acc routine gang
+  end subroutine
+
+  subroutine sub4()
+    !$acc routine vector
+  end subroutine
+
+  subroutine sub5()
+    !$acc routine worker
+  end subroutine
+
+  subroutine sub6()
+    !$acc routine gang(dim:2)
+  end subroutine
+
+  subroutine sub7()
+    !$acc routine bind("sub7_")
+  end subroutine
+
+  subroutine sub8()
+    !$acc routine bind(sub7)
+  end subroutine
+
+  subroutine sub9()
+    !$acc routine vector
+    !$acc routine seq bind(sub7)
+    !$acc routine gang bind(sub8)
+  end subroutine
+
+  subroutine sub10()
+  end subroutine
+
+  subroutine sub11()
+    !$acc routine seq nohost
+  end subroutine
+
 end module
 
 !Expect: acc_mod.mod
@@ -29,4 +78,40 @@ module acc_mod
 ! !$acc declare device_resident(data_device_resident)
 ! integer(4)::data_link(1_8:50_8)
 ! !$acc declare link(data_link)
+! contains
+! subroutine sub1()
+! !$acc routine
+! end
+! subroutine sub2()
+! !$acc routine seq
+! end
+! subroutine sub3()
+! !$acc routine gang
+! end
+! subroutine sub4()
+! !$acc routine vector
+! end
+! subroutine sub5()
+! !$acc routine worker
+! end
+! subroutine sub6()
+! !$acc routine gang(dim:2)
+! end
+! subroutine sub7()
+! !$acc routine bind("sub7_")
+! end
+! subroutine sub8()
+! !$acc routine bind(sub7)
+! end
+! subroutine sub9()
+! !$acc routine vector
+! !$acc routine seq bind(sub7)
+! !$acc routine gang bind(sub8)
+! end
+! subroutine sub10()
+! !$acc routine seq
+! end
+! subroutine sub11()
+! !$acc routine seq nohost
+! end
 ! end


        


More information about the flang-commits mailing list