[flang-commits] [flang] [mlir] [mlir][openacc] Add device_type support for compute operations (PR #75864)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Dec 20 07:45:55 PST 2023


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/75864

>From 97fa81b7486339ae7e4884bdc67357d35207768f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 7 Dec 2023 14:04:54 -0800
Subject: [PATCH] [mlir][openacc] Add device_type support for compute
 operations

---
 flang/lib/Lower/OpenACC.cpp                   | 106 +++-
 flang/test/Lower/OpenACC/acc-device-type.f90  |  44 ++
 flang/test/Lower/OpenACC/acc-kernels-loop.f90 |  14 +-
 flang/test/Lower/OpenACC/acc-kernels.f90      |  14 +-
 .../test/Lower/OpenACC/acc-parallel-loop.f90  |  14 +-
 flang/test/Lower/OpenACC/acc-parallel.f90     |  16 +-
 flang/test/Lower/OpenACC/acc-serial-loop.f90  |  10 +-
 flang/test/Lower/OpenACC/acc-serial.f90       |  10 +-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        | 286 +++++++---
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 515 +++++++++++++++++-
 mlir/test/Dialect/OpenACC/invalid.mlir        |   4 +-
 mlir/test/Dialect/OpenACC/ops.mlir            |  76 +--
 mlir/unittests/Dialect/CMakeLists.txt         |   1 +
 mlir/unittests/Dialect/OpenACC/CMakeLists.txt |   8 +
 .../Dialect/OpenACC/OpenACCOpsTest.cpp        | 275 ++++++++++
 15 files changed, 1216 insertions(+), 177 deletions(-)
 create mode 100644 flang/test/Lower/OpenACC/acc-device-type.f90
 create mode 100644 mlir/unittests/Dialect/OpenACC/CMakeLists.txt
 create mode 100644 mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index fae54eefb02f70..59db5ab71b7024 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
   case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
     return mlir::acc::DeviceType::Multicore;
   }
-  return mlir::acc::DeviceType::Default;
+  return mlir::acc::DeviceType::None;
 }
 
 static void gatherDeviceTypeAttrs(
@@ -1781,26 +1781,25 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                 bool outerCombined = false) {
 
   // Parallel operation operands
-  mlir::Value async;
-  mlir::Value numWorkers;
-  mlir::Value vectorLength;
   mlir::Value ifCond;
   mlir::Value selfCond;
   mlir::Value waitDevnum;
   llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
       copyEntryOperands, copyoutEntryOperands, createEntryOperands,
-      dataClauseOperands, numGangs;
+      dataClauseOperands, numGangs, numWorkers, vectorLength, async;
+  llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
+      vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
+      waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
       reductionRecipes;
 
-  // Async, wait and self clause have optional values but can be present with
+  // The self clause has an optional value but can be present with
   // no value as well. When there is no value, the op has an attribute to
   // represent the clause.
-  bool addAsyncAttr = false;
-  bool addWaitAttr = false;
   bool addSelfAttr = false;
 
   bool hasDefaultNone = false;
@@ -1808,6 +1807,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
+
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separatly as clauses can appear
   // more than once.
@@ -1815,27 +1819,52 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *asyncClause =
             std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+      const auto &asyncClauseValue = asyncClause->v;
+      if (asyncClauseValue) { // async has a value.
+        async.push_back(fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
+        asyncDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
-                    addWaitAttr, stmtCtx);
+      const auto &waitClauseValue = waitClause->v;
+      if (waitClauseValue) { // wait has a value.
+        const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+        const auto &waitList =
+            std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+        auto crtWaitOperands = waitOperands.size();
+        for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+          waitOperands.push_back(fir::getBase(converter.genExprValue(
+              *Fortran::semantics::GetExpr(value), stmtCtx)));
+        }
+        waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+      } else {
+        waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
+      auto crtNumGangs = numGangs.size();
       for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
         numGangs.push_back(fir::getBase(converter.genExprValue(
             *Fortran::semantics::GetExpr(expr), stmtCtx)));
+      numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+      numGangsSegments.push_back(numGangs.size() - crtNumGangs);
     } else if (const auto *numWorkersClause =
                    std::get_if<Fortran::parser::AccClause::NumWorkers>(
                        &clause.u)) {
-      numWorkers = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+      numWorkers.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
+      numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *vectorLengthClause =
                    std::get_if<Fortran::parser::AccClause::VectorLength>(
                        &clause.u)) {
-      vectorLength = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+      vectorLength.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
+      vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *ifClause =
                    std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1986,18 +2015,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
       else if ((defaultClause->v).v ==
                llvm::acc::DefaultValue::ACC_Default_present)
         hasDefaultPresent = true;
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value, 8> operands;
   llvm::SmallVector<int32_t, 8> operandSegments;
-  addOperand(operands, operandSegments, async);
+  addOperands(operands, operandSegments, async);
   addOperands(operands, operandSegments, waitOperands);
   if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
     addOperands(operands, operandSegments, numGangs);
-    addOperand(operands, operandSegments, numWorkers);
-    addOperand(operands, operandSegments, vectorLength);
+    addOperands(operands, operandSegments, numWorkers);
+    addOperands(operands, operandSegments, vectorLength);
   }
   addOperand(operands, operandSegments, ifCond);
   addOperand(operands, operandSegments, selfCond);
@@ -2018,10 +2056,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
         builder, currentLocation, eval, operands, operandSegments,
         outerCombined);
 
-  if (addAsyncAttr)
-    computeOp.setAsyncAttrAttr(builder.getUnitAttr());
-  if (addWaitAttr)
-    computeOp.setWaitAttrAttr(builder.getUnitAttr());
   if (addSelfAttr)
     computeOp.setSelfAttrAttr(builder.getUnitAttr());
 
@@ -2030,6 +2064,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   if (hasDefaultPresent)
     computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
 
+  if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
+    if (!numWorkersDeviceTypes.empty())
+      computeOp.setNumWorkersDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
+    if (!vectorLengthDeviceTypes.empty())
+      computeOp.setVectorLengthDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
+    if (!numGangsDeviceTypes.empty())
+      computeOp.setNumGangsDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
+    if (!numGangsSegments.empty())
+      computeOp.setNumGangsSegmentsAttr(
+          builder.getDenseI32ArrayAttr(numGangsSegments));
+  }
+  if (!asyncDeviceTypes.empty())
+    computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+  if (!asyncOnlyDeviceTypes.empty())
+    computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
+
+  if (!waitOperandsDeviceTypes.empty())
+    computeOp.setWaitOperandsDeviceTypeAttr(
+        builder.getArrayAttr(waitOperandsDeviceTypes));
+  if (!waitOperandsSegments.empty())
+    computeOp.setWaitOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!waitOnlyDeviceTypes.empty())
+    computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
+
   if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
     if (!privatizations.empty())
       computeOp.setPrivatizationsAttr(
diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
new file mode 100644
index 00000000000000..871dbc95f60fcb
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-device-type.f90
@@ -0,0 +1,44 @@
+! This test checks lowering of the OpenACC device_type clause on directives
+! where its position and the clauses that follow have special semantics.
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1()
+
+  !$acc parallel num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {
+
+  !$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+  !$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+  !$acc parallel vector_length(1)
+  !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)
+
+  !$acc parallel device_type(multicore) vector_length(1)
+  !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])
+
+  !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])
+
+  !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+
+end subroutine
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 34e72326972417..93bc699031d550 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc kernels loop async(1)
   DO i = 1, n
@@ -103,7 +103,7 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc kernels loop wait(1)
   DO i = 1, n
@@ -111,7 +111,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels wait([[WAIT1]] : i32) {
+! CHECK:      acc.kernels wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 1f882c6df51061..99629bb8351723 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -40,7 +40,7 @@ subroutine acc_kernels
 
 ! CHECK:      acc.kernels  {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc kernels async(1)
   !$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
 
 ! CHECK:      acc.kernels  {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc kernels wait(1)
   !$acc end kernels
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels  wait([[WAIT1]] : i32) {
+! CHECK:      acc.kernels  wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -78,7 +78,7 @@ subroutine acc_kernels
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.kernels  wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.kernels  wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -87,7 +87,7 @@ subroutine acc_kernels
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels  wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.kernels  wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -95,7 +95,7 @@ subroutine acc_kernels
   !$acc end kernels
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels  num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.kernels  num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -103,7 +103,7 @@ subroutine acc_kernels
   !$acc end kernels
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels  num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.kernels  num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 1856215ce59d13..deee7089033ead 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -64,7 +64,7 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc parallel loop async(1)
   DO i = 1, n
@@ -105,7 +105,7 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc parallel loop wait(1)
   DO i = 1, n
@@ -113,7 +113,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel wait([[WAIT1]] : i32) {
+! CHECK:      acc.parallel wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -128,7 +128,7 @@ subroutine acc_parallel_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -143,7 +143,7 @@ subroutine acc_parallel_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -157,7 +157,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -171,7 +171,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index bbf51ba36a7dea..a369bf01f25995 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -62,7 +62,7 @@ subroutine acc_parallel
 
 ! CHECK:      acc.parallel {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc parallel async(1)
   !$acc end parallel
@@ -85,13 +85,13 @@ subroutine acc_parallel
 
 ! CHECK:      acc.parallel {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc parallel wait(1)
   !$acc end parallel
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel wait([[WAIT1]] : i32) {
+! CHECK:      acc.parallel wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -100,7 +100,7 @@ subroutine acc_parallel
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -109,7 +109,7 @@ subroutine acc_parallel
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -117,7 +117,7 @@ subroutine acc_parallel
   !$acc end parallel
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -125,14 +125,14 @@ subroutine acc_parallel
   !$acc end parallel
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
   !$acc parallel num_gangs(1, 1, 1)
   !$acc end parallel
 
-! CHECK:      acc.parallel num_gangs(%{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32) {
+! CHECK:      acc.parallel num_gangs({%{{.*}} : i32, %{{.*}} : i32, %{{.*}} : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 4ed7bb8da29a1a..712bfc80ce387c 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -83,7 +83,7 @@ subroutine acc_serial_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc serial loop async(1)
   DO i = 1, n
@@ -124,7 +124,7 @@ subroutine acc_serial_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc serial loop wait(1)
   DO i = 1, n
@@ -132,7 +132,7 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.serial wait([[WAIT1]] : i32) {
+! CHECK:      acc.serial wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -147,7 +147,7 @@ subroutine acc_serial_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -162,7 +162,7 @@ subroutine acc_serial_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index ab3b0ccd545958..d05e51d3d274f4 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -62,7 +62,7 @@ subroutine acc_serial
 
 ! CHECK:      acc.serial {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc serial async(1)
   !$acc end serial
@@ -85,13 +85,13 @@ subroutine acc_serial
 
 ! CHECK:      acc.serial {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc serial wait(1)
   !$acc end serial
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.serial wait([[WAIT1]] : i32) {
+! CHECK:      acc.serial wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -100,7 +100,7 @@ subroutine acc_serial
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -109,7 +109,7 @@ subroutine acc_serial
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index a78c3e98c95517..234c1076e14e3b 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -156,29 +156,46 @@ def DeclareActionAttr : OpenACC_Attr<"DeclareAction", "declare_action"> {
 }
 
 // Device type enumeration.
-def OpenACC_DeviceTypeStar      : I32EnumAttrCase<"Star", 0, "star">;
-def OpenACC_DeviceTypeDefault   : I32EnumAttrCase<"Default", 1, "default">;
-def OpenACC_DeviceTypeHost      : I32EnumAttrCase<"Host", 2, "host">;
-def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 3, "multicore">;
-def OpenACC_DeviceTypeNvidia    : I32EnumAttrCase<"Nvidia", 4, "nvidia">;
-def OpenACC_DeviceTypeRadeon    : I32EnumAttrCase<"Radeon", 5, "radeon">;
-
+def OpenACC_DeviceTypeNone      : I32EnumAttrCase<"None", 0, "none">;
+def OpenACC_DeviceTypeStar      : I32EnumAttrCase<"Star", 1, "star">;
+def OpenACC_DeviceTypeDefault   : I32EnumAttrCase<"Default", 2, "default">;
+def OpenACC_DeviceTypeHost      : I32EnumAttrCase<"Host", 3, "host">;
+def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 4, "multicore">;
+def OpenACC_DeviceTypeNvidia    : I32EnumAttrCase<"Nvidia", 5, "nvidia">;
+def OpenACC_DeviceTypeRadeon    : I32EnumAttrCase<"Radeon", 6, "radeon">;
 
 def OpenACC_DeviceType : I32EnumAttr<"DeviceType",
     "built-in device type supported by OpenACC",
-    [OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault,
+    [OpenACC_DeviceTypeNone, OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault,
      OpenACC_DeviceTypeHost, OpenACC_DeviceTypeMulticore,
      OpenACC_DeviceTypeNvidia, OpenACC_DeviceTypeRadeon
     ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::acc";
 }
+
+// Device type attribute is used to associate a value for clauses that
+// appear after a device_type clause. The list of clauses allowed after the
+// device_type clause is defined per construct as follows:
+// Loop construct: collapse, gang, worker, vector, seq, independent, auto,
+//                 and tile
+// Compute construct: async, wait, num_gangs, num_workers, and vector_length
+// Data construct: async and wait
+// Routine: gang, worker, vector, seq and bind
+//
+// `none` means that the value appears before any device_type clause.
+//
 def OpenACC_DeviceTypeAttr : EnumAttr<OpenACC_Dialect,
                                       OpenACC_DeviceType,
                                       "device_type"> {
   let assemblyFormat = [{ ```<` $value `>` }];
 }
 
+def DeviceTypeArrayAttr :
+  TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "device type array attribute"> {
+  let constBuilderCall = ?;
+}
+
 // Define a resource for the OpenACC runtime counters.
 def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">;
 
@@ -863,24 +880,32 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     ```
   }];
 
-  let arguments = (ins Optional<IntOrIndex>:$async,
-                       UnitAttr:$asyncAttr,
-                       Variadic<IntOrIndex>:$waitOperands,
-                       UnitAttr:$waitAttr,
-                       Variadic<IntOrIndex>:$numGangs,
-                       Optional<IntOrIndex>:$numWorkers,
-                       Optional<IntOrIndex>:$vectorLength,
-                       Optional<I1>:$ifCond,
-                       Optional<I1>:$selfCond,
-                       UnitAttr:$selfAttr,
-                       Variadic<AnyType>:$reductionOperands,
-                       OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
-                       OptionalAttr<SymbolRefArrayAttr>:$privatizations,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
-                       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-                       OptionalAttr<DefaultValueAttr>:$defaultAttr);
+  let arguments = (ins
+      Variadic<IntOrIndex>:$async,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+      Variadic<IntOrIndex>:$waitOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
+      Variadic<IntOrIndex>:$numGangs,
+      OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$numGangsDeviceType,
+      Variadic<IntOrIndex>:$numWorkers,
+      OptionalAttr<DeviceTypeArrayAttr>:$numWorkersDeviceType,
+      Variadic<IntOrIndex>:$vectorLength,
+      OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType,
+      Optional<I1>:$ifCond,
+      Optional<I1>:$selfCond,
+      UnitAttr:$selfAttr,
+      Variadic<AnyType>:$reductionOperands,
+      OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+      OptionalAttr<SymbolRefArrayAttr>:$privatizations,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
+      OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+      OptionalAttr<DefaultValueAttr>:$defaultAttr);
 
   let regions = (region AnyRegion:$region);
 
@@ -890,22 +915,69 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
 
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
+
+    /// Return true if the op has the async attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasAsyncOnly();
+    /// Return true if the op has the async attribute for the given device_type.
+    bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+    /// Return the value of the async clause if present.
+    mlir::Value getAsyncValue();
+    /// Return the value of the async clause for the given device_type if
+    /// present.
+    mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+    /// Return the value of the num_workers clause if present.
+    mlir::Value getNumWorkersValue();
+    /// Return the value of the num_workers clause for the given device_type if
+    /// present.
+    mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType);
+
+    /// Return the value of the vector_length clause if present.
+    mlir::Value getVectorLengthValue();
+    /// Return the value of the vector_length clause for the given device_type 
+    /// if present.
+    mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType);
+
+    /// Return the values of the num_gangs clause if present.
+    mlir::Operation::operand_range getNumGangsValues();
+    /// Return the values of the num_gangs clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getNumGangsValues(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the wait attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWaitOnly();
+    /// Return true if the op has the wait attribute for the given device_type.
+    bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+    /// Return the values of the wait clause if present.
+    mlir::Operation::operand_range getWaitValues();
+    /// Return the values of the wait clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getWaitValues(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `async` `(` $async `:` type($async) `)`
+      | `async` `(` custom<DeviceTypeOperands>($async,
+            type($async), $asyncDeviceType) `)`
       | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
             type($gangFirstPrivateOperands), $firstprivatizations)
         `)`
-      | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
-      | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
+      | `num_gangs` `(` custom<NumGangs>($numGangs,
+            type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
+      | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
+            type($numWorkers), $numWorkersDeviceType) `)`
       | `private` `(` custom<SymOperandList>(
             $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
         `)`
-      | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
-      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+      | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
+            type($vectorLength), $vectorLengthDeviceType) `)`
+      | `wait` `(` custom<WaitOperands>($waitOperands,
+            type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
       | `self` `(` $selfCond `)`
       | `if` `(` $ifCond `)`
       | `reduction` `(` custom<SymOperandList>(
@@ -939,21 +1011,25 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
     ```
   }];
 
-  let arguments = (ins Optional<IntOrIndex>:$async,
-                       UnitAttr:$asyncAttr,
-                       Variadic<IntOrIndex>:$waitOperands,
-                       UnitAttr:$waitAttr,
-                       Optional<I1>:$ifCond,
-                       Optional<I1>:$selfCond,
-                       UnitAttr:$selfAttr,
-                       Variadic<AnyType>:$reductionOperands,
-                       OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
-                       OptionalAttr<SymbolRefArrayAttr>:$privatizations,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
-                       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-                       OptionalAttr<DefaultValueAttr>:$defaultAttr);
+  let arguments = (ins
+      Variadic<IntOrIndex>:$async,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+      Variadic<IntOrIndex>:$waitOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
+      Optional<I1>:$ifCond,
+      Optional<I1>:$selfCond,
+      UnitAttr:$selfAttr,
+      Variadic<AnyType>:$reductionOperands,
+      OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+      OptionalAttr<SymbolRefArrayAttr>:$privatizations,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
+      OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+      OptionalAttr<DefaultValueAttr>:$defaultAttr);
 
   let regions = (region AnyRegion:$region);
 
@@ -963,19 +1039,44 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
 
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
+
+    /// Return true if the op has the async attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasAsyncOnly();
+    /// Return true if the op has the async attribute for the given device_type.
+    bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+    /// Return the value of the async clause if present.
+    mlir::Value getAsyncValue();
+    /// Return the value of the async clause for the given device_type if
+    /// present.
+    mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the wait attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWaitOnly();
+    /// Return true if the op has the wait attribute for the given device_type.
+    bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+    /// Return the values of the wait clause if present.
+    mlir::Operation::operand_range getWaitValues();
+    /// Return the values of the wait clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getWaitValues(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `async` `(` $async `:` type($async) `)`
+      | `async` `(` custom<DeviceTypeOperands>($async,
+            type($async), $asyncDeviceType) `)`
       | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
             type($gangFirstPrivateOperands), $firstprivatizations)
         `)`
       | `private` `(` custom<SymOperandList>(
             $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
         `)`
-      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+      | `wait` `(` custom<WaitOperands>($waitOperands,
+            type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
       | `self` `(` $selfCond `)`
       | `if` `(` $ifCond `)`
       | `reduction` `(` custom<SymOperandList>(
@@ -1011,18 +1112,26 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
     ```
   }];
 
-  let arguments = (ins Optional<IntOrIndex>:$async,
-                       UnitAttr:$asyncAttr,
-                       Variadic<IntOrIndex>:$waitOperands,
-                       UnitAttr:$waitAttr,
-                       Variadic<IntOrIndex>:$numGangs,
-                       Optional<IntOrIndex>:$numWorkers,
-                       Optional<IntOrIndex>:$vectorLength,
-                       Optional<I1>:$ifCond,
-                       Optional<I1>:$selfCond,
-                       UnitAttr:$selfAttr,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-                       OptionalAttr<DefaultValueAttr>:$defaultAttr);
+  let arguments = (ins
+      Variadic<IntOrIndex>:$async,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+      Variadic<IntOrIndex>:$waitOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
+      Variadic<IntOrIndex>:$numGangs,
+      OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$numGangsDeviceType,
+      Variadic<IntOrIndex>:$numWorkers,
+      OptionalAttr<DeviceTypeArrayAttr>:$numWorkersDeviceType,
+      Variadic<IntOrIndex>:$vectorLength,
+      OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType,
+      Optional<I1>:$ifCond,
+      Optional<I1>:$selfCond,
+      UnitAttr:$selfAttr,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+      OptionalAttr<DefaultValueAttr>:$defaultAttr);
 
   let regions = (region AnyRegion:$region);
 
@@ -1032,16 +1141,63 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
 
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
+
+    /// Return true if the op has the async attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasAsyncOnly();
+    /// Return true if the op has the async attribute for the given device_type.
+    bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+    /// Return the value of the async clause if present.
+    mlir::Value getAsyncValue();
+    /// Return the value of the async clause for the given device_type if
+    /// present.
+    mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+    /// Return the value of the num_workers clause if present.
+    mlir::Value getNumWorkersValue();
+    /// Return the value of the num_workers clause for the given device_type if
+    /// present.
+    mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType);
+
+    /// Return the value of the vector_length clause if present.
+    mlir::Value getVectorLengthValue();
+    /// Return the value of the vector_length clause for the given device_type 
+    /// if present.
+    mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType);
+
+    /// Return the values of the num_gangs clause if present.
+    mlir::Operation::operand_range getNumGangsValues();
+    /// Return the values of the num_gangs clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getNumGangsValues(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the wait attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWaitOnly();
+    /// Return true if the op has the wait attribute for the given device_type.
+    bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+    /// Return the values of the wait clause if present.
+    mlir::Operation::operand_range getWaitValues();
+    /// Return the values of the wait clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getWaitValues(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `async` `(` $async `:` type($async) `)`
-      | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
-      | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
-      | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
-      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+      | `async` `(` custom<DeviceTypeOperands>($async,
+            type($async), $asyncDeviceType) `)`
+      | `num_gangs` `(` custom<NumGangs>($numGangs,
+            type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
+      | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
+            type($numWorkers), $numWorkersDeviceType) `)`
+      | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
+            type($vectorLength), $vectorLengthDeviceType) `)`
+      | `wait` `(` custom<WaitOperands>($waitOperands,
+            type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
       | `self` `(` $selfCond `)`
       | `if` `(` $ifCond `)`
     )
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index df4f7825545c2b..45e0632db5ef2b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -615,15 +615,49 @@ unsigned ParallelOp::getNumDataOperands() {
 }
 
 Value ParallelOp::getDataOperand(unsigned i) {
-  unsigned numOptional = getAsync() ? 1 : 0;
+  unsigned numOptional = getAsync().size();
   numOptional += getNumGangs().size();
-  numOptional += getNumWorkers() ? 1 : 0;
-  numOptional += getVectorLength() ? 1 : 0;
+  numOptional += getNumWorkers().size();
+  numOptional += getVectorLength().size();
   numOptional += getIfCond() ? 1 : 0;
   numOptional += getSelfCond() ? 1 : 0;
   return getOperand(getWaitOperands().size() + numOptional + i);
 }
 
+template <typename Op>
+static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
+    ArrayAttr deviceTypes, llvm::StringRef keyword) {
+  if (!operands.empty() && // an absent device_type array is a mismatch too
+      (!deviceTypes || deviceTypes.size() != operands.size()))
+    return op.emitOpError() << keyword << " operands count must match "
+                            << keyword << " device_type count";
+  return success();
+}
+
+/// Cross-check `keyword` operands against their segment and device_type attrs.
+template <typename Op>
+static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
+    Op op, OperandRange operands, DenseI32ArrayAttr segments,
+    ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
+  std::size_t numOperandsInSegments = 0;
+  if (!segments)
+    return success();
+  for (auto segCount : segments.asArrayRef()) {
+    if (maxInSegment != 0 && segCount > maxInSegment)
+      return op.emitOpError() << keyword << " expects a maximum of "
+                              << maxInSegment << " values per segment";
+    numOperandsInSegments += segCount;
+  }
+  if (numOperandsInSegments != operands.size())
+    return op.emitOpError()
+           << keyword << " operand count does not match count in segments";
+  // An absent device_type array is diagnosed here instead of dereferenced.
+  if (!deviceTypes || deviceTypes.size() != (size_t)segments.size())
+    return op.emitOpError()
+           << keyword << " segment count does not match device_type count";
+  return success();
+}
+
 LogicalResult acc::ParallelOp::verify() {
   if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
           *this, getPrivatizations(), getGangPrivateOperands(), "private",
@@ -633,11 +667,322 @@ LogicalResult acc::ParallelOp::verify() {
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
           "reductions", false)))
     return failure();
-  if (getNumGangs().size() > 3)
-    return emitOpError() << "num_gangs expects a maximum of 3 values";
+
+  if (failed(verifyDeviceTypeAndSegmentCountMatch(
+          *this, getNumGangs(), getNumGangsSegmentsAttr(),
+          getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
+    return failure();
+
+  if (failed(verifyDeviceTypeAndSegmentCountMatch(
+          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
+          getWaitOperandsDeviceTypeAttr(), "wait")))
+    return failure();
+
+  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
+                                        getNumWorkersDeviceTypeAttr(),
+                                        "num_workers")))
+    return failure();
+
+  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
+                                        getVectorLengthDeviceTypeAttr(),
+                                        "vector_length")))
+    return failure();
+
+  if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
+                                        getAsyncDeviceTypeAttr(), "async")))
+    return failure();
+
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
 
+/// Index of `deviceType` in `segments`, or nullopt when it is not listed.
+static std::optional<unsigned> findSegment(ArrayAttr segments,
+                                           mlir::acc::DeviceType deviceType) {
+  for (unsigned idx = 0, e = segments.size(); idx < e; ++idx) {
+    // dyn_cast yields null for a foreign attribute: test before dereferencing.
+    auto dt = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(segments[idx]);
+    if (dt && dt.getValue() == deviceType)
+      return idx;
+  }
+  return std::nullopt;
+}
+
+static mlir::Value
+getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
+                            mlir::Operation::operand_range range,
+                            mlir::acc::DeviceType deviceType) {
+  // Null Value unless `deviceType` is listed and has a matching operand.
+  auto pos = arrayAttr ? findSegment(*arrayAttr, deviceType) : std::nullopt;
+  if (pos && *pos < range.size())
+    return range[*pos];
+  return {};
+}
+
+bool acc::ParallelOp::hasAsyncOnly() {
+  return hasAsyncOnly(mlir::acc::DeviceType::None);
+}
+
+bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
+  // True only when the optional asyncOnly array exists and lists `deviceType`.
+  auto onlyAttr = getAsyncOnly();
+  if (!onlyAttr)
+    return false;
+  return findSegment(*onlyAttr, deviceType).has_value();
+}
+
+mlir::Value acc::ParallelOp::getAsyncValue() {
+  return getAsyncValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+  return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
+                                     deviceType);
+}
+
+mlir::Value acc::ParallelOp::getNumWorkersValue() {
+  return getNumWorkersValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value
+acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
+  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
+                                     deviceType);
+}
+
+mlir::Value acc::ParallelOp::getVectorLengthValue() {
+  return getVectorLengthValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value
+acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
+  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
+                                     getVectorLength(), deviceType);
+}
+
+mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
+  return getNumGangsValues(mlir::acc::DeviceType::None);
+}
+
+/// Operands of the segment tagged `deviceType`; an empty range when it is not
+/// listed or when no segment size is available for it.
+static mlir::Operation::operand_range
+getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
+                      mlir::Operation::operand_range range,
+                      std::optional<llvm::ArrayRef<int32_t>> segments,
+                      mlir::acc::DeviceType deviceType) {
+  auto pos = arrayAttr ? findSegment(*arrayAttr, deviceType) : std::nullopt;
+  if (!pos || !segments || *pos >= segments->size())
+    return range.take_front(0);
+  int32_t nbOperandsBefore = 0;
+  for (int32_t count : segments->take_front(*pos))
+    nbOperandsBefore += count;
+  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
+}
+
+mlir::Operation::operand_range
+ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
+  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
+                               getNumGangsSegments(), deviceType);
+}
+
+bool acc::ParallelOp::hasWaitOnly() {
+  return hasWaitOnly(mlir::acc::DeviceType::None);
+}
+
+bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
+  // True only when the optional waitOnly array exists and lists `deviceType`.
+  auto onlyAttr = getWaitOnly();
+  if (!onlyAttr)
+    return false;
+  return findSegment(*onlyAttr, deviceType).has_value();
+}
+
+mlir::Operation::operand_range ParallelOp::getWaitValues() {
+  return getWaitValues(mlir::acc::DeviceType::None);
+}
+
+mlir::Operation::operand_range
+ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
+                               getWaitOperandsSegments(), deviceType);
+}
+
+/// Parse `{v : t, ...} [device_type]?` groups separated by commas. `segments`
+/// receives one operand count per group; a group without a bracketed
+/// attribute is tagged DeviceType::None.
+static ParseResult parseNumGangs(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::DenseI32ArrayAttr &segments) {
+  llvm::SmallVector<DeviceTypeAttr> attributes;
+  llvm::SmallVector<int32_t> seg;
+  do {
+    if (failed(parser.parseLBrace()))
+      return failure();
+    // `operands` accumulates across groups: remember where this one starts.
+    const int32_t crtOperandsSize = operands.size();
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() {
+              if (parser.parseOperand(operands.emplace_back()) ||
+                  parser.parseColonType(types.emplace_back()))
+                return failure();
+              return success();
+            })))
+      return failure();
+    seg.push_back(operands.size() - crtOperandsSize);
+    if (failed(parser.parseRBrace()))
+      return failure();
+
+    if (succeeded(parser.parseOptionalLSquare())) {
+      if (parser.parseAttribute(attributes.emplace_back()) ||
+          parser.parseRSquare())
+        return failure();
+    } else {
+      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+          parser.getContext(), mlir::acc::DeviceType::None));
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
+  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+  return success();
+}
+
+/// Print one `{v : t, ...}` group per segment, followed by ` [device_type]`
+/// unless it is DeviceType::None; nothing is printed if an attribute is unset.
+static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          mlir::OperandRange operands, mlir::TypeRange types,
+                          std::optional<mlir::ArrayAttr> deviceTypes,
+                          std::optional<mlir::DenseI32ArrayAttr> segments) {
+  if (!deviceTypes || !*deviceTypes || !segments || !*segments)
+    return;
+  unsigned opIdx = 0;
+  for (unsigned i = 0; i < deviceTypes->size(); ++i) {
+    p << (i != 0 ? ", {" : "{");
+    for (int32_t j = 0; j < (*segments)[i]; ++j, ++opIdx) {
+      if (j != 0)
+        p << ", ";
+      p << operands[opIdx] << " : " << operands[opIdx].getType();
+    }
+    p << "}";
+    auto dt = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
+    if (dt && dt.getValue() != mlir::acc::DeviceType::None)
+      p << " [" << (*deviceTypes)[i] << "]";
+  }
+}
+
+/// Parse `{v : t, ...} [device_type]?` groups separated by commas. `segments`
+/// receives one operand count per group; a group without a bracketed
+/// attribute is tagged DeviceType::None.
+static ParseResult parseWaitOperands(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::DenseI32ArrayAttr &segments) {
+  llvm::SmallVector<DeviceTypeAttr> attributes;
+  llvm::SmallVector<int32_t> seg;
+  do {
+    if (failed(parser.parseLBrace()))
+      return failure();
+    // `operands` accumulates across groups: remember where this one starts.
+    const int32_t crtOperandsSize = operands.size();
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() {
+              if (parser.parseOperand(operands.emplace_back()) ||
+                  parser.parseColonType(types.emplace_back()))
+                return failure();
+              return success();
+            })))
+      return failure();
+    seg.push_back(operands.size() - crtOperandsSize);
+    if (failed(parser.parseRBrace()))
+      return failure();
+
+    if (succeeded(parser.parseOptionalLSquare())) {
+      if (parser.parseAttribute(attributes.emplace_back()) ||
+          parser.parseRSquare())
+        return failure();
+    } else {
+      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+          parser.getContext(), mlir::acc::DeviceType::None));
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
+  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+  return success();
+}
+
+/// Print one `{v : t, ...}` group per segment, followed by ` [device_type]`
+/// unless it is DeviceType::None; nothing is printed if an attribute is unset.
+static void printWaitOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                              mlir::OperandRange operands,
+                              mlir::TypeRange types,
+                              std::optional<mlir::ArrayAttr> deviceTypes,
+                              std::optional<mlir::DenseI32ArrayAttr> segments) {
+  if (!deviceTypes || !*deviceTypes || !segments || !*segments)
+    return;
+  unsigned opIdx = 0;
+  for (unsigned i = 0; i < deviceTypes->size(); ++i) {
+    p << (i != 0 ? ", {" : "{");
+    for (int32_t j = 0; j < (*segments)[i]; ++j, ++opIdx) {
+      if (j != 0)
+        p << ", ";
+      p << operands[opIdx] << " : " << operands[opIdx].getType();
+    }
+    p << "}";
+    auto dt = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
+    if (dt && dt.getValue() != mlir::acc::DeviceType::None)
+      p << " [" << (*deviceTypes)[i] << "]";
+  }
+}
+
+/// Parse a comma-separated list of `operand : type [device_type]?` entries.
+static ParseResult parseDeviceTypeOperands(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<DeviceTypeAttr> parsedTypes;
+  auto parseEntry = [&]() {
+    if (parser.parseOperand(operands.emplace_back()) ||
+        parser.parseColonType(types.emplace_back()))
+      return failure();
+    if (succeeded(parser.parseOptionalLSquare()))
+      return failure(parser.parseAttribute(parsedTypes.emplace_back()) ||
+                     parser.parseRSquare());
+    // No `[...]` suffix: the entry is tagged DeviceType::None.
+    parsedTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    return success();
+  };
+  if (failed(parser.parseCommaSeparatedList(parseEntry)))
+    return failure();
+  llvm::SmallVector<mlir::Attribute> asAttrs(parsedTypes.begin(),
+                                             parsedTypes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), asAttrs);
+  return success();
+}
+
+/// Print `value : type` entries, each followed by ` [device_type]` unless the
+/// type is DeviceType::None; entries without an operand are not printed.
+static void
+printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                        mlir::OperandRange operands, mlir::TypeRange types,
+                        std::optional<mlir::ArrayAttr> deviceTypes) {
+  unsigned e = (deviceTypes && *deviceTypes) ? deviceTypes->size() : 0;
+  for (unsigned i = 0; i < e && i < operands.size(); ++i) {
+    p << (i != 0 ? ", " : "") << operands[i] << " : " << operands[i].getType();
+    auto dt = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
+    if (dt && dt.getValue() != mlir::acc::DeviceType::None)
+      p << " [" << (*deviceTypes)[i] << "]";
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // SerialOp
 //===----------------------------------------------------------------------===//
@@ -648,12 +993,55 @@ unsigned SerialOp::getNumDataOperands() {
 }
 
 Value SerialOp::getDataOperand(unsigned i) {
-  unsigned numOptional = getAsync() ? 1 : 0;
+  unsigned numOptional = getAsync().size();
   numOptional += getIfCond() ? 1 : 0;
   numOptional += getSelfCond() ? 1 : 0;
   return getOperand(getWaitOperands().size() + numOptional + i);
 }
 
+bool acc::SerialOp::hasAsyncOnly() {
+  return hasAsyncOnly(mlir::acc::DeviceType::None);
+}
+
+bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
+  // True only when the optional asyncOnly array exists and lists `deviceType`.
+  auto onlyAttr = getAsyncOnly();
+  if (!onlyAttr)
+    return false;
+  return findSegment(*onlyAttr, deviceType).has_value();
+}
+
+mlir::Value acc::SerialOp::getAsyncValue() {
+  return getAsyncValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+  return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
+                                     deviceType);
+}
+
+bool acc::SerialOp::hasWaitOnly() {
+  return hasWaitOnly(mlir::acc::DeviceType::None);
+}
+
+bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
+  // True only when the optional waitOnly array exists and lists `deviceType`.
+  auto onlyAttr = getWaitOnly();
+  if (!onlyAttr)
+    return false;
+  return findSegment(*onlyAttr, deviceType).has_value();
+}
+
+mlir::Operation::operand_range SerialOp::getWaitValues() {
+  return getWaitValues(mlir::acc::DeviceType::None);
+}
+
+mlir::Operation::operand_range
+SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
+                               getWaitOperandsSegments(), deviceType);
+}
+
 LogicalResult acc::SerialOp::verify() {
   if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
           *this, getPrivatizations(), getGangPrivateOperands(), "private",
@@ -663,6 +1051,16 @@ LogicalResult acc::SerialOp::verify() {
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
           "reductions", false)))
     return failure();
+
+  if (failed(verifyDeviceTypeAndSegmentCountMatch(
+          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
+          getWaitOperandsDeviceTypeAttr(), "wait")))
+    return failure();
+
+  if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
+                                        getAsyncDeviceTypeAttr(), "async")))
+    return failure();
+
   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
 }
 
@@ -675,19 +1073,114 @@ unsigned KernelsOp::getNumDataOperands() {
 }
 
 Value KernelsOp::getDataOperand(unsigned i) {
-  unsigned numOptional = getAsync() ? 1 : 0;
+  unsigned numOptional = getAsync().size();
   numOptional += getWaitOperands().size();
   numOptional += getNumGangs().size();
-  numOptional += getNumWorkers() ? 1 : 0;
-  numOptional += getVectorLength() ? 1 : 0;
+  numOptional += getNumWorkers().size();
+  numOptional += getVectorLength().size();
   numOptional += getIfCond() ? 1 : 0;
   numOptional += getSelfCond() ? 1 : 0;
   return getOperand(numOptional + i);
 }
 
+bool acc::KernelsOp::hasAsyncOnly() {
+  return hasAsyncOnly(mlir::acc::DeviceType::None);
+}
+
+bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
+  // True only when the optional asyncOnly array exists and lists `deviceType`.
+  auto onlyAttr = getAsyncOnly();
+  if (!onlyAttr)
+    return false;
+  return findSegment(*onlyAttr, deviceType).has_value();
+}
+
+mlir::Value acc::KernelsOp::getAsyncValue() {
+  return getAsyncValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+  return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
+                                     deviceType);
+}
+
+mlir::Value acc::KernelsOp::getNumWorkersValue() {
+  return getNumWorkersValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value
+acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
+  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
+                                     deviceType);
+}
+
+mlir::Value acc::KernelsOp::getVectorLengthValue() {
+  return getVectorLengthValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value
+acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
+  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
+                                     getVectorLength(), deviceType);
+}
+
+mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
+  return getNumGangsValues(mlir::acc::DeviceType::None);
+}
+
+mlir::Operation::operand_range
+KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
+  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
+                               getNumGangsSegments(), deviceType);
+}
+
+bool acc::KernelsOp::hasWaitOnly() {
+  return hasWaitOnly(mlir::acc::DeviceType::None);
+}
+
+bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
+  // True only when the optional waitOnly array exists and lists `deviceType`.
+  auto onlyAttr = getWaitOnly();
+  if (!onlyAttr)
+    return false;
+  return findSegment(*onlyAttr, deviceType).has_value();
+}
+
+mlir::Operation::operand_range KernelsOp::getWaitValues() {
+  return getWaitValues(mlir::acc::DeviceType::None);
+}
+
+mlir::Operation::operand_range
+KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+  return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
+                               getWaitOperandsSegments(), deviceType);
+}
+
 LogicalResult acc::KernelsOp::verify() {
-  if (getNumGangs().size() > 3)
-    return emitOpError() << "num_gangs expects a maximum of 3 values";
+  if (failed(verifyDeviceTypeAndSegmentCountMatch(
+          *this, getNumGangs(), getNumGangsSegmentsAttr(),
+          getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
+    return failure();
+
+  if (failed(verifyDeviceTypeAndSegmentCountMatch(
+          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
+          getWaitOperandsDeviceTypeAttr(), "wait")))
+    return failure();
+
+  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
+                                        getNumWorkersDeviceTypeAttr(),
+                                        "num_workers")))
+    return failure();
+
+  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
+                                        getVectorLengthDeviceTypeAttr(),
+                                        "vector_length")))
+    return failure();
+
+  if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
+                                        getAsyncDeviceTypeAttr(), "async")))
+    return failure();
+
   return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
 }
 
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index b9ac68d0592c87..c18d964b370f2c 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -462,8 +462,8 @@ acc.loop gang() {
 // -----
 
 %i64value = arith.constant 1 : i64
-// expected-error at +1 {{num_gangs expects a maximum of 3 values}}
-acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) {
+// expected-error at +1 {{num_gangs expects a maximum of 3 values per segment}}
+acc.parallel num_gangs({%i64value: i64, %i64value : i64, %i64value : i64, %i64value : i64}) {
 }
 
 // -----
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 05b0450c7fb916..5a95811685f845 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -137,7 +137,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
   %pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32>
   acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
     %private = acc.private varPtr(%c : memref<10xf32>) -> memref<10xf32>
-    acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
+    acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
       acc.loop gang {
         scf.for %x = %lb to %c10 step %st {
           acc.loop worker {
@@ -180,7 +180,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 // CHECK-NEXT:   [[NUMWORKERS:%.*]] = arith.constant 10 : i64
 // CHECK:        acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
 // CHECK-NEXT:     %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) -> memref<10xf32> 
-// CHECK-NEXT:     acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
+// CHECK-NEXT:     acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
 // CHECK-NEXT:       acc.loop gang {
 // CHECK-NEXT:         scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:           acc.loop worker {
@@ -439,25 +439,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
   }
   acc.parallel async(%idxValue: index) {
   }
-  acc.parallel wait(%i64value: i64) {
+  acc.parallel wait({%i64value: i64}) {
   }
-  acc.parallel wait(%i32value: i32) {
+  acc.parallel wait({%i32value: i32}) {
   }
-  acc.parallel wait(%idxValue: index) {
+  acc.parallel wait({%idxValue: index}) {
   }
-  acc.parallel wait(%i64value, %i32value, %idxValue : i64, i32, index) {
+  acc.parallel wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
   }
-  acc.parallel num_gangs(%i64value: i64) {
+  acc.parallel num_gangs({%i64value: i64}) {
   }
-  acc.parallel num_gangs(%i32value: i32) {
+  acc.parallel num_gangs({%i32value: i32}) {
   }
-  acc.parallel num_gangs(%idxValue: index) {
+  acc.parallel num_gangs({%idxValue: index}) {
   }
-  acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) {
+  acc.parallel num_gangs({%i64value: i64, %i64value: i64, %idxValue: index}) {
   }
-  acc.parallel num_workers(%i64value: i64) {
+  acc.parallel num_workers(%i64value: i64 [#acc.device_type<nvidia>]) {
   }
-  acc.parallel num_workers(%i32value: i32) {
+  acc.parallel num_workers(%i32value: i32 [#acc.device_type<default>]) {
   }
   acc.parallel num_workers(%idxValue: index) {
   }
@@ -492,25 +492,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
 // CHECK-NEXT: }
 // CHECK:      acc.parallel async([[IDXVALUE]] : index) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel wait([[I64VALUE]] : i64) {
+// CHECK:      acc.parallel wait({[[I64VALUE]] : i64}) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel wait([[I32VALUE]] : i32) {
+// CHECK:      acc.parallel wait({[[I32VALUE]] : i32}) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel wait([[IDXVALUE]] : index) {
+// CHECK:      acc.parallel wait({[[IDXVALUE]] : index}) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
+// CHECK:      acc.parallel wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel num_gangs([[I64VALUE]] : i64) {
+// CHECK:      acc.parallel num_gangs({[[I64VALUE]] : i64}) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel num_gangs([[I32VALUE]] : i32) {
+// CHECK:      acc.parallel num_gangs({[[I32VALUE]] : i32}) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel num_gangs([[IDXVALUE]] : index) {
+// CHECK:      acc.parallel num_gangs({[[IDXVALUE]] : index}) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) {
+// CHECK:      acc.parallel num_gangs({[[I64VALUE]] : i64, [[I64VALUE]] : i64, [[IDXVALUE]] : index}) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel num_workers([[I64VALUE]] : i64) {
+// CHECK:      acc.parallel num_workers([[I64VALUE]] : i64 [#acc.device_type<nvidia>]) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel num_workers([[I32VALUE]] : i32) {
+// CHECK:      acc.parallel num_workers([[I32VALUE]] : i32 [#acc.device_type<default>]) {
 // CHECK-NEXT: }
 // CHECK:      acc.parallel num_workers([[IDXVALUE]] : index) {
 // CHECK-NEXT: }
@@ -590,13 +590,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
   }
   acc.serial async(%idxValue: index) {
   }
-  acc.serial wait(%i64value: i64) {
+  acc.serial wait({%i64value: i64}) {
   }
-  acc.serial wait(%i32value: i32) {
+  acc.serial wait({%i32value: i32}) {
   }
-  acc.serial wait(%idxValue: index) {
+  acc.serial wait({%idxValue: index}) {
   }
-  acc.serial wait(%i64value, %i32value, %idxValue : i64, i32, index) {
+  acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
   }
   %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) -> memref<10xf32>
   acc.serial private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) {
@@ -627,13 +627,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
 // CHECK-NEXT: }
 // CHECK:      acc.serial async([[IDXVALUE]] : index) {
 // CHECK-NEXT: }
-// CHECK:      acc.serial wait([[I64VALUE]] : i64) {
+// CHECK:      acc.serial wait({[[I64VALUE]] : i64}) {
 // CHECK-NEXT: }
-// CHECK:      acc.serial wait([[I32VALUE]] : i32) {
+// CHECK:      acc.serial wait({[[I32VALUE]] : i32}) {
 // CHECK-NEXT: }
-// CHECK:      acc.serial wait([[IDXVALUE]] : index) {
+// CHECK:      acc.serial wait({[[IDXVALUE]] : index}) {
 // CHECK-NEXT: }
-// CHECK:      acc.serial wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
+// CHECK:      acc.serial wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
 // CHECK-NEXT: }
 // CHECK:      %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) -> memref<10xf32>
 // CHECK:      acc.serial firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTP]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
@@ -665,13 +665,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
   }
   acc.kernels async(%idxValue: index) {
   }
-  acc.kernels wait(%i64value: i64) {
+  acc.kernels wait({%i64value: i64}) {
   }
-  acc.kernels wait(%i32value: i32) {
+  acc.kernels wait({%i32value: i32}) {
   }
-  acc.kernels wait(%idxValue: index) {
+  acc.kernels wait({%idxValue: index}) {
   }
-  acc.kernels wait(%i64value, %i32value, %idxValue : i64, i32, index) {
+  acc.kernels wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
   }
   acc.kernels {
   } attributes {defaultAttr = #acc<defaultvalue none>}
@@ -699,13 +699,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
 // CHECK-NEXT: }
 // CHECK:      acc.kernels async([[IDXVALUE]] : index) {
 // CHECK-NEXT: }
-// CHECK:      acc.kernels wait([[I64VALUE]] : i64) {
+// CHECK:      acc.kernels wait({[[I64VALUE]] : i64}) {
 // CHECK-NEXT: }
-// CHECK:      acc.kernels wait([[I32VALUE]] : i32) {
+// CHECK:      acc.kernels wait({[[I32VALUE]] : i32}) {
 // CHECK-NEXT: }
-// CHECK:      acc.kernels wait([[IDXVALUE]] : index) {
+// CHECK:      acc.kernels wait({[[IDXVALUE]] : index}) {
 // CHECK-NEXT: }
-// CHECK:      acc.kernels wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
+// CHECK:      acc.kernels wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
 // CHECK-NEXT: }
 // CHECK:      acc.kernels {
 // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 2dec4ba3c001e8..13393569f36fe7 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -10,6 +10,7 @@ add_subdirectory(ArmSME)
 add_subdirectory(Index)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
+add_subdirectory(OpenACC)
 add_subdirectory(SCF)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
new file mode 100644
index 00000000000000..5133d7fc38296c
--- /dev/null
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIROpenACCTests
+  OpenACCOpsTest.cpp
+)
+target_link_libraries(MLIROpenACCTests
+  PRIVATE
+  MLIRIR
+  MLIROpenACCDialect
+)
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
new file mode 100644
index 00000000000000..dcf6c1240c55d3
--- /dev/null
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -0,0 +1,275 @@
+//===- OpenACCOpsTest.cpp - OpenACC ops extra functions Tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+class OpenACCOpsTest : public ::testing::Test {
+protected:
+  OpenACCOpsTest() : b(&context), loc(UnknownLoc::get(&context)) {
+    context.loadDialect<acc::OpenACCDialect, arith::ArithDialect>();
+  }
+
+  MLIRContext context;
+  OpBuilder b;
+  Location loc;
+  llvm::SmallVector<DeviceType> dtypes = {
+      DeviceType::None,    DeviceType::Star, DeviceType::Multicore,
+      DeviceType::Default, DeviceType::Host, DeviceType::Nvidia,
+      DeviceType::Radeon};
+  llvm::SmallVector<DeviceType> dtypesWithoutNone = {
+      DeviceType::Star, DeviceType::Multicore, DeviceType::Default,
+      DeviceType::Host, DeviceType::Nvidia,    DeviceType::Radeon};
+};
+
+template <typename Op>
+void testAsyncOnly(OpBuilder &b, MLIRContext &context, Location loc,
+                   llvm::SmallVector<DeviceType> &dtypes) {
+  Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+  EXPECT_FALSE(op.hasAsyncOnly());
+  for (auto d : dtypes)
+    EXPECT_FALSE(op.hasAsyncOnly(d));
+
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op.setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op.hasAsyncOnly());
+  EXPECT_TRUE(op.hasAsyncOnly(DeviceType::None));
+  op.removeAsyncOnlyAttr();
+
+  auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
+  op.setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
+  EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Host));
+  EXPECT_FALSE(op.hasAsyncOnly());
+  op.removeAsyncOnlyAttr();
+
+  auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
+  op.setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
+  EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Star));
+  EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Host));
+  EXPECT_FALSE(op.hasAsyncOnly());
+}
+
+TEST_F(OpenACCOpsTest, asyncOnlyTest) {
+  testAsyncOnly<ParallelOp>(b, context, loc, dtypes);
+  testAsyncOnly<KernelsOp>(b, context, loc, dtypes);
+  testAsyncOnly<SerialOp>(b, context, loc, dtypes);
+}
+
+template <typename Op>
+void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
+                    llvm::SmallVector<DeviceType> &dtypes) {
+  Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+
+  mlir::Value empty;
+  EXPECT_EQ(op.getAsyncValue(), empty);
+  for (auto d : dtypes)
+    EXPECT_EQ(op.getAsyncValue(d), empty);
+
+  mlir::Value val = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
+  auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+  op.setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+  op.getAsyncMutable().assign(val);
+  EXPECT_EQ(op.getAsyncValue(), empty);
+  EXPECT_EQ(op.getAsyncValue(DeviceType::Nvidia), val);
+}
+
+TEST_F(OpenACCOpsTest, asyncValueTest) {
+  testAsyncValue<ParallelOp>(b, context, loc, dtypes);
+  testAsyncValue<KernelsOp>(b, context, loc, dtypes);
+  testAsyncValue<SerialOp>(b, context, loc, dtypes);
+}
+
+template <typename Op>
+void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
+                        llvm::SmallVector<DeviceType> &dtypes,
+                        llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
+  Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+  EXPECT_EQ(op.getNumGangsValues().begin(), op.getNumGangsValues().end());
+
+  mlir::Value val1 = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
+  mlir::Value val2 = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(4));
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op.getNumGangsMutable().assign(val1);
+  op.setNumGangsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op.setNumGangsSegments(b.getDenseI32ArrayAttr({1}));
+  EXPECT_EQ(op.getNumGangsValues().front(), val1);
+  for (auto d : dtypesWithoutNone)
+    EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end());
+
+  op.getNumGangsMutable().clear();
+  op.removeNumGangsDeviceTypeAttr();
+  op.removeNumGangsSegmentsAttr();
+  for (auto d : dtypes)
+    EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end());
+
+  op.getNumGangsMutable().append(val1);
+  op.getNumGangsMutable().append(val2);
+  op.setNumGangsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host),
+                      DeviceTypeAttr::get(&context, DeviceType::Star)}));
+  op.setNumGangsSegments(b.getDenseI32ArrayAttr({1, 1}));
+  EXPECT_EQ(op.getNumGangsValues(DeviceType::None).begin(),
+            op.getNumGangsValues(DeviceType::None).end());
+  EXPECT_EQ(op.getNumGangsValues(DeviceType::Host).front(), val1);
+  EXPECT_EQ(op.getNumGangsValues(DeviceType::Star).front(), val2);
+
+  op.getNumGangsMutable().clear();
+  op.removeNumGangsDeviceTypeAttr();
+  op.removeNumGangsSegmentsAttr();
+  for (auto d : dtypes)
+    EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end());
+
+  op.getNumGangsMutable().append(val1);
+  op.getNumGangsMutable().append(val2);
+  op.getNumGangsMutable().append(val1);
+  op.setNumGangsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default),
+                      DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
+  op.setNumGangsSegments(b.getDenseI32ArrayAttr({2, 1}));
+  EXPECT_EQ(op.getNumGangsValues(DeviceType::None).begin(),
+            op.getNumGangsValues(DeviceType::None).end());
+  EXPECT_EQ(op.getNumGangsValues(DeviceType::Default).front(), val1);
+  EXPECT_EQ(op.getNumGangsValues(DeviceType::Default).drop_front().front(),
+            val2);
+  EXPECT_EQ(op.getNumGangsValues(DeviceType::Multicore).front(), val1);
+}
+
+TEST_F(OpenACCOpsTest, numGangsValuesTest) {
+  testNumGangsValues<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testNumGangsValues<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
+}
+
+template <typename Op>
+void testVectorLength(OpBuilder &b, MLIRContext &context, Location loc,
+                      llvm::SmallVector<DeviceType> &dtypes) {
+  auto op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+
+  mlir::Value empty;
+  EXPECT_EQ(op.getVectorLengthValue(), empty);
+  for (auto d : dtypes)
+    EXPECT_EQ(op.getVectorLengthValue(d), empty);
+
+  mlir::Value val = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
+  auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+  op.setVectorLengthDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+  op.getVectorLengthMutable().assign(val);
+  EXPECT_EQ(op.getVectorLengthValue(), empty);
+  EXPECT_EQ(op.getVectorLengthValue(DeviceType::Nvidia), val);
+}
+
+TEST_F(OpenACCOpsTest, vectorLengthTest) {
+  testVectorLength<ParallelOp>(b, context, loc, dtypes);
+  testVectorLength<KernelsOp>(b, context, loc, dtypes);
+}
+
+template <typename Op>
+void testWaitOnly(OpBuilder &b, MLIRContext &context, Location loc,
+                  llvm::SmallVector<DeviceType> &dtypes,
+                  llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
+  Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+  EXPECT_FALSE(op.hasWaitOnly());
+  for (auto d : dtypes)
+    EXPECT_FALSE(op.hasWaitOnly(d));
+
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op.setWaitOnlyAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op.hasWaitOnly());
+  EXPECT_TRUE(op.hasWaitOnly(DeviceType::None));
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op.hasWaitOnly(d));
+  op.removeWaitOnlyAttr();
+
+  auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
+  op.setWaitOnlyAttr(b.getArrayAttr({dtypeHost}));
+  EXPECT_TRUE(op.hasWaitOnly(DeviceType::Host));
+  EXPECT_FALSE(op.hasWaitOnly());
+  op.removeWaitOnlyAttr();
+
+  auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
+  op.setWaitOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
+  EXPECT_TRUE(op.hasWaitOnly(DeviceType::Star));
+  EXPECT_TRUE(op.hasWaitOnly(DeviceType::Host));
+  EXPECT_FALSE(op.hasWaitOnly());
+}
+
+TEST_F(OpenACCOpsTest, waitOnlyTest) {
+  testWaitOnly<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitOnly<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitOnly<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
+}
+
+template <typename Op>
+void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc,
+                    llvm::SmallVector<DeviceType> &dtypes,
+                    llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
+  Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+  EXPECT_EQ(op.getWaitValues().begin(), op.getWaitValues().end());
+
+  mlir::Value val1 = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
+  mlir::Value val2 = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(4));
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op.getWaitOperandsMutable().assign(val1);
+  op.setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({1}));
+  EXPECT_EQ(op.getWaitValues().front(), val1);
+  for (auto d : dtypesWithoutNone)
+    EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end());
+
+  op.getWaitOperandsMutable().clear();
+  op.removeWaitOperandsDeviceTypeAttr();
+  op.removeWaitOperandsSegmentsAttr();
+  for (auto d : dtypes)
+    EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end());
+
+  op.getWaitOperandsMutable().append(val1);
+  op.getWaitOperandsMutable().append(val2);
+  op.setWaitOperandsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host),
+                      DeviceTypeAttr::get(&context, DeviceType::Star)}));
+  op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({1, 1}));
+  EXPECT_EQ(op.getWaitValues(DeviceType::None).begin(),
+            op.getWaitValues(DeviceType::None).end());
+  EXPECT_EQ(op.getWaitValues(DeviceType::Host).front(), val1);
+  EXPECT_EQ(op.getWaitValues(DeviceType::Star).front(), val2);
+
+  op.getWaitOperandsMutable().clear();
+  op.removeWaitOperandsDeviceTypeAttr();
+  op.removeWaitOperandsSegmentsAttr();
+  for (auto d : dtypes)
+    EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end());
+
+  op.getWaitOperandsMutable().append(val1);
+  op.getWaitOperandsMutable().append(val2);
+  op.getWaitOperandsMutable().append(val1);
+  op.setWaitOperandsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default),
+                      DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
+  op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({2, 1}));
+  EXPECT_EQ(op.getWaitValues(DeviceType::None).begin(),
+            op.getWaitValues(DeviceType::None).end());
+  EXPECT_EQ(op.getWaitValues(DeviceType::Default).front(), val1);
+  EXPECT_EQ(op.getWaitValues(DeviceType::Default).drop_front().front(), val2);
+  EXPECT_EQ(op.getWaitValues(DeviceType::Multicore).front(), val1);
+}
+
+TEST_F(OpenACCOpsTest, waitValuesTest) {
+  testWaitValues<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitValues<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitValues<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
+}



More information about the flang-commits mailing list