[flang-commits] [flang] [mlir] [mlir][openacc] Add device_type support for compute operations (PR #75864)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Mon Dec 18 14:35:49 PST 2023
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/75864
This patch adds a representation for `device_type` clause information on compute constructs (parallel, kernels, serial).
The `device_type` clause on a compute construct affects the clauses that appear after it. The values affected by `device_type` are now tied to an attribute array that represents the device_type. `DeviceType::None` is used to represent a value produced by a clause that appears before any `device_type` clause. The operands and the attribute information are parsed/printed together.
This is an example with the `vector_length` clause. The first value (64) is not affected by `device_type`, so it is represented with `DeviceType::None`; `None` is not printed. The second value (128) is tied to the `device_type(multicore)` clause.
```
!$acc parallel vector_length(64) device_type(multicore) vector_length(128)
```
```
acc.parallel vector_length(%c64 : i32, %c128 : i32 [#acc.device_type<multicore>]) {
}
```
When a single clause such as `num_gangs` or `wait` can produce multiple values, an extra attribute describes the number of values belonging to each `device_type`. Values and attributes are parsed/printed together.
```
acc.parallel num_gangs({%c2 : i32}, {%c4 : i32} [#acc.device_type<nvidia>])
```
While preparing this patch, I noticed that the wait devnum is not part of the operations and is not lowered. It will be added in a follow-up patch.
>From cdad045654f4bfc6d5066958e35aac7b473451ce Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 7 Dec 2023 14:04:54 -0800
Subject: [PATCH] [mlir][openacc] Add device_type support for compute
operations
---
flang/lib/Lower/OpenACC.cpp | 106 +++-
flang/test/Lower/OpenACC/acc-device-type.f90 | 44 ++
flang/test/Lower/OpenACC/acc-kernels-loop.f90 | 14 +-
flang/test/Lower/OpenACC/acc-kernels.f90 | 14 +-
.../test/Lower/OpenACC/acc-parallel-loop.f90 | 14 +-
flang/test/Lower/OpenACC/acc-parallel.f90 | 16 +-
flang/test/Lower/OpenACC/acc-serial-loop.f90 | 10 +-
flang/test/Lower/OpenACC/acc-serial.f90 | 10 +-
.../mlir/Dialect/OpenACC/OpenACCOps.td | 286 +++++++---
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 515 +++++++++++++++++-
mlir/test/Dialect/OpenACC/invalid.mlir | 4 +-
mlir/test/Dialect/OpenACC/ops.mlir | 76 +--
mlir/unittests/Dialect/CMakeLists.txt | 1 +
mlir/unittests/Dialect/OpenACC/CMakeLists.txt | 8 +
.../Dialect/OpenACC/OpenACCOpsTest.cpp | 275 ++++++++++
15 files changed, 1216 insertions(+), 177 deletions(-)
create mode 100644 flang/test/Lower/OpenACC/acc-device-type.f90
create mode 100644 mlir/unittests/Dialect/OpenACC/CMakeLists.txt
create mode 100644 mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 531685948bc843..57e14bf77e092c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1451,7 +1451,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
return mlir::acc::DeviceType::Multicore;
}
- return mlir::acc::DeviceType::Default;
+ return mlir::acc::DeviceType::None;
}
static void gatherDeviceTypeAttrs(
@@ -1752,26 +1752,25 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
bool outerCombined = false) {
// Parallel operation operands
- mlir::Value async;
- mlir::Value numWorkers;
- mlir::Value vectorLength;
mlir::Value ifCond;
mlir::Value selfCond;
mlir::Value waitDevnum;
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
- dataClauseOperands, numGangs;
+ dataClauseOperands, numGangs, numWorkers, vectorLength, async;
+ llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
+ vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+ llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
- // Async, wait and self clause have optional values but can be present with
+ // Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
// represent the clause.
- bool addAsyncAttr = false;
- bool addWaitAttr = false;
bool addSelfAttr = false;
bool hasDefaultNone = false;
@@ -1779,6 +1778,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ // device_type attribute is set to `none` until a device_type clause is
+ // encountered.
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None);
+
// Lower clauses values mapped to operands.
// Keep track of each group of operands separatly as clauses can appear
// more than once.
@@ -1786,27 +1790,52 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
- genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+ const auto &asyncClauseValue = asyncClause->v;
+ if (asyncClauseValue) { // async has a value.
+ async.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
+ asyncDeviceTypes.push_back(crtDeviceTypeAttr);
+ } else {
+ asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands, waitDevnum,
- addWaitAttr, stmtCtx);
+ const auto &waitClauseValue = waitClause->v;
+ if (waitClauseValue) { // wait has a value.
+ const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+ const auto &waitList =
+ std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+ auto crtWaitOperands = waitOperands.size();
+ for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+ waitOperands.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(value), stmtCtx)));
+ }
+ waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+ } else {
+ waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
+ auto crtNumGangs = numGangs.size();
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
numGangs.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(expr), stmtCtx)));
+ numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+ numGangsSegments.push_back(numGangs.size() - crtNumGangs);
} else if (const auto *numWorkersClause =
std::get_if<Fortran::parser::AccClause::NumWorkers>(
&clause.u)) {
- numWorkers = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+ numWorkers.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
+ numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *vectorLengthClause =
std::get_if<Fortran::parser::AccClause::VectorLength>(
&clause.u)) {
- vectorLength = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+ vectorLength.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
+ vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1957,18 +1986,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
else if ((defaultClause->v).v ==
llvm::acc::DefaultValue::ACC_Default_present)
hasDefaultPresent = true;
+ } else if (const auto *deviceTypeClause =
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
+ &clause.u)) {
+ const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+ deviceTypeClause->v;
+ assert(deviceTypeExprList.v.size() == 1 &&
+ "expect only one device_type expr");
+ crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
}
}
// Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value, 8> operands;
llvm::SmallVector<int32_t, 8> operandSegments;
- addOperand(operands, operandSegments, async);
+ addOperands(operands, operandSegments, async);
addOperands(operands, operandSegments, waitOperands);
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
addOperands(operands, operandSegments, numGangs);
- addOperand(operands, operandSegments, numWorkers);
- addOperand(operands, operandSegments, vectorLength);
+ addOperands(operands, operandSegments, numWorkers);
+ addOperands(operands, operandSegments, vectorLength);
}
addOperand(operands, operandSegments, ifCond);
addOperand(operands, operandSegments, selfCond);
@@ -1989,10 +2027,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder, currentLocation, eval, operands, operandSegments,
outerCombined);
- if (addAsyncAttr)
- computeOp.setAsyncAttrAttr(builder.getUnitAttr());
- if (addWaitAttr)
- computeOp.setWaitAttrAttr(builder.getUnitAttr());
if (addSelfAttr)
computeOp.setSelfAttrAttr(builder.getUnitAttr());
@@ -2001,6 +2035,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (hasDefaultPresent)
computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
+ if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
+ if (!numWorkersDeviceTypes.empty())
+ computeOp.setNumWorkersDeviceTypeAttr(
+ mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
+ if (!vectorLengthDeviceTypes.empty())
+ computeOp.setVectorLengthDeviceTypeAttr(
+ mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
+ if (!numGangsDeviceTypes.empty())
+ computeOp.setNumGangsDeviceTypeAttr(
+ mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
+ if (!numGangsSegments.empty())
+ computeOp.setNumGangsSegmentsAttr(
+ builder.getDenseI32ArrayAttr(numGangsSegments));
+ }
+ if (!asyncDeviceTypes.empty())
+ computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ if (!asyncOnlyDeviceTypes.empty())
+ computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
+
+ if (!waitOperandsDeviceTypes.empty())
+ computeOp.setWaitOperandsDeviceTypeAttr(
+ builder.getArrayAttr(waitOperandsDeviceTypes));
+ if (!waitOperandsSegments.empty())
+ computeOp.setWaitOperandsSegmentsAttr(
+ builder.getDenseI32ArrayAttr(waitOperandsSegments));
+ if (!waitOnlyDeviceTypes.empty())
+ computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
+
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
if (!privatizations.empty())
computeOp.setPrivatizationsAttr(
diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
new file mode 100644
index 00000000000000..871dbc95f60fcb
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-device-type.f90
@@ -0,0 +1,44 @@
+! This test checks lowering of the OpenACC device_type clause on directives, where its
+! position and the clauses that follow it have special semantics.
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1()
+
+ !$acc parallel num_workers(16)
+ !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {
+
+ !$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
+ !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+ !$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
+ !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+ !$acc parallel vector_length(1)
+ !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)
+
+ !$acc parallel device_type(multicore) vector_length(1)
+ !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])
+
+ !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
+ !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])
+
+ !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
+ !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+
+end subroutine
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 34e72326972417..93bc699031d550 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
!$acc kernels loop async(1)
DO i = 1, n
@@ -103,7 +103,7 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
!$acc kernels loop wait(1)
DO i = 1, n
@@ -111,7 +111,7 @@ subroutine acc_kernels_loop
END DO
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.kernels wait([[WAIT1]] : i32) {
+! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
END DO
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
+! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
END DO
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
+! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 1f882c6df51061..99629bb8351723 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -40,7 +40,7 @@ subroutine acc_kernels
! CHECK: acc.kernels {
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
!$acc kernels async(1)
!$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
! CHECK: acc.kernels {
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
!$acc kernels wait(1)
!$acc end kernels
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.kernels wait([[WAIT1]] : i32) {
+! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
@@ -78,7 +78,7 @@ subroutine acc_kernels
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
@@ -87,7 +87,7 @@ subroutine acc_kernels
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
@@ -95,7 +95,7 @@ subroutine acc_kernels
!$acc end kernels
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
+! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
@@ -103,7 +103,7 @@ subroutine acc_kernels
!$acc end kernels
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
+! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 1856215ce59d13..deee7089033ead 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -64,7 +64,7 @@ subroutine acc_parallel_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
!$acc parallel loop async(1)
DO i = 1, n
@@ -105,7 +105,7 @@ subroutine acc_parallel_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
!$acc parallel loop wait(1)
DO i = 1, n
@@ -113,7 +113,7 @@ subroutine acc_parallel_loop
END DO
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.parallel wait([[WAIT1]] : i32) {
+! CHECK: acc.parallel wait({[[WAIT1]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -128,7 +128,7 @@ subroutine acc_parallel_loop
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK: acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -143,7 +143,7 @@ subroutine acc_parallel_loop
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK: acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -157,7 +157,7 @@ subroutine acc_parallel_loop
END DO
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.parallel num_gangs([[NUMGANGS1]] : i32) {
+! CHECK: acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -171,7 +171,7 @@ subroutine acc_parallel_loop
END DO
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.parallel num_gangs([[NUMGANGS2]] : i32) {
+! CHECK: acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index bbf51ba36a7dea..a369bf01f25995 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -62,7 +62,7 @@ subroutine acc_parallel
! CHECK: acc.parallel {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
!$acc parallel async(1)
!$acc end parallel
@@ -85,13 +85,13 @@ subroutine acc_parallel
! CHECK: acc.parallel {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
!$acc parallel wait(1)
!$acc end parallel
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.parallel wait([[WAIT1]] : i32) {
+! CHECK: acc.parallel wait({[[WAIT1]] : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -100,7 +100,7 @@ subroutine acc_parallel
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK: acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -109,7 +109,7 @@ subroutine acc_parallel
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK: acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -117,7 +117,7 @@ subroutine acc_parallel
!$acc end parallel
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.parallel num_gangs([[NUMGANGS1]] : i32) {
+! CHECK: acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -125,14 +125,14 @@ subroutine acc_parallel
!$acc end parallel
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.parallel num_gangs([[NUMGANGS2]] : i32) {
+! CHECK: acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
!$acc parallel num_gangs(1, 1, 1)
!$acc end parallel
-! CHECK: acc.parallel num_gangs(%{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32) {
+! CHECK: acc.parallel num_gangs({%{{.*}} : i32, %{{.*}} : i32, %{{.*}} : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 4ed7bb8da29a1a..712bfc80ce387c 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -83,7 +83,7 @@ subroutine acc_serial_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
!$acc serial loop async(1)
DO i = 1, n
@@ -124,7 +124,7 @@ subroutine acc_serial_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
!$acc serial loop wait(1)
DO i = 1, n
@@ -132,7 +132,7 @@ subroutine acc_serial_loop
END DO
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.serial wait([[WAIT1]] : i32) {
+! CHECK: acc.serial wait({[[WAIT1]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -147,7 +147,7 @@ subroutine acc_serial_loop
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK: acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -162,7 +162,7 @@ subroutine acc_serial_loop
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK: acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index ab3b0ccd545958..d05e51d3d274f4 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -62,7 +62,7 @@ subroutine acc_serial
! CHECK: acc.serial {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
!$acc serial async(1)
!$acc end serial
@@ -85,13 +85,13 @@ subroutine acc_serial
! CHECK: acc.serial {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
!$acc serial wait(1)
!$acc end serial
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.serial wait([[WAIT1]] : i32) {
+! CHECK: acc.serial wait({[[WAIT1]] : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -100,7 +100,7 @@ subroutine acc_serial
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK: acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -109,7 +109,7 @@ subroutine acc_serial
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK: acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 9d48b1f1c3f9af..1ac69f73d48a17 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -156,29 +156,46 @@ def DeclareActionAttr : OpenACC_Attr<"DeclareAction", "declare_action"> {
}
// Device type enumeration.
-def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 0, "star">;
-def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 1, "default">;
-def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 2, "host">;
-def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 3, "multicore">;
-def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 4, "nvidia">;
-def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 5, "radeon">;
-
+def OpenACC_DeviceTypeNone : I32EnumAttrCase<"None", 0, "none">;
+def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 1, "star">;
+def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 2, "default">;
+def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 3, "host">;
+def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 4, "multicore">;
+def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 5, "nvidia">;
+def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 6, "radeon">;
def OpenACC_DeviceType : I32EnumAttr<"DeviceType",
"built-in device type supported by OpenACC",
- [OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault,
+ [OpenACC_DeviceTypeNone, OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault,
OpenACC_DeviceTypeHost, OpenACC_DeviceTypeMulticore,
OpenACC_DeviceTypeNvidia, OpenACC_DeviceTypeRadeon
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::acc";
}
+
+// Device type attribute is used to associate a value with clauses that
+// appear after a device_type clause. The list of clauses allowed after the
+// device_type clause is defined per construct as follows:
+// Loop construct: collapse, gang, worker, vector, seq, independent, auto,
+// and tile
+// Compute construct: async, wait, num_gangs, num_workers, and vector_length
+// Data construct: async and wait
+// Routine: gang, worker, vector, seq and bind
+//
+// The `none` value means that the value appears before any device_type clause.
+//
def OpenACC_DeviceTypeAttr : EnumAttr<OpenACC_Dialect,
OpenACC_DeviceType,
"device_type"> {
let assemblyFormat = [{ ```<` $value `>` }];
}
+def DeviceTypeArrayAttr :
+ TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "device type array attribute"> {
+ let constBuilderCall = ?;
+}
+
// Used for data specification in data clauses (2.7.1).
// Either (or both) extent and upperbound must be specified.
def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
@@ -764,24 +781,32 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
```
}];
- let arguments = (ins Optional<IntOrIndex>:$async,
- UnitAttr:$asyncAttr,
- Variadic<IntOrIndex>:$waitOperands,
- UnitAttr:$waitAttr,
- Variadic<IntOrIndex>:$numGangs,
- Optional<IntOrIndex>:$numWorkers,
- Optional<IntOrIndex>:$vectorLength,
- Optional<I1>:$ifCond,
- Optional<I1>:$selfCond,
- UnitAttr:$selfAttr,
- Variadic<AnyType>:$reductionOperands,
- OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
- Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
- OptionalAttr<SymbolRefArrayAttr>:$privatizations,
- Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
- OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ let arguments = (ins
+ Variadic<IntOrIndex>:$async,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+ Variadic<IntOrIndex>:$waitOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
+ Variadic<IntOrIndex>:$numGangs,
+ OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$numGangsDeviceType,
+ Variadic<IntOrIndex>:$numWorkers,
+ OptionalAttr<DeviceTypeArrayAttr>:$numWorkersDeviceType,
+ Variadic<IntOrIndex>:$vectorLength,
+ OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType,
+ Optional<I1>:$ifCond,
+ Optional<I1>:$selfCond,
+ UnitAttr:$selfAttr,
+ Variadic<AnyType>:$reductionOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizations,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ OptionalAttr<DefaultValueAttr>:$defaultAttr);
let regions = (region AnyRegion:$region);
@@ -791,22 +816,69 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
+
+ /// Return true if the op has the async attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasAsyncOnly();
+ /// Return true if the op has the async attribute for the given device_type.
+ bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+ /// Return the value of the async clause if present.
+ mlir::Value getAsyncValue();
+ /// Return the value of the async clause for the given device_type if
+ /// present.
+ mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+ /// Return the value of the num_workers clause if present.
+ mlir::Value getNumWorkersValue();
+ /// Return the value of the num_workers clause for the given device_type if
+ /// present.
+ mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType);
+
+ /// Return the value of the vector_length clause if present.
+ mlir::Value getVectorLengthValue();
+ /// Return the value of the vector_length clause for the given device_type
+ /// if present.
+ mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType);
+
+ /// Return the values of the num_gangs clause if present.
+ mlir::Operation::operand_range getNumGangsValues();
+ /// Return the values of the num_gangs clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getNumGangsValues(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the wait attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasWaitOnly();
+ /// Return true if the op has the wait attribute for the given device_type.
+ bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+ /// Return the values of the wait clause if present.
+ mlir::Operation::operand_range getWaitValues();
+ /// Return the values of the wait clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` $async `:` type($async) `)`
+ | `async` `(` custom<DeviceTypeOperands>($async,
+ type($async), $asyncDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
type($gangFirstPrivateOperands), $firstprivatizations)
`)`
- | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
- | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
+ | `num_gangs` `(` custom<NumGangs>($numGangs,
+ type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
+ | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
+ type($numWorkers), $numWorkersDeviceType) `)`
| `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
`)`
- | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
- | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+ | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
+ type($vectorLength), $vectorLengthDeviceType) `)`
+ | `wait` `(` custom<WaitOperands>($waitOperands,
+ type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
| `reduction` `(` custom<SymOperandList>(
@@ -839,21 +911,25 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
```
}];
- let arguments = (ins Optional<IntOrIndex>:$async,
- UnitAttr:$asyncAttr,
- Variadic<IntOrIndex>:$waitOperands,
- UnitAttr:$waitAttr,
- Optional<I1>:$ifCond,
- Optional<I1>:$selfCond,
- UnitAttr:$selfAttr,
- Variadic<AnyType>:$reductionOperands,
- OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
- Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
- OptionalAttr<SymbolRefArrayAttr>:$privatizations,
- Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
- OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ let arguments = (ins
+ Variadic<IntOrIndex>:$async,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+ Variadic<IntOrIndex>:$waitOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
+ Optional<I1>:$ifCond,
+ Optional<I1>:$selfCond,
+ UnitAttr:$selfAttr,
+ Variadic<AnyType>:$reductionOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizations,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ OptionalAttr<DefaultValueAttr>:$defaultAttr);
let regions = (region AnyRegion:$region);
@@ -863,19 +939,44 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
+
+ /// Return true if the op has the async attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasAsyncOnly();
+ /// Return true if the op has the async attribute for the given device_type.
+ bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+ /// Return the value of the async clause if present.
+ mlir::Value getAsyncValue();
+ /// Return the value of the async clause for the given device_type if
+ /// present.
+ mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the wait attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasWaitOnly();
+ /// Return true if the op has the wait attribute for the given device_type.
+ bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+ /// Return the values of the wait clause if present.
+ mlir::Operation::operand_range getWaitValues();
+ /// Return the values of the wait clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` $async `:` type($async) `)`
+ | `async` `(` custom<DeviceTypeOperands>($async,
+ type($async), $asyncDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
type($gangFirstPrivateOperands), $firstprivatizations)
`)`
| `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
`)`
- | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+ | `wait` `(` custom<WaitOperands>($waitOperands,
+ type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
| `reduction` `(` custom<SymOperandList>(
@@ -910,18 +1011,26 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
```
}];
- let arguments = (ins Optional<IntOrIndex>:$async,
- UnitAttr:$asyncAttr,
- Variadic<IntOrIndex>:$waitOperands,
- UnitAttr:$waitAttr,
- Variadic<IntOrIndex>:$numGangs,
- Optional<IntOrIndex>:$numWorkers,
- Optional<IntOrIndex>:$vectorLength,
- Optional<I1>:$ifCond,
- Optional<I1>:$selfCond,
- UnitAttr:$selfAttr,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ let arguments = (ins
+ Variadic<IntOrIndex>:$async,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+ Variadic<IntOrIndex>:$waitOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
+ Variadic<IntOrIndex>:$numGangs,
+ OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$numGangsDeviceType,
+ Variadic<IntOrIndex>:$numWorkers,
+ OptionalAttr<DeviceTypeArrayAttr>:$numWorkersDeviceType,
+ Variadic<IntOrIndex>:$vectorLength,
+ OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType,
+ Optional<I1>:$ifCond,
+ Optional<I1>:$selfCond,
+ UnitAttr:$selfAttr,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ OptionalAttr<DefaultValueAttr>:$defaultAttr);
let regions = (region AnyRegion:$region);
@@ -931,16 +1040,63 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
+
+ /// Return true if the op has the async attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasAsyncOnly();
+ /// Return true if the op has the async attribute for the given device_type.
+ bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+ /// Return the value of the async clause if present.
+ mlir::Value getAsyncValue();
+ /// Return the value of the async clause for the given device_type if
+ /// present.
+ mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+ /// Return the value of the num_workers clause if present.
+ mlir::Value getNumWorkersValue();
+ /// Return the value of the num_workers clause for the given device_type if
+ /// present.
+ mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType);
+
+ /// Return the value of the vector_length clause if present.
+ mlir::Value getVectorLengthValue();
+ /// Return the value of the vector_length clause for the given device_type
+ /// if present.
+ mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType);
+
+ /// Return the values of the num_gangs clause if present.
+ mlir::Operation::operand_range getNumGangsValues();
+ /// Return the values of the num_gangs clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getNumGangsValues(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the wait attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasWaitOnly();
+ /// Return true if the op has the wait attribute for the given device_type.
+ bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+ /// Return the values of the wait clause if present.
+ mlir::Operation::operand_range getWaitValues();
+ /// Return the values of the wait clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` $async `:` type($async) `)`
- | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
- | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
- | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
- | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+ | `async` `(` custom<DeviceTypeOperands>($async,
+ type($async), $asyncDeviceType) `)`
+ | `num_gangs` `(` custom<NumGangs>($numGangs,
+ type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
+ | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
+ type($numWorkers), $numWorkersDeviceType) `)`
+ | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
+ type($vectorLength), $vectorLengthDeviceType) `)`
+ | `wait` `(` custom<WaitOperands>($waitOperands,
+ type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 08e83cad482207..3011cff900090d 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -615,15 +615,49 @@ unsigned ParallelOp::getNumDataOperands() {
}
+/// Return the i-th data operand. The index is offset past the wait, async,
+/// num_gangs, num_workers, vector_length, if and self operands, whose counts
+/// are summed below (the data operands are assumed to follow them in operand
+/// order).
Value ParallelOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync() ? 1 : 0;
+ // async, num_workers and vector_length now carry one value per device_type,
+ // so their sizes are counted instead of testing for presence.
+ unsigned numOptional = getAsync().size();
numOptional += getNumGangs().size();
- numOptional += getNumWorkers() ? 1 : 0;
- numOptional += getVectorLength() ? 1 : 0;
+ numOptional += getNumWorkers().size();
+ numOptional += getVectorLength().size();
numOptional += getIfCond() ? 1 : 0;
numOptional += getSelfCond() ? 1 : 0;
return getOperand(getWaitOperands().size() + numOptional + i);
}
+/// Verify that a one-value-per-device_type clause (async, num_workers,
+/// vector_length) has exactly one device_type entry per operand.
+///
+/// \param op          operation used to emit the diagnostic.
+/// \param operands    the clause's operands.
+/// \param deviceTypes the clause's device_type array; null when the attribute
+///                    is absent.
+/// \param keyword     clause name used in the diagnostic.
+/// \returns failure (with an emitted error) on a count mismatch.
+template <typename Op>
+static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
+                                                ArrayAttr deviceTypes,
+                                                llvm::StringRef keyword) {
+  if (operands.empty())
+    return success();
+  // A missing device_type array cannot describe any operand; check for null
+  // before reading its size instead of dereferencing a null attribute.
+  if (!deviceTypes || deviceTypes.size() != operands.size())
+    return op.emitOpError() << keyword << " operands count must match "
+                            << keyword << " device_type count";
+  return success();
+}
+
+/// Verify a multi-value clause (num_gangs, wait) whose operands are split into
+/// segments, one segment per device_type.
+///
+/// \param op           operation used to emit diagnostics.
+/// \param operands     all operands of the clause, concatenated by segment.
+/// \param segments     number of operands in each segment; null if absent.
+/// \param deviceTypes  one device_type per segment; null if absent.
+/// \param keyword      clause name used in diagnostics.
+/// \param maxInSegment when non-zero, the maximum number of values allowed in
+///                     a single segment.
+/// \returns failure (with an emitted error) if the counts are inconsistent.
+template <typename Op>
+static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
+    Op op, OperandRange operands, DenseI32ArrayAttr segments,
+    ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
+  // Without segment sizes the operands cannot be attributed to any
+  // device_type (the accessors and printers index the segments), so this is
+  // only valid when there are no operands at all.
+  if (!segments) {
+    if (!operands.empty())
+      return op.emitOpError()
+             << keyword << " operand count does not match count in segments";
+    return success();
+  }
+
+  std::size_t numOperandsInSegments = 0;
+  for (int32_t segCount : segments.asArrayRef()) {
+    // A negative size would wrap when added to the unsigned total.
+    if (segCount < 0)
+      return op.emitOpError()
+             << keyword << " segment sizes must be non-negative";
+    if (maxInSegment != 0 && segCount > maxInSegment)
+      return op.emitOpError() << keyword << " expects a maximum of "
+                              << maxInSegment << " values per segment";
+    numOperandsInSegments += static_cast<std::size_t>(segCount);
+  }
+  if (numOperandsInSegments != operands.size())
+    return op.emitOpError()
+           << keyword << " operand count does not match count in segments";
+  // Check for a missing device_type array before reading its size.
+  if (!deviceTypes ||
+      deviceTypes.size() != static_cast<std::size_t>(segments.size()))
+    return op.emitOpError()
+           << keyword << " segment count does not match device_type count";
+  return success();
+}
+
+/// Verify acc.parallel: private/firstprivate/reduction recipe symbols, the
+/// device_type bookkeeping of each per-device_type clause, and the data
+/// operands.
LogicalResult acc::ParallelOp::verify() {
if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
*this, getPrivatizations(), getGangPrivateOperands(), "private",
@@ -633,11 +667,322 @@ LogicalResult acc::ParallelOp::verify() {
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions", false)))
return failure();
- if (getNumGangs().size() > 3)
- return emitOpError() << "num_gangs expects a maximum of 3 values";
+
+ // num_gangs: segmented per device_type, at most 3 values per segment.
+ if (failed(verifyDeviceTypeAndSegmentCountMatch(
+ *this, getNumGangs(), getNumGangsSegmentsAttr(),
+ getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
+ return failure();
+
+ // wait: segmented per device_type, no per-segment limit.
+ if (failed(verifyDeviceTypeAndSegmentCountMatch(
+ *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
+ getWaitOperandsDeviceTypeAttr(), "wait")))
+ return failure();
+
+ // num_workers, vector_length and async: one value per device_type entry.
+ if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
+ getNumWorkersDeviceTypeAttr(),
+ "num_workers")))
+ return failure();
+
+ if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
+ getVectorLengthDeviceTypeAttr(),
+ "vector_length")))
+ return failure();
+
+ if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
+ getAsyncDeviceTypeAttr(), "async")))
+ return failure();
+
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
+/// Return the position of the first entry of `segments` (an array of
+/// #acc.device_type attributes) whose value equals `deviceType`, or
+/// std::nullopt if there is none. A null array yields std::nullopt.
+/// Entries that are not DeviceTypeAttr never match but still occupy a
+/// position, so returned indices stay aligned with the operand segments.
+static std::optional<unsigned> findSegment(ArrayAttr segments,
+                                           mlir::acc::DeviceType deviceType) {
+  if (!segments)
+    return std::nullopt;
+  unsigned segmentIdx = 0;
+  for (mlir::Attribute attr : segments) {
+    // dyn_cast yields null for an unexpected attribute kind; check it instead
+    // of dereferencing the result unconditionally.
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr && deviceTypeAttr.getValue() == deviceType)
+      return segmentIdx;
+    ++segmentIdx;
+  }
+  return std::nullopt;
+}
+
+/// Return the operand of `range` at the position where `deviceType` appears
+/// in `arrayAttr`. Returns a null Value when the attribute is absent or does
+/// not list `deviceType`.
+static mlir::Value
+getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
+                            mlir::Operation::operand_range range,
+                            mlir::acc::DeviceType deviceType) {
+  std::optional<unsigned> position =
+      arrayAttr ? findSegment(*arrayAttr, deviceType) : std::nullopt;
+  if (!position)
+    return mlir::Value();
+  return range[*position];
+}
+
+/// Return true if DeviceType::None is listed in the asyncOnly attribute.
+bool acc::ParallelOp::hasAsyncOnly() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return hasAsyncOnly(noDeviceType);
+}
+
+/// Return true if `deviceType` is listed in the asyncOnly attribute; false
+/// when the attribute is absent.
+bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
+  auto asyncOnlyTypes = getAsyncOnly();
+  if (!asyncOnlyTypes)
+    return false;
+  return findSegment(*asyncOnlyTypes, deviceType).has_value();
+}
+
+/// Return the async value attached to DeviceType::None, or a null Value.
+mlir::Value acc::ParallelOp::getAsyncValue() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getAsyncValue(noDeviceType);
+}
+
+/// Return the async value attached to `deviceType`, or a null Value.
+mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+  auto asyncTypes = getAsyncDeviceType();
+  mlir::Operation::operand_range asyncValues = getAsync();
+  return getValueInDeviceTypeSegment(asyncTypes, asyncValues, deviceType);
+}
+
+/// Return the num_workers value attached to DeviceType::None, or a null Value.
+mlir::Value acc::ParallelOp::getNumWorkersValue() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getNumWorkersValue(noDeviceType);
+}
+
+/// Return the num_workers value attached to `deviceType`, or a null Value.
+mlir::Value
+acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
+  auto workerTypes = getNumWorkersDeviceType();
+  mlir::Operation::operand_range workerValues = getNumWorkers();
+  return getValueInDeviceTypeSegment(workerTypes, workerValues, deviceType);
+}
+
+/// Return the vector_length value attached to DeviceType::None, or a null
+/// Value.
+mlir::Value acc::ParallelOp::getVectorLengthValue() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getVectorLengthValue(noDeviceType);
+}
+
+/// Return the vector_length value attached to `deviceType`, or a null Value.
+mlir::Value
+acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
+  auto lengthTypes = getVectorLengthDeviceType();
+  mlir::Operation::operand_range lengthValues = getVectorLength();
+  return getValueInDeviceTypeSegment(lengthTypes, lengthValues, deviceType);
+}
+
+/// Return the num_gangs values attached to DeviceType::None (possibly empty).
+mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getNumGangsValues(noDeviceType);
+}
+
+/// Return the sub-range of `range` belonging to the segment tagged with
+/// `deviceType`.
+///
+/// \param arrayAttr  one device_type per segment; may be absent.
+/// \param range      all operands, concatenated segment by segment.
+/// \param segments   number of operands in each segment; may be absent.
+/// \param deviceType the device_type to look up.
+/// \returns the matching operands, or an empty range when either attribute
+///          is absent or `deviceType` is not listed.
+static mlir::Operation::operand_range
+getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
+                      mlir::Operation::operand_range range,
+                      std::optional<llvm::ArrayRef<int32_t>> segments,
+                      mlir::acc::DeviceType deviceType) {
+  auto empty = range.take_front(0);
+  // Both attributes are needed to locate a segment; never dereference an
+  // absent segments attribute.
+  if (!arrayAttr || !segments)
+    return empty;
+  std::optional<unsigned> pos = findSegment(*arrayAttr, deviceType);
+  // Guard against a device_type array longer than the segment list.
+  if (!pos || *pos >= segments->size())
+    return empty;
+  int32_t nbOperandsBefore = 0;
+  for (unsigned i = 0; i < *pos; ++i)
+    nbOperandsBefore += (*segments)[i];
+  return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
+}
+
+/// Return the num_gangs values attached to `deviceType` (possibly empty).
+mlir::Operation::operand_range
+ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
+  auto gangTypes = getNumGangsDeviceType();
+  auto gangSegments = getNumGangsSegments();
+  return getValuesFromSegments(gangTypes, getNumGangs(), gangSegments,
+                               deviceType);
+}
+
+/// Return true if DeviceType::None is listed in the waitOnly attribute.
+bool acc::ParallelOp::hasWaitOnly() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return hasWaitOnly(noDeviceType);
+}
+
+/// Return true if `deviceType` is listed in the waitOnly attribute; false
+/// when the attribute is absent.
+bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
+  auto waitOnlyTypes = getWaitOnly();
+  if (!waitOnlyTypes)
+    return false;
+  return findSegment(*waitOnlyTypes, deviceType).has_value();
+}
+
+/// Return the wait values attached to DeviceType::None (possibly empty).
+mlir::Operation::operand_range ParallelOp::getWaitValues() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getWaitValues(noDeviceType);
+}
+
+/// Return the wait values attached to `deviceType` (possibly empty).
+mlir::Operation::operand_range
+ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+  auto waitTypes = getWaitOperandsDeviceType();
+  auto waitSegments = getWaitOperandsSegments();
+  return getValuesFromSegments(waitTypes, getWaitOperands(), waitSegments,
+                               deviceType);
+}
+
+/// Parse the body of a `num_gangs` clause. The body is one or more
+/// comma-separated groups of the form
+///   `{` %v `:` type (`,` %v `:` type)* `}` (`[` device_type-attr `]`)?
+/// Each brace group forms one segment. A group without a bracketed
+/// device_type is attributed to DeviceType::None.
+///
+/// On success, `deviceTypes` holds one DeviceTypeAttr per group and
+/// `segments` holds the number of operands parsed in each group.
+static ParseResult parseNumGangs(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::DenseI32ArrayAttr &segments) {
+  llvm::SmallVector<DeviceTypeAttr> attributes;
+  llvm::SmallVector<int32_t> seg;
+
+  do {
+    if (failed(parser.parseLBrace()))
+      return failure();
+
+    // The segment size is the number of operands parsed inside these braces,
+    // not the running total of all operands parsed so far.
+    const std::size_t segmentStart = operands.size();
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() {
+              if (parser.parseOperand(operands.emplace_back()) ||
+                  parser.parseColonType(types.emplace_back()))
+                return failure();
+              return success();
+            })))
+      return failure();
+
+    seg.push_back(static_cast<int32_t>(operands.size() - segmentStart));
+
+    if (failed(parser.parseRBrace()))
+      return failure();
+
+    if (succeeded(parser.parseOptionalLSquare())) {
+      if (parser.parseAttribute(attributes.emplace_back()) ||
+          parser.parseRSquare())
+        return failure();
+    } else {
+      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+          parser.getContext(), mlir::acc::DeviceType::None));
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
+  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+
+  return success();
+}
+
+/// Print the body of a `num_gangs` clause in the form accepted by
+/// parseNumGangs. Each segment is printed as one `{...}` group. The group is
+/// followed by `[device_type-attr]` unless its device_type is None.
+/// Assumes both attributes are present, as the verifier requires whenever
+/// there are operands.
+static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          mlir::OperandRange operands, mlir::TypeRange types,
+                          std::optional<mlir::ArrayAttr> deviceTypes,
+                          std::optional<mlir::DenseI32ArrayAttr> segments) {
+  unsigned opIdx = 0;
+  for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << "{";
+    for (int32_t j = 0; j < (*segments)[i]; ++j) {
+      if (j != 0)
+        p << ", ";
+      p << operands[opIdx] << " : " << operands[opIdx].getType();
+      ++opIdx;
+    }
+    p << "}";
+    // The attribute is declared as a DeviceTypeArrayAttr. Use cast, which
+    // asserts on a mismatch, rather than dereferencing an unchecked dyn_cast.
+    auto deviceTypeAttr =
+        mlir::cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
+    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+      p << " [" << (*deviceTypes)[i] << "]";
+  }
+}
+
+/// Parse the body of a `wait` clause. The body is one or more
+/// comma-separated groups of the form
+///   `{` %v `:` type (`,` %v `:` type)* `}` (`[` device_type-attr `]`)?
+/// Each brace group forms one segment. A group without a bracketed
+/// device_type is attributed to DeviceType::None.
+///
+/// On success, `deviceTypes` holds one DeviceTypeAttr per group and
+/// `segments` holds the number of operands parsed in each group.
+static ParseResult parseWaitOperands(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::DenseI32ArrayAttr &segments) {
+  llvm::SmallVector<DeviceTypeAttr> attributes;
+  llvm::SmallVector<int32_t> seg;
+
+  do {
+    if (failed(parser.parseLBrace()))
+      return failure();
+
+    // The segment size is the number of operands parsed inside these braces,
+    // not the running total of all operands parsed so far.
+    const std::size_t segmentStart = operands.size();
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() {
+              if (parser.parseOperand(operands.emplace_back()) ||
+                  parser.parseColonType(types.emplace_back()))
+                return failure();
+              return success();
+            })))
+      return failure();
+
+    seg.push_back(static_cast<int32_t>(operands.size() - segmentStart));
+
+    if (failed(parser.parseRBrace()))
+      return failure();
+
+    if (succeeded(parser.parseOptionalLSquare())) {
+      if (parser.parseAttribute(attributes.emplace_back()) ||
+          parser.parseRSquare())
+        return failure();
+    } else {
+      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+          parser.getContext(), mlir::acc::DeviceType::None));
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
+  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+
+  return success();
+}
+
+/// Print the body of a `wait` clause in the form accepted by
+/// parseWaitOperands. Each segment is printed as one `{...}` group. The group
+/// is followed by `[device_type-attr]` unless its device_type is None.
+/// Assumes both attributes are present, as the verifier requires whenever
+/// there are operands.
+static void printWaitOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                              mlir::OperandRange operands,
+                              mlir::TypeRange types,
+                              std::optional<mlir::ArrayAttr> deviceTypes,
+                              std::optional<mlir::DenseI32ArrayAttr> segments) {
+  unsigned opIdx = 0;
+  for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << "{";
+    for (int32_t j = 0; j < (*segments)[i]; ++j) {
+      if (j != 0)
+        p << ", ";
+      p << operands[opIdx] << " : " << operands[opIdx].getType();
+      ++opIdx;
+    }
+    p << "}";
+    // The attribute is declared as a DeviceTypeArrayAttr. Use cast, which
+    // asserts on a mismatch, rather than dereferencing an unchecked dyn_cast.
+    auto deviceTypeAttr =
+        mlir::cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
+    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+      p << " [" << (*deviceTypes)[i] << "]";
+  }
+}
+
+/// Parse a comma-separated list of `%v : type` entries, each optionally
+/// followed by `[device_type-attr]`. An entry without brackets is attributed
+/// to DeviceType::None. On success, `deviceTypes` holds one DeviceTypeAttr
+/// per parsed operand.
+static ParseResult parseDeviceTypeOperands(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<DeviceTypeAttr> parsedTypes;
+
+  auto parseEntry = [&]() -> ParseResult {
+    if (failed(parser.parseOperand(operands.emplace_back())) ||
+        failed(parser.parseColonType(types.emplace_back())))
+      return failure();
+    if (failed(parser.parseOptionalLSquare())) {
+      // No bracketed device_type: the value belongs to DeviceType::None.
+      parsedTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+          parser.getContext(), mlir::acc::DeviceType::None));
+      return success();
+    }
+    if (failed(parser.parseAttribute(parsedTypes.emplace_back())) ||
+        failed(parser.parseRSquare()))
+      return failure();
+    return success();
+  };
+
+  if (failed(parser.parseCommaSeparatedList(parseEntry)))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> genericTypes(parsedTypes.begin(),
+                                                  parsedTypes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), genericTypes);
+  return success();
+}
+
+/// Print `%v : type` entries separated by commas, in the form accepted by
+/// parseDeviceTypeOperands. Each entry is followed by `[device_type-attr]`
+/// unless its device_type is None. Iterates over `deviceTypes` and indexes
+/// `operands` at the same position, so the two must have equal length.
+static void
+printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                        mlir::OperandRange operands, mlir::TypeRange types,
+                        std::optional<mlir::ArrayAttr> deviceTypes) {
+  for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << operands[i] << " : " << operands[i].getType();
+    // The attribute is declared as a DeviceTypeArrayAttr. Use cast, which
+    // asserts on a mismatch, rather than dereferencing an unchecked dyn_cast.
+    auto deviceTypeAttr =
+        mlir::cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
+    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+      p << " [" << (*deviceTypes)[i] << "]";
+  }
+}
+
//===----------------------------------------------------------------------===//
// SerialOp
//===----------------------------------------------------------------------===//
@@ -648,12 +993,55 @@ unsigned SerialOp::getNumDataOperands() {
}
+/// Return the i-th data operand. The index is offset past the wait, async,
+/// if and self operands, whose counts are summed below.
Value SerialOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync().size() ? 1 : 0;
+ // async is now Variadic (one value per device_type), so count its size.
+ unsigned numOptional = getAsync().size();
numOptional += getIfCond() ? 1 : 0;
numOptional += getSelfCond() ? 1 : 0;
return getOperand(getWaitOperands().size() + numOptional + i);
}
+/// Return true if DeviceType::None is listed in the asyncOnly attribute.
+bool acc::SerialOp::hasAsyncOnly() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return hasAsyncOnly(noDeviceType);
+}
+
+/// Return true if `deviceType` is listed in the asyncOnly attribute; false
+/// when the attribute is absent.
+bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
+  auto asyncOnlyTypes = getAsyncOnly();
+  if (!asyncOnlyTypes)
+    return false;
+  return findSegment(*asyncOnlyTypes, deviceType).has_value();
+}
+
+/// Return the async value attached to DeviceType::None, or a null Value.
+mlir::Value acc::SerialOp::getAsyncValue() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getAsyncValue(noDeviceType);
+}
+
+/// Return the async value attached to `deviceType`, or a null Value.
+mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+  auto asyncTypes = getAsyncDeviceType();
+  mlir::Operation::operand_range asyncValues = getAsync();
+  return getValueInDeviceTypeSegment(asyncTypes, asyncValues, deviceType);
+}
+
+/// Return true if DeviceType::None is listed in the waitOnly attribute.
+bool acc::SerialOp::hasWaitOnly() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return hasWaitOnly(noDeviceType);
+}
+
+/// Return true if `deviceType` is listed in the waitOnly attribute; false
+/// when the attribute is absent.
+bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
+  auto waitOnlyTypes = getWaitOnly();
+  if (!waitOnlyTypes)
+    return false;
+  return findSegment(*waitOnlyTypes, deviceType).has_value();
+}
+
+/// Return the wait values attached to DeviceType::None (possibly empty).
+mlir::Operation::operand_range SerialOp::getWaitValues() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getWaitValues(noDeviceType);
+}
+
+/// Return the wait values attached to `deviceType` (possibly empty).
+mlir::Operation::operand_range
+SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+  auto waitTypes = getWaitOperandsDeviceType();
+  auto waitSegments = getWaitOperandsSegments();
+  return getValuesFromSegments(waitTypes, getWaitOperands(), waitSegments,
+                               deviceType);
+}
+
+/// Verify acc.serial: private/firstprivate/reduction recipe symbols, the
+/// device_type bookkeeping of the wait and async clauses, and the data
+/// operands.
LogicalResult acc::SerialOp::verify() {
if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
*this, getPrivatizations(), getGangPrivateOperands(), "private",
@@ -663,6 +1051,16 @@ LogicalResult acc::SerialOp::verify() {
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions", false)))
return failure();
+
+ // wait: segmented per device_type, no per-segment limit.
+ if (failed(verifyDeviceTypeAndSegmentCountMatch(
+ *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
+ getWaitOperandsDeviceTypeAttr(), "wait")))
+ return failure();
+
+ // async: one value per device_type entry.
+ if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
+ getAsyncDeviceTypeAttr(), "async")))
+ return failure();
+
return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
@@ -675,19 +1073,114 @@ unsigned KernelsOp::getNumDataOperands() {
}
+/// Return the i-th data operand. In the argument list, async, wait, num_gangs,
+/// num_workers, vector_length, if and self precede dataClauseOperands, so the
+/// index is offset by the number of those operands present.
Value KernelsOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync() ? 1 : 0;
+ // async, num_workers and vector_length are now Variadic (one value per
+ // device_type), so their sizes are counted instead of testing presence.
+ unsigned numOptional = getAsync().size();
numOptional += getWaitOperands().size();
numOptional += getNumGangs().size();
- numOptional += getNumWorkers() ? 1 : 0;
- numOptional += getVectorLength() ? 1 : 0;
+ numOptional += getNumWorkers().size();
+ numOptional += getVectorLength().size();
numOptional += getIfCond() ? 1 : 0;
numOptional += getSelfCond() ? 1 : 0;
return getOperand(numOptional + i);
}
+/// Return true if DeviceType::None is listed in the asyncOnly attribute.
+bool acc::KernelsOp::hasAsyncOnly() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return hasAsyncOnly(noDeviceType);
+}
+
+/// Return true if `deviceType` is listed in the asyncOnly attribute; false
+/// when the attribute is absent.
+bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
+  auto asyncOnlyTypes = getAsyncOnly();
+  if (!asyncOnlyTypes)
+    return false;
+  return findSegment(*asyncOnlyTypes, deviceType).has_value();
+}
+
+/// Return the async value attached to DeviceType::None, or a null Value.
+mlir::Value acc::KernelsOp::getAsyncValue() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getAsyncValue(noDeviceType);
+}
+
+/// Return the async value attached to `deviceType`, or a null Value.
+mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+  auto asyncTypes = getAsyncDeviceType();
+  mlir::Operation::operand_range asyncValues = getAsync();
+  return getValueInDeviceTypeSegment(asyncTypes, asyncValues, deviceType);
+}
+
+/// Return the num_workers value attached to DeviceType::None, or a null Value.
+mlir::Value acc::KernelsOp::getNumWorkersValue() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getNumWorkersValue(noDeviceType);
+}
+
+/// Return the num_workers value attached to `deviceType`, or a null Value.
+mlir::Value
+acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
+  auto workerTypes = getNumWorkersDeviceType();
+  mlir::Operation::operand_range workerValues = getNumWorkers();
+  return getValueInDeviceTypeSegment(workerTypes, workerValues, deviceType);
+}
+
+/// Return the vector_length value attached to DeviceType::None, or a null
+/// Value.
+mlir::Value acc::KernelsOp::getVectorLengthValue() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getVectorLengthValue(noDeviceType);
+}
+
+/// Return the vector_length value attached to `deviceType`, or a null Value.
+mlir::Value
+acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
+  auto lengthTypes = getVectorLengthDeviceType();
+  mlir::Operation::operand_range lengthValues = getVectorLength();
+  return getValueInDeviceTypeSegment(lengthTypes, lengthValues, deviceType);
+}
+
+/// Return the num_gangs values attached to DeviceType::None (possibly empty).
+mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getNumGangsValues(noDeviceType);
+}
+
+/// Return the num_gangs values attached to `deviceType` (possibly empty).
+mlir::Operation::operand_range
+KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
+  auto gangTypes = getNumGangsDeviceType();
+  auto gangSegments = getNumGangsSegments();
+  return getValuesFromSegments(gangTypes, getNumGangs(), gangSegments,
+                               deviceType);
+}
+
+/// Return true if DeviceType::None is listed in the waitOnly attribute.
+bool acc::KernelsOp::hasWaitOnly() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return hasWaitOnly(noDeviceType);
+}
+
+/// Return true if `deviceType` is listed in the waitOnly attribute; false
+/// when the attribute is absent.
+bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
+  auto waitOnlyTypes = getWaitOnly();
+  if (!waitOnlyTypes)
+    return false;
+  return findSegment(*waitOnlyTypes, deviceType).has_value();
+}
+
+/// Return the wait values attached to DeviceType::None (possibly empty).
+mlir::Operation::operand_range KernelsOp::getWaitValues() {
+  constexpr auto noDeviceType = mlir::acc::DeviceType::None;
+  return getWaitValues(noDeviceType);
+}
+
+/// Return the wait values attached to `deviceType` (possibly empty).
+mlir::Operation::operand_range
+KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+  auto waitTypes = getWaitOperandsDeviceType();
+  auto waitSegments = getWaitOperandsSegments();
+  return getValuesFromSegments(waitTypes, getWaitOperands(), waitSegments,
+                               deviceType);
+}
+
+/// Verify acc.kernels: the device_type bookkeeping of each per-device_type
+/// clause, then the data operands.
LogicalResult acc::KernelsOp::verify() {
- if (getNumGangs().size() > 3)
- return emitOpError() << "num_gangs expects a maximum of 3 values";
+ // num_gangs: segmented per device_type, at most 3 values per segment.
+ if (failed(verifyDeviceTypeAndSegmentCountMatch(
+ *this, getNumGangs(), getNumGangsSegmentsAttr(),
+ getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
+ return failure();
+
+ // wait: segmented per device_type, no per-segment limit.
+ if (failed(verifyDeviceTypeAndSegmentCountMatch(
+ *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
+ getWaitOperandsDeviceTypeAttr(), "wait")))
+ return failure();
+
+ // num_workers, vector_length and async: one value per device_type entry.
+ if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
+ getNumWorkersDeviceTypeAttr(),
+ "num_workers")))
+ return failure();
+
+ if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
+ getVectorLengthDeviceTypeAttr(),
+ "vector_length")))
+ return failure();
+
+ if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
+ getAsyncDeviceTypeAttr(), "async")))
+ return failure();
+
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index b9ac68d0592c87..c18d964b370f2c 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -462,8 +462,8 @@ acc.loop gang() {
// -----
%i64value = arith.constant 1 : i64
-// expected-error@+1 {{num_gangs expects a maximum of 3 values}}
-acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) {
+// expected-error@+1 {{num_gangs expects a maximum of 3 values per segment}}
+acc.parallel num_gangs({%i64value: i64, %i64value : i64, %i64value : i64, %i64value : i64}) {
}
// -----
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 05b0450c7fb916..5a95811685f845 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -137,7 +137,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
%pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32>
acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
%private = acc.private varPtr(%c : memref<10xf32>) -> memref<10xf32>
- acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
+ acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
acc.loop gang {
scf.for %x = %lb to %c10 step %st {
acc.loop worker {
@@ -180,7 +180,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
// CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64
// CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
// CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) -> memref<10xf32>
-// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
+// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
// CHECK-NEXT: acc.loop gang {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: acc.loop worker {
@@ -439,25 +439,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
}
acc.parallel async(%idxValue: index) {
}
- acc.parallel wait(%i64value: i64) {
+ acc.parallel wait({%i64value: i64}) {
}
- acc.parallel wait(%i32value: i32) {
+ acc.parallel wait({%i32value: i32}) {
}
- acc.parallel wait(%idxValue: index) {
+ acc.parallel wait({%idxValue: index}) {
}
- acc.parallel wait(%i64value, %i32value, %idxValue : i64, i32, index) {
+ acc.parallel wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
}
- acc.parallel num_gangs(%i64value: i64) {
+ acc.parallel num_gangs({%i64value: i64}) {
}
- acc.parallel num_gangs(%i32value: i32) {
+ acc.parallel num_gangs({%i32value: i32}) {
}
- acc.parallel num_gangs(%idxValue: index) {
+ acc.parallel num_gangs({%idxValue: index}) {
}
- acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) {
+ acc.parallel num_gangs({%i64value: i64, %i64value: i64, %idxValue: index}) {
}
- acc.parallel num_workers(%i64value: i64) {
+ acc.parallel num_workers(%i64value: i64 [#acc.device_type<nvidia>]) {
}
- acc.parallel num_workers(%i32value: i32) {
+ acc.parallel num_workers(%i32value: i32 [#acc.device_type<default>]) {
}
acc.parallel num_workers(%idxValue: index) {
}
@@ -492,25 +492,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
// CHECK-NEXT: }
// CHECK: acc.parallel async([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.parallel wait([[I64VALUE]] : i64) {
+// CHECK: acc.parallel wait({[[I64VALUE]] : i64}) {
// CHECK-NEXT: }
-// CHECK: acc.parallel wait([[I32VALUE]] : i32) {
+// CHECK: acc.parallel wait({[[I32VALUE]] : i32}) {
// CHECK-NEXT: }
-// CHECK: acc.parallel wait([[IDXVALUE]] : index) {
+// CHECK: acc.parallel wait({[[IDXVALUE]] : index}) {
// CHECK-NEXT: }
-// CHECK: acc.parallel wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
+// CHECK: acc.parallel wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_gangs([[I64VALUE]] : i64) {
+// CHECK: acc.parallel num_gangs({[[I64VALUE]] : i64}) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_gangs([[I32VALUE]] : i32) {
+// CHECK: acc.parallel num_gangs({[[I32VALUE]] : i32}) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_gangs([[IDXVALUE]] : index) {
+// CHECK: acc.parallel num_gangs({[[IDXVALUE]] : index}) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) {
+// CHECK: acc.parallel num_gangs({[[I64VALUE]] : i64, [[I64VALUE]] : i64, [[IDXVALUE]] : index}) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_workers([[I64VALUE]] : i64) {
+// CHECK: acc.parallel num_workers([[I64VALUE]] : i64 [#acc.device_type<nvidia>]) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_workers([[I32VALUE]] : i32) {
+// CHECK: acc.parallel num_workers([[I32VALUE]] : i32 [#acc.device_type<default>]) {
// CHECK-NEXT: }
// CHECK: acc.parallel num_workers([[IDXVALUE]] : index) {
// CHECK-NEXT: }
@@ -590,13 +590,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
}
acc.serial async(%idxValue: index) {
}
- acc.serial wait(%i64value: i64) {
+ acc.serial wait({%i64value: i64}) {
}
- acc.serial wait(%i32value: i32) {
+ acc.serial wait({%i32value: i32}) {
}
- acc.serial wait(%idxValue: index) {
+ acc.serial wait({%idxValue: index}) {
}
- acc.serial wait(%i64value, %i32value, %idxValue : i64, i32, index) {
+ acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
}
%firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) -> memref<10xf32>
acc.serial private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) {
@@ -627,13 +627,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
// CHECK-NEXT: }
// CHECK: acc.serial async([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.serial wait([[I64VALUE]] : i64) {
+// CHECK: acc.serial wait({[[I64VALUE]] : i64}) {
// CHECK-NEXT: }
-// CHECK: acc.serial wait([[I32VALUE]] : i32) {
+// CHECK: acc.serial wait({[[I32VALUE]] : i32}) {
// CHECK-NEXT: }
-// CHECK: acc.serial wait([[IDXVALUE]] : index) {
+// CHECK: acc.serial wait({[[IDXVALUE]] : index}) {
// CHECK-NEXT: }
-// CHECK: acc.serial wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
+// CHECK: acc.serial wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
// CHECK-NEXT: }
// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) -> memref<10xf32>
// CHECK: acc.serial firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTP]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
@@ -665,13 +665,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
}
acc.kernels async(%idxValue: index) {
}
- acc.kernels wait(%i64value: i64) {
+ acc.kernels wait({%i64value: i64}) {
}
- acc.kernels wait(%i32value: i32) {
+ acc.kernels wait({%i32value: i32}) {
}
- acc.kernels wait(%idxValue: index) {
+ acc.kernels wait({%idxValue: index}) {
}
- acc.kernels wait(%i64value, %i32value, %idxValue : i64, i32, index) {
+ acc.kernels wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
}
acc.kernels {
} attributes {defaultAttr = #acc<defaultvalue none>}
@@ -699,13 +699,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
// CHECK-NEXT: }
// CHECK: acc.kernels async([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.kernels wait([[I64VALUE]] : i64) {
+// CHECK: acc.kernels wait({[[I64VALUE]] : i64}) {
// CHECK-NEXT: }
-// CHECK: acc.kernels wait([[I32VALUE]] : i32) {
+// CHECK: acc.kernels wait({[[I32VALUE]] : i32}) {
// CHECK-NEXT: }
-// CHECK: acc.kernels wait([[IDXVALUE]] : index) {
+// CHECK: acc.kernels wait({[[IDXVALUE]] : index}) {
// CHECK-NEXT: }
-// CHECK: acc.kernels wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
+// CHECK: acc.kernels wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
// CHECK-NEXT: }
// CHECK: acc.kernels {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 2dec4ba3c001e8..13393569f36fe7 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -10,6 +10,7 @@ add_subdirectory(ArmSME)
add_subdirectory(Index)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
+add_subdirectory(OpenACC)
add_subdirectory(SCF)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
new file mode 100644
index 00000000000000..5133d7fc38296c
--- /dev/null
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIROpenACCTests
+  OpenACCOpsTest.cpp
+)
+# OpenACCOpsTest.cpp uses the arith dialect directly (arith::ArithDialect,
+# arith::ConstantOp), so link it explicitly rather than relying on it being
+# pulled in transitively through MLIROpenACCDialect.
+target_link_libraries(MLIROpenACCTests
+  PRIVATE
+  MLIRArithDialect
+  MLIRIR
+  MLIROpenACCDialect
+)
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
new file mode 100644
index 00000000000000..dcf6c1240c55d3
--- /dev/null
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -0,0 +1,275 @@
+//===- OpenACCOpsTest.cpp - OpenACC ops extra functions Tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+/// Fixture shared by the OpenACC op tests in this file. It owns an
+/// MLIRContext with the OpenACC and arith dialects loaded, an OpBuilder that
+/// is constructed from the context alone (no insertion point is set), an
+/// unknown location, and the device_type lists the per-op checks iterate over.
+class OpenACCOpsTest : public ::testing::Test {
+protected:
+  OpenACCOpsTest() {
+    context.loadDialect<acc::OpenACCDialect, arith::ArithDialect>();
+  }
+
+  // Must stay the first member: `b` and `loc` are initialized from it and
+  // members are initialized in declaration order.
+  MLIRContext context;
+  OpBuilder b{&context};
+  Location loc{UnknownLoc::get(&context)};
+
+  // Every device_type, including None (the value recorded for a clause that
+  // appears before any device_type clause).
+  llvm::SmallVector<DeviceType> dtypes{
+      DeviceType::None,    DeviceType::Star, DeviceType::Multicore,
+      DeviceType::Default, DeviceType::Host, DeviceType::Nvidia,
+      DeviceType::Radeon};
+  // Same list as `dtypes` with DeviceType::None left out.
+  llvm::SmallVector<DeviceType> dtypesWithoutNone{
+      DeviceType::Star, DeviceType::Multicore, DeviceType::Default,
+      DeviceType::Host, DeviceType::Nvidia,    DeviceType::Radeon};
+};
+
+/// Checks hasAsyncOnly() on a freshly built \p Op. Without an asyncOnly
+/// attribute no device_type is async-only; once the attribute is set, exactly
+/// the listed device_types are. The no-argument overload is expected to answer
+/// for DeviceType::None.
+template <typename Op>
+void testAsyncOnly(OpBuilder &b, MLIRContext &context, Location loc,
+                   llvm::ArrayRef<DeviceType> dtypes,
+                   llvm::ArrayRef<DeviceType> dtypesWithoutNone) {
+  // `b` has no insertion point, so the created op is not attached to any
+  // block; hold it in an OwningOpRef so it is erased on return, not leaked.
+  OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+  EXPECT_FALSE(op->hasAsyncOnly());
+  for (auto d : dtypes)
+    EXPECT_FALSE(op->hasAsyncOnly(d));
+
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasAsyncOnly());
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::None));
+  // Marking None async-only must not mark any other device_type.
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasAsyncOnly(d));
+  op->removeAsyncOnlyAttr();
+
+  auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+  EXPECT_FALSE(op->hasAsyncOnly());
+  op->removeAsyncOnlyAttr();
+
+  auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Star));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+  EXPECT_FALSE(op->hasAsyncOnly());
+}
+
+TEST_F(OpenACCOpsTest, asyncOnlyTest) {
+  testAsyncOnly<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testAsyncOnly<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testAsyncOnly<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
+}
+
+/// Checks getAsyncValue() on \p Op: no value is returned for any device_type
+/// while the async operand list is empty, and a value bound to one
+/// device_type is returned only for that device_type.
+template <typename Op>
+void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
+                    llvm::ArrayRef<DeviceType> dtypes) {
+  // Ops built by `b` are detached, so they are owned by OwningOpRef to avoid
+  // leaking them. The constant is declared before the op that uses its value:
+  // locals are destroyed in reverse order, so the user goes away first and
+  // the value has no remaining uses when its defining op is destroyed.
+  OwningOpRef<arith::ConstantOp> cst =
+      b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
+  mlir::Value val = cst->getResult();
+  OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+
+  mlir::Value empty;
+  EXPECT_EQ(op->getAsyncValue(), empty);
+  for (auto d : dtypes)
+    EXPECT_EQ(op->getAsyncValue(d), empty);
+
+  auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+  op->setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+  op->getAsyncMutable().assign(val);
+  EXPECT_EQ(op->getAsyncValue(), empty);
+  EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val);
+}
+
+TEST_F(OpenACCOpsTest, asyncValueTest) {
+  testAsyncValue<ParallelOp>(b, context, loc, dtypes);
+  testAsyncValue<KernelsOp>(b, context, loc, dtypes);
+  testAsyncValue<SerialOp>(b, context, loc, dtypes);
+}
+
+/// Checks getNumGangsValues() on \p Op. Values are grouped per device_type by
+/// the array set through setNumGangsDeviceTypeAttr together with the
+/// per-entry counts set through setNumGangsSegments; a device_type with no
+/// entry yields an empty range.
+template <typename Op>
+void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
+                        llvm::ArrayRef<DeviceType> dtypes,
+                        llvm::ArrayRef<DeviceType> dtypesWithoutNone) {
+  // Ops built by `b` are detached, so they are owned by OwningOpRef to avoid
+  // leaking them. Constants are declared before the op that uses them so the
+  // op is destroyed first on return and no value is destroyed while in use.
+  OwningOpRef<arith::ConstantOp> cst1 =
+      b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
+  OwningOpRef<arith::ConstantOp> cst2 =
+      b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(4));
+  mlir::Value val1 = cst1->getResult();
+  mlir::Value val2 = cst2->getResult();
+  OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+  EXPECT_EQ(op->getNumGangsValues().begin(), op->getNumGangsValues().end());
+
+  // One value bound to DeviceType::None.
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->getNumGangsMutable().assign(val1);
+  op->setNumGangsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op->setNumGangsSegments(b.getDenseI32ArrayAttr({1}));
+  EXPECT_EQ(op->getNumGangsValues().front(), val1);
+  for (auto d : dtypesWithoutNone)
+    EXPECT_EQ(op->getNumGangsValues(d).begin(), op->getNumGangsValues(d).end());
+
+  op->getNumGangsMutable().clear();
+  op->removeNumGangsDeviceTypeAttr();
+  op->removeNumGangsSegmentsAttr();
+  for (auto d : dtypes)
+    EXPECT_EQ(op->getNumGangsValues(d).begin(), op->getNumGangsValues(d).end());
+
+  // One value each for Host and Star; nothing for None.
+  op->getNumGangsMutable().append(val1);
+  op->getNumGangsMutable().append(val2);
+  op->setNumGangsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host),
+                      DeviceTypeAttr::get(&context, DeviceType::Star)}));
+  op->setNumGangsSegments(b.getDenseI32ArrayAttr({1, 1}));
+  EXPECT_EQ(op->getNumGangsValues(DeviceType::None).begin(),
+            op->getNumGangsValues(DeviceType::None).end());
+  EXPECT_EQ(op->getNumGangsValues(DeviceType::Host).front(), val1);
+  EXPECT_EQ(op->getNumGangsValues(DeviceType::Star).front(), val2);
+
+  op->getNumGangsMutable().clear();
+  op->removeNumGangsDeviceTypeAttr();
+  op->removeNumGangsSegmentsAttr();
+  for (auto d : dtypes)
+    EXPECT_EQ(op->getNumGangsValues(d).begin(), op->getNumGangsValues(d).end());
+
+  // Two values for Default followed by one for Multicore.
+  op->getNumGangsMutable().append(val1);
+  op->getNumGangsMutable().append(val2);
+  op->getNumGangsMutable().append(val1);
+  op->setNumGangsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default),
+                      DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
+  op->setNumGangsSegments(b.getDenseI32ArrayAttr({2, 1}));
+  EXPECT_EQ(op->getNumGangsValues(DeviceType::None).begin(),
+            op->getNumGangsValues(DeviceType::None).end());
+  EXPECT_EQ(op->getNumGangsValues(DeviceType::Default).front(), val1);
+  EXPECT_EQ(op->getNumGangsValues(DeviceType::Default).drop_front().front(),
+            val2);
+  EXPECT_EQ(op->getNumGangsValues(DeviceType::Multicore).front(), val1);
+}
+
+TEST_F(OpenACCOpsTest, numGangsValuesTest) {
+  testNumGangsValues<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testNumGangsValues<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
+}
+
+/// Checks getVectorLengthValue() on \p Op: no value is returned for any
+/// device_type while vector_length is unset, and a value bound to one
+/// device_type is returned only for that device_type.
+template <typename Op>
+void testVectorLength(OpBuilder &b, MLIRContext &context, Location loc,
+                      llvm::ArrayRef<DeviceType> dtypes) {
+  // Ops built by `b` are detached, so they are owned by OwningOpRef to avoid
+  // leaking them. The constant is declared before the op that uses its value
+  // so the op is destroyed first on return.
+  OwningOpRef<arith::ConstantOp> cst =
+      b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
+  mlir::Value val = cst->getResult();
+  OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+
+  mlir::Value empty;
+  EXPECT_EQ(op->getVectorLengthValue(), empty);
+  for (auto d : dtypes)
+    EXPECT_EQ(op->getVectorLengthValue(d), empty);
+
+  auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+  op->setVectorLengthDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+  op->getVectorLengthMutable().assign(val);
+  EXPECT_EQ(op->getVectorLengthValue(), empty);
+  EXPECT_EQ(op->getVectorLengthValue(DeviceType::Nvidia), val);
+}
+
+TEST_F(OpenACCOpsTest, vectorLengthTest) {
+  testVectorLength<ParallelOp>(b, context, loc, dtypes);
+  testVectorLength<KernelsOp>(b, context, loc, dtypes);
+}
+
+/// Checks hasWaitOnly() on a freshly built \p Op. Without a waitOnly
+/// attribute no device_type is wait-only; once the attribute is set, exactly
+/// the listed device_types are. The no-argument overload is expected to answer
+/// for DeviceType::None.
+template <typename Op>
+void testWaitOnly(OpBuilder &b, MLIRContext &context, Location loc,
+                  llvm::ArrayRef<DeviceType> dtypes,
+                  llvm::ArrayRef<DeviceType> dtypesWithoutNone) {
+  // `b` has no insertion point, so the created op is not attached to any
+  // block; hold it in an OwningOpRef so it is erased on return, not leaked.
+  OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+  EXPECT_FALSE(op->hasWaitOnly());
+  for (auto d : dtypes)
+    EXPECT_FALSE(op->hasWaitOnly(d));
+
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->setWaitOnlyAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasWaitOnly());
+  EXPECT_TRUE(op->hasWaitOnly(DeviceType::None));
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasWaitOnly(d));
+  op->removeWaitOnlyAttr();
+
+  auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
+  op->setWaitOnlyAttr(b.getArrayAttr({dtypeHost}));
+  EXPECT_TRUE(op->hasWaitOnly(DeviceType::Host));
+  EXPECT_FALSE(op->hasWaitOnly());
+  op->removeWaitOnlyAttr();
+
+  auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
+  op->setWaitOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
+  EXPECT_TRUE(op->hasWaitOnly(DeviceType::Star));
+  EXPECT_TRUE(op->hasWaitOnly(DeviceType::Host));
+  EXPECT_FALSE(op->hasWaitOnly());
+}
+
+TEST_F(OpenACCOpsTest, waitOnlyTest) {
+  testWaitOnly<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitOnly<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitOnly<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
+}
+
+/// Checks getWaitValues() on \p Op. Values are grouped per device_type by the
+/// array set through setWaitOperandsDeviceTypeAttr together with the
+/// per-entry counts set through setWaitOperandsSegments; a device_type with
+/// no entry yields an empty range.
+template <typename Op>
+void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc,
+                    llvm::ArrayRef<DeviceType> dtypes,
+                    llvm::ArrayRef<DeviceType> dtypesWithoutNone) {
+  // Ops built by `b` are detached, so they are owned by OwningOpRef to avoid
+  // leaking them. Constants are declared before the op that uses them so the
+  // op is destroyed first on return and no value is destroyed while in use.
+  OwningOpRef<arith::ConstantOp> cst1 =
+      b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
+  OwningOpRef<arith::ConstantOp> cst2 =
+      b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(4));
+  mlir::Value val1 = cst1->getResult();
+  mlir::Value val2 = cst2->getResult();
+  OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
+  EXPECT_EQ(op->getWaitValues().begin(), op->getWaitValues().end());
+
+  // One value bound to DeviceType::None.
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->getWaitOperandsMutable().assign(val1);
+  op->setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({1}));
+  EXPECT_EQ(op->getWaitValues().front(), val1);
+  for (auto d : dtypesWithoutNone)
+    EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end());
+
+  op->getWaitOperandsMutable().clear();
+  op->removeWaitOperandsDeviceTypeAttr();
+  op->removeWaitOperandsSegmentsAttr();
+  for (auto d : dtypes)
+    EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end());
+
+  // One value each for Host and Star; nothing for None.
+  op->getWaitOperandsMutable().append(val1);
+  op->getWaitOperandsMutable().append(val2);
+  op->setWaitOperandsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host),
+                      DeviceTypeAttr::get(&context, DeviceType::Star)}));
+  op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({1, 1}));
+  EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(),
+            op->getWaitValues(DeviceType::None).end());
+  EXPECT_EQ(op->getWaitValues(DeviceType::Host).front(), val1);
+  EXPECT_EQ(op->getWaitValues(DeviceType::Star).front(), val2);
+
+  op->getWaitOperandsMutable().clear();
+  op->removeWaitOperandsDeviceTypeAttr();
+  op->removeWaitOperandsSegmentsAttr();
+  for (auto d : dtypes)
+    EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end());
+
+  // Two values for Default followed by one for Multicore.
+  op->getWaitOperandsMutable().append(val1);
+  op->getWaitOperandsMutable().append(val2);
+  op->getWaitOperandsMutable().append(val1);
+  op->setWaitOperandsDeviceTypeAttr(
+      b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default),
+                      DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
+  op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({2, 1}));
+  EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(),
+            op->getWaitValues(DeviceType::None).end());
+  EXPECT_EQ(op->getWaitValues(DeviceType::Default).front(), val1);
+  EXPECT_EQ(op->getWaitValues(DeviceType::Default).drop_front().front(), val2);
+  EXPECT_EQ(op->getWaitValues(DeviceType::Multicore).front(), val1);
+}
+
+TEST_F(OpenACCOpsTest, waitValuesTest) {
+  testWaitValues<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitValues<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
+  testWaitValues<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
+}
More information about the flang-commits
mailing list