[flang-commits] [flang] [mlir] [mlir][openacc][flang] Support wait devnum and homogenize async/wait IR (PR #79525)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Thu Jan 25 15:49:54 PST 2024
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/79525
>From 170935dae9efee24b8e905d035ad41ec0cb78e27 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 23 Jan 2024 09:55:28 -0800
Subject: [PATCH] [mlir][openacc][flang] Support wait devnum and homogenize
async/wait IR
---
flang/lib/Lower/OpenACC.cpp | 93 +++---
flang/test/Lower/OpenACC/acc-data.f90 | 6 +-
flang/test/Lower/OpenACC/acc-kernels-loop.f90 | 4 +-
flang/test/Lower/OpenACC/acc-kernels.f90 | 4 +-
.../test/Lower/OpenACC/acc-parallel-loop.f90 | 4 +-
flang/test/Lower/OpenACC/acc-parallel.f90 | 4 +-
flang/test/Lower/OpenACC/acc-serial-loop.f90 | 4 +-
flang/test/Lower/OpenACC/acc-serial.f90 | 4 +-
flang/test/Lower/OpenACC/acc-update.f90 | 5 +-
.../mlir/Dialect/OpenACC/OpenACCOps.td | 96 +++---
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 275 +++++++++++++-----
mlir/test/Dialect/OpenACC/invalid.mlir | 10 +-
mlir/test/Dialect/OpenACC/ops.mlir | 8 +-
.../Dialect/OpenACC/OpenACCOpsTest.cpp | 34 ++-
14 files changed, 368 insertions(+), 183 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecfdaa5be993584..427b36a12a2df01 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, descTy);
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
builder, loc, loadOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, loadOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
}
}
-static void
-genWaitClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::AccClause::Wait *waitClause,
- llvm::SmallVector<mlir::Value> &waitOperands,
- llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
- llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
- llvm::SmallVector<int32_t> &waitOperandsSegments,
- mlir::Value &waitDevnum,
- llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
- Fortran::lower::StatementContext &stmtCtx) {
+static void genWaitClauseWithDeviceType(
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::AccClause::Wait *waitClause,
+ llvm::SmallVector<mlir::Value> &waitOperands,
+ llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
+ llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
+ llvm::SmallVector<bool> &hasDevnums,
+ llvm::SmallVector<int32_t> &waitOperandsSegments,
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
+ Fortran::lower::StatementContext &stmtCtx) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
+ llvm::SmallVector<mlir::Value> waitValues;
+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+ const auto &waitDevnumValue =
+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+ bool hasDevnum = false;
+ if (waitDevnumValue) {
+ waitValues.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
+ hasDevnum = true;
+ }
+
const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- llvm::SmallVector<mlir::Value> waitValues;
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}
+
for (auto deviceTypeAttr : deviceTypeAttrs) {
for (auto value : waitValues)
waitOperands.push_back(value);
waitOperandsDeviceTypes.push_back(deviceTypeAttr);
waitOperandsSegments.push_back(waitValues.size());
+ hasDevnums.push_back(hasDevnum);
}
-
- // TODO: move to device_type model.
- const auto &waitDevnumValue =
- std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- if (waitDevnumValue)
- waitDevnum = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
for (auto deviceTypeAttr : deviceTypeAttrs)
waitOnlyDeviceTypes.push_back(deviceTypeAttr);
@@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
+ llvm::SmallVector<bool> hasWaitDevnums;
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
- mlir::Value waitDevnum; // TODO not yet implemented on compute op.
// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
@@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
@@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder.getDenseI32ArrayAttr(numGangsSegments));
}
if (!asyncDeviceTypes.empty())
- computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ computeOp.setAsyncOperandsDeviceTypeAttr(
+ builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
@@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
computeOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
+ if (!hasWaitDevnums.empty())
+ computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
@@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> waitOperandsSegments;
+ llvm::SmallVector<bool> hasWaitDevnums;
bool hasDefaultNone = false;
bool hasDefaultPresent = false;
@@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if(const auto *defaultClause =
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<int32_t> operandSegments;
addOperand(operands, operandSegments, ifCond);
addOperands(operands, operandSegments, async);
- addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
addOperands(operands, operandSegments, dataClauseOperands);
@@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
operandSegments);
if (!asyncDeviceTypes.empty())
- dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ dataOp.setAsyncOperandsDeviceTypeAttr(
+ builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
if (!waitOperandsDeviceTypes.empty())
@@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
dataOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
+ if (!hasWaitDevnums.empty())
+ dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
@@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
}
+static inline mlir::ArrayAttr
+getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
+  return !values.empty() ? b.getBoolArrayAttr(values) : nullptr; // null when no flags were collected
+}
+
static inline mlir::DenseI32ArrayAttr
getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
llvm::SmallVector<int32_t> &values) {
@@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
waitOperands, deviceTypeOperands, asyncOperands;
llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+ llvm::SmallVector<bool> hasWaitDevnums;
llvm::SmallVector<int32_t> waitOperandsSegments;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
@@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
@@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
builder.create<mlir::acc::UpdateOp>(
currentLocation, ifCond, asyncOperands,
getArrayAttr(builder, asyncOperandsDeviceTypes),
- getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+ getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
getDenseI32ArrayAttr(builder, waitOperandsSegments),
getArrayAttr(builder, waitOperandsDeviceTypes),
+ getBoolArrayAttr(builder, hasWaitDevnums),
getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
ifPresent);
@@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index 75ffd1fc3fcab2f..5b4ab5a65ee6bd2 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -164,8 +164,8 @@ subroutine acc_data
!$acc data present(a) wait
!$acc end data
-! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK: acc.data dataOperands(%{{.*}}) wait {
+! CHECK: }
!$acc data present(a) wait(1)
!$acc end data
@@ -176,7 +176,7 @@ subroutine acc_data
!$acc data present(a) wait(devnum: 0: 1)
!$acc end data
-! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) {
! CHECK: }{{$}}
!$acc data default(none)
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 21660d5c3a13163..d2134e8d2337ce6 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -93,12 +93,12 @@ subroutine acc_kernels_loop
a(i) = b(i)
END DO
-! CHECK: acc.kernels {
+! CHECK: acc.kernels wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc kernels loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 99629bb8351723b..06194edbe165498 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -61,9 +61,9 @@ subroutine acc_kernels
!$acc kernels wait
!$acc end kernels
-! CHECK: acc.kernels {
+! CHECK: acc.kernels wait {
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc kernels wait(1)
!$acc end kernels
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 614d201f98e26c4..24e443a20c895d1 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -95,12 +95,12 @@ subroutine acc_parallel_loop
a(i) = b(i)
END DO
-! CHECK: acc.parallel {
+! CHECK: acc.parallel wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc parallel loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index a369bf01f259955..6b37ecb5fab9aa6 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -83,9 +83,9 @@ subroutine acc_parallel
!$acc parallel wait
!$acc end parallel
-! CHECK: acc.parallel {
+! CHECK: acc.parallel wait {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc parallel wait(1)
!$acc end parallel
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 4134f9ff0ccf577..9c0dbff0d7dac16 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -114,12 +114,12 @@ subroutine acc_serial_loop
a(i) = b(i)
END DO
-! CHECK: acc.serial {
+! CHECK: acc.serial wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc serial loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index d05e51d3d274f45..d0fa9436be14a14 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -83,9 +83,9 @@ subroutine acc_serial
!$acc serial wait
!$acc end serial
-! CHECK: acc.serial {
+! CHECK: acc.serial wait {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc serial wait(1)
!$acc end serial
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index ba036ac92811826..f42ae1356664b67 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -101,10 +101,7 @@ subroutine acc_update
!$acc update host(a) wait(devnum: 1: queues: 1, 2)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
-! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
-! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) device_type(host, nvidia) async
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 87fd587782e7c35..9398cbfdacee469 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -903,12 +903,13 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
}];
let arguments = (ins
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
Variadic<IntOrIndex>:$waitOperands,
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
Variadic<IntOrIndex>:$numGangs,
OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
@@ -979,13 +980,18 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
/// present.
mlir::Operation::operand_range
getWaitValues(mlir::acc::DeviceType deviceType);
+ /// Return the wait devnum value clause if present;
+ mlir::Value getWaitDevnum();
+ /// Return the wait devnum value clause for the given device_type if
+ /// present.
+ mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
+ | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+ type($asyncOperands), $asyncOperandsDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
type($gangFirstPrivateOperands), $firstprivatizations)
`)`
@@ -998,8 +1004,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
`)`
| `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
type($vectorLength), $vectorLengthDeviceType) `)`
- | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
- type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
+ | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+ $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+ $waitOnly)
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
| `reduction` `(` custom<SymOperandList>(
@@ -1034,12 +1041,13 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
}];
let arguments = (ins
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
Variadic<IntOrIndex>:$waitOperands,
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
Optional<I1>:$ifCond,
Optional<I1>:$selfCond,
@@ -1084,21 +1092,27 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
/// present.
mlir::Operation::operand_range
getWaitValues(mlir::acc::DeviceType deviceType);
+ /// Return the wait devnum value clause if present;
+ mlir::Value getWaitDevnum();
+ /// Return the wait devnum value clause for the given device_type if
+ /// present.
+ mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
+ | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+ type($asyncOperands), $asyncOperandsDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
type($gangFirstPrivateOperands), $firstprivatizations)
`)`
| `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
`)`
- | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
- type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
+ | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+ $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+ $waitOnly)
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
| `reduction` `(` custom<SymOperandList>(
@@ -1135,12 +1149,13 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
}];
let arguments = (ins
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
Variadic<IntOrIndex>:$waitOperands,
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
Variadic<IntOrIndex>:$numGangs,
OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
@@ -1205,22 +1220,27 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
/// present.
mlir::Operation::operand_range
getWaitValues(mlir::acc::DeviceType deviceType);
+ /// Return the wait devnum value clause if present;
+ mlir::Value getWaitDevnum();
+ /// Return the wait devnum value clause for the given device_type if
+ /// present.
+ mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
+ | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+ type($asyncOperands), $asyncOperandsDeviceType) `)`
| `num_gangs` `(` custom<NumGangs>($numGangs,
type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
| `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
type($numWorkers), $numWorkersDeviceType) `)`
| `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
type($vectorLength), $vectorLengthDeviceType) `)`
- | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
- type($waitOperands), $waitOperandsDeviceType,
- $waitOperandsSegments) `)`
+ | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+ $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+ $waitOnly)
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
)
@@ -1258,13 +1278,13 @@ def OpenACC_DataOp : OpenACC_Op<"data",
let arguments = (ins Optional<I1>:$ifCond,
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
- Optional<IntOrIndex>:$waitDevnum,
Variadic<IntOrIndex>:$waitOperands,
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr);
@@ -1300,18 +1320,22 @@ def OpenACC_DataOp : OpenACC_Op<"data",
/// present.
mlir::Operation::operand_range
getWaitValues(mlir::acc::DeviceType deviceType);
+ /// Return the wait devnum value clause if present;
+ mlir::Value getWaitDevnum();
+ /// Return the wait devnum value clause for the given device_type if
+ /// present.
+ mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
+ | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+ type($asyncOperands), $asyncOperandsDeviceType) `)`
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
- | `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
- type($waitOperands), $waitOperandsDeviceType,
- $waitOperandsSegments) `)`
+ | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+ $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+ $waitOnly)
)
$region attr-dict-with-keyword
}];
@@ -2199,11 +2223,11 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
Variadic<IntOrIndex>:$asyncOperands,
OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$async,
- Optional<IntOrIndex>:$waitDevnum,
Variadic<IntOrIndex>:$waitOperands,
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
- OptionalAttr<DeviceTypeArrayAttr>:$wait,
+ OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
UnitAttr:$ifPresent);
@@ -2236,6 +2260,11 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
/// present.
mlir::Operation::operand_range
getWaitValues(mlir::acc::DeviceType deviceType);
+ /// Return the wait devnum value clause if present;
+ mlir::Value getWaitDevnum();
+ /// Return the wait devnum value clause for the given device_type if
+ /// present.
+ mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
@@ -2244,10 +2273,9 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
| `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
$asyncOperands, type($asyncOperands),
$asyncOperandsDeviceType, $async)
- | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
- | `wait` `` custom<WaitClause>($waitOperands,
- type($waitOperands), $waitOperandsDeviceType,
- $waitOperandsSegments, $wait)
+ | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+ $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+ $waitOnly)
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
)
attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index e1e69113bca1683..042153ac749102b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -104,6 +104,87 @@ static void printDeviceTypes(mlir::OpAsmPrinter &p,
p << "]";
}
+static std::optional<unsigned> findSegment(ArrayAttr segments,
+                                           mlir::acc::DeviceType deviceType) {
+  unsigned segmentIdx = 0;
+  for (auto attr : segments) {
+    auto deviceTypeAttr = mlir::cast<mlir::acc::DeviceTypeAttr>(attr); // asserts instead of null deref
+    if (deviceTypeAttr.getValue() == deviceType)
+      return std::make_optional(segmentIdx);
+    ++segmentIdx;
+  }
+  return std::nullopt; // deviceType has no entry in this array
+}
+
+static mlir::Operation::operand_range
+getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
+                      mlir::Operation::operand_range range,
+                      std::optional<llvm::ArrayRef<int32_t>> segments,
+                      mlir::acc::DeviceType deviceType) {
+  if (!arrayAttr || !segments) // segments is dereferenced below; never touch an empty optional
+    return range.take_front(0);
+  if (auto pos = findSegment(*arrayAttr, deviceType)) {
+    int32_t nbOperandsBefore = 0; // operands owned by earlier device_type segments
+    for (unsigned i = 0; i < *pos; ++i)
+      nbOperandsBefore += (*segments)[i];
+    return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
+  }
+  return range.take_front(0); // empty range: deviceType not present
+}
+
+static mlir::Value
+getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
+                   mlir::Operation::operand_range operands,
+                   std::optional<llvm::ArrayRef<int32_t>> segments,
+                   std::optional<mlir::ArrayAttr> hasWaitDevnum,
+                   mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(deviceTypeAttr) || !hasWaitDevnum || !*hasWaitDevnum)
+    return {}; // no wait operands, or no devnum flags recorded
+  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
+    if (mlir::cast<mlir::BoolAttr>(hasWaitDevnum->getValue()[*pos]).getValue()) // flag value, not attr non-nullness
+      return getValuesFromSegments(deviceTypeAttr, operands, segments,
+                                   deviceType)
+          .front(); // devnum is the first operand of its segment
+  return {};
+}
+
+static mlir::Operation::operand_range
+getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
+                           mlir::Operation::operand_range operands,
+                           std::optional<llvm::ArrayRef<int32_t>> segments,
+                           std::optional<mlir::ArrayAttr> hasWaitDevnum,
+                           mlir::acc::DeviceType deviceType) {
+  auto range =
+      getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
+  if (range.empty() || !hasWaitDevnum || !*hasWaitDevnum)
+    return range; // nothing to strip
+  if (auto pos = findSegment(*deviceTypeAttr, deviceType))
+    if (mlir::cast<mlir::BoolAttr>(hasWaitDevnum->getValue()[*pos]).getValue()) // flag value, not attr non-nullness
+      return range.drop_front(1); // first value is devnum
+  return range;
+}
+
+template <typename Op>
+static LogicalResult checkWaitAndAsyncConflict(Op op) {
+  for (uint32_t dtypeInt = 0; dtypeInt <= acc::getMaxEnumValForDeviceType();
+       ++dtypeInt) { // inclusive bound: the max enum value is itself a device type
+    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
+
+    // The async attribute represents the async clause without value. Therefore
+    // the attribute and operand cannot appear at the same time.
+    if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
+        op.hasAsyncOnly(dtype))
+      return op.emitError("async attribute cannot appear with asyncOperand");
+
+    // The wait attribute represents the wait clause without values. Therefore
+    // the attribute and operands cannot appear at the same time.
+    if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
+        op.hasWaitOnly(dtype))
+      return op.emitError("wait attribute cannot appear with waitOperands");
+  }
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
@@ -649,7 +730,7 @@ unsigned ParallelOp::getNumDataOperands() {
}
Value ParallelOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync().size();
+ unsigned numOptional = getAsyncOperands().size();
numOptional += getNumGangs().size();
numOptional += getNumWorkers().size();
numOptional += getVectorLength().size();
@@ -722,23 +803,15 @@ LogicalResult acc::ParallelOp::verify() {
"vector_length")))
return failure();
- if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
- getAsyncDeviceTypeAttr(), "async")))
+ if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
+ getAsyncOperandsDeviceTypeAttr(),
+ "async")))
return failure();
- return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
-}
+ if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
+ return failure();
-static std::optional<unsigned> findSegment(ArrayAttr segments,
- mlir::acc::DeviceType deviceType) {
- unsigned segmentIdx = 0;
- for (auto attr : segments) {
- auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
- if (deviceTypeAttr.getValue() == deviceType)
- return std::make_optional(segmentIdx);
- ++segmentIdx;
- }
- return std::nullopt;
+ return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
static mlir::Value
@@ -765,8 +838,8 @@ mlir::Value acc::ParallelOp::getAsyncValue() {
}
mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
- deviceType);
+ return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
+ getAsyncOperands(), deviceType);
}
mlir::Value acc::ParallelOp::getNumWorkersValue() {
@@ -793,22 +866,6 @@ mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
return getNumGangsValues(mlir::acc::DeviceType::None);
}
-static mlir::Operation::operand_range
-getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
- mlir::Operation::operand_range range,
- std::optional<llvm::ArrayRef<int32_t>> segments,
- mlir::acc::DeviceType deviceType) {
- if (!arrayAttr)
- return range.take_front(0);
- if (auto pos = findSegment(*arrayAttr, deviceType)) {
- int32_t nbOperandsBefore = 0;
- for (unsigned i = 0; i < *pos; ++i)
- nbOperandsBefore += (*segments)[i];
- return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
- }
- return range.take_front(0);
-}
-
mlir::Operation::operand_range
ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
@@ -829,8 +886,19 @@ mlir::Operation::operand_range ParallelOp::getWaitValues() {
mlir::Operation::operand_range
ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
- getWaitOperandsSegments(), deviceType);
+ return getWaitValuesWithoutDevnum(
+ getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+ getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value ParallelOp::getWaitDevnum() {
+ return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+ return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+ getWaitOperandsSegments(), getHasWaitDevnum(),
+ deviceType);
}
static ParseResult parseNumGangs(
@@ -967,8 +1035,9 @@ static ParseResult parseWaitClause(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
- mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) {
- llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs;
+ mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
+ mlir::ArrayAttr &keywordOnly) {
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
llvm::SmallVector<int32_t> seg;
bool needCommaBeforeOperands = false;
@@ -1003,6 +1072,14 @@ static ParseResult parseWaitClause(
int32_t crtOperandsSize = operands.size();
+ if (succeeded(parser.parseOptionalKeyword("devnum"))) {
+ if (failed(parser.parseColon()))
+ return failure();
+ devnum.push_back(BoolAttr::get(parser.getContext(), true));
+ } else {
+ devnum.push_back(BoolAttr::get(parser.getContext(), false));
+ }
+
if (failed(parser.parseCommaSeparatedList(
mlir::AsmParser::Delimiter::None, [&]() {
if (parser.parseOperand(operands.emplace_back()) ||
@@ -1033,6 +1110,7 @@ static ParseResult parseWaitClause(
deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+ hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
return success();
}
@@ -1052,6 +1130,7 @@ static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
mlir::OperandRange operands, mlir::TypeRange types,
std::optional<mlir::ArrayAttr> deviceTypes,
std::optional<mlir::DenseI32ArrayAttr> segments,
+ std::optional<mlir::ArrayAttr> hasDevNum,
std::optional<mlir::ArrayAttr> keywordOnly) {
if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
@@ -1066,6 +1145,9 @@ static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
unsigned opIdx = 0;
llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
p << "{";
+ auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
+ if (boolAttr && boolAttr.getValue())
+ p << "devnum: ";
llvm::interleaveComma(
llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
p << operands[opIdx] << " : " << operands[opIdx].getType();
@@ -1209,7 +1291,7 @@ unsigned SerialOp::getNumDataOperands() {
}
Value SerialOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync().size();
+ unsigned numOptional = getAsyncOperands().size();
numOptional += getIfCond() ? 1 : 0;
numOptional += getSelfCond() ? 1 : 0;
return getOperand(getWaitOperands().size() + numOptional + i);
@@ -1228,8 +1310,8 @@ mlir::Value acc::SerialOp::getAsyncValue() {
}
mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
- deviceType);
+ return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
+ getAsyncOperands(), deviceType);
}
bool acc::SerialOp::hasWaitOnly() {
@@ -1246,8 +1328,19 @@ mlir::Operation::operand_range SerialOp::getWaitValues() {
mlir::Operation::operand_range
SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
- getWaitOperandsSegments(), deviceType);
+ return getWaitValuesWithoutDevnum(
+ getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+ getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value SerialOp::getWaitDevnum() {
+ return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+ return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+ getWaitOperandsSegments(), getHasWaitDevnum(),
+ deviceType);
}
LogicalResult acc::SerialOp::verify() {
@@ -1265,8 +1358,12 @@ LogicalResult acc::SerialOp::verify() {
getWaitOperandsDeviceTypeAttr(), "wait")))
return failure();
- if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
- getAsyncDeviceTypeAttr(), "async")))
+ if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
+ getAsyncOperandsDeviceTypeAttr(),
+ "async")))
+ return failure();
+
+ if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
return failure();
return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
@@ -1281,7 +1378,7 @@ unsigned KernelsOp::getNumDataOperands() {
}
Value KernelsOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync().size();
+ unsigned numOptional = getAsyncOperands().size();
numOptional += getWaitOperands().size();
numOptional += getNumGangs().size();
numOptional += getNumWorkers().size();
@@ -1304,8 +1401,8 @@ mlir::Value acc::KernelsOp::getAsyncValue() {
}
mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
- deviceType);
+ return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
+ getAsyncOperands(), deviceType);
}
mlir::Value acc::KernelsOp::getNumWorkersValue() {
@@ -1352,8 +1449,19 @@ mlir::Operation::operand_range KernelsOp::getWaitValues() {
mlir::Operation::operand_range
KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
- getWaitOperandsSegments(), deviceType);
+ return getWaitValuesWithoutDevnum(
+ getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+ getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value KernelsOp::getWaitDevnum() {
+ return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+ return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+ getWaitOperandsSegments(), getHasWaitDevnum(),
+ deviceType);
}
LogicalResult acc::KernelsOp::verify() {
@@ -1377,8 +1485,12 @@ LogicalResult acc::KernelsOp::verify() {
"vector_length")))
return failure();
- if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
- getAsyncDeviceTypeAttr(), "async")))
+ if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
+ getAsyncOperandsDeviceTypeAttr(),
+ "async")))
+ return failure();
+
+ if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
return failure();
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
@@ -1943,6 +2055,9 @@ LogicalResult acc::DataOp::verify() {
return emitError("expect data entry/exit operation or acc.getdeviceptr "
"as defining op");
+ if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
+ return failure();
+
return success();
}
@@ -1950,7 +2065,7 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
Value DataOp::getDataOperand(unsigned i) {
unsigned numOptional = getIfCond() ? 1 : 0;
- numOptional += getAsync().size() ? 1 : 0;
+ numOptional += getAsyncOperands().size() ? 1 : 0;
numOptional += getWaitOperands().size();
return getOperand(numOptional + i);
}
@@ -1968,8 +2083,8 @@ mlir::Value DataOp::getAsyncValue() {
}
mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
- deviceType);
+ return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
+ getAsyncOperands(), deviceType);
}
bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
@@ -1984,8 +2099,19 @@ mlir::Operation::operand_range DataOp::getWaitValues() {
mlir::Operation::operand_range
DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
- getWaitOperandsSegments(), deviceType);
+ return getWaitValuesWithoutDevnum(
+ getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+ getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value DataOp::getWaitDevnum() {
+ return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+ return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+ getWaitOperandsSegments(), getHasWaitDevnum(),
+ deviceType);
}
//===----------------------------------------------------------------------===//
@@ -2549,23 +2675,8 @@ LogicalResult acc::UpdateOp::verify() {
getWaitOperandsDeviceTypeAttr(), "wait")))
return failure();
- for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
- ++dtypeInt) {
- auto dtype = static_cast<acc::DeviceType>(dtypeInt);
-
- // The async attribute represent the async clause without value. Therefore
- // the attribute and operand cannot appear at the same time.
- if (getAsyncValue(dtype) && hasAsyncOnly(dtype))
- return emitError("async attribute cannot appear with asyncOperand");
-
- // The wait attribute represent the wait clause without values. Therefore
- // the attribute and operands cannot appear at the same time.
- if (!getWaitValues(dtype).empty() && hasWaitOnly(dtype))
- return emitError("wait attribute cannot appear with waitOperands");
- }
-
- if (getWaitDevnum() && getWaitOperands().empty())
- return emitError("wait_devnum cannot appear without waitOperands");
+ if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
+ return failure();
for (mlir::Value operand : getDataClauseOperands())
if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
@@ -2582,7 +2693,6 @@ unsigned UpdateOp::getNumDataOperands() {
Value UpdateOp::getDataOperand(unsigned i) {
unsigned numOptional = getAsyncOperands().size();
- numOptional += getWaitDevnum() ? 1 : 0;
numOptional += getIfCond() ? 1 : 0;
return getOperand(getWaitOperands().size() + numOptional + i);
}
@@ -2619,7 +2729,7 @@ bool UpdateOp::hasWaitOnly() {
}
bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- return hasDeviceType(getWait(), deviceType);
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range UpdateOp::getWaitValues() {
@@ -2628,8 +2738,19 @@ mlir::Operation::operand_range UpdateOp::getWaitValues() {
mlir::Operation::operand_range
UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
- getWaitOperandsSegments(), deviceType);
+ return getWaitValuesWithoutDevnum(
+ getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
+ getHasWaitDevnum(), deviceType);
+}
+
+mlir::Value UpdateOp::getWaitDevnum() {
+ return getWaitDevnum(mlir::acc::DeviceType::None);
+}
+
+mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
+ return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
+ getWaitOperandsSegments(), getHasWaitDevnum(),
+ deviceType);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 80d439f19d9f4cf..16df33eec642ce7 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -126,14 +126,6 @@ acc.update
// -----
-%cst = arith.constant 1 : index
-%value = memref.alloc() : memref<f32>
-%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
-// expected-error at +1 {{wait_devnum cannot appear without waitOperands}}
-acc.update wait_devnum(%cst: index) dataOperands(%0: memref<f32>)
-
-// -----
-
%cst = arith.constant 1 : index
%value = memref.alloc() : memref<f32>
%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
@@ -146,7 +138,7 @@ acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async =
%value = memref.alloc() : memref<f32>
%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
// expected-error at +1 {{wait attribute cannot appear with waitOperands}}
-acc.update wait({%cst: index}) dataOperands(%0: memref<f32>) attributes {wait = [#acc.device_type<none>]}
+acc.update wait({%cst: index}) dataOperands(%0: memref<f32>) attributes {waitOnly = [#acc.device_type<none>]}
// -----
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 45b41f1a7722566..4e6ed8645cdbce7 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -802,7 +802,7 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
} attributes { defaultAttr = #acc<defaultvalue none>, wait }
%wd1 = arith.constant 1 : i64
- acc.data wait_devnum(%wd1 : i64) wait({%w1 : i64}) {
+ acc.data wait({devnum: %wd1 : i64, %w1 : i64}) {
} attributes { defaultAttr = #acc<defaultvalue none>, wait }
return
@@ -916,7 +916,7 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
// CHECK: acc.data wait({%{{.*}} : i64}) {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
-// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait({%{{.*}} : i64}) {
+// CHECK: acc.data wait({devnum: %{{.*}} : i64, %{{.*}} : i64}) {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
// -----
@@ -934,7 +934,7 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
acc.update async(%idxValue: index) dataOperands(%0: memref<f32>)
- acc.update wait_devnum(%i64Value: i64) wait({%i32Value : i32, %idxValue : index}) dataOperands(%0: memref<f32>)
+ acc.update wait({devnum: %i64Value: i64, %i32Value : i32, %idxValue : index}) dataOperands(%0: memref<f32>)
acc.update if(%ifCond) dataOperands(%0: memref<f32>)
acc.update dataOperands(%0: memref<f32>)
acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
@@ -953,7 +953,7 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
// CHECK: acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref<f32>)
-// CHECK: acc.update wait_devnum([[I64VALUE]] : i64) wait({[[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref<f32>)
+// CHECK: acc.update wait({devnum: [[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index 474f887928992a9..41b751b8d7f7cbc 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -86,13 +86,13 @@ void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
OwningOpRef<arith::ConstantIndexOp> val =
b.create<arith::ConstantIndexOp>(loc, 1);
auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
- op->setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
- op->getAsyncMutable().assign(val->getResult());
+ op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+ op->getAsyncOperandsMutable().assign(val->getResult());
EXPECT_EQ(op->getAsyncValue(), empty);
EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult());
- op->getAsyncMutable().clear();
- op->removeAsyncDeviceTypeAttr();
+ op->getAsyncOperandsMutable().clear();
+ op->removeAsyncOperandsDeviceTypeAttr();
}
TEST_F(OpenACCOpsTest, asyncValueTest) {
@@ -232,6 +232,8 @@ TEST_F(OpenACCOpsTest, waitOnlyTest) {
testWaitOnly<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
testWaitOnly<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
testWaitOnly<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
+ testWaitOnly<UpdateOp>(b, context, loc, dtypes, dtypesWithoutNone);
+ testWaitOnly<DataOp>(b, context, loc, dtypes, dtypesWithoutNone);
}
template <typename Op>
@@ -245,6 +247,8 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc,
b.create<arith::ConstantIndexOp>(loc, 1);
OwningOpRef<arith::ConstantIndexOp> val2 =
b.create<arith::ConstantIndexOp>(loc, 4);
+ OwningOpRef<arith::ConstantIndexOp> val3 =
+ b.create<arith::ConstantIndexOp>(loc, 5);
auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
op->getWaitOperandsMutable().assign(val1->getResult());
op->setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
@@ -294,6 +298,28 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc,
op->getWaitOperandsMutable().clear();
op->removeWaitOperandsDeviceTypeAttr();
op->removeWaitOperandsSegmentsAttr();
+
+ op->getWaitOperandsMutable().append(val3->getResult());
+ op->getWaitOperandsMutable().append(val2->getResult());
+ op->getWaitOperandsMutable().append(val1->getResult());
+ op->setWaitOperandsDeviceTypeAttr(
+ b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
+ op->setHasWaitDevnumAttr(b.getBoolArrayAttr({true}));
+ op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({3}));
+ EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(),
+ op->getWaitValues(DeviceType::None).end());
+ EXPECT_FALSE(op->getWaitDevnum());
+
+ EXPECT_EQ(op->getWaitDevnum(DeviceType::Multicore), val3->getResult());
+ EXPECT_EQ(op->getWaitValues(DeviceType::Multicore).front(),
+ val2->getResult());
+ EXPECT_EQ(op->getWaitValues(DeviceType::Multicore).drop_front().front(),
+ val1->getResult());
+
+ op->getWaitOperandsMutable().clear();
+ op->removeWaitOperandsDeviceTypeAttr();
+ op->removeWaitOperandsSegmentsAttr();
+ op->removeHasWaitDevnumAttr();
}
TEST_F(OpenACCOpsTest, waitValuesTest) {
More information about the flang-commits
mailing list