[flang-commits] [mlir] [flang] [flang][openacc] Do not accept static and num for gang clause on routine dir (PR #77673)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Jan 11 13:07:43 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/77673

>From 6da2c829d6222be3f6ce8eb5ac21f3a047f75810 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 10 Jan 2024 11:38:04 -0800
Subject: [PATCH 1/2] [flang][openacc] Do not accept static and num for gang
 clause on routine dir

---
 flang/lib/Semantics/check-acc-structure.cpp  | 12 ++++++++++++
 flang/test/Semantics/OpenACC/acc-routine.f90 | 10 ++++++++++
 2 files changed, 22 insertions(+)

diff --git a/flang/lib/Semantics/check-acc-structure.cpp b/flang/lib/Semantics/check-acc-structure.cpp
index 4a5798a8a531a4..5c2a871c322e3a 100644
--- a/flang/lib/Semantics/check-acc-structure.cpp
+++ b/flang/lib/Semantics/check-acc-structure.cpp
@@ -560,14 +560,26 @@ void AccStructureChecker::Enter(const parser::AccClause::Gang &g) {
   if (g.v) {
     bool hasNum = false;
     bool hasDim = false;
+    bool hasStatic = false;
     const Fortran::parser::AccGangArgList &x = *g.v;
     for (const Fortran::parser::AccGangArg &gangArg : x.v) {
       if (std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u))
         hasNum = true;
       else if (std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u))
         hasDim = true;
+      else if (std::get_if<Fortran::parser::AccGangArg::Static>(&gangArg.u))
+        hasStatic = true;
     }
 
+    if (GetContext().directive == llvm::acc::Directive::ACCD_routine &&
+        (hasStatic || hasNum))
+      context_.Say(GetContext().clauseSource,
+          "Only the dim argument is allowed on the %s clause on the %s directive"_err_en_US,
+          parser::ToUpperCaseLetters(
+              llvm::acc::getOpenACCClauseName(llvm::acc::Clause::ACCC_gang)
+                  .str()),
+          ContextDirectiveAsFortran());
+
     if (hasDim && hasNum)
       context_.Say(GetContext().clauseSource,
           "The num argument is not allowed when dim is specified"_err_en_US);
diff --git a/flang/test/Semantics/OpenACC/acc-routine.f90 b/flang/test/Semantics/OpenACC/acc-routine.f90
index 4dcb849c642c83..f27084115fbee2 100644
--- a/flang/test/Semantics/OpenACC/acc-routine.f90
+++ b/flang/test/Semantics/OpenACC/acc-routine.f90
@@ -13,3 +13,13 @@ subroutine sub2(a)
 subroutine sub3()
   !$acc routine bind(sub1)
 end subroutine
+
+subroutine sub4() ! gang(num:) must be rejected on a ROUTINE directive
+  !ERROR: Only the dim argument is allowed on the GANG clause on the ROUTINE directive
+  !$acc routine gang(num: 1)
+end subroutine
+
+subroutine sub5() ! gang(static:) must be rejected on a ROUTINE directive
+  !ERROR: Only the dim argument is allowed on the GANG clause on the ROUTINE directive
+  !$acc routine gang(static: 1)
+end subroutine

>From ac3743ca5b6cd77e666e4a29ea237d33dd4e99df Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 11 Jan 2024 13:07:31 -0800
Subject: [PATCH 2/2] Add missing braces

---
 flang/lib/Semantics/check-acc-structure.cpp   | 10 ++-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        | 33 +++++--
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 89 +++++++++++++++++++
 3 files changed, 122 insertions(+), 10 deletions(-)

diff --git a/flang/lib/Semantics/check-acc-structure.cpp b/flang/lib/Semantics/check-acc-structure.cpp
index 5c2a871c322e3a..6c163b4498c993 100644
--- a/flang/lib/Semantics/check-acc-structure.cpp
+++ b/flang/lib/Semantics/check-acc-structure.cpp
@@ -563,22 +563,24 @@ void AccStructureChecker::Enter(const parser::AccClause::Gang &g) {
     bool hasStatic = false;
     const Fortran::parser::AccGangArgList &x = *g.v;
     for (const Fortran::parser::AccGangArg &gangArg : x.v) {
-      if (std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u))
+      if (std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
         hasNum = true;
-      else if (std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u))
+      } else if (std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u)) {
         hasDim = true;
-      else if (std::get_if<Fortran::parser::AccGangArg::Static>(&gangArg.u))
+      } else if (std::get_if<Fortran::parser::AccGangArg::Static>(&gangArg.u)) {
         hasStatic = true;
+      }
     }
 
     if (GetContext().directive == llvm::acc::Directive::ACCD_routine &&
-        (hasStatic || hasNum))
+        (hasStatic || hasNum)) {
       context_.Say(GetContext().clauseSource,
           "Only the dim argument is allowed on the %s clause on the %s directive"_err_en_US,
           parser::ToUpperCaseLetters(
               llvm::acc::getOpenACCClauseName(llvm::acc::Clause::ACCC_gang)
                   .str()),
           ContextDirectiveAsFortran());
+    }
 
     if (hasDim && hasNum)
       context_.Say(GetContext().clauseSource,
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index e6954062a50e0c..03bf89f35a0ab2 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1998,15 +1998,36 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
                        SymbolNameAttr:$func_name,
                        OptionalAttr<StrAttr>:$bind_name,
                        UnitAttr:$gang,
-                       UnitAttr:$worker,
-                       UnitAttr:$vector,
-                       UnitAttr:$seq,
+                       OptionalAttr<DeviceTypeArrayAttr>:$worker,
+                       OptionalAttr<DeviceTypeArrayAttr>:$vector,
+                       OptionalAttr<DeviceTypeArrayAttr>:$seq,
                        UnitAttr:$nohost,
                        UnitAttr:$implicit,
                        OptionalAttr<APIntAttr>:$gangDim);
 
   let extraClassDeclaration = [{
     static StringRef getGangDimKeyword() { return "dim"; }
+
+    /// Return true if the op has the worker attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWorker();
+    /// Return true if the op has the worker attribute for the given
+    /// device_type.
+    bool hasWorker(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the vector attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasVector();
+    /// Return true if the op has the vector attribute for the given
+    /// device_type.
+    bool hasVector(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the seq attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasSeq();
+    /// Return true if the op has the seq attribute for the given
+    /// device_type.
+    bool hasSeq(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
@@ -2014,9 +2035,9 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
     oilist (
         `bind` `(` $bind_name `)`
       | `gang` `` custom<RoutineGangClause>($gang, $gangDim)
-      | `worker` $worker
-      | `vector` $vector
-      | `seq` $seq
+      | `worker` custom<DeviceTypeArrayAttr>($worker)
+      | `vector` custom<DeviceTypeArrayAttr>($vector)
+      | `seq` custom<DeviceTypeArrayAttr>($seq)
       | `nohost` $nohost
       | `implicit` $implicit
     ) attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c53673fa426038..b4adef7e7dd1b9 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2001,6 +2001,95 @@ void printRoutineGangClause(OpAsmPrinter &p, Operation *op, UnitAttr gang,
       << " : " << gangDim.getType() << ")";
 }
 
+/// Parse an optional `([#acc.device_type<...>, ...])` clause suffix.
+static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
+                                            mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<mlir::Attribute> attributes;
+  if (succeeded(parser.parseOptionalLParen())) {
+    // After `(`, a non-empty `[...]` of device_type attrs and `)` are required.
+    auto parseElement = [&]() -> ParseResult {
+      mlir::acc::DeviceTypeAttr attr; // Other attribute kinds are rejected.
+      if (parser.parseAttribute(attr))
+        return failure();
+      attributes.push_back(attr);
+      return success();
+    };
+    if (parser.parseLSquare() || parser.parseCommaSeparatedList(parseElement) ||
+        parser.parseRSquare() || parser.parseRParen())
+      return failure();
+  } else {
+    // Keyword only: the clause applies to the default (none) device_type.
+    attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+  }
+  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
+  return success();
+}
+
+/// Return true if `arrayAttr` is present, non-null and has at least one entry.
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  return arrayAttr.has_value() && static_cast<bool>(*arrayAttr) &&
+         !arrayAttr->empty();
+}
+
+/// Print the device_type list of a worker/vector/seq clause as `([...])`.
+/// Nothing is printed for an empty list or for a lone `#acc.device_type<none>`
+/// (the entry the parser builds for a bare keyword). `op` is unused.
+static void
+printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                         std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+  if (deviceTypes->size() == 1) {
+    auto onlyType =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
+    // dyn_cast may yield null for a foreign attribute; never dereference it.
+    if (onlyType && onlyType.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+  // Print entries as plain attributes so a foreign one cannot print as null.
+  p << "([";
+  llvm::interleaveComma(*deviceTypes, p,
+                        [&](mlir::Attribute attr) { p << attr; });
+  p << "])";
+}
+
+/// Query the worker clause for the default (none) device_type, which is the
+/// entry a bare `worker` keyword produces.
+bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
+
+/// Return true if findSegment locates `deviceType` in the worker clause list.
+bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
+  auto workerDeviceTypes = getWorker();
+  if (!workerDeviceTypes)
+    return false;
+  return static_cast<bool>(findSegment(*workerDeviceTypes, deviceType));
+}
+
+/// Query the vector clause for the default (none) device_type, which is the
+/// entry a bare `vector` keyword produces.
+bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
+
+/// Return true if findSegment locates `deviceType` in the vector clause list.
+bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
+  auto vectorDeviceTypes = getVector();
+  if (!vectorDeviceTypes)
+    return false;
+  return static_cast<bool>(findSegment(*vectorDeviceTypes, deviceType));
+}
+
+/// Query the seq clause for the default (none) device_type, which is the
+/// entry a bare `seq` keyword produces.
+bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
+
+/// Return true if findSegment locates `deviceType` in the seq clause list.
+bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
+  auto seqDeviceTypes = getSeq();
+  if (!seqDeviceTypes)
+    return false;
+  return static_cast<bool>(findSegment(*seqDeviceTypes, deviceType));
+}
+
 //===----------------------------------------------------------------------===//
 // InitOp
 //===----------------------------------------------------------------------===//



More information about the flang-commits mailing list