[flang-commits] [flang] [mlir] [mlir][openacc][flang] Support wait devnum and homogenize async/wait IR (PR #79525)
via flang-commits
flang-commits at lists.llvm.org
Thu Jan 25 15:49:35 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
- Support wait(devnum: ) with device_type support on all operations that require it
- The devnum value is stored as the first value of waitOperands in its device_type sub-segment. The hasWaitDevnum attribute indicates which sub-segments have a wait(devnum) value.
- Make async/wait information homogeneous across compute, data, and update ops.
- Unify operand/attribute names across operations and use the same custom parser/printer
---
Patch is 55.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79525.diff
14 Files Affected:
- (modified) flang/lib/Lower/OpenACC.cpp (+57-36)
- (modified) flang/test/Lower/OpenACC/acc-data.f90 (+3-3)
- (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+2-2)
- (modified) flang/test/Lower/OpenACC/acc-kernels.f90 (+2-2)
- (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+2-2)
- (modified) flang/test/Lower/OpenACC/acc-parallel.f90 (+2-2)
- (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+2-2)
- (modified) flang/test/Lower/OpenACC/acc-serial.f90 (+2-2)
- (modified) flang/test/Lower/OpenACC/acc-update.f90 (+1-4)
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+62-34)
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+198-77)
- (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+1-9)
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+4-4)
- (modified) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+30-4)
``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecfdaa5be993584..427b36a12a2df01 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, descTy);
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
builder, loc, loadOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, loadOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
}
}
-static void
-genWaitClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::AccClause::Wait *waitClause,
- llvm::SmallVector<mlir::Value> &waitOperands,
- llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
- llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
- llvm::SmallVector<int32_t> &waitOperandsSegments,
- mlir::Value &waitDevnum,
- llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
- Fortran::lower::StatementContext &stmtCtx) {
+static void genWaitClauseWithDeviceType(
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::AccClause::Wait *waitClause,
+ llvm::SmallVector<mlir::Value> &waitOperands,
+ llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
+ llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
+ llvm::SmallVector<bool> &hasDevnums,
+ llvm::SmallVector<int32_t> &waitOperandsSegments,
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
+ Fortran::lower::StatementContext &stmtCtx) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
+ llvm::SmallVector<mlir::Value> waitValues;
+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+ const auto &waitDevnumValue =
+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+ bool hasDevnum = false;
+ if (waitDevnumValue) {
+ waitValues.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
+ hasDevnum = true;
+ }
+
const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- llvm::SmallVector<mlir::Value> waitValues;
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}
+
for (auto deviceTypeAttr : deviceTypeAttrs) {
for (auto value : waitValues)
waitOperands.push_back(value);
waitOperandsDeviceTypes.push_back(deviceTypeAttr);
waitOperandsSegments.push_back(waitValues.size());
+ hasDevnums.push_back(hasDevnum);
}
-
- // TODO: move to device_type model.
- const auto &waitDevnumValue =
- std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- if (waitDevnumValue)
- waitDevnum = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
for (auto deviceTypeAttr : deviceTypeAttrs)
waitOnlyDeviceTypes.push_back(deviceTypeAttr);
@@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
+ llvm::SmallVector<bool> hasWaitDevnums;
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
- mlir::Value waitDevnum; // TODO not yet implemented on compute op.
// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
@@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
@@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder.getDenseI32ArrayAttr(numGangsSegments));
}
if (!asyncDeviceTypes.empty())
- computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ computeOp.setAsyncOperandsDeviceTypeAttr(
+ builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
@@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
computeOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
+ if (!hasWaitDevnums.empty())
+ computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
@@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> waitOperandsSegments;
+ llvm::SmallVector<bool> hasWaitDevnums;
bool hasDefaultNone = false;
bool hasDefaultPresent = false;
@@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if(const auto *defaultClause =
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<int32_t> operandSegments;
addOperand(operands, operandSegments, ifCond);
addOperands(operands, operandSegments, async);
- addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
addOperands(operands, operandSegments, dataClauseOperands);
@@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
operandSegments);
if (!asyncDeviceTypes.empty())
- dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ dataOp.setAsyncOperandsDeviceTypeAttr(
+ builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
if (!waitOperandsDeviceTypes.empty())
@@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
dataOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
+ if (!hasWaitDevnums.empty())
+ dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
@@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
}
+static inline mlir::ArrayAttr
+getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
+ return values.empty() ? nullptr : b.getBoolArrayAttr(values);
+}
+
static inline mlir::DenseI32ArrayAttr
getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
llvm::SmallVector<int32_t> &values) {
@@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
waitOperands, deviceTypeOperands, asyncOperands;
llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+ llvm::SmallVector<bool> hasWaitDevnums;
llvm::SmallVector<int32_t> waitOperandsSegments;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
@@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
@@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
builder.create<mlir::acc::UpdateOp>(
currentLocation, ifCond, asyncOperands,
getArrayAttr(builder, asyncOperandsDeviceTypes),
- getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+ getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
getDenseI32ArrayAttr(builder, waitOperandsSegments),
getArrayAttr(builder, waitOperandsDeviceTypes),
+ getBoolArrayAttr(builder, hasWaitDevnums),
getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
ifPresent);
@@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index 75ffd1fc3fcab2f..5b4ab5a65ee6bd2 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -164,8 +164,8 @@ subroutine acc_data
!$acc data present(a) wait
!$acc end data
-! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK: acc.data dataOperands(%{{.*}}) wait {
+! CHECK: }
!$acc data present(a) wait(1)
!$acc end data
@@ -176,7 +176,7 @@ subroutine acc_data
!$acc data present(a) wait(devnum: 0: 1)
!$acc end data
-! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) {
! CHECK: }{{$}}
!$acc data default(none)
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 21660d5c3a13163..d2134e8d2337ce6 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -93,12 +93,12 @@ subroutine acc_kernels_loop
a(i) = b(i)
END DO
-! CHECK: acc.kernels {
+! CHECK: acc.kernels wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc kernels loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 99629bb8351723b..06194edbe165498 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -61,9 +61,9 @@ subroutine acc_kernels
!$acc kernels wait
!$acc end kernels
-! CHECK: acc.kernels {
+! CHECK: acc.kernels wait {
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc kernels wait(1)
!$acc end kernels
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 614d201f98e26c4..24e443a20c895d1 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -95,12 +95,12 @@ subroutine acc_parallel_loop
a(i) = b(i)
END DO
-! CHECK: acc.parallel {
+! CHECK: acc.parallel wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc parallel loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index a369bf01f259955..6b37ecb5fab9aa6 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -83,9 +83,9 @@ subroutine acc_parallel
!$acc parallel wait
!$acc end parallel
-! CHECK: acc.parallel {
+! CHECK: acc.parallel wait {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc parallel wait(1)
!$acc end parallel
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 4134f9ff0ccf577..9c0dbff0d7dac16 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -114,12 +114,12 @@ subroutine acc_serial_loop
a(i) = b(i)
END DO
-! CHECK: acc.serial {
+! CHECK: acc.serial wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc serial loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index d05e51d3d274f45..d0fa9436be14a14 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -83,9 +83,9 @@ subroutine acc_serial
!$acc serial wait
!$acc end serial
-! CHECK: acc.serial {
+! CHECK: acc.serial wait {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc serial wait(1)
!$acc end serial
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index ba036ac92811826..f42ae1356664b67 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -101,10 +101,7 @@ subroutine acc_update
!$acc update host(a) wait(devnum: 1: queues: 1, 2)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
-! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
-! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) device_type(host, nvidia) async
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 87fd587782e7c35..9398cbfdacee469 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -903,12 +903,13 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
}];
let arguments = (ins
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
Variadic<IntOrIndex>:$waitOperands,
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
Variadic<IntOrIndex>:$numGangs,
OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
@@ -979,13 +980,18 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
/// present.
mlir::Operation::operand_range
getWaitValues(mlir::acc::DeviceType deviceType);
+ /// Return the wait devnum value clause if present;
+ mlir::Value getWaitDevnum();
+ /// Return the wait devnum value clause for the given device_type if
+ /// present.
+ mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
+ | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+ type($asyncOp...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/79525
More information about the flang-commits
mailing list