[flang-commits] [flang] [mlir] [mlir][openacc][flang] Support wait devnum and homogenize async/wait IR (PR #79525)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Jan 25 15:49:05 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/79525

- Support wait(devnum: ...) with device_type support on all operations that require it
  - The devnum value is stored as the first value of waitOperands in its device_type sub-segment. The hasWaitDevnum attribute indicates which sub-segments have a wait(devnum) value.
- Make async/wait information homogeneous across the compute ops and the data and update ops.
  - Unify operand/attribute names across operations and use the same custom parser/printer.

>From 196d4d2f85b444888e01f93d0bfde7d15228d7e5 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 23 Jan 2024 09:55:28 -0800
Subject: [PATCH] [mlir][openacc][flang] Support wait devnum and homogenize
 async/wait IR

---
 flang/lib/Lower/OpenACC.cpp                   |  93 +++---
 flang/test/Lower/OpenACC/acc-data.f90         |   6 +-
 flang/test/Lower/OpenACC/acc-kernels-loop.f90 |   4 +-
 flang/test/Lower/OpenACC/acc-kernels.f90      |   4 +-
 .../test/Lower/OpenACC/acc-parallel-loop.f90  |   4 +-
 flang/test/Lower/OpenACC/acc-parallel.f90     |   4 +-
 flang/test/Lower/OpenACC/acc-serial-loop.f90  |   4 +-
 flang/test/Lower/OpenACC/acc-serial.f90       |   4 +-
 flang/test/Lower/OpenACC/acc-update.f90       |   5 +-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        |  96 +++---
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 275 +++++++++++++-----
 mlir/test/Dialect/OpenACC/invalid.mlir        |  10 +-
 mlir/test/Dialect/OpenACC/ops.mlir            |   8 +-
 .../Dialect/OpenACC/OpenACCOpsTest.cpp        |  34 ++-
 14 files changed, 368 insertions(+), 183 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecfdaa5be993584..427b36a12a2df01 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
           builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, descTy);
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
           builder, loc, loadOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, loadOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
-static void
-genWaitClause(Fortran::lower::AbstractConverter &converter,
-              const Fortran::parser::AccClause::Wait *waitClause,
-              llvm::SmallVector<mlir::Value> &waitOperands,
-              llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
-              llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
-              llvm::SmallVector<int32_t> &waitOperandsSegments,
-              mlir::Value &waitDevnum,
-              llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
-              Fortran::lower::StatementContext &stmtCtx) {
+static void genWaitClauseWithDeviceType(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::parser::AccClause::Wait *waitClause,
+    llvm::SmallVector<mlir::Value> &waitOperands,
+    llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
+    llvm::SmallVector<bool> &hasDevnums,
+    llvm::SmallVector<int32_t> &waitOperandsSegments,
+    llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
+    Fortran::lower::StatementContext &stmtCtx) {
   const auto &waitClauseValue = waitClause->v;
   if (waitClauseValue) { // wait has a value.
+    llvm::SmallVector<mlir::Value> waitValues;
+
     const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+    const auto &waitDevnumValue =
+        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+    bool hasDevnum = false;
+    if (waitDevnumValue) {
+      waitValues.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
+      hasDevnum = true;
+    }
+
     const auto &waitList =
         std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-    llvm::SmallVector<mlir::Value> waitValues;
     for (const Fortran::parser::ScalarIntExpr &value : waitList) {
       waitValues.push_back(fir::getBase(converter.genExprValue(
           *Fortran::semantics::GetExpr(value), stmtCtx)));
     }
+
     for (auto deviceTypeAttr : deviceTypeAttrs) {
       for (auto value : waitValues)
         waitOperands.push_back(value);
       waitOperandsDeviceTypes.push_back(deviceTypeAttr);
       waitOperandsSegments.push_back(waitValues.size());
+      hasDevnums.push_back(hasDevnum);
     }
-
-    // TODO: move to device_type model.
-    const auto &waitDevnumValue =
-        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-    if (waitDevnumValue)
-      waitDevnum = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
   } else {
     for (auto deviceTypeAttr : deviceTypeAttrs)
       waitOnlyDeviceTypes.push_back(deviceTypeAttr);
@@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
       vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
       waitOperandsDeviceTypes, waitOnlyDeviceTypes;
   llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
+  llvm::SmallVector<bool> hasWaitDevnums;
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
       reductionRecipes;
-  mlir::Value waitDevnum; // TODO not yet implemented on compute op.
 
   // Self clause has optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
@@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                      asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
@@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
           builder.getDenseI32ArrayAttr(numGangsSegments));
   }
   if (!asyncDeviceTypes.empty())
-    computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+    computeOp.setAsyncOperandsDeviceTypeAttr(
+        builder.getArrayAttr(asyncDeviceTypes));
   if (!asyncOnlyDeviceTypes.empty())
     computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
 
@@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   if (!waitOperandsSegments.empty())
     computeOp.setWaitOperandsSegmentsAttr(
         builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!hasWaitDevnums.empty())
+    computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
   if (!waitOnlyDeviceTypes.empty())
     computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
 
@@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
       waitOperandsDeviceTypes, waitOnlyDeviceTypes;
   llvm::SmallVector<int32_t> waitOperandsSegments;
+  llvm::SmallVector<bool> hasWaitDevnums;
 
   bool hasDefaultNone = false;
   bool hasDefaultPresent = false;
@@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                      asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if(const auto *defaultClause = 
                   std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
       if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<int32_t> operandSegments;
   addOperand(operands, operandSegments, ifCond);
   addOperands(operands, operandSegments, async);
-  addOperand(operands, operandSegments, waitDevnum);
   addOperands(operands, operandSegments, waitOperands);
   addOperands(operands, operandSegments, dataClauseOperands);
 
@@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
       operandSegments);
 
   if (!asyncDeviceTypes.empty())
-    dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+    dataOp.setAsyncOperandsDeviceTypeAttr(
+        builder.getArrayAttr(asyncDeviceTypes));
   if (!asyncOnlyDeviceTypes.empty())
     dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
   if (!waitOperandsDeviceTypes.empty())
@@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   if (!waitOperandsSegments.empty())
     dataOp.setWaitOperandsSegmentsAttr(
         builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!hasWaitDevnums.empty())
+    dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
   if (!waitOnlyDeviceTypes.empty())
     dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
 
@@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
   return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
 }
 
+static inline mlir::ArrayAttr
+getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
+  return !values.empty() ? b.getBoolArrayAttr(values) : mlir::ArrayAttr{};
+}
+
 static inline mlir::DenseI32ArrayAttr
 getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
                      llvm::SmallVector<int32_t> &values) {
@@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
       waitOperands, deviceTypeOperands, asyncOperands;
   llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
       asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<bool> hasWaitDevnums;
   llvm::SmallVector<int32_t> waitOperandsSegments;
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
@@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
                      crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
@@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   builder.create<mlir::acc::UpdateOp>(
       currentLocation, ifCond, asyncOperands,
       getArrayAttr(builder, asyncOperandsDeviceTypes),
-      getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+      getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
       getDenseI32ArrayAttr(builder, waitOperandsSegments),
       getArrayAttr(builder, waitOperandsDeviceTypes),
+      getBoolArrayAttr(builder, hasWaitDevnums),
       getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
       ifPresent);
 
@@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index 75ffd1fc3fcab2f..5b4ab5a65ee6bd2 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -164,8 +164,8 @@ subroutine acc_data
   !$acc data present(a) wait
   !$acc end data
 
-! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK: acc.data dataOperands(%{{.*}}) wait {
+! CHECK: }
 
   !$acc data present(a) wait(1)
   !$acc end data
@@ -176,7 +176,7 @@ subroutine acc_data
   !$acc data present(a) wait(devnum: 0: 1)
   !$acc end data
 
-! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) {
 ! CHECK: }{{$}}
 
   !$acc data default(none)
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 21660d5c3a13163..d2134e8d2337ce6 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -93,12 +93,12 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.kernels {
+! CHECK:      acc.kernels wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc kernels loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 99629bb8351723b..06194edbe165498 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -61,9 +61,9 @@ subroutine acc_kernels
   !$acc kernels wait
   !$acc end kernels
 
-! CHECK:      acc.kernels  {
+! CHECK:      acc.kernels wait {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc kernels wait(1)
   !$acc end kernels
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 614d201f98e26c4..24e443a20c895d1 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -95,12 +95,12 @@ subroutine acc_parallel_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.parallel {
+! CHECK:      acc.parallel wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc parallel loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index a369bf01f259955..6b37ecb5fab9aa6 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -83,9 +83,9 @@ subroutine acc_parallel
   !$acc parallel wait
   !$acc end parallel
 
-! CHECK:      acc.parallel {
+! CHECK:      acc.parallel wait {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc parallel wait(1)
   !$acc end parallel
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 4134f9ff0ccf577..9c0dbff0d7dac16 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -114,12 +114,12 @@ subroutine acc_serial_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.serial {
+! CHECK:      acc.serial wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc serial loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index d05e51d3d274f45..d0fa9436be14a14 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -83,9 +83,9 @@ subroutine acc_serial
   !$acc serial wait
   !$acc end serial
 
-! CHECK:      acc.serial {
+! CHECK:      acc.serial wait {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc serial wait(1)
   !$acc end serial
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index ba036ac92811826..f42ae1356664b67 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -101,10 +101,7 @@ subroutine acc_update
 
   !$acc update host(a) wait(devnum: 1: queues: 1, 2)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
-! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
-! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) device_type(host, nvidia) async
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 87fd587782e7c35..9398cbfdacee469 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -903,12 +903,13 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
   }];
 
   let arguments = (ins
-      Variadic<IntOrIndex>:$async,
-      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
       OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
       Variadic<IntOrIndex>:$waitOperands,
       OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
       Variadic<IntOrIndex>:$numGangs,
       OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
@@ -979,13 +980,18 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     /// present.
     mlir::Operation::operand_range
     getWaitValues(mlir::acc::DeviceType deviceType);
+    /// Return the wait devnum value clause if present;
+    mlir::Value getWaitDevnum();
+    /// Return the wait devnum value clause for the given device_type if
+    /// present.
+    mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `async` `(` custom<DeviceTypeOperands>($async,
-            type($async), $asyncDeviceType) `)`
+      | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+            type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
             type($gangFirstPrivateOperands), $firstprivatizations)
         `)`
@@ -998,8 +1004,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
         `)`
       | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
             type($vectorLength), $vectorLengthDeviceType) `)`
-      | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
-            type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
+      | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+          $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+          $waitOnly)
       | `self` `(` $selfCond `)`
       | `if` `(` $ifCond `)`
       | `reduction` `(` custom<SymOperandList>(
@@ -1034,12 +1041,13 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
   }];
 
   let arguments = (ins
-      Variadic<IntOrIndex>:$async,
-      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
       OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
       Variadic<IntOrIndex>:$waitOperands,
       OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
       Optional<I1>:$ifCond,
       Optional<I1>:$selfCond,
@@ -1084,21 +1092,27 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
     /// present.
     mlir::Operation::operand_range
     getWaitValues(mlir::acc::DeviceType deviceType);
+    /// Return the wait devnum value clause if present;
+    mlir::Value getWaitDevnum();
+    /// Return the wait devnum value clause for the given device_type if
+    /// present.
+    mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `async` `(` custom<DeviceTypeOperands>($async,
-            type($async), $asyncDeviceType) `)`
+      | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+            type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
             type($gangFirstPrivateOperands), $firstprivatizations)
         `)`
       | `private` `(` custom<SymOperandList>(
             $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
         `)`
-      | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
-            type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
+      | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+          $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+          $waitOnly)
       | `self` `(` $selfCond `)`
       | `if` `(` $ifCond `)`
       | `reduction` `(` custom<SymOperandList>(
@@ -1135,12 +1149,13 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
   }];
 
   let arguments = (ins
-      Variadic<IntOrIndex>:$async,
-      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
       OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
       Variadic<IntOrIndex>:$waitOperands,
       OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
       Variadic<IntOrIndex>:$numGangs,
       OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
@@ -1205,22 +1220,27 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
     /// present.
     mlir::Operation::operand_range
     getWaitValues(mlir::acc::DeviceType deviceType);
+    /// Return the wait devnum value clause if present;
+    mlir::Value getWaitDevnum();
+    /// Return the wait devnum value clause for the given device_type if
+    /// present.
+    mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `async` `(` custom<DeviceTypeOperands>($async,
-            type($async), $asyncDeviceType) `)`
+      | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+            type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `num_gangs` `(` custom<NumGangs>($numGangs,
             type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
       | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
             type($numWorkers), $numWorkersDeviceType) `)`
       | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
             type($vectorLength), $vectorLengthDeviceType) `)`
-      | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
-            type($waitOperands), $waitOperandsDeviceType,
-            $waitOperandsSegments) `)`
+      | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+          $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+          $waitOnly)
       | `self` `(` $selfCond `)`
       | `if` `(` $ifCond `)`
     )
@@ -1258,13 +1278,13 @@ def OpenACC_DataOp : OpenACC_Op<"data",
 
 
   let arguments = (ins Optional<I1>:$ifCond,
-      Variadic<IntOrIndex>:$async,
-      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
       OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
-      Optional<IntOrIndex>:$waitDevnum,
       Variadic<IntOrIndex>:$waitOperands,
       OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
       OptionalAttr<DefaultValueAttr>:$defaultAttr);
@@ -1300,18 +1320,22 @@ def OpenACC_DataOp : OpenACC_Op<"data",
     /// present.
     mlir::Operation::operand_range
     getWaitValues(mlir::acc::DeviceType deviceType);
+    /// Return the wait devnum value clause if present;
+    mlir::Value getWaitDevnum();
+    /// Return the wait devnum value clause for the given device_type if
+    /// present.
+    mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `if` `(` $ifCond `)`
-      | `async` `(` custom<DeviceTypeOperands>($async,
-            type($async), $asyncDeviceType) `)`
+      | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+            type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
-      | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
-            type($waitOperands), $waitOperandsDeviceType, 
-            $waitOperandsSegments) `)`
+      | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+          $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+          $waitOnly)
     )
     $region attr-dict-with-keyword
   }];
@@ -2199,11 +2223,11 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
       Variadic<IntOrIndex>:$asyncOperands,
       OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
       OptionalAttr<DeviceTypeArrayAttr>:$async,
-      Optional<IntOrIndex>:$waitDevnum,
       Variadic<IntOrIndex>:$waitOperands,
       OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
-      OptionalAttr<DeviceTypeArrayAttr>:$wait,
+      OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
       UnitAttr:$ifPresent);
 
@@ -2236,6 +2260,11 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
     /// present.
     mlir::Operation::operand_range
     getWaitValues(mlir::acc::DeviceType deviceType);
+    /// Return the wait devnum value clause if present;
+    mlir::Value getWaitDevnum();
+    /// Return the wait devnum value clause for the given device_type if
+    /// present.
+    mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
@@ -2244,10 +2273,9 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
       | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
             $asyncOperands, type($asyncOperands),
             $asyncOperandsDeviceType, $async)
-      | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
-      | `wait` `` custom<WaitClause>($waitOperands,
-            type($waitOperands), $waitOperandsDeviceType, 
-            $waitOperandsSegments, $wait)
+      | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+          $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+          $waitOnly)
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
     attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index e1e69113bca1683..042153ac749102b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -104,6 +104,121 @@ static void printDeviceTypes(mlir::OpAsmPrinter &p,
   p << "]";
 }
 
+/// Return the index of the entry of \p segments whose device_type equals
+/// \p deviceType, or std::nullopt when no entry matches.
+static std::optional<unsigned> findSegment(ArrayAttr segments,
+                                           mlir::acc::DeviceType deviceType) {
+  unsigned segmentIdx = 0;
+  for (auto attr : segments) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    // Skip entries that are not device_type attributes instead of
+    // dereferencing a null attribute.
+    if (deviceTypeAttr && deviceTypeAttr.getValue() == deviceType)
+      return std::make_optional(segmentIdx);
+    ++segmentIdx;
+  }
+  return std::nullopt;
+}
+
+/// Return the operands of \p range that belong to the \p deviceType
+/// sub-segment. \p arrayAttr gives the device_type of each sub-segment and
+/// \p segments the operand count of each one, in the same order. Returns an
+/// empty range when there is no \p deviceType sub-segment.
+static mlir::Operation::operand_range
+getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
+                      mlir::Operation::operand_range range,
+                      std::optional<llvm::ArrayRef<int32_t>> segments,
+                      mlir::acc::DeviceType deviceType) {
+  // Both the device_type list and the segment sizes are needed to locate a
+  // sub-segment; report no values rather than dereferencing a missing one.
+  if (!arrayAttr || !*arrayAttr || !segments)
+    return range.take_front(0);
+  auto pos = findSegment(*arrayAttr, deviceType);
+  if (!pos || *pos >= segments->size())
+    return range.take_front(0);
+  int32_t nbOperandsBefore = 0;
+  for (unsigned i = 0; i < *pos; ++i)
+    nbOperandsBefore += (*segments)[i];
+  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
+}
+
+/// Return true when entry \p pos of the optional \p hasWaitDevnum array is a
+/// BoolAttr holding `true`, i.e. that wait sub-segment starts with a devnum
+/// value.
+static bool hasWaitDevnumAt(std::optional<mlir::ArrayAttr> hasWaitDevnum,
+                            unsigned pos) {
+  if (!hasWaitDevnum || !*hasWaitDevnum || pos >= hasWaitDevnum->size())
+    return false;
+  // Unwrap the BoolAttr: a plain mlir::Attribute converts to `true` whenever
+  // it is non-null, which is also the case for a `false` BoolAttr.
+  auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[pos]);
+  return boolAttr && boolAttr.getValue();
+}
+
+/// Return the wait(devnum:) value of the \p deviceType sub-segment of the
+/// wait operands, or a null value when that sub-segment has no devnum.
+static mlir::Value
+getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
+                   mlir::Operation::operand_range operands,
+                   std::optional<llvm::ArrayRef<int32_t>> segments,
+                   std::optional<mlir::ArrayAttr> hasWaitDevnum,
+                   mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(deviceTypeAttr))
+    return {};
+  auto pos = findSegment(*deviceTypeAttr, deviceType);
+  if (!pos || !hasWaitDevnumAt(hasWaitDevnum, *pos))
+    return {};
+  // The devnum value is stored as the first operand of its sub-segment.
+  auto range =
+      getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
+  return range.empty() ? mlir::Value() : range.front();
+}
+
+/// Return the wait operands of the \p deviceType sub-segment, excluding the
+/// leading devnum value when that sub-segment has one.
+static mlir::Operation::operand_range
+getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
+                           mlir::Operation::operand_range operands,
+                           std::optional<llvm::ArrayRef<int32_t>> segments,
+                           std::optional<mlir::ArrayAttr> hasWaitDevnum,
+                           mlir::acc::DeviceType deviceType) {
+  auto range =
+      getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
+  if (range.empty())
+    return range;
+  // A non-empty range implies deviceTypeAttr is present and lists deviceType.
+  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
+    if (hasWaitDevnumAt(hasWaitDevnum, *pos))
+      return range.drop_front(1); // first value is devnum
+  return range;
+}
+
+/// Verify that, for every device_type, neither the async nor the wait clause
+/// is recorded both as a bare keyword (hasAsyncOnly/hasWaitOnly) and with
+/// operands.
+template <typename Op>
+static LogicalResult checkWaitAndAsyncConflict(Op op) {
+  // getMaxEnumValForDeviceType() is the largest enumerator value itself, so
+  // the bound is inclusive; `!=` would skip the last device_type.
+  for (uint32_t dtypeInt = 0; dtypeInt <= acc::getMaxEnumValForDeviceType();
+       ++dtypeInt) {
+    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
+
+    // The async attribute represents the async clause without value.
+    // Therefore the attribute and operand cannot appear at the same time.
+    if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
+        op.hasAsyncOnly(dtype))
+      return op.emitError("async attribute cannot appear with asyncOperand");
+
+    // The wait attribute represents the wait clause without values.
+    // Therefore the attribute and operands cannot appear at the same time.
+    if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
+        op.hasWaitOnly(dtype))
+      return op.emitError("wait attribute cannot appear with waitOperands");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DataBoundsOp
 //===----------------------------------------------------------------------===//
@@ -649,7 +730,7 @@ unsigned ParallelOp::getNumDataOperands() {
 }
 
 Value ParallelOp::getDataOperand(unsigned i) {
-  unsigned numOptional = getAsync().size();
+  unsigned numOptional = getAsyncOperands().size();
   numOptional += getNumGangs().size();
   numOptional += getNumWorkers().size();
   numOptional += getVectorLength().size();
@@ -722,23 +803,15 @@ LogicalResult acc::ParallelOp::verify() {
                                         "vector_length")))
     return failure();
 
-  if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
-                                        getAsyncDeviceTypeAttr(), "async")))
+  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
+                                        getAsyncOperandsDeviceTypeAttr(),
+                                        "async")))
     return failure();
 
-  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
-}
+  if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
+    return failure();
 
-static std::optional<unsigned> findSegment(ArrayAttr segments,
-                                           mlir::acc::DeviceType deviceType) {
-  unsigned segmentIdx = 0;
-  for (auto attr : segments) {
-    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
-    if (deviceTypeAttr.getValue() == deviceType)
-      return std::make_optional(segmentIdx);
-    ++segmentIdx;
-  }
-  return std::nullopt;
+  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
 
 static mlir::Value
@@ -765,8 +838,8 @@ mlir::Value acc::ParallelOp::getAsyncValue() {
 }
 
 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
-  return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
-                                     deviceType);
+  return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
+                                     getAsyncOperands(), deviceType);
 }
 
 mlir::Value acc::ParallelOp::getNumWorkersValue() {
@@ -793,22 +866,6 @@ mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
   return getNumGangsValues(mlir::acc::DeviceType::None);
 }
 
-static mlir::Operation::operand_range
-getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
-                      mlir::Operation::operand_range range,
-                      std::optional<llvm::ArrayRef<int32_t>> segments,
-                      mlir::acc::DeviceType deviceType) {
-  if (!arrayAttr)
-    return range.take_front(0);
-  if (auto pos = findSegment(*arrayAttr, deviceType)) {
-    int32_t nbOperandsBefore = 0;
-    for (unsigned i = 0; i < *pos; ++i)
-      nbOperandsBefore += (*segments)[i];
-    return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
-  }
-  return range.take_front(0);
-}
-
 mlir::Operation::operand_range
 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
   return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
@@ -829,8 +886,19 @@ mlir::Operation::operand_range ParallelOp::getWaitValues() {
 
 mlir::Operation::operand_range
 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
-  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
-                               getWaitOperandsSegments(), deviceType);
+  return getWaitValuesWithoutDevnum(
+      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+      getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value ParallelOp::getWaitDevnum() {
+  return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+                            getWaitOperandsSegments(), getHasWaitDevnum(),
+                            deviceType);
 }
 
 static ParseResult parseNumGangs(
@@ -967,8 +1035,9 @@ static ParseResult parseWaitClause(
     mlir::OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
     llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
-    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) {
-  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs;
+    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
+    mlir::ArrayAttr &keywordOnly) {
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
   llvm::SmallVector<int32_t> seg;
 
   bool needCommaBeforeOperands = false;
@@ -1003,6 +1072,14 @@ static ParseResult parseWaitClause(
 
     int32_t crtOperandsSize = operands.size();
 
+    if (succeeded(parser.parseOptionalKeyword("devnum"))) {
+      if (failed(parser.parseColon()))
+        return failure();
+      devnum.push_back(BoolAttr::get(parser.getContext(), true));
+    } else {
+      devnum.push_back(BoolAttr::get(parser.getContext(), false));
+    }
+
     if (failed(parser.parseCommaSeparatedList(
             mlir::AsmParser::Delimiter::None, [&]() {
               if (parser.parseOperand(operands.emplace_back()) ||
@@ -1033,6 +1110,7 @@ static ParseResult parseWaitClause(
   deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
   keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
 
   return success();
 }
@@ -1052,6 +1130,7 @@ static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
                             mlir::OperandRange operands, mlir::TypeRange types,
                             std::optional<mlir::ArrayAttr> deviceTypes,
                             std::optional<mlir::DenseI32ArrayAttr> segments,
+                            std::optional<mlir::ArrayAttr> hasDevNum,
                             std::optional<mlir::ArrayAttr> keywordOnly) {
 
   if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
@@ -1066,6 +1145,9 @@ static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
   unsigned opIdx = 0;
   llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
     p << "{";
+    auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
+    if (boolAttr && boolAttr.getValue())
+      p << "devnum: ";
     llvm::interleaveComma(
         llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
           p << operands[opIdx] << " : " << operands[opIdx].getType();
@@ -1209,7 +1291,7 @@ unsigned SerialOp::getNumDataOperands() {
 }
 
 Value SerialOp::getDataOperand(unsigned i) {
-  unsigned numOptional = getAsync().size();
+  unsigned numOptional = getAsyncOperands().size();
   numOptional += getIfCond() ? 1 : 0;
   numOptional += getSelfCond() ? 1 : 0;
   return getOperand(getWaitOperands().size() + numOptional + i);
@@ -1228,8 +1310,8 @@ mlir::Value acc::SerialOp::getAsyncValue() {
 }
 
 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
-  return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
-                                     deviceType);
+  return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
+                                     getAsyncOperands(), deviceType);
 }
 
 bool acc::SerialOp::hasWaitOnly() {
@@ -1246,8 +1328,19 @@ mlir::Operation::operand_range SerialOp::getWaitValues() {
 
 mlir::Operation::operand_range
 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
-  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
-                               getWaitOperandsSegments(), deviceType);
+  return getWaitValuesWithoutDevnum(
+      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+      getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value SerialOp::getWaitDevnum() {
+  return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+                            getWaitOperandsSegments(), getHasWaitDevnum(),
+                            deviceType);
 }
 
 LogicalResult acc::SerialOp::verify() {
@@ -1265,8 +1358,12 @@ LogicalResult acc::SerialOp::verify() {
           getWaitOperandsDeviceTypeAttr(), "wait")))
     return failure();
 
-  if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
-                                        getAsyncDeviceTypeAttr(), "async")))
+  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
+                                        getAsyncOperandsDeviceTypeAttr(),
+                                        "async")))
+    return failure();
+
+  if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
     return failure();
 
   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
@@ -1281,7 +1378,7 @@ unsigned KernelsOp::getNumDataOperands() {
 }
 
 Value KernelsOp::getDataOperand(unsigned i) {
-  unsigned numOptional = getAsync().size();
+  unsigned numOptional = getAsyncOperands().size();
   numOptional += getWaitOperands().size();
   numOptional += getNumGangs().size();
   numOptional += getNumWorkers().size();
@@ -1304,8 +1401,8 @@ mlir::Value acc::KernelsOp::getAsyncValue() {
 }
 
 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
-  return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
-                                     deviceType);
+  return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
+                                     getAsyncOperands(), deviceType);
 }
 
 mlir::Value acc::KernelsOp::getNumWorkersValue() {
@@ -1352,8 +1449,19 @@ mlir::Operation::operand_range KernelsOp::getWaitValues() {
 
 mlir::Operation::operand_range
 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
-  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
-                               getWaitOperandsSegments(), deviceType);
+  return getWaitValuesWithoutDevnum(
+      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+      getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value KernelsOp::getWaitDevnum() {
+  return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+                            getWaitOperandsSegments(), getHasWaitDevnum(),
+                            deviceType);
 }
 
 LogicalResult acc::KernelsOp::verify() {
@@ -1377,8 +1485,12 @@ LogicalResult acc::KernelsOp::verify() {
                                         "vector_length")))
     return failure();
 
-  if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
-                                        getAsyncDeviceTypeAttr(), "async")))
+  if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
+                                        getAsyncOperandsDeviceTypeAttr(),
+                                        "async")))
+    return failure();
+
+  if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
     return failure();
 
   return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
@@ -1943,6 +2055,9 @@ LogicalResult acc::DataOp::verify() {
       return emitError("expect data entry/exit operation or acc.getdeviceptr "
                        "as defining op");
 
+  if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
+    return failure();
+
   return success();
 }
 
@@ -1950,7 +2065,7 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
 
 Value DataOp::getDataOperand(unsigned i) {
   unsigned numOptional = getIfCond() ? 1 : 0;
-  numOptional += getAsync().size() ? 1 : 0;
+  numOptional += getAsyncOperands().size(); // skip every async operand
   numOptional += getWaitOperands().size();
   return getOperand(numOptional + i);
 }
@@ -1968,8 +2083,8 @@ mlir::Value DataOp::getAsyncValue() {
 }
 
 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
-  return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
-                                     deviceType);
+  return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
+                                     getAsyncOperands(), deviceType);
 }
 
 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
@@ -1984,8 +2099,19 @@ mlir::Operation::operand_range DataOp::getWaitValues() {
 
 mlir::Operation::operand_range
 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
-  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
-                               getWaitOperandsSegments(), deviceType);
+  return getWaitValuesWithoutDevnum(
+      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+      getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value DataOp::getWaitDevnum() {
+  return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+                            getWaitOperandsSegments(), getHasWaitDevnum(),
+                            deviceType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2549,23 +2675,8 @@ LogicalResult acc::UpdateOp::verify() {
           getWaitOperandsDeviceTypeAttr(), "wait")))
     return failure();
 
-  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
-       ++dtypeInt) {
-    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
-
-    // The async attribute represent the async clause without value. Therefore
-    // the attribute and operand cannot appear at the same time.
-    if (getAsyncValue(dtype) && hasAsyncOnly(dtype))
-      return emitError("async attribute cannot appear with asyncOperand");
-
-    // The wait attribute represent the wait clause without values. Therefore
-    // the attribute and operands cannot appear at the same time.
-    if (!getWaitValues(dtype).empty() && hasWaitOnly(dtype))
-      return emitError("wait attribute cannot appear with waitOperands");
-  }
-
-  if (getWaitDevnum() && getWaitOperands().empty())
-    return emitError("wait_devnum cannot appear without waitOperands");
+  if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
+    return failure();
 
   for (mlir::Value operand : getDataClauseOperands())
     if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
@@ -2582,7 +2693,6 @@ unsigned UpdateOp::getNumDataOperands() {
 
 Value UpdateOp::getDataOperand(unsigned i) {
   unsigned numOptional = getAsyncOperands().size();
-  numOptional += getWaitDevnum() ? 1 : 0;
   numOptional += getIfCond() ? 1 : 0;
   return getOperand(getWaitOperands().size() + numOptional + i);
 }
@@ -2619,7 +2729,7 @@ bool UpdateOp::hasWaitOnly() {
 }
 
 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  return hasDeviceType(getWait(), deviceType);
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range UpdateOp::getWaitValues() {
@@ -2628,8 +2738,19 @@ mlir::Operation::operand_range UpdateOp::getWaitValues() {
 
 mlir::Operation::operand_range
 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
-  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
-                               getWaitOperandsSegments(), deviceType);
+  return getWaitValuesWithoutDevnum(
+      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+      getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value UpdateOp::getWaitDevnum() {
+  return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+  return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+                            getWaitOperandsSegments(), getHasWaitDevnum(),
+                            deviceType);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 80d439f19d9f4cf..16df33eec642ce7 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -126,14 +126,6 @@ acc.update
 
 // -----
 
-%cst = arith.constant 1 : index
-%value = memref.alloc() : memref<f32>
-%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
-// expected-error at +1 {{wait_devnum cannot appear without waitOperands}}
-acc.update wait_devnum(%cst: index) dataOperands(%0: memref<f32>)
-
-// -----
-
 %cst = arith.constant 1 : index
 %value = memref.alloc() : memref<f32>
 %0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
@@ -146,7 +138,7 @@ acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async =
 %value = memref.alloc() : memref<f32>
 %0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
 // expected-error at +1 {{wait attribute cannot appear with waitOperands}}
-acc.update wait({%cst: index}) dataOperands(%0: memref<f32>) attributes {wait = [#acc.device_type<none>]} 
+acc.update wait({%cst: index}) dataOperands(%0: memref<f32>) attributes {waitOnly = [#acc.device_type<none>]} 
 
 // -----
 
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 45b41f1a7722566..4e6ed8645cdbce7 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -802,7 +802,7 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
   } attributes { defaultAttr = #acc<defaultvalue none>, wait }
 
   %wd1 = arith.constant 1 : i64
-  acc.data wait_devnum(%wd1 : i64) wait({%w1 : i64}) {
+  acc.data wait({devnum: %wd1 : i64, %w1 : i64}) {
   } attributes { defaultAttr = #acc<defaultvalue none>, wait }
 
   return
@@ -916,7 +916,7 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
 // CHECK:      acc.data wait({%{{.*}} : i64}) {
 // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
 
-// CHECK:      acc.data wait_devnum(%{{.*}} : i64) wait({%{{.*}} : i64}) {
+// CHECK:      acc.data wait({devnum: %{{.*}} : i64, %{{.*}} : i64}) {
 // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
 
 // -----
@@ -934,7 +934,7 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
   acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
   acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
   acc.update async(%idxValue: index) dataOperands(%0: memref<f32>)
-  acc.update wait_devnum(%i64Value: i64) wait({%i32Value : i32, %idxValue : index}) dataOperands(%0: memref<f32>)
+  acc.update wait({devnum: %i64Value: i64, %i32Value : i32, %idxValue : index}) dataOperands(%0: memref<f32>)
   acc.update if(%ifCond) dataOperands(%0: memref<f32>)
   acc.update dataOperands(%0: memref<f32>)
   acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
@@ -953,7 +953,7 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
 // CHECK:   acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref<f32>)
-// CHECK:   acc.update wait_devnum([[I64VALUE]] : i64) wait({[[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update wait({devnum: [[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index 474f887928992a9..41b751b8d7f7cbc 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -86,13 +86,13 @@ void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
   OwningOpRef<arith::ConstantIndexOp> val =
       b.create<arith::ConstantIndexOp>(loc, 1);
   auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
-  op->setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
-  op->getAsyncMutable().assign(val->getResult());
+  op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+  op->getAsyncOperandsMutable().assign(val->getResult());
   EXPECT_EQ(op->getAsyncValue(), empty);
   EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult());
 
-  op->getAsyncMutable().clear();
-  op->removeAsyncDeviceTypeAttr();
+  op->getAsyncOperandsMutable().clear();
+  op->removeAsyncOperandsDeviceTypeAttr();
 }
 
 TEST_F(OpenACCOpsTest, asyncValueTest) {
@@ -232,6 +232,8 @@ TEST_F(OpenACCOpsTest, waitOnlyTest) {
   testWaitOnly<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
   testWaitOnly<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
   testWaitOnly<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitOnly<UpdateOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitOnly<DataOp>(b, context, loc, dtypes, dtypesWithoutNone);
 }
 
 template <typename Op>
@@ -245,6 +247,8 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc,
       b.create<arith::ConstantIndexOp>(loc, 1);
   OwningOpRef<arith::ConstantIndexOp> val2 =
       b.create<arith::ConstantIndexOp>(loc, 4);
+  OwningOpRef<arith::ConstantIndexOp> val3 =
+      b.create<arith::ConstantIndexOp>(loc, 5);
   auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
   op->getWaitOperandsMutable().assign(val1->getResult());
   op->setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
@@ -294,6 +298,28 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc,
   op->getWaitOperandsMutable().clear();
   op->removeWaitOperandsDeviceTypeAttr();
   op->removeWaitOperandsSegmentsAttr();
+
+  op->getWaitOperandsMutable().append(val3->getResult());
+  op->getWaitOperandsMutable().append(val2->getResult());
+  op->getWaitOperandsMutable().append(val1->getResult());
+  op->setWaitOperandsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
+  op->setHasWaitDevnumAttr(b.getBoolArrayAttr({true}));
+  op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({3}));
+  EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(),
+            op->getWaitValues(DeviceType::None).end());
+  EXPECT_FALSE(op->getWaitDevnum());
+
+  EXPECT_EQ(op->getWaitDevnum(DeviceType::Multicore), val3->getResult());
+  EXPECT_EQ(op->getWaitValues(DeviceType::Multicore).front(),
+            val2->getResult());
+  EXPECT_EQ(op->getWaitValues(DeviceType::Multicore).drop_front().front(),
+            val1->getResult());
+
+  op->getWaitOperandsMutable().clear();
+  op->removeWaitOperandsDeviceTypeAttr();
+  op->removeWaitOperandsSegmentsAttr();
+  op->removeHasWaitDevnumAttr();
 }
 
 TEST_F(OpenACCOpsTest, waitValuesTest) {



More information about the flang-commits mailing list