[flang-commits] [flang] [mlir] [mlir][openacc][flang] Support wait devnum and homogenize async/wait IR (PR #79525)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 25 15:49:35 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

- Support wait(devnum: ) with device_type support on all operations that require it
  - The devnum value is stored as the first value of waitOperands in its device_type sub-segment. The hasWaitDevnum attribute indicates which sub-segments have a wait(devnum) value.
- Make async/wait information homogeneous across compute ops, the data op, and the update op.
  - Unify operand/attribute names across operations and use the same custom parser/printer.

---

Patch is 55.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79525.diff


14 Files Affected:

- (modified) flang/lib/Lower/OpenACC.cpp (+57-36) 
- (modified) flang/test/Lower/OpenACC/acc-data.f90 (+3-3) 
- (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+2-2) 
- (modified) flang/test/Lower/OpenACC/acc-kernels.f90 (+2-2) 
- (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+2-2) 
- (modified) flang/test/Lower/OpenACC/acc-parallel.f90 (+2-2) 
- (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+2-2) 
- (modified) flang/test/Lower/OpenACC/acc-serial.f90 (+2-2) 
- (modified) flang/test/Lower/OpenACC/acc-update.f90 (+1-4) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+62-34) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+198-77) 
- (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+1-9) 
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+4-4) 
- (modified) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+30-4) 


``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecfdaa5be993584..427b36a12a2df01 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
           builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, descTy);
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
           builder, loc, loadOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, loadOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
-static void
-genWaitClause(Fortran::lower::AbstractConverter &converter,
-              const Fortran::parser::AccClause::Wait *waitClause,
-              llvm::SmallVector<mlir::Value> &waitOperands,
-              llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
-              llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
-              llvm::SmallVector<int32_t> &waitOperandsSegments,
-              mlir::Value &waitDevnum,
-              llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
-              Fortran::lower::StatementContext &stmtCtx) {
+static void genWaitClauseWithDeviceType(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::parser::AccClause::Wait *waitClause,
+    llvm::SmallVector<mlir::Value> &waitOperands,
+    llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
+    llvm::SmallVector<bool> &hasDevnums,
+    llvm::SmallVector<int32_t> &waitOperandsSegments,
+    llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
+    Fortran::lower::StatementContext &stmtCtx) {
   const auto &waitClauseValue = waitClause->v;
   if (waitClauseValue) { // wait has a value.
+    llvm::SmallVector<mlir::Value> waitValues;
+
     const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+    const auto &waitDevnumValue =
+        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+    bool hasDevnum = false;
+    if (waitDevnumValue) {
+      waitValues.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
+      hasDevnum = true;
+    }
+
     const auto &waitList =
         std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-    llvm::SmallVector<mlir::Value> waitValues;
     for (const Fortran::parser::ScalarIntExpr &value : waitList) {
       waitValues.push_back(fir::getBase(converter.genExprValue(
           *Fortran::semantics::GetExpr(value), stmtCtx)));
     }
+
     for (auto deviceTypeAttr : deviceTypeAttrs) {
       for (auto value : waitValues)
         waitOperands.push_back(value);
       waitOperandsDeviceTypes.push_back(deviceTypeAttr);
       waitOperandsSegments.push_back(waitValues.size());
+      hasDevnums.push_back(hasDevnum);
     }
-
-    // TODO: move to device_type model.
-    const auto &waitDevnumValue =
-        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-    if (waitDevnumValue)
-      waitDevnum = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
   } else {
     for (auto deviceTypeAttr : deviceTypeAttrs)
       waitOnlyDeviceTypes.push_back(deviceTypeAttr);
@@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
       vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
       waitOperandsDeviceTypes, waitOnlyDeviceTypes;
   llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
+  llvm::SmallVector<bool> hasWaitDevnums;
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
       reductionRecipes;
-  mlir::Value waitDevnum; // TODO not yet implemented on compute op.
 
   // Self clause has optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
@@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                      asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
@@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
           builder.getDenseI32ArrayAttr(numGangsSegments));
   }
   if (!asyncDeviceTypes.empty())
-    computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+    computeOp.setAsyncOperandsDeviceTypeAttr(
+        builder.getArrayAttr(asyncDeviceTypes));
   if (!asyncOnlyDeviceTypes.empty())
     computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
 
@@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   if (!waitOperandsSegments.empty())
     computeOp.setWaitOperandsSegmentsAttr(
         builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!hasWaitDevnums.empty())
+    computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
   if (!waitOnlyDeviceTypes.empty())
     computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
 
@@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
       waitOperandsDeviceTypes, waitOnlyDeviceTypes;
   llvm::SmallVector<int32_t> waitOperandsSegments;
+  llvm::SmallVector<bool> hasWaitDevnums;
 
   bool hasDefaultNone = false;
   bool hasDefaultPresent = false;
@@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                      asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if(const auto *defaultClause = 
                   std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
       if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<int32_t> operandSegments;
   addOperand(operands, operandSegments, ifCond);
   addOperands(operands, operandSegments, async);
-  addOperand(operands, operandSegments, waitDevnum);
   addOperands(operands, operandSegments, waitOperands);
   addOperands(operands, operandSegments, dataClauseOperands);
 
@@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
       operandSegments);
 
   if (!asyncDeviceTypes.empty())
-    dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+    dataOp.setAsyncOperandsDeviceTypeAttr(
+        builder.getArrayAttr(asyncDeviceTypes));
   if (!asyncOnlyDeviceTypes.empty())
     dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
   if (!waitOperandsDeviceTypes.empty())
@@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   if (!waitOperandsSegments.empty())
     dataOp.setWaitOperandsSegmentsAttr(
         builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!hasWaitDevnums.empty())
+    dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
   if (!waitOnlyDeviceTypes.empty())
     dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
 
@@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
   return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
 }
 
+static inline mlir::ArrayAttr
+getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
+  return values.empty() ? nullptr : b.getBoolArrayAttr(values);
+}
+
 static inline mlir::DenseI32ArrayAttr
 getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
                      llvm::SmallVector<int32_t> &values) {
@@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
       waitOperands, deviceTypeOperands, asyncOperands;
   llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
       asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<bool> hasWaitDevnums;
   llvm::SmallVector<int32_t> waitOperandsSegments;
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
@@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
                      crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
@@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   builder.create<mlir::acc::UpdateOp>(
       currentLocation, ifCond, asyncOperands,
       getArrayAttr(builder, asyncOperandsDeviceTypes),
-      getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+      getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
       getDenseI32ArrayAttr(builder, waitOperandsSegments),
       getArrayAttr(builder, waitOperandsDeviceTypes),
+      getBoolArrayAttr(builder, hasWaitDevnums),
       getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
       ifPresent);
 
@@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index 75ffd1fc3fcab2f..5b4ab5a65ee6bd2 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -164,8 +164,8 @@ subroutine acc_data
   !$acc data present(a) wait
   !$acc end data
 
-! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK: acc.data dataOperands(%{{.*}}) wait {
+! CHECK: }
 
   !$acc data present(a) wait(1)
   !$acc end data
@@ -176,7 +176,7 @@ subroutine acc_data
   !$acc data present(a) wait(devnum: 0: 1)
   !$acc end data
 
-! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) {
 ! CHECK: }{{$}}
 
   !$acc data default(none)
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 21660d5c3a13163..d2134e8d2337ce6 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -93,12 +93,12 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.kernels {
+! CHECK:      acc.kernels wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc kernels loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 99629bb8351723b..06194edbe165498 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -61,9 +61,9 @@ subroutine acc_kernels
   !$acc kernels wait
   !$acc end kernels
 
-! CHECK:      acc.kernels  {
+! CHECK:      acc.kernels wait {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc kernels wait(1)
   !$acc end kernels
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 614d201f98e26c4..24e443a20c895d1 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -95,12 +95,12 @@ subroutine acc_parallel_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.parallel {
+! CHECK:      acc.parallel wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc parallel loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index a369bf01f259955..6b37ecb5fab9aa6 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -83,9 +83,9 @@ subroutine acc_parallel
   !$acc parallel wait
   !$acc end parallel
 
-! CHECK:      acc.parallel {
+! CHECK:      acc.parallel wait {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc parallel wait(1)
   !$acc end parallel
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 4134f9ff0ccf577..9c0dbff0d7dac16 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -114,12 +114,12 @@ subroutine acc_serial_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.serial {
+! CHECK:      acc.serial wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc serial loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index d05e51d3d274f45..d0fa9436be14a14 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -83,9 +83,9 @@ subroutine acc_serial
   !$acc serial wait
   !$acc end serial
 
-! CHECK:      acc.serial {
+! CHECK:      acc.serial wait {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc serial wait(1)
   !$acc end serial
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index ba036ac92811826..f42ae1356664b67 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -101,10 +101,7 @@ subroutine acc_update
 
   !$acc update host(a) wait(devnum: 1: queues: 1, 2)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
-! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
-! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) device_type(host, nvidia) async
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 87fd587782e7c35..9398cbfdacee469 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -903,12 +903,13 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
   }];
 
   let arguments = (ins
-      Variadic<IntOrIndex>:$async,
-      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
       OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
       Variadic<IntOrIndex>:$waitOperands,
       OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
       Variadic<IntOrIndex>:$numGangs,
       OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
@@ -979,13 +980,18 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     /// present.
     mlir::Operation::operand_range
     getWaitValues(mlir::acc::DeviceType deviceType);
+    /// Return the wait devnum value clause if present;
+    mlir::Value getWaitDevnum();
+    /// Return the wait devnum value clause for the given device_type if
+    /// present.
+    mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `async` `(` custom<DeviceTypeOperands>($async,
-            type($async), $asyncDeviceType) `)`
+      | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+            type($asyncOp...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/79525


More information about the flang-commits mailing list