[flang-commits] [flang] b06bc7c - [mlir][flang][openacc] Device type support on acc routine op (#78375)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 18 09:04:16 PST 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-01-18T09:04:11-08:00
New Revision: b06bc7c6a00c4e8acac4fa76e402c6a7e2035090

URL: https://github.com/llvm/llvm-project/commit/b06bc7c6a00c4e8acac4fa76e402c6a7e2035090
DIFF: https://github.com/llvm/llvm-project/commit/b06bc7c6a00c4e8acac4fa76e402c6a7e2035090.diff

LOG: [mlir][flang][openacc] Device type support on acc routine op (#78375)

This patch adds support for device_type on the acc.routine operation.
device_type can be specified on seq, worker, vector, gang and bind
information.

The support follows the same design as the one used for compute
operations, data operations and the loop operation.

Added: 
    

Modified: 
    flang/lib/Lower/OpenACC.cpp
    flang/test/Lower/OpenACC/acc-routine.f90
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/ops.mlir
    mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index db9ed72bc87257..fd89d27db74dc0 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3469,6 +3469,72 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   llvm_unreachable("unsupported declarative directive");
 }
 
+static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  for (auto attr : arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+  return false;
+}
+
+template <typename RetTy, typename AttrTy>
+static std::optional<RetTy>
+getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
+                              llvm::SmallVector<mlir::Attribute> &deviceTypes,
+                              mlir::acc::DeviceType deviceType) {
+  assert(attributes.size() == deviceTypes.size() &&
+         "expect same number of attributes");
+  for (auto it : llvm::enumerate(deviceTypes)) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value());
+    if (deviceTypeAttr.getValue() == deviceType) {
+      if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
+        auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]);
+        return strAttr.getValue();
+      } else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
+        auto intAttr =
+            mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index()]);
+        return intAttr.getInt();
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+static bool compareDeviceTypeInfo(
+    mlir::acc::RoutineOp op,
+    llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
+  for (uint32_t dtypeInt = 0;
+       dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
+    auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
+    if (op.getBindNameValue(dtype) !=
+        getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
+            bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
+      return false;
+    if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
+      return false;
+    if (op.getGangDimValue(dtype) !=
+        getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
+            gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
+      return false;
+    if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
+      return false;
+    if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
+      return false;
+    if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
+      return false;
+  }
+  return true;
+}
+
 static void attachRoutineInfo(mlir::func::FuncOp func,
                               mlir::SymbolRefAttr routineAttr) {
   llvm::SmallVector<mlir::SymbolRefAttr> routines;
@@ -3518,17 +3584,23 @@ void Fortran::lower::genOpenACCRoutineConstruct(
       funcName = funcOp.getName();
     }
   }
-  bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
-       hasNohost = false;
-  std::optional<std::string> bindName = std::nullopt;
-  std::optional<int64_t> gangDim = std::nullopt;
+  bool hasNohost = false;
+
+  llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
+      workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
+      gangDimDeviceTypes, gangDimValues;
+
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
 
   for (const Fortran::parser::AccClause &clause : clauses.v) {
     if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      hasSeq = true;
+      seqDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *gangClause =
                    std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
-      hasGang = true;
+
       if (gangClause->v) {
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
@@ -3539,21 +3611,27 @@ void Fortran::lower::genOpenACCRoutineConstruct(
             if (!dimValue)
               mlir::emitError(loc,
                               "dim value must be a constant positive integer");
-            gangDim = *dimValue;
+            gangDimValues.push_back(
+                builder.getIntegerAttr(builder.getI64Type(), *dimValue));
+            gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
           }
         }
+      } else {
+        gangDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
-      hasVector = true;
+      vectorDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
-      hasWorker = true;
+      workerDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
       hasNohost = true;
     } else if (const auto *bindClause =
                    std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
       if (const auto *name =
               std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
-        bindName = converter.mangleName(*name->symbol);
+        bindNames.push_back(
+            builder.getStringAttr(converter.mangleName(*name->symbol)));
+        bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
       } else if (const auto charExpr =
                      std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
                          &bindClause->v.u)) {
@@ -3562,8 +3640,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
                                                           *charExpr);
         if (!name)
           mlir::emitError(loc, "Could not retrieve the bind name");
-        bindName = *name;
+        bindNames.push_back(builder.getStringAttr(*name));
+        bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
       }
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
     }
   }
 
@@ -3575,12 +3663,11 @@ void Fortran::lower::genOpenACCRoutineConstruct(
     if (routineOp.getFuncName().str().compare(funcName) == 0) {
       // If the routine is already specified with the same clauses, just skip
       // the operation creation.
-      if (routineOp.getBindName() == bindName &&
-          routineOp.getGang() == hasGang &&
-          routineOp.getWorker() == hasWorker &&
-          routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq &&
-          routineOp.getNohost() == hasNohost &&
-          routineOp.getGangDim() == gangDim)
+      if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
+                                gangDeviceTypes, gangDimValues,
+                                gangDimDeviceTypes, seqDeviceTypes,
+                                workerDeviceTypes, vectorDeviceTypes) &&
+          routineOp.getNohost() == hasNohost)
         return;
       mlir::emitError(loc, "Routine already specified with different clauses");
     }
@@ -3588,10 +3675,19 @@ void Fortran::lower::genOpenACCRoutineConstruct(
 
   modBuilder.create<mlir::acc::RoutineOp>(
       loc, routineOpName.str(), funcName,
-      bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang,
-      hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false,
-      gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim)
-              : mlir::IntegerAttr{});
+      bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
+      bindNameDeviceTypes.empty() ? nullptr
+                                  : builder.getArrayAttr(bindNameDeviceTypes),
+      workerDeviceTypes.empty() ? nullptr
+                                : builder.getArrayAttr(workerDeviceTypes),
+      vectorDeviceTypes.empty() ? nullptr
+                                : builder.getArrayAttr(vectorDeviceTypes),
+      seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
+      hasNohost, /*implicit=*/false,
+      gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
+      gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
+      gangDimDeviceTypes.empty() ? nullptr
+                                 : builder.getArrayAttr(gangDimDeviceTypes));
 
   if (funcOp)
     attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));

diff  --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 8b94279503334a..2fe150e70b0cfb 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -2,12 +2,14 @@
 
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
-
+! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
+! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
+! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
 ! CHECK: acc.routine @acc_routine_10 func(@_QPacc_routine11) seq
 ! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
 ! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
 ! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
-! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim = 1 : i32)
+! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim: 1 : i64)
 ! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost
 ! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker
 ! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector
@@ -106,3 +108,15 @@ subroutine acc_routine14()
 subroutine acc_routine15()
   !$acc routine bind(acc_routine16)
 end subroutine
+
+subroutine acc_routine16()
+  !$acc routine device_type(host) seq dtype(nvidia) gang
+end subroutine
+
+subroutine acc_routine17()
+  !$acc routine device_type(host) worker dtype(multicore) vector 
+end subroutine
+
+subroutine acc_routine18()
+  !$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16) 
+end subroutine

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 24f129d92805c0..7344ab2852b9ce 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1994,27 +1994,63 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        SymbolNameAttr:$func_name,
-                       OptionalAttr<StrAttr>:$bind_name,
-                       UnitAttr:$gang,
-                       UnitAttr:$worker,
-                       UnitAttr:$vector,
-                       UnitAttr:$seq,
+                       OptionalAttr<StrArrayAttr>:$bindName,
+                       OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
+                       OptionalAttr<DeviceTypeArrayAttr>:$worker,
+                       OptionalAttr<DeviceTypeArrayAttr>:$vector,
+                       OptionalAttr<DeviceTypeArrayAttr>:$seq,
                        UnitAttr:$nohost,
                        UnitAttr:$implicit,
-                       OptionalAttr<APIntAttr>:$gangDim);
+                       OptionalAttr<DeviceTypeArrayAttr>:$gang,
+                       OptionalAttr<I64ArrayAttr>:$gangDim,
+                       OptionalAttr<DeviceTypeArrayAttr>:$gangDimDeviceType);
 
   let extraClassDeclaration = [{
     static StringRef getGangDimKeyword() { return "dim"; }
+
+    /// Return true if the op has the worker attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWorker();
+    /// Return true if the op has the worker attribute for the given
+    /// device_type.
+    bool hasWorker(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the vector attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasVector();
+    /// Return true if the op has the vector attribute for the given
+    /// device_type.
+    bool hasVector(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the seq attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasSeq();
+    /// Return true if the op has the seq attribute for the given
+    /// device_type.
+    bool hasSeq(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the gang attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasGang();
+    /// Return true if the op has the gang attribute for the given
+    /// device_type.
+    bool hasGang(mlir::acc::DeviceType deviceType);
+
+    std::optional<int64_t> getGangDimValue();
+    std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
+
+    std::optional<llvm::StringRef> getBindNameValue();
+    std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     $sym_name `func` `(` $func_name `)`
     oilist (
-        `bind` `(` $bind_name `)`
-      | `gang` `` custom<RoutineGangClause>($gang, $gangDim)
-      | `worker` $worker
-      | `vector` $vector
-      | `seq` $seq
+        `bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
+      | `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
+      | `worker` custom<DeviceTypeArrayAttr>($worker)
+      | `vector` custom<DeviceTypeArrayAttr>($vector)
+      | `seq` custom<DeviceTypeArrayAttr>($seq)
       | `nohost` $nohost
       | `implicit` $implicit
     ) attr-dict-with-keyword

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 664a0161b79c1e..20465f6bb86ed1 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1033,7 +1033,7 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
   return success();
 }
 
-bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
   if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
     return true;
   return false;
@@ -2090,55 +2090,281 @@ LogicalResult acc::DeclareOp::verify() {
 // RoutineOp
 //===----------------------------------------------------------------------===//
 
+static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(arrayAttr))
+    return false;
+
+  for (auto attr : *arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+
+  return false;
+}
+
+static unsigned getParallelismForDeviceType(acc::RoutineOp op,
+                                            acc::DeviceType dtype) {
+  unsigned parallelism = 0;
+  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
+  parallelism += op.hasWorker(dtype) ? 1 : 0;
+  parallelism += op.hasVector(dtype) ? 1 : 0;
+  parallelism += op.hasSeq(dtype) ? 1 : 0;
+  return parallelism;
+}
+
 LogicalResult acc::RoutineOp::verify() {
-  int parallelism = 0;
-  parallelism += getGang() ? 1 : 0;
-  parallelism += getWorker() ? 1 : 0;
-  parallelism += getVector() ? 1 : 0;
-  parallelism += getSeq() ? 1 : 0;
+  unsigned baseParallelism =
+      getParallelismForDeviceType(*this, acc::DeviceType::None);
 
-  if (parallelism > 1)
+  if (baseParallelism > 1)
     return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
                           "be present at the same time";
 
+  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
+       ++dtypeInt) {
+    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
+    if (dtype == acc::DeviceType::None)
+      continue;
+    unsigned parallelism = getParallelismForDeviceType(*this, dtype);
+
+    if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
+      return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
+                            "be present at the same time";
+  }
+
   return success();
 }
 
-static ParseResult parseRoutineGangClause(OpAsmParser &parser, UnitAttr &gang,
-                                          IntegerAttr &gangDim) {
-  // Since gang clause exists, ensure that unit attribute is set.
-  gang = UnitAttr::get(parser.getBuilder().getContext());
+static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
+                                 mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
 
-  // Next, look for dim on gang. Don't initialize `gangDim` yet since
-  // we leave it without attribute if there is no `dim` specifier.
-  if (succeeded(parser.parseOptionalLParen())) {
-    // Look for syntax that looks like `dim = 1 : i32`.
-    // Thus first look for `dim =`
-    if (failed(parser.parseKeyword(RoutineOp::getGangDimKeyword())) ||
-        failed(parser.parseEqual()))
-      return failure();
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(bindNameAttrs.emplace_back()))
+          return failure();
+        if (failed(parser.parseOptionalLSquare())) {
+          deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        } else {
+          if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        }
+        return success();
+      })))
+    return failure();
 
-    int64_t dimValue;
-    Type valueType;
-    // Now look for `1 : i32`
-    if (failed(parser.parseInteger(dimValue)) ||
-        failed(parser.parseColonType(valueType)))
-      return failure();
+  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
+  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+
+  return success();
+}
 
-    gangDim = IntegerAttr::get(valueType, dimValue);
+static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          std::optional<mlir::ArrayAttr> bindName,
+                          std::optional<mlir::ArrayAttr> deviceTypes) {
+  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
+                        [&](const auto &pair) {
+                          p << std::get<0>(pair);
+                          printSingleDeviceType(p, std::get<1>(pair));
+                        });
+}
+
+static ParseResult parseRoutineGangClause(OpAsmParser &parser,
+                                          mlir::ArrayAttr &gang,
+                                          mlir::ArrayAttr &gangDim,
+                                          mlir::ArrayAttr &gangDimDeviceTypes) {
+
+  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
+      gangDimDeviceTypeAttrs;
+  bool needCommaBeforeOperands = false;
 
-    if (failed(parser.parseRParen()))
+  // Gang keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    gang = ArrayAttr::get(parser.getContext(), gangAttrs);
+    return success();
+  }
+
+  // Parse keyword only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(gangAttrs.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
       return failure();
+    needCommaBeforeOperands = true;
   }
 
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
+            parser.parseColon() ||
+            parser.parseAttribute(gangDimAttrs.emplace_back()))
+          return failure();
+        if (succeeded(parser.parseOptionalLSquare())) {
+          if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        } else {
+          gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        }
+        return success();
+      })))
+    return failure();
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
+  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
+  gangDimDeviceTypes =
+      ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
+
   return success();
 }
 
-void printRoutineGangClause(OpAsmPrinter &p, Operation *op, UnitAttr gang,
-                            IntegerAttr gangDim) {
-  if (gangDim)
-    p << "(" << RoutineOp::getGangDimKeyword() << " = " << gangDim.getValue()
-      << " : " << gangDim.getType() << ")";
+void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
+                            std::optional<mlir::ArrayAttr> gang,
+                            std::optional<mlir::ArrayAttr> gangDim,
+                            std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
+
+  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
+      gang->size() == 1) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+
+  printDeviceTypes(p, gang);
+
+  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
+    p << ", ";
+
+  if (hasDeviceTypeValues(gangDimDeviceTypes))
+    llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
+                          [&](const auto &pair) {
+                            p << acc::RoutineOp::getGangDimKeyword() << ": ";
+                            p << std::get<0>(pair);
+                            printSingleDeviceType(p, std::get<1>(pair));
+                          });
+
+  p << ")";
+}
+
+static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
+                                            mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<mlir::Attribute> attributes;
+  // Keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
+    return success();
+  }
+
+  // Parse device type attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(attributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare() || parser.parseRParen())
+      return failure();
+  }
+  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
+  return success();
+}
+
+static void
+printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                         std::optional<mlir::ArrayAttr> deviceTypes) {
+
+  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+
+  p << "([";
+  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
+    auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    p << dTypeAttr;
+  });
+  p << "])";
+}
+
+bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
+
+bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
+  return hasDeviceType(getWorker(), deviceType);
+}
+
+bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
+
+bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
+  return hasDeviceType(getVector(), deviceType);
+}
+
+bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
+
+bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
+  return hasDeviceType(getSeq(), deviceType);
+}
+
+std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
+  return getBindNameValue(mlir::acc::DeviceType::None);
+}
+
+std::optional<llvm::StringRef>
+RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(getBindNameDeviceType()))
+    return std::nullopt;
+  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
+    auto attr = (*getBindName())[*pos];
+    auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
+    return stringAttr.getValue();
+  }
+  return std::nullopt;
+}
+
+bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
+
+bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
+  return hasDeviceType(getGang(), deviceType);
+}
+
+std::optional<int64_t> RoutineOp::getGangDimValue() {
+  return getGangDimValue(mlir::acc::DeviceType::None);
+}
+
+std::optional<int64_t>
+RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(getGangDimDeviceType()))
+    return std::nullopt;
+  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
+    auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
+    return intAttr.getInt();
+  }
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 8fa37bc98294ce..99b44183758d95 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1656,7 +1656,7 @@ acc.routine @acc_func_rout5 func(@acc_func) bind("acc_func_gpu_worker") worker
 acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
 acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") implicit gang
 acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim = 1 : i32)
+acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim: 1 : i64)
 
 // CHECK-LABEL: func.func @acc_func(
 // CHECK: attributes {acc.routine_info = #acc.routine_info<[@acc_func_rout1, @acc_func_rout2, @acc_func_rout3,
@@ -1669,7 +1669,7 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(
 // CHECK: acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
 // CHECK: acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") gang implicit
 // CHECK: acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim = 1 : i32)
+// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim: 1 : i64)
 
 // -----
 

diff  --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index d78d7b0fdf6769..474f887928992a 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -347,3 +347,61 @@ TEST_F(OpenACCOpsTest, loopOpGangVectorWorkerTest) {
   }
   op->removeVectorAttr();
 }
+
+TEST_F(OpenACCOpsTest, routineOpTest) {
+  OwningOpRef<RoutineOp> op =
+      b.create<RoutineOp>(loc, TypeRange{}, ValueRange{});
+
+  EXPECT_FALSE(op->hasSeq());
+  EXPECT_FALSE(op->hasVector());
+  EXPECT_FALSE(op->hasWorker());
+
+  for (auto d : dtypes) {
+    EXPECT_FALSE(op->hasSeq(d));
+    EXPECT_FALSE(op->hasVector(d));
+    EXPECT_FALSE(op->hasWorker(d));
+  }
+
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->setSeqAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasSeq());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasSeq(d));
+  op->removeSeqAttr();
+
+  op->setVectorAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasVector());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasVector(d));
+  op->removeVectorAttr();
+
+  op->setWorkerAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasWorker());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasWorker(d));
+  op->removeWorkerAttr();
+
+  op->setGangAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasGang());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasGang(d));
+  op->removeGangAttr();
+
+  op->setGangDimDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op->setGangDimAttr(b.getArrayAttr({b.getIntegerAttr(b.getI64Type(), 8)}));
+  EXPECT_TRUE(op->getGangDimValue().has_value());
+  EXPECT_EQ(op->getGangDimValue().value(), 8);
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->getGangDimValue(d).has_value());
+  op->removeGangDimDeviceTypeAttr();
+  op->removeGangDimAttr();
+
+  op->setBindNameDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op->setBindNameAttr(b.getArrayAttr({b.getStringAttr("fname")}));
+  EXPECT_TRUE(op->getBindNameValue().has_value());
+  EXPECT_EQ(op->getBindNameValue().value(), "fname");
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->getBindNameValue(d).has_value());
+  op->removeBindNameDeviceTypeAttr();
+  op->removeBindNameAttr();
+}


        


More information about the flang-commits mailing list