[flang-commits] [flang] [mlir] [mlir][flang][openacc] Add device_type support for update op (PR #78764)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Mon Jan 22 14:51:35 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/78764

>From 96b042b1e53b86896bdd1e5e3a35b1ed628fc8fe Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 18 Jan 2024 10:09:31 -0800
Subject: [PATCH 1/2] [mlir][flang][openacc] Add device_type support for update
 op

---
 flang/lib/Lower/OpenACC.cpp                   |  80 +++----
 flang/test/Lower/OpenACC/acc-update.f90       |  21 +-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        |  49 ++++-
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 204 +++++++++++++++---
 mlir/test/Dialect/OpenACC/invalid.mlir        |   4 +-
 mlir/test/Dialect/OpenACC/ops.mlir            |  18 +-
 6 files changed, 279 insertions(+), 97 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index d619d47fc235971..dac14242bd800aa 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2941,27 +2941,42 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+// Null attribute when `attributes` is empty so the optional attr is omitted.
+static inline mlir::ArrayAttr
+getArrayAttr(fir::FirOpBuilder &b, llvm::ArrayRef<mlir::Attribute> attributes) {
+  return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
+}
+
+// Null attribute when `values` is empty so the optional attr is omitted.
+static inline mlir::DenseI32ArrayAttr
+getDenseI32ArrayAttr(fir::FirOpBuilder &b, llvm::ArrayRef<int32_t> values) {
+  return values.empty() ? nullptr : b.getDenseI32ArrayAttr(values);
+}
+
 static void
 genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
                mlir::Location currentLocation,
                Fortran::semantics::SemanticsContext &semanticsContext,
                Fortran::lower::StatementContext &stmtCtx,
                const Fortran::parser::AccClauseList &accClauseList) {
-  mlir::Value ifCond, async, waitDevnum;
+  mlir::Value ifCond, waitDevnum;
   llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
-      waitOperands, deviceTypeOperands;
-  llvm::SmallVector<mlir::Attribute> deviceTypes;
-
-  // Async and wait clause have optional values but can be present with
-  // no value as well. When there is no value, the op has an attribute to
-  // represent the clause.
-  bool addAsyncAttr = false;
-  bool addWaitAttr = false;
-  bool addIfPresentAttr = false;
+      waitOperands, deviceTypeOperands, asyncOperands;
+  llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
+      asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<int32_t> waitOperandsSegments;
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
-  // Lower clauses values mapped to operands.
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None));
+
+  bool ifPresent = false;
+
+  // Lower clauses values mapped to operands and array attributes.
   // Keep track of each group of operands separately as clauses can appear
   // more than once.
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2971,15 +2986,19 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+      genAsyncClause(converter, asyncClause, asyncOperands,
+                     asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
+                     crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
-                    addWaitAttr, stmtCtx);
+      genWaitClause(converter, waitClause, waitOperands,
+                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     } else if (const auto *hostClause =
                    std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
       genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2993,7 +3012,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
           dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
           /*implicit=*/false);
     } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
-      addIfPresentAttr = true;
+      ifPresent = true;
     } else if (const auto *selfClause =
                    std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
       const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
@@ -3010,30 +3029,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
 
   dataClauseOperands.append(updateHostOperands);
 
-  // Prepare the operand segment size attribute and the operands value range.
-  llvm::SmallVector<mlir::Value> operands;
-  llvm::SmallVector<int32_t> operandSegments;
-  addOperand(operands, operandSegments, ifCond);
-  addOperand(operands, operandSegments, async);
-  addOperand(operands, operandSegments, waitDevnum);
-  addOperands(operands, operandSegments, waitOperands);
-  addOperands(operands, operandSegments, dataClauseOperands);
-
-  mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
-      builder, currentLocation, operands, operandSegments);
-  if (!deviceTypes.empty())
-    updateOp.setDeviceTypesAttr(
-        mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
+  builder.create<mlir::acc::UpdateOp>(
+      currentLocation, ifCond, asyncOperands,
+      getArrayAttr(builder, asyncOperandsDeviceTypes),
+      getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+      getDenseI32ArrayAttr(builder, waitOperandsSegments),
+      getArrayAttr(builder, waitOperandsDeviceTypes),
+      getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
+      ifPresent);
 
   genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
       builder, updateHostOperands, /*structured=*/false);
-
-  if (addAsyncAttr)
-    updateOp.setAsyncAttr(builder.getUnitAttr());
-  if (addWaitAttr)
-    updateOp.setWaitAttr(builder.getUnitAttr());
-  if (addIfPresentAttr)
-    updateOp.setIfPresentAttr(builder.getUnitAttr());
 }
 
 static void
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index d2b15f8bd258e7a..ba036ac92811826 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -61,17 +61,17 @@ subroutine acc_update
 
   !$acc update host(a) async
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+! CHECK: acc.update async() dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+! CHECK: acc.update wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) async wait
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+! CHECK: acc.update async() wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) async(1)
@@ -89,14 +89,14 @@ subroutine acc_update
   !$acc update host(a) wait(1)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait([[WAIT1]] : i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT1]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait(queues: 1, 2)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait(devnum: 1: queues: 1, 2)
@@ -104,17 +104,12 @@ subroutine acc_update
 ! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
 ! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
 ! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
-  !$acc update host(a) device_type(default, host)
+  !$acc update host(a) device_type(host, nvidia) async
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]} 
-! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
-
-  !$acc update host(a) device_type(*)
-! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<star>]} 
+! CHECK: acc.update async([#acc.device_type<host>, #acc.device_type<nvidia>]) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
 end subroutine acc_update
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 992f2809644a6a8..87fd587782e7c35 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2196,14 +2196,16 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
   }];
 
   let arguments = (ins Optional<I1>:$ifCond,
-                       Optional<IntOrIndex>:$asyncOperand,
-                       Optional<IntOrIndex>:$waitDevnum,
-                       Variadic<IntOrIndex>:$waitOperands,
-                       UnitAttr:$async,
-                       UnitAttr:$wait,
-                       OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-                       UnitAttr:$ifPresent);
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$async,
+      Optional<IntOrIndex>:$waitDevnum,
+      Variadic<IntOrIndex>:$waitOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$wait,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+      UnitAttr:$ifPresent);
 
   let extraClassDeclaration = [{
     /// The number of data operands.
@@ -2211,14 +2213,41 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
 
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
+
+    /// Return true if the op has the async attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasAsyncOnly();
+    /// Return true if the op has the async attribute for the given device_type.
+    bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+    /// Return the value of the async clause if present.
+    mlir::Value getAsyncValue();
+    /// Return the value of the async clause for the given device_type if
+    /// present.
+    mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the wait attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWaitOnly();
+    /// Return true if the op has the wait attribute for the given device_type.
+    bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+    /// Return the values of the wait clause if present.
+    mlir::Operation::operand_range getWaitValues();
+    /// Return the values of the wait clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getWaitValues(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `if` `(` $ifCond `)`
-      | `async` `(` $asyncOperand `:` type($asyncOperand) `)`
+      | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+            $asyncOperands, type($asyncOperands),
+            $asyncOperandsDeviceType, $async)
       | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
-      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+      | `wait` `` custom<WaitClause>($waitOperands,
+            type($waitOperands), $waitOperandsDeviceType, 
+            $waitOperandsSegments, $wait)
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
     attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index f6229e5192a0abd..c04b563d20d7b20 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -963,6 +963,121 @@ static void printDeviceTypeOperandsWithSegment(
   });
 }
 
+static ParseResult parseWaitClause(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) {
+  // Parse the custom `wait` clause syntax. Accepted forms, where `dt` is a
+  // #acc.device_type attribute:
+  //   wait
+  //   wait([dt, ...])
+  //   wait([dt, ...], {%v : t, ...}[dt], {%w : t, ...}, ...)
+  //   wait({%v : t, ...}[dt], ...)
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs;
+  llvm::SmallVector<int32_t> seg;
+
+  // Bare `wait` keyword: record it for device_type none.
+  if (failed(parser.parseOptionalLParen())) {
+    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+    return success();
+  }
+
+  // Optional leading list of device_types that use the keyword-only form.
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          return parser.parseAttribute(keywordAttrs.emplace_back());
+        })) ||
+        parser.parseRSquare())
+      return failure();
+    // The keyword-only list may make up the whole clause; otherwise a comma
+    // separates it from the operand groups.
+    if (succeeded(parser.parseOptionalRParen())) {
+      keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+      return success();
+    }
+    if (failed(parser.parseComma()))
+      return failure();
+  }
+
+  // One or more `{%v : t, ...}` groups, each optionally tagged with `[dt]`;
+  // an untagged group belongs to device_type none.
+  do {
+    if (failed(parser.parseLBrace()))
+      return failure();
+    int32_t crtOperandsSize = operands.size();
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() {
+              if (parser.parseOperand(operands.emplace_back()) ||
+                  parser.parseColonType(types.emplace_back()))
+                return failure();
+              return success();
+            })))
+      return failure();
+    seg.push_back(operands.size() - crtOperandsSize);
+    if (failed(parser.parseRBrace()))
+      return failure();
+    if (succeeded(parser.parseOptionalLSquare())) {
+      if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+          parser.parseRSquare())
+        return failure();
+    } else {
+      deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+          parser.getContext(), mlir::acc::DeviceType::None));
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+  return success();
+}
+
+static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
+  // True iff `attrs` holds exactly one entry and that entry is
+  // DeviceType::None (the marker used when no device_type applies).
+  if (!hasDeviceTypeValues(attrs) || attrs->size() != 1)
+    return false;
+  auto onlyAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]);
+  if (!onlyAttr)
+    return false;
+  return onlyAttr.getValue() == mlir::acc::DeviceType::None;
+}
+
+static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                            mlir::OperandRange operands, mlir::TypeRange types,
+                            std::optional<mlir::ArrayAttr> deviceTypes,
+                            std::optional<mlir::DenseI32ArrayAttr> segments,
+                            std::optional<mlir::ArrayAttr> keywordOnly) {
+  // A bare `wait` (keyword only, device_type none) prints nothing more.
+  if (operands.empty() && hasOnlyDeviceTypeNone(keywordOnly))
+    return;
+  p << "(";
+  printDeviceTypes(p, keywordOnly);
+  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  // With only keyword device_types, deviceTypes is unset: don't iterate it.
+  if (hasDeviceTypeValues(deviceTypes)) {
+    unsigned opIdx = 0;
+    llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
+      p << "{";
+      llvm::interleaveComma(
+          llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto) {
+            p << operands[opIdx] << " : " << operands[opIdx].getType();
+            ++opIdx;
+          });
+      p << "}";
+      printSingleDeviceType(p, it.value());
+    });
+  }
+  p << ")";
+}
+
 static ParseResult parseDeviceTypeOperands(
     mlir::OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
@@ -993,6 +1108,8 @@ static void
 printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
                         mlir::OperandRange operands, mlir::TypeRange types,
                         std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
   llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
     p << std::get<1>(it) << " : " << std::get<1>(it).getType();
     printSingleDeviceType(p, std::get<0>(it));
@@ -1068,15 +1185,10 @@ static void printDeviceTypeOperandsWithKeywordOnly(
     std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
 
   p << "(";
-
-  if (operands.begin() == operands.end() && keywordOnlyDeviceTypes &&
-      keywordOnlyDeviceTypes->size() == 1) {
-    auto deviceTypeAttr =
-        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*keywordOnlyDeviceTypes)[0]);
-    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) {
-      p << ")";
-      return;
-    }
+  if (operands.begin() == operands.end() &&
+      hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
+    p << ")";
+    return;
   }
 
   printDeviceTypes(p, keywordOnlyDeviceTypes);
@@ -1452,14 +1564,9 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
 
   p << "(";
   if (operands.begin() == operands.end() &&
-      hasDeviceTypeValues(gangOnlyDeviceTypes) &&
-      gangOnlyDeviceTypes->size() == 1) {
-    auto deviceTypeAttr =
-        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
-    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) {
-      p << ")";
-      return;
-    }
+      hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
+    p << ")";
+    return;
   }
 
   printDeviceTypes(p, gangOnlyDeviceTypes);
@@ -2432,15 +2539,20 @@ LogicalResult acc::UpdateOp::verify() {
   if (getDataClauseOperands().empty())
     return emitError("at least one value must be present in dataOperands");
 
-  // The async attribute represent the async clause without value. Therefore the
-  // attribute and operand cannot appear at the same time.
-  if (getAsyncOperand() && getAsync())
-    return emitError("async attribute cannot appear with asyncOperand");
+  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
+       ++dtypeInt) {
+    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
 
-  // The wait attribute represent the wait clause without values. Therefore the
-  // attribute and operands cannot appear at the same time.
-  if (!getWaitOperands().empty() && getWait())
-    return emitError("wait attribute cannot appear with waitOperands");
+    // The async attribute represent the async clause without value. Therefore
+    // the attribute and operand cannot appear at the same time.
+    if (getAsyncValue(dtype) && hasAsyncOnly(dtype))
+      return emitError("async attribute cannot appear with asyncOperand");
+
+    // The wait attribute represent the wait clause without values. Therefore
+    // the attribute and operands cannot appear at the same time.
+    if (!getWaitValues(dtype).empty() && hasWaitOnly(dtype))
+      return emitError("wait attribute cannot appear with waitOperands");
+  }
 
   if (getWaitDevnum() && getWaitOperands().empty())
     return emitError("wait_devnum cannot appear without waitOperands");
@@ -2459,7 +2571,7 @@ unsigned UpdateOp::getNumDataOperands() {
 }
 
 Value UpdateOp::getDataOperand(unsigned i) {
-  unsigned numOptional = getAsyncOperand() ? 1 : 0;
+  unsigned numOptional = getAsyncOperands().size();
   numOptional += getWaitDevnum() ? 1 : 0;
   numOptional += getIfCond() ? 1 : 0;
   return getOperand(getWaitOperands().size() + numOptional + i);
@@ -2470,6 +2582,46 @@ void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveConstantIfCondition<UpdateOp>>(context);
 }
 
+bool UpdateOp::hasAsyncOnly() {
+  return hasAsyncOnly(mlir::acc::DeviceType::None);
+}
+
+bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType dtype) {
+  return hasDeviceType(getAsync(), dtype);
+}
+
+mlir::Value UpdateOp::getAsyncValue() {
+  return getAsyncValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType dtype) {
+  // asyncOperands is indexed by the position found in its device_type array.
+  auto dtypes = getAsyncOperandsDeviceType();
+  if (!hasDeviceTypeValues(dtypes))
+    return {};
+
+  auto pos = findSegment(*dtypes, dtype);
+  return pos ? getAsyncOperands()[*pos] : mlir::Value{};
+}
+
+bool UpdateOp::hasWaitOnly() {
+  return hasWaitOnly(mlir::acc::DeviceType::None);
+}
+
+bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType dtype) {
+  return hasDeviceType(getWait(), dtype);
+}
+
+mlir::Operation::operand_range UpdateOp::getWaitValues() {
+  return getWaitValues(mlir::acc::DeviceType::None);
+}
+
+mlir::Operation::operand_range
+UpdateOp::getWaitValues(mlir::acc::DeviceType dtype) {
+  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
+                               getWaitOperandsSegments(), dtype);
+}
+
 //===----------------------------------------------------------------------===//
 // WaitOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 57ae5856149d114..80d439f19d9f4cf 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -138,7 +138,7 @@ acc.update wait_devnum(%cst: index) dataOperands(%0: memref<f32>)
 %value = memref.alloc() : memref<f32>
 %0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
 // expected-error at +1 {{async attribute cannot appear with asyncOperand}}
-acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async}
+acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async = [#acc.device_type<none>]} 
 
 // -----
 
@@ -146,7 +146,7 @@ acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async}
 %value = memref.alloc() : memref<f32>
 %0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
 // expected-error at +1 {{wait attribute cannot appear with waitOperands}}
-acc.update wait(%cst: index) dataOperands(%0: memref<f32>) attributes {wait}
+acc.update wait({%cst: index}) dataOperands(%0: memref<f32>) attributes {wait = [#acc.device_type<none>]} 
 
 // -----
 
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index d4c884a837f875b..45b41f1a7722566 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -934,17 +934,17 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
   acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
   acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
   acc.update async(%idxValue: index) dataOperands(%0: memref<f32>)
-  acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) dataOperands(%0: memref<f32>)
+  acc.update wait_devnum(%i64Value: i64) wait({%i32Value : i32, %idxValue : index}) dataOperands(%0: memref<f32>)
   acc.update if(%ifCond) dataOperands(%0: memref<f32>)
-  acc.update dataOperands(%0: memref<f32>) attributes {acc.device_types = [#acc.device_type<star>]}
+  acc.update dataOperands(%0: memref<f32>)
   acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
-  acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {async}
-  acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
+  acc.update async() dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
+  acc.update wait dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
   acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {ifPresent}
   return
 }
 
-// CHECK: func @testupdateop([[ARGA:%.*]]: memref<f32>, [[ARGB:%.*]]: memref<f32>, [[ARGC:%.*]]: memref<f32>) {
+// CHECK: func.func @testupdateop(%{{.*}}: memref<f32>, %{{.*}}: memref<f32>, %{{.*}}: memref<f32>)
 // CHECK:   [[I64VALUE:%.*]] = arith.constant 1 : i64
 // CHECK:   [[I32VALUE:%.*]] = arith.constant 1 : i32
 // CHECK:   [[IDXVALUE:%.*]] = arith.constant 1 : index
@@ -953,12 +953,12 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
 // CHECK:   acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref<f32>)
-// CHECK:   acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update wait_devnum([[I64VALUE]] : i64) wait({[[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref<f32>)
-// CHECK:   acc.update dataOperands(%{{.*}} : memref<f32>) attributes {acc.device_types = [#acc.device_type<star>]}
+// CHECK:   acc.update dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
-// CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {async}
-// CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
+// CHECK:   acc.update async() dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
+// CHECK:   acc.update wait dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
 // CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {ifPresent}
 
 // -----

>From 945231c361ffbbd6f89bcd791f472579051d1853 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 22 Jan 2024 14:51:03 -0800
Subject: [PATCH 2/2] Verify async/wait device_type counts in acc.update verifier

---
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c04b563d20d7b20..e1e69113bca1683 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2539,6 +2539,16 @@ LogicalResult acc::UpdateOp::verify() {
   if (getDataClauseOperands().empty())
     return emitError("at least one value must be present in dataOperands");
 
+  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
+                                        getAsyncOperandsDeviceTypeAttr(),
+                                        "async")))
+    return failure();
+
+  if (failed(verifyDeviceTypeAndSegmentCountMatch(
+          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
+          getWaitOperandsDeviceTypeAttr(), "wait")))
+    return failure();
+
   for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
        ++dtypeInt) {
     auto dtype = static_cast<acc::DeviceType>(dtypeInt);



More information about the flang-commits mailing list