[flang-commits] [mlir] [flang] [mlir][flang][openacc] Device type support on acc routine op (PR #78375)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Jan 17 14:54:47 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/78375

>From 74ca692f5ed54842d8f5b7b661765b82f92c64d6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 11 Jan 2024 13:07:31 -0800
Subject: [PATCH 1/2] [mlir][openacc] Device type support on acc routine op

---
 flang/lib/Lower/OpenACC.cpp                   | 138 +++++++--
 flang/test/Lower/OpenACC/acc-routine.f90      |  18 +-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        |  58 +++-
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 287 ++++++++++++++++--
 mlir/test/Dialect/OpenACC/ops.mlir            |   4 +-
 .../Dialect/OpenACC/OpenACCOpsTest.cpp        |  58 ++++
 6 files changed, 495 insertions(+), 68 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index db9ed72bc87257..fd89d27db74dc0 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3469,6 +3469,72 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   llvm_unreachable("unsupported declarative directive");
 }
 
+static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  for (auto attr : arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+  return false;
+}
+
+template <typename RetTy, typename AttrTy>
+static std::optional<RetTy>
+getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
+                              llvm::SmallVector<mlir::Attribute> &deviceTypes,
+                              mlir::acc::DeviceType deviceType) {
+  assert(attributes.size() == deviceTypes.size() &&
+         "expect same number of attributes");
+  for (auto it : llvm::enumerate(deviceTypes)) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value());
+    if (deviceTypeAttr.getValue() == deviceType) {
+      if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
+        auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]);
+        return strAttr.getValue();
+      } else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
+        auto intAttr =
+            mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index()]);
+        return intAttr.getInt();
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+static bool compareDeviceTypeInfo(
+    mlir::acc::RoutineOp op,
+    llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
+  for (uint32_t dtypeInt = 0;
+       dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
+    auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
+    if (op.getBindNameValue(dtype) !=
+        getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
+            bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
+      return false;
+    if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
+      return false;
+    if (op.getGangDimValue(dtype) !=
+        getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
+            gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
+      return false;
+    if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
+      return false;
+    if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
+      return false;
+    if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
+      return false;
+  }
+  return true;
+}
+
 static void attachRoutineInfo(mlir::func::FuncOp func,
                               mlir::SymbolRefAttr routineAttr) {
   llvm::SmallVector<mlir::SymbolRefAttr> routines;
@@ -3518,17 +3584,23 @@ void Fortran::lower::genOpenACCRoutineConstruct(
       funcName = funcOp.getName();
     }
   }
-  bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
-       hasNohost = false;
-  std::optional<std::string> bindName = std::nullopt;
-  std::optional<int64_t> gangDim = std::nullopt;
+  bool hasNohost = false;
+
+  llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
+      workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
+      gangDimDeviceTypes, gangDimValues;
+
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
 
   for (const Fortran::parser::AccClause &clause : clauses.v) {
     if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      hasSeq = true;
+      seqDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *gangClause =
                    std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
-      hasGang = true;
+
       if (gangClause->v) {
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
@@ -3539,21 +3611,27 @@ void Fortran::lower::genOpenACCRoutineConstruct(
             if (!dimValue)
               mlir::emitError(loc,
                               "dim value must be a constant positive integer");
-            gangDim = *dimValue;
+            gangDimValues.push_back(
+                builder.getIntegerAttr(builder.getI64Type(), *dimValue));
+            gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
           }
         }
+      } else {
+        gangDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
-      hasVector = true;
+      vectorDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
-      hasWorker = true;
+      workerDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
       hasNohost = true;
     } else if (const auto *bindClause =
                    std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
       if (const auto *name =
               std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
-        bindName = converter.mangleName(*name->symbol);
+        bindNames.push_back(
+            builder.getStringAttr(converter.mangleName(*name->symbol)));
+        bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
       } else if (const auto charExpr =
                      std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
                          &bindClause->v.u)) {
@@ -3562,8 +3640,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
                                                           *charExpr);
         if (!name)
           mlir::emitError(loc, "Could not retrieve the bind name");
-        bindName = *name;
+        bindNames.push_back(builder.getStringAttr(*name));
+        bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
       }
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
     }
   }
 
@@ -3575,12 +3663,11 @@ void Fortran::lower::genOpenACCRoutineConstruct(
     if (routineOp.getFuncName().str().compare(funcName) == 0) {
       // If the routine is already specified with the same clauses, just skip
       // the operation creation.
-      if (routineOp.getBindName() == bindName &&
-          routineOp.getGang() == hasGang &&
-          routineOp.getWorker() == hasWorker &&
-          routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq &&
-          routineOp.getNohost() == hasNohost &&
-          routineOp.getGangDim() == gangDim)
+      if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
+                                gangDeviceTypes, gangDimValues,
+                                gangDimDeviceTypes, seqDeviceTypes,
+                                workerDeviceTypes, vectorDeviceTypes) &&
+          routineOp.getNohost() == hasNohost)
         return;
       mlir::emitError(loc, "Routine already specified with different clauses");
     }
@@ -3588,10 +3675,19 @@ void Fortran::lower::genOpenACCRoutineConstruct(
 
   modBuilder.create<mlir::acc::RoutineOp>(
       loc, routineOpName.str(), funcName,
-      bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang,
-      hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false,
-      gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim)
-              : mlir::IntegerAttr{});
+      bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
+      bindNameDeviceTypes.empty() ? nullptr
+                                  : builder.getArrayAttr(bindNameDeviceTypes),
+      workerDeviceTypes.empty() ? nullptr
+                                : builder.getArrayAttr(workerDeviceTypes),
+      vectorDeviceTypes.empty() ? nullptr
+                                : builder.getArrayAttr(vectorDeviceTypes),
+      seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
+      hasNohost, /*implicit=*/false,
+      gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
+      gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
+      gangDimDeviceTypes.empty() ? nullptr
+                                 : builder.getArrayAttr(gangDimDeviceTypes));
 
   if (funcOp)
     attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 8b94279503334a..8e9e65da32cd19 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -2,12 +2,14 @@
 
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
-
+! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
+! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
+! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
 ! CHECK: acc.routine @acc_routine_10 func(@_QPacc_routine11) seq
 ! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
 ! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
 ! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
-! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim = 1 : i32)
+! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(1 : i64)
 ! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost
 ! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker
 ! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector
@@ -106,3 +108,15 @@ subroutine acc_routine14()
 subroutine acc_routine15()
   !$acc routine bind(acc_routine16)
 end subroutine
+
+subroutine acc_routine16()
+  !$acc routine device_type(host) seq dtype(nvidia) gang
+end subroutine
+
+subroutine acc_routine17()
+  !$acc routine device_type(host) worker dtype(multicore) vector 
+end subroutine
+
+subroutine acc_routine18()
+  !$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16) 
+end subroutine
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 24f129d92805c0..7344ab2852b9ce 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1994,27 +1994,63 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        SymbolNameAttr:$func_name,
-                       OptionalAttr<StrAttr>:$bind_name,
-                       UnitAttr:$gang,
-                       UnitAttr:$worker,
-                       UnitAttr:$vector,
-                       UnitAttr:$seq,
+                       OptionalAttr<StrArrayAttr>:$bindName,
+                       OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
+                       OptionalAttr<DeviceTypeArrayAttr>:$worker,
+                       OptionalAttr<DeviceTypeArrayAttr>:$vector,
+                       OptionalAttr<DeviceTypeArrayAttr>:$seq,
                        UnitAttr:$nohost,
                        UnitAttr:$implicit,
-                       OptionalAttr<APIntAttr>:$gangDim);
+                       OptionalAttr<DeviceTypeArrayAttr>:$gang,
+                       OptionalAttr<I64ArrayAttr>:$gangDim,
+                       OptionalAttr<DeviceTypeArrayAttr>:$gangDimDeviceType);
 
   let extraClassDeclaration = [{
     static StringRef getGangDimKeyword() { return "dim"; }
+
+    /// Return true if the op has the worker attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWorker();
+    /// Return true if the op has the worker attribute for the given
+    /// device_type.
+    bool hasWorker(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the vector attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasVector();
+    /// Return true if the op has the vector attribute for the given
+    /// device_type.
+    bool hasVector(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the seq attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasSeq();
+    /// Return true if the op has the seq attribute for the given
+    /// device_type.
+    bool hasSeq(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the gang attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasGang();
+    /// Return true if the op has the gang attribute for the given
+    /// device_type.
+    bool hasGang(mlir::acc::DeviceType deviceType);
+
+    std::optional<int64_t> getGangDimValue();
+    std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
+
+    std::optional<llvm::StringRef> getBindNameValue();
+    std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     $sym_name `func` `(` $func_name `)`
     oilist (
-        `bind` `(` $bind_name `)`
-      | `gang` `` custom<RoutineGangClause>($gang, $gangDim)
-      | `worker` $worker
-      | `vector` $vector
-      | `seq` $seq
+        `bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
+      | `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
+      | `worker` custom<DeviceTypeArrayAttr>($worker)
+      | `vector` custom<DeviceTypeArrayAttr>($vector)
+      | `seq` custom<DeviceTypeArrayAttr>($seq)
       | `nohost` $nohost
       | `implicit` $implicit
     ) attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bf3264b5da9802..82e614cb7572f6 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1046,7 +1046,7 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
   return success();
 }
 
-bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
   if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
     return true;
   return false;
@@ -2131,55 +2131,278 @@ LogicalResult acc::DeclareOp::verify() {
 // RoutineOp
 //===----------------------------------------------------------------------===//
 
+static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(arrayAttr))
+    return false;
+
+  for (auto attr : *arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+
+  return false;
+}
+
+static unsigned getParallelismForDeviceType(acc::RoutineOp op,
+                                            acc::DeviceType dtype) {
+  unsigned parallelism = 0;
+  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
+  parallelism += op.hasWorker(dtype) ? 1 : 0;
+  parallelism += op.hasVector(dtype) ? 1 : 0;
+  parallelism += op.hasSeq(dtype) ? 1 : 0;
+  return parallelism;
+}
+
 LogicalResult acc::RoutineOp::verify() {
-  int parallelism = 0;
-  parallelism += getGang() ? 1 : 0;
-  parallelism += getWorker() ? 1 : 0;
-  parallelism += getVector() ? 1 : 0;
-  parallelism += getSeq() ? 1 : 0;
+  unsigned baseParallelism =
+      getParallelismForDeviceType(*this, acc::DeviceType::None);
 
-  if (parallelism > 1)
+  if (baseParallelism > 1)
     return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
                           "be present at the same time";
 
+  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
+       ++dtypeInt) {
+    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
+    if (dtype == acc::DeviceType::None)
+      continue;
+    unsigned parallelism = getParallelismForDeviceType(*this, dtype);
+
+    if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
+      return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
+                            "be present at the same time";
+  }
+
   return success();
 }
 
-static ParseResult parseRoutineGangClause(OpAsmParser &parser, UnitAttr &gang,
-                                          IntegerAttr &gangDim) {
-  // Since gang clause exists, ensure that unit attribute is set.
-  gang = UnitAttr::get(parser.getBuilder().getContext());
+static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
+                                 mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
 
-  // Next, look for dim on gang. Don't initialize `gangDim` yet since
-  // we leave it without attribute if there is no `dim` specifier.
-  if (succeeded(parser.parseOptionalLParen())) {
-    // Look for syntax that looks like `dim = 1 : i32`.
-    // Thus first look for `dim =`
-    if (failed(parser.parseKeyword(RoutineOp::getGangDimKeyword())) ||
-        failed(parser.parseEqual()))
-      return failure();
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(bindNameAttrs.emplace_back()))
+          return failure();
+        if (failed(parser.parseOptionalLSquare())) {
+          deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        } else {
+          if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        }
+        return success();
+      })))
+    return failure();
 
-    int64_t dimValue;
-    Type valueType;
-    // Now look for `1 : i32`
-    if (failed(parser.parseInteger(dimValue)) ||
-        failed(parser.parseColonType(valueType)))
-      return failure();
+  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
+  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+
+  return success();
+}
+
+static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          std::optional<mlir::ArrayAttr> bindName,
+                          std::optional<mlir::ArrayAttr> deviceTypes) {
+  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
+                        [&](const auto &pair) {
+                          p << std::get<0>(pair);
+                          printSingleDeviceType(p, std::get<1>(pair));
+                        });
+}
+
+static ParseResult parseRoutineGangClause(OpAsmParser &parser,
+                                          mlir::ArrayAttr &gang,
+                                          mlir::ArrayAttr &gangDim,
+                                          mlir::ArrayAttr &gangDimDeviceTypes) {
 
-    gangDim = IntegerAttr::get(valueType, dimValue);
+  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
+      gangDimDeviceTypeAttrs;
+  bool needCommaBeforeOperands = false;
 
-    if (failed(parser.parseRParen()))
+  // Gang keyword only: no `(` follows, so record gang for device type none
+  if (failed(parser.parseOptionalLParen())) {
+    gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    gang = ArrayAttr::get(parser.getContext(), gangAttrs);
+    return success();
+  }
+
+  // Parse the bracketed device types that carry gang without a dim value
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(gangAttrs.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
       return failure();
+    needCommaBeforeOperands = true;
+  }
+
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(gangDimAttrs.emplace_back()))
+          return failure();
+        if (succeeded(parser.parseOptionalLSquare())) {
+          if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        } else {
+          gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        }
+        return success();
+      })))
+    return failure();
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
+  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
+  gangDimDeviceTypes =
+      ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
+
+  return success();
+}
+
+void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
+                            std::optional<mlir::ArrayAttr> gang,
+                            std::optional<mlir::ArrayAttr> gangDim,
+                            std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
+
+  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
+      gang->size() == 1) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+
+  printDeviceTypes(p, gang);
+
+  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
+    p << ", ";
+
+  if (hasDeviceTypeValues(gangDimDeviceTypes))
+    llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
+                          [&](const auto &pair) {
+                            p << std::get<0>(pair);
+                            printSingleDeviceType(p, std::get<1>(pair));
+                          });
+
+  p << ")";
+}
+
+static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
+                                            mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<mlir::Attribute> attributes;
+  // Keyword only: no `(` follows, so record the none device type
+  if (failed(parser.parseOptionalLParen())) {
+    attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
+    return success();
   }
 
+  // Parse device type attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(attributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare() || parser.parseRParen())
+      return failure();
+  }
+  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
   return success();
 }
 
-void printRoutineGangClause(OpAsmPrinter &p, Operation *op, UnitAttr gang,
-                            IntegerAttr gangDim) {
-  if (gangDim)
-    p << "(" << RoutineOp::getGangDimKeyword() << " = " << gangDim.getValue()
-      << " : " << gangDim.getType() << ")";
+static void
+printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                         std::optional<mlir::ArrayAttr> deviceTypes) {
+
+  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+
+  p << "([";
+  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
+    auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    p << dTypeAttr;
+  });
+  p << "])";
+}
+
+bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
+
+bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
+  return hasDeviceType(getWorker(), deviceType);
+}
+
+bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
+
+bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
+  return hasDeviceType(getVector(), deviceType);
+}
+
+bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
+
+bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
+  return hasDeviceType(getSeq(), deviceType);
+}
+
+std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
+  return getBindNameValue(mlir::acc::DeviceType::None);
+}
+
+std::optional<llvm::StringRef>
+RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(getBindNameDeviceType()))
+    return std::nullopt;
+  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
+    auto attr = (*getBindName())[*pos];
+    auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
+    return stringAttr.getValue();
+  }
+  return std::nullopt;
+}
+
+bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
+
+bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
+  return hasDeviceType(getGang(), deviceType);
+}
+
+std::optional<int64_t> RoutineOp::getGangDimValue() {
+  return getGangDimValue(mlir::acc::DeviceType::None);
+}
+
+std::optional<int64_t>
+RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(getGangDimDeviceType()))
+    return std::nullopt;
+  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
+    auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
+    return intAttr.getInt();
+  }
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 8fa37bc98294ce..77599fdd32e567 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1656,7 +1656,7 @@ acc.routine @acc_func_rout5 func(@acc_func) bind("acc_func_gpu_worker") worker
 acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
 acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") implicit gang
 acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim = 1 : i32)
+acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(1 : i64)
 
 // CHECK-LABEL: func.func @acc_func(
 // CHECK: attributes {acc.routine_info = #acc.routine_info<[@acc_func_rout1, @acc_func_rout2, @acc_func_rout3,
@@ -1669,7 +1669,7 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(
 // CHECK: acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
 // CHECK: acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") gang implicit
 // CHECK: acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim = 1 : i32)
+// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(1 : i64)
 
 // -----
 
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index d78d7b0fdf6769..474f887928992a 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -347,3 +347,61 @@ TEST_F(OpenACCOpsTest, loopOpGangVectorWorkerTest) {
   }
   op->removeVectorAttr();
 }
+
+TEST_F(OpenACCOpsTest, routineOpTest) {
+  OwningOpRef<RoutineOp> op =
+      b.create<RoutineOp>(loc, TypeRange{}, ValueRange{});
+
+  EXPECT_FALSE(op->hasSeq());
+  EXPECT_FALSE(op->hasVector());
+  EXPECT_FALSE(op->hasWorker());
+
+  for (auto d : dtypes) {
+    EXPECT_FALSE(op->hasSeq(d));
+    EXPECT_FALSE(op->hasVector(d));
+    EXPECT_FALSE(op->hasWorker(d));
+  }
+
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->setSeqAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasSeq());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasSeq(d));
+  op->removeSeqAttr();
+
+  op->setVectorAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasVector());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasVector(d));
+  op->removeVectorAttr();
+
+  op->setWorkerAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasWorker());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasWorker(d));
+  op->removeWorkerAttr();
+
+  op->setGangAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasGang());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasGang(d));
+  op->removeGangAttr();
+
+  op->setGangDimDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op->setGangDimAttr(b.getArrayAttr({b.getIntegerAttr(b.getI64Type(), 8)}));
+  EXPECT_TRUE(op->getGangDimValue().has_value());
+  EXPECT_EQ(op->getGangDimValue().value(), 8);
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->getGangDimValue(d).has_value());
+  op->removeGangDimDeviceTypeAttr();
+  op->removeGangDimAttr();
+
+  op->setBindNameDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op->setBindNameAttr(b.getArrayAttr({b.getStringAttr("fname")}));
+  EXPECT_TRUE(op->getBindNameValue().has_value());
+  EXPECT_EQ(op->getBindNameValue().value(), "fname");
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->getBindNameValue(d).has_value());
+  op->removeBindNameDeviceTypeAttr();
+  op->removeBindNameAttr();
+}

>From 66f8415ecf19b9f7e2010dce8869f14d47db80c4 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 17 Jan 2024 14:54:30 -0800
Subject: [PATCH 2/2] Keep dim keyword for gang

---
 flang/test/Lower/OpenACC/acc-routine.f90 | 2 +-
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp  | 5 ++++-
 mlir/test/Dialect/OpenACC/ops.mlir       | 4 ++--
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 8e9e65da32cd19..2fe150e70b0cfb 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -9,7 +9,7 @@
 ! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
 ! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
 ! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
-! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(1 : i64)
+! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim: 1 : i64)
 ! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost
 ! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker
 ! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 82e614cb7572f6..5cfbe233a670f3 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2248,7 +2248,9 @@ static ParseResult parseRoutineGangClause(OpAsmParser &parser,
     return failure();
 
   if (failed(parser.parseCommaSeparatedList([&]() {
-        if (parser.parseAttribute(gangDimAttrs.emplace_back()))
+        if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
+            parser.parseColon() ||
+            parser.parseAttribute(gangDimAttrs.emplace_back()))
           return failure();
         if (succeeded(parser.parseOptionalLSquare())) {
           if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
@@ -2295,6 +2297,7 @@ void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
   if (hasDeviceTypeValues(gangDimDeviceTypes))
     llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
                           [&](const auto &pair) {
+                            p << acc::RoutineOp::getGangDimKeyword() << ": ";
                             p << std::get<0>(pair);
                             printSingleDeviceType(p, std::get<1>(pair));
                           });
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 77599fdd32e567..99b44183758d95 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1656,7 +1656,7 @@ acc.routine @acc_func_rout5 func(@acc_func) bind("acc_func_gpu_worker") worker
 acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
 acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") implicit gang
 acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(1 : i64)
+acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim: 1 : i64)
 
 // CHECK-LABEL: func.func @acc_func(
 // CHECK: attributes {acc.routine_info = #acc.routine_info<[@acc_func_rout1, @acc_func_rout2, @acc_func_rout3,
@@ -1669,7 +1669,7 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(
 // CHECK: acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
 // CHECK: acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") gang implicit
 // CHECK: acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(1 : i64)
+// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim: 1 : i64)
 
 // -----
 



More information about the flang-commits mailing list