[flang-commits] [flang] [flang][openacc] Allow open acc routines from other modules. (PR #136012)

Andre Kuhlenschmidt via flang-commits flang-commits at lists.llvm.org
Wed Apr 16 12:18:11 PDT 2025


https://github.com/akuhlens created https://github.com/llvm/llvm-project/pull/136012

OpenACC routine annotations in separate compilation units are currently ignored, which leads to compilation errors. There are two reasons this OpenACC routine information is currently ignored, and this PR addresses both:
- The module file reader doesn't read OpenACC directives back in from module files.
  - This is a simple fix in `flang/lib/Semantics/mod-file.cpp`.
- Lowering to HLFIR doesn't generate routine directives for symbols imported from other modules that are OpenACC routines.
  - This is the majority of the diff, and is addressed by the changes that start in `flang/lib/Lower/CallInterface.cpp`.

>From 0f4591ee621e2e9d7acb0e6066b556cb7e243162 Mon Sep 17 00:00:00 2001
From: Andre Kuhlenschmidt <akuhlenschmi at nvidia.com>
Date: Wed, 16 Apr 2025 12:01:24 -0700
Subject: [PATCH] initial commit

---
 flang/include/flang/Lower/AbstractConverter.h |   4 +
 flang/include/flang/Lower/OpenACC.h           |  10 +-
 flang/include/flang/Semantics/symbol.h        |  23 +-
 flang/lib/Lower/Bridge.cpp                    |   7 +-
 flang/lib/Lower/CallInterface.cpp             |  10 +
 flang/lib/Lower/OpenACC.cpp                   | 197 ++++++++++++++----
 flang/lib/Semantics/mod-file.cpp              |   1 +
 flang/lib/Semantics/resolve-directives.cpp    |  83 ++++----
 8 files changed, 233 insertions(+), 102 deletions(-)

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 1d1323642bf9c..59419e829718f 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -14,6 +14,7 @@
 #define FORTRAN_LOWER_ABSTRACTCONVERTER_H
 
 #include "flang/Lower/LoweringOptions.h"
+#include "flang/Lower/OpenACC.h"
 #include "flang/Lower/PFTDefs.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Dialect/FIRAttr.h"
@@ -357,6 +358,11 @@ class AbstractConverter {
   /// functions in order to be in sync).
   virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
 
+  /// Get the OpenACC routine infos whose function has not been created yet,
+  /// so that the routine info can be attached to it later.
+  virtual Fortran::lower::AccRoutineInfoMappingList &
+  getAccDelayedRoutines() = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 0d7038a7fd856..7832e8b69ea23 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -22,6 +22,9 @@ class StringRef;
 } // namespace llvm
 
 namespace mlir {
+namespace func {
+class FuncOp;
+} // namespace func
 class Location;
 class Type;
 class ModuleOp;
@@ -42,6 +45,7 @@ struct OpenACCRoutineConstruct;
 } // namespace parser
 
 namespace semantics {
+class OpenACCRoutineInfo;
 class SemanticsContext;
 class Symbol;
 } // namespace semantics
@@ -79,8 +83,13 @@ void genOpenACCDeclarativeConstruct(AbstractConverter &,
 void genOpenACCRoutineConstruct(AbstractConverter &,
                                 Fortran::semantics::SemanticsContext &,
                                 mlir::ModuleOp,
-                                const parser::OpenACCRoutineConstruct &,
-                                AccRoutineInfoMappingList &);
+                                const parser::OpenACCRoutineConstruct &);
+
+/// Generate an acc.routine operation for the given function from the OpenACC
+/// routine information recorded on its symbol during semantic analysis.
+void genOpenACCRoutineConstruct(
+    AbstractConverter &, mlir::ModuleOp, mlir::func::FuncOp,
+    const std::vector<Fortran::semantics::OpenACCRoutineInfo> &);
 
 void finalizeOpenACCRoutineAttachment(mlir::ModuleOp,
                                       AccRoutineInfoMappingList &);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 715811885c219..1b6b247c9f5bc 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -127,6 +127,10 @@ class WithBindName {
 // Device type specific OpenACC routine information
 class OpenACCRoutineDeviceTypeInfo {
 public:
+  // Create the info for the clauses that apply to device type dType.
+  explicit OpenACCRoutineDeviceTypeInfo(
+      Fortran::common::OpenACCDeviceType dType)
+      : deviceType_{dType} {}
   bool isSeq() const { return isSeq_; }
   void set_isSeq(bool value = true) { isSeq_ = value; }
   bool isVector() const { return isVector_; }
@@ -141,9 +145,7 @@ class OpenACCRoutineDeviceTypeInfo {
     return bindName_ ? &*bindName_ : nullptr;
   }
   void set_bindName(std::string &&name) { bindName_ = std::move(name); }
-  void set_dType(Fortran::common::OpenACCDeviceType dType) {
-    deviceType_ = dType;
-  }
+
   Fortran::common::OpenACCDeviceType dType() const { return deviceType_; }
 
 private:
@@ -162,13 +164,24 @@ class OpenACCRoutineDeviceTypeInfo {
 // in as objects in the OpenACCRoutineDeviceTypeInfo list.
 class OpenACCRoutineInfo : public OpenACCRoutineDeviceTypeInfo {
 public:
+  OpenACCRoutineInfo()
+      : OpenACCRoutineDeviceTypeInfo(Fortran::common::OpenACCDeviceType::None) {
+  }
   bool isNohost() const { return isNohost_; }
   void set_isNohost(bool value = true) { isNohost_ = value; }
-  std::list<OpenACCRoutineDeviceTypeInfo> &deviceTypeInfos() {
+  const std::list<OpenACCRoutineDeviceTypeInfo> &deviceTypeInfos() const {
     return deviceTypeInfos_;
   }
-  void add_deviceTypeInfo(OpenACCRoutineDeviceTypeInfo &info) {
-    deviceTypeInfos_.push_back(info);
+
+  OpenACCRoutineDeviceTypeInfo &add_deviceTypeInfo(
+      Fortran::common::OpenACCDeviceType type) {
+    return add_deviceTypeInfo(OpenACCRoutineDeviceTypeInfo(type));
+  }
+
+  OpenACCRoutineDeviceTypeInfo &add_deviceTypeInfo(
+      OpenACCRoutineDeviceTypeInfo &&info) {
+    deviceTypeInfos_.push_back(std::move(info));
+    return deviceTypeInfos_.back();
   }
 
 private:
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index b4d1197822a43..9285d587585f8 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -443,7 +443,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                     bridge.getModule(), bridge.getKindMap(), &mlirSymbolTable);
                 Fortran::lower::genOpenACCRoutineConstruct(
                     *this, bridge.getSemanticsContext(), bridge.getModule(),
-                    d.routine, accRoutineInfos);
+                    d.routine);
                 builder = nullptr;
               },
           },
@@ -4287,6 +4287,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return Fortran::lower::createMutableBox(loc, *this, expr, localSymbols);
   }
 
+  Fortran::lower::AccRoutineInfoMappingList &
+  getAccDelayedRoutines() override final {
+    return accRoutineInfos;
+  }
+
   // Create the [newRank] array with the lower bounds to be passed to the
   // runtime as a descriptor.
   mlir::Value createLboundArray(llvm::ArrayRef<mlir::Value> lbounds,
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index 226ba1e52c968..867248f16237e 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -1689,6 +1689,21 @@ class SignatureBuilder
                           "SignatureBuilder should only be used once");
     declare();
     interfaceDetermined = true;
+    // Generate the acc.routine operation for procedures whose interface
+    // symbol carries OpenACC routine information (e.g. procedures declared
+    // in another compilation unit and read back from a module file).
+    if (procDesignator) {
+      if (const Fortran::semantics::Symbol *iface{
+              procDesignator->GetInterfaceSymbol()}) {
+        if (const auto *details{
+                iface->detailsIf<Fortran::semantics::SubprogramDetails>()}) {
+          if (!details->openACCRoutineInfos().empty())
+            Fortran::lower::genOpenACCRoutineConstruct(
+                converter, converter.getModuleOp(), getFuncOp(),
+                details->openACCRoutineInfos());
+        }
+      }
+    }
     return getFuncOp();
   }
 
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 3dd35ed9ae481..37b660408af6c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Frontend/OpenACC/ACC.h.inc"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "mlir/IR/MLIRContext.h"
 
 #define DEBUG_TYPE "flang-lower-openacc"
 
@@ -4139,11 +4140,160 @@ static void attachRoutineInfo(mlir::func::FuncOp func,
       mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
 }
 
+/// Create an acc.routine operation for \p funcName in \p mod from the given
+/// clause attributes. Nothing is created when an acc.routine with the same
+/// clauses already exists for the function; one with different clauses is
+/// reported as an error. The new routine is attached to \p funcOp, or queued
+/// in the converter's delayed routines when \p funcOp is null.
+static void createOpenACCRoutineConstruct(
+    Fortran::lower::AbstractConverter &converter, mlir::Location loc,
+    mlir::ModuleOp mod, mlir::func::FuncOp funcOp, const std::string &funcName,
+    bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
+    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &gangDimValues,
+    llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes) {
+
+  std::stringstream routineOpName;
+  routineOpName << accRoutinePrefix.str() << routineCounter++;
+
+  for (auto routineOp : mod.getOps<mlir::acc::RoutineOp>()) {
+    if (routineOp.getFuncName().str().compare(funcName) == 0) {
+      // If the routine is already specified with the same clauses, just skip
+      // the operation creation.
+      if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
+                                gangDeviceTypes, gangDimValues,
+                                gangDimDeviceTypes, seqDeviceTypes,
+                                workerDeviceTypes, vectorDeviceTypes) &&
+          routineOp.getNohost() == hasNohost)
+        return;
+      mlir::emitError(loc, "Routine already specified with different clauses");
+    }
+  }
+  std::string routineOpStr = routineOpName.str();
+  mlir::OpBuilder modBuilder(mod.getBodyRegion());
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  modBuilder.create<mlir::acc::RoutineOp>(
+      loc, routineOpStr, funcName,
+      bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
+      bindNameDeviceTypes.empty() ? nullptr
+                                  : builder.getArrayAttr(bindNameDeviceTypes),
+      workerDeviceTypes.empty() ? nullptr
+                                : builder.getArrayAttr(workerDeviceTypes),
+      vectorDeviceTypes.empty() ? nullptr
+                                : builder.getArrayAttr(vectorDeviceTypes),
+      seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
+      hasNohost, /*implicit=*/false,
+      gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
+      gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
+      gangDimDeviceTypes.empty() ? nullptr
+                                 : builder.getArrayAttr(gangDimDeviceTypes));
+
+  if (funcOp)
+    attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpStr));
+  else
+    // FuncOp is not lowered yet. Keep the information so the routine info
+    // can be attached later to the funcOp.
+    converter.getAccDelayedRoutines().push_back(
+        std::make_pair(funcName, builder.getSymbolRefAttr(routineOpStr)));
+}
+
+/// Append the device type of \p dinfo to the lists of each routine clause
+/// (seq, vector, worker, gang, bind) that \p dinfo sets, together with the
+/// gang dimension or bind name where present.
+static void interpretRoutineDeviceInfo(
+    fir::FirOpBuilder &builder,
+    const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo,
+    llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &bindNames,
+    llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &gangDimValues,
+    llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
+  mlir::MLIRContext *context{builder.getContext()};
+  if (dinfo.isSeq()) {
+    seqDeviceTypes.push_back(
+        mlir::acc::DeviceTypeAttr::get(context, getDeviceType(dinfo.dType())));
+  }
+  if (dinfo.isVector()) {
+    vectorDeviceTypes.push_back(
+        mlir::acc::DeviceTypeAttr::get(context, getDeviceType(dinfo.dType())));
+  }
+  if (dinfo.isWorker()) {
+    workerDeviceTypes.push_back(
+        mlir::acc::DeviceTypeAttr::get(context, getDeviceType(dinfo.dType())));
+  }
+  if (dinfo.isGang()) {
+    unsigned gangDim = dinfo.gangDim();
+    auto deviceType =
+        mlir::acc::DeviceTypeAttr::get(context, getDeviceType(dinfo.dType()));
+    if (!gangDim) {
+      gangDeviceTypes.push_back(deviceType);
+    } else {
+      gangDimValues.push_back(
+          builder.getIntegerAttr(builder.getI64Type(), gangDim));
+      gangDimDeviceTypes.push_back(deviceType);
+    }
+  }
+  if (const std::string *bindName{dinfo.bindName()}) {
+    bindNames.push_back(builder.getStringAttr(*bindName));
+    bindNameDeviceTypes.push_back(
+        mlir::acc::DeviceTypeAttr::get(context, getDeviceType(dinfo.dType())));
+  }
+}
+
+void Fortran::lower::genOpenACCRoutineConstruct(
+    Fortran::lower::AbstractConverter &converter, mlir::ModuleOp mod,
+    mlir::func::FuncOp funcOp,
+    const std::vector<Fortran::semantics::OpenACCRoutineInfo> &routineInfos) {
+  CHECK(funcOp && "Expected a valid function operation");
+  fir::FirOpBuilder &builder{converter.getFirOpBuilder()};
+  mlir::Location loc{funcOp.getLoc()};
+  std::string funcName{funcOp.getName()};
+
+  // Collect the routine clauses
+  bool hasNohost{false};
+
+  llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
+      workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
+      gangDimDeviceTypes, gangDimValues;
+
+  for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
+    // Device Independent Attributes
+    if (info.isNohost()) {
+      hasNohost = true;
+    }
+    // Note: Device Independent Attributes are set to the
+    // none device type in `info`.
+    interpretRoutineDeviceInfo(builder, info, seqDeviceTypes, vectorDeviceTypes,
+                               workerDeviceTypes, bindNameDeviceTypes,
+                               bindNames, gangDeviceTypes, gangDimValues,
+                               gangDimDeviceTypes);
+
+    // Device Dependent Attributes
+    for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
+         info.deviceTypeInfos()) {
+      interpretRoutineDeviceInfo(
+          builder, dinfo, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
+          bindNameDeviceTypes, bindNames, gangDeviceTypes, gangDimValues,
+          gangDimDeviceTypes);
+    }
+  }
+  createOpenACCRoutineConstruct(
+      converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
+      bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
+      seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
+}
+
 void Fortran::lower::genOpenACCRoutineConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp mod,
-    const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
-    Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
+    const Fortran::parser::OpenACCRoutineConstruct &routineConstruct) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   mlir::Location loc = converter.genLocation(routineConstruct.source);
   std::optional<Fortran::parser::Name> name =
@@ -4174,6 +4324,7 @@ void Fortran::lower::genOpenACCRoutineConstruct(
       funcName = funcOp.getName();
     }
   }
+  // TODO: Refactor this to use the OpenACCRoutineInfo
   bool hasNohost = false;
 
   llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
@@ -4226,6 +4377,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
                    std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
       if (const auto *name =
               std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
+        // FIXME: This case mangles the name; the one below does not.
+        // Which is correct?
         mlir::Attribute bindNameAttr =
             builder.getStringAttr(converter.mangleName(*name->symbol));
         for (auto crtDeviceTypeAttr : crtDeviceTypes) {
@@ -4255,47 +4408,10 @@ void Fortran::lower::genOpenACCRoutineConstruct(
     }
   }
 
-  mlir::OpBuilder modBuilder(mod.getBodyRegion());
-  std::stringstream routineOpName;
-  routineOpName << accRoutinePrefix.str() << routineCounter++;
-
-  for (auto routineOp : mod.getOps<mlir::acc::RoutineOp>()) {
-    if (routineOp.getFuncName().str().compare(funcName) == 0) {
-      // If the routine is already specified with the same clauses, just skip
-      // the operation creation.
-      if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
-                                gangDeviceTypes, gangDimValues,
-                                gangDimDeviceTypes, seqDeviceTypes,
-                                workerDeviceTypes, vectorDeviceTypes) &&
-          routineOp.getNohost() == hasNohost)
-        return;
-      mlir::emitError(loc, "Routine already specified with different clauses");
-    }
-  }
-
-  modBuilder.create<mlir::acc::RoutineOp>(
-      loc, routineOpName.str(), funcName,
-      bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
-      bindNameDeviceTypes.empty() ? nullptr
-                                  : builder.getArrayAttr(bindNameDeviceTypes),
-      workerDeviceTypes.empty() ? nullptr
-                                : builder.getArrayAttr(workerDeviceTypes),
-      vectorDeviceTypes.empty() ? nullptr
-                                : builder.getArrayAttr(vectorDeviceTypes),
-      seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
-      hasNohost, /*implicit=*/false,
-      gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
-      gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
-      gangDimDeviceTypes.empty() ? nullptr
-                                 : builder.getArrayAttr(gangDimDeviceTypes));
-
-  if (funcOp)
-    attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
-  else
-    // FuncOp is not lowered yet. Keep the information so the routine info
-    // can be attached later to the funcOp.
-    accRoutineInfos.push_back(std::make_pair(
-        funcName, builder.getSymbolRefAttr(routineOpName.str())));
+  createOpenACCRoutineConstruct(
+      converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
+      bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
+      seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
 }
 
 void Fortran::lower::finalizeOpenACCRoutineAttachment(
@@ -4443,8 +4559,7 @@ void Fortran::lower::genOpenACCDeclarativeConstruct(
             fir::FirOpBuilder &builder = converter.getFirOpBuilder();
             mlir::ModuleOp mod = builder.getModule();
             Fortran::lower::genOpenACCRoutineConstruct(
-                converter, semanticsContext, mod, routineConstruct,
-                accRoutineInfos);
+                converter, semanticsContext, mod, routineConstruct);
           },
       },
       accDeclConstruct.u);
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index ee356e56e4458..befd204a671fc 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -1387,6 +1387,7 @@ Scope *ModFileReader::Read(SourceName name, std::optional<bool> isIntrinsic,
   parser::Options options;
   options.isModuleFile = true;
   options.features.Enable(common::LanguageFeature::BackslashEscapes);
+  options.features.Enable(common::LanguageFeature::OpenACC);
   options.features.Enable(common::LanguageFeature::OpenMP);
   options.features.Enable(common::LanguageFeature::CUDA);
   if (!isIntrinsic.value_or(false) && !notAModule) {
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index d75b4ea13d35f..93c334a3ca3cb 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1034,61 +1034,55 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
     Symbol &symbol, const parser::OpenACCRoutineConstruct &x) {
   if (symbol.has<SubprogramDetails>()) {
     Fortran::semantics::OpenACCRoutineInfo info;
+    // Clauses apply to the device-independent `info` until a device_type
+    // clause selects the device types that subsequent clauses apply to.
+    std::vector<OpenACCRoutineDeviceTypeInfo *> currentDevices;
+    currentDevices.push_back(&info);
     const auto &clauses = std::get<Fortran::parser::AccClauseList>(x.t);
     for (const Fortran::parser::AccClause &clause : clauses.v) {
-      if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-        if (info.deviceTypeInfos().empty()) {
-          info.set_isSeq();
-        } else {
-          info.deviceTypeInfos().back().set_isSeq();
+      if (const auto *dTypeClause =
+              std::get_if<Fortran::parser::AccClause::DeviceType>(&clause.u)) {
+        currentDevices.clear();
+        for (const auto &deviceTypeExpr : dTypeClause->v.v) {
+          currentDevices.push_back(&info.add_deviceTypeInfo(deviceTypeExpr.v));
         }
+      } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
+        info.set_isNohost();
+      } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
+        for (auto *device : currentDevices)
+          device->set_isSeq();
+      } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
+        for (auto *device : currentDevices)
+          device->set_isVector();
+      } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
+        for (auto *device : currentDevices)
+          device->set_isWorker();
       } else if (const auto *gangClause =
                      std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
-        if (info.deviceTypeInfos().empty()) {
-          info.set_isGang();
-        } else {
-          info.deviceTypeInfos().back().set_isGang();
-        }
+        for (auto *device : currentDevices)
+          device->set_isGang();
         if (gangClause->v) {
           const Fortran::parser::AccGangArgList &x = *gangClause->v;
+          int numArgs{0};
           for (const Fortran::parser::AccGangArg &gangArg : x.v) {
+            CHECK(numArgs <= 1 && "expecting 0 or 1 gang dim args");
             if (const auto *dim =
                     std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u)) {
               if (const auto v{EvaluateInt64(context_, dim->v)}) {
-                if (info.deviceTypeInfos().empty()) {
-                  info.set_gangDim(*v);
-                } else {
-                  info.deviceTypeInfos().back().set_gangDim(*v);
-                }
+                for (auto *device : currentDevices)
+                  device->set_gangDim(*v);
               }
             }
+            numArgs++;
           }
         }
-      } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
-        if (info.deviceTypeInfos().empty()) {
-          info.set_isVector();
-        } else {
-          info.deviceTypeInfos().back().set_isVector();
-        }
-      } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
-        if (info.deviceTypeInfos().empty()) {
-          info.set_isWorker();
-        } else {
-          info.deviceTypeInfos().back().set_isWorker();
-        }
-      } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
-        info.set_isNohost();
       } else if (const auto *bindClause =
                      std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
+        std::string bindName;
         if (const auto *name =
                 std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
           if (Symbol *sym = ResolveFctName(*name)) {
-            if (info.deviceTypeInfos().empty()) {
-              info.set_bindName(sym->name().ToString());
-            } else {
-              info.deviceTypeInfos().back().set_bindName(
-                  sym->name().ToString());
-            }
+            bindName = sym->name().ToString();
           } else {
             context_.Say((*name).source,
                 "No function or subroutine declared for '%s'"_err_en_US,
@@ -1101,21 +1095,16 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
               Fortran::parser::Unwrap<Fortran::parser::CharLiteralConstant>(
                   *charExpr);
           std::string str{std::get<std::string>(charConst->t)};
-          std::stringstream bindName;
-          bindName << "\"" << str << "\"";
-          if (info.deviceTypeInfos().empty()) {
-            info.set_bindName(bindName.str());
-          } else {
-            info.deviceTypeInfos().back().set_bindName(bindName.str());
+          std::stringstream bindNameStream;
+          bindNameStream << "\"" << str << "\"";
+          bindName = bindNameStream.str();
+        }
+        if (!bindName.empty()) {
+          // FIXME: do we need to ensure there is only one device?
+          for (auto *device : currentDevices) {
+            device->set_bindName(std::string(bindName));
           }
         }
-      } else if (const auto *dType =
-                     std::get_if<Fortran::parser::AccClause::DeviceType>(
-                         &clause.u)) {
-        const parser::AccDeviceTypeExprList &deviceTypeExprList = dType->v;
-        OpenACCRoutineDeviceTypeInfo dtypeInfo;
-        dtypeInfo.set_dType(deviceTypeExprList.v.front().v);
-        info.add_deviceTypeInfo(dtypeInfo);
       }
     }
     symbol.get<SubprogramDetails>().add_openACCRoutineInfo(info);



More information about the flang-commits mailing list