[flang-commits] [mlir] [flang] [mlir][flang][openacc] Add device_type support for update op (PR #78764)

via flang-commits flang-commits at lists.llvm.org
Fri Jan 19 11:19:44 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Add support for device_type information on the acc.update operation, and update the lowering from Flang accordingly.

---

Patch is 29.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78764.diff


6 Files Affected:

- (modified) flang/lib/Lower/OpenACC.cpp (+43-37) 
- (modified) flang/test/Lower/OpenACC/acc-update.f90 (+8-13) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+39-10) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+189-40) 
- (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+2-2) 
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+6-6) 


``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 682ca06cabd6f6..541ea2e114324f 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2840,27 +2840,42 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+static inline mlir::ArrayAttr
+getArrayAttr(fir::FirOpBuilder &b,
+             llvm::SmallVector<mlir::Attribute> &attributes) {
+  return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
+}
+
+static inline mlir::DenseI32ArrayAttr
+getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
+                     llvm::SmallVector<int32_t> &values) {
+  return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
+}
+
 static void
 genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
                mlir::Location currentLocation,
                Fortran::semantics::SemanticsContext &semanticsContext,
                Fortran::lower::StatementContext &stmtCtx,
                const Fortran::parser::AccClauseList &accClauseList) {
-  mlir::Value ifCond, async, waitDevnum;
+  mlir::Value ifCond, waitDevnum;
   llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
-      waitOperands, deviceTypeOperands;
-  llvm::SmallVector<mlir::Attribute> deviceTypes;
-
-  // Async and wait clause have optional values but can be present with
-  // no value as well. When there is no value, the op has an attribute to
-  // represent the clause.
-  bool addAsyncAttr = false;
-  bool addWaitAttr = false;
-  bool addIfPresentAttr = false;
+      waitOperands, deviceTypeOperands, asyncOperands;
+  llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
+      asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<int32_t> waitOperandsSegments;
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
-  // Lower clauses values mapped to operands.
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None));
+
+  bool ifPresent = false;
+
+  // Lower clauses values mapped to operands and array attributes.
   // Keep track of each group of operands separately as clauses can appear
   // more than once.
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2870,15 +2885,19 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+      genAsyncClause(converter, asyncClause, asyncOperands,
+                     asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
+                     crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
-                    addWaitAttr, stmtCtx);
+      genWaitClause(converter, waitClause, waitOperands,
+                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     } else if (const auto *hostClause =
                    std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
       genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2892,7 +2911,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
           dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
           /*implicit=*/false);
     } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
-      addIfPresentAttr = true;
+      ifPresent = true;
     } else if (const auto *selfClause =
                    std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
       const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
@@ -2909,30 +2928,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
 
   dataClauseOperands.append(updateHostOperands);
 
-  // Prepare the operand segment size attribute and the operands value range.
-  llvm::SmallVector<mlir::Value> operands;
-  llvm::SmallVector<int32_t> operandSegments;
-  addOperand(operands, operandSegments, ifCond);
-  addOperand(operands, operandSegments, async);
-  addOperand(operands, operandSegments, waitDevnum);
-  addOperands(operands, operandSegments, waitOperands);
-  addOperands(operands, operandSegments, dataClauseOperands);
-
-  mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
-      builder, currentLocation, operands, operandSegments);
-  if (!deviceTypes.empty())
-    updateOp.setDeviceTypesAttr(
-        mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
+  builder.create<mlir::acc::UpdateOp>(
+      currentLocation, ifCond, asyncOperands,
+      getArrayAttr(builder, asyncOperandsDeviceTypes),
+      getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+      getDenseI32ArrayAttr(builder, waitOperandsSegments),
+      getArrayAttr(builder, waitOperandsDeviceTypes),
+      getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
+      ifPresent);
 
   genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
       builder, updateHostOperands, /*structured=*/false);
-
-  if (addAsyncAttr)
-    updateOp.setAsyncAttr(builder.getUnitAttr());
-  if (addWaitAttr)
-    updateOp.setWaitAttr(builder.getUnitAttr());
-  if (addIfPresentAttr)
-    updateOp.setIfPresentAttr(builder.getUnitAttr());
 }
 
 static void
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index d2b15f8bd258e7..ac7a56c56b1f20 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -61,17 +61,17 @@ subroutine acc_update
 
   !$acc update host(a) async
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+! CHECK: acc.update async dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+! CHECK: acc.update wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) async wait
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+! CHECK: acc.update async wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) async(1)
@@ -89,14 +89,14 @@ subroutine acc_update
   !$acc update host(a) wait(1)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait([[WAIT1]] : i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT1]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait(queues: 1, 2)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait(devnum: 1: queues: 1, 2)
@@ -104,17 +104,12 @@ subroutine acc_update
 ! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
 ! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
 ! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
-  !$acc update host(a) device_type(default, host)
+  !$acc update host(a) device_type(host, nvidia) async
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]} 
-! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
-
-  !$acc update host(a) device_type(*)
-! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<star>]} 
+! CHECK: acc.update async([#acc.device_type<host>, #acc.device_type<nvidia>]) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
 end subroutine acc_update
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 7344ab2852b9ce..5b678e84b93ee4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2187,14 +2187,16 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
   }];
 
   let arguments = (ins Optional<I1>:$ifCond,
-                       Optional<IntOrIndex>:$asyncOperand,
-                       Optional<IntOrIndex>:$waitDevnum,
-                       Variadic<IntOrIndex>:$waitOperands,
-                       UnitAttr:$async,
-                       UnitAttr:$wait,
-                       OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-                       UnitAttr:$ifPresent);
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$async,
+      Optional<IntOrIndex>:$waitDevnum,
+      Variadic<IntOrIndex>:$waitOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$wait,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+      UnitAttr:$ifPresent);
 
   let extraClassDeclaration = [{
     /// The number of data operands.
@@ -2202,14 +2204,41 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
 
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
+
+    /// Return true if the op has the async attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasAsync();
+    /// Return true if the op has the async attribute for the given device_type.
+    bool hasAsync(mlir::acc::DeviceType deviceType);
+    /// Return the value of the async clause if present.
+    mlir::Value getAsyncValue();
+    /// Return the value of the async clause for the given device_type if
+    /// present.
+    mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the wait attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWait();
+    /// Return true if the op has the wait attribute for the given device_type.
+    bool hasWait(mlir::acc::DeviceType deviceType);
+    /// Return the values of the wait clause if present.
+    mlir::Operation::operand_range getWaitValues();
+    /// Return the values of the wait clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getWaitValues(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `if` `(` $ifCond `)`
-      | `async` `(` $asyncOperand `:` type($asyncOperand) `)`
+      | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+            $asyncOperands, type($asyncOperands),
+            $asyncOperandsDeviceType, $async)
       | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
-      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+      | `wait` `` custom<WaitClause>($waitOperands,
+            type($waitOperands), $waitOperandsDeviceType, 
+            $waitOperandsSegments, $wait)
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
     attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bc03adbcae64df..4e31f7b163b9dc 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -936,6 +936,138 @@ static void printDeviceTypeOperandsWithSegment(
   });
 }
 
+static ParseResult parseWaitClause(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) {
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs;
+  llvm::SmallVector<int32_t> seg;
+
+  bool needCommaBeforeOperands = false;
+
+  // Keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+    return success();
+  }
+
+  // Parse keyword only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(keywordAttrs.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
+
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  do {
+    if (failed(parser.parseLBrace()))
+      return failure();
+
+    int32_t crtOperandsSize = operands.size();
+
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() {
+              if (parser.parseOperand(operands.emplace_back()) ||
+                  parser.parseColonType(types.emplace_back()))
+                return failure();
+              return success();
+            })))
+      return failure();
+
+    seg.push_back(operands.size() - crtOperandsSize);
+
+    if (failed(parser.parseRBrace()))
+      return failure();
+
+    if (succeeded(parser.parseOptionalLSquare())) {
+      if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+          parser.parseRSquare())
+        return failure();
+    } else {
+      deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+          parser.getContext(), mlir::acc::DeviceType::None));
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+
+  return success();
+}
+
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+    return true;
+  return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+
+  p << "[";
+  llvm::interleaveComma(*deviceTypes, p,
+                        [&](mlir::Attribute attr) { p << attr; });
+  p << "]";
+}
+
+static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
+  if (!hasDeviceTypeValues(attrs))
+    return false;
+  if (attrs->size() != 1)
+    return false;
+  if (auto deviceTypeAttr =
+          mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
+    return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
+  return false;
+}
+
+static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                            mlir::OperandRange operands, mlir::TypeRange types,
+                            std::optional<mlir::ArrayAttr> deviceTypes,
+                            std::optional<mlir::DenseI32ArrayAttr> segments,
+                            std::optional<mlir::ArrayAttr> keywordOnly) {
+
+  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
+    return;
+
+  p << "(";
+
+  printDeviceTypes(p, keywordOnly);
+  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  unsigned opIdx = 0;
+  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
+    p << "{";
+    llvm::interleaveComma(
+        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
+          p << operands[opIdx] << " : " << operands[opIdx].getType();
+          ++opIdx;
+        });
+    p << "}";
+    printSingleDeviceType(p, it.value());
+  });
+
+  p << ")";
+}
+
 static ParseResult parseDeviceTypeOperands(
     mlir::OpAsmP...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/78764


More information about the flang-commits mailing list