[llvm-branch-commits] [mlir] 10df608 - Revert "[mlir][openacc] Add device_type support for compute operations (#75864)"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Dec 20 16:06:55 PST 2023
Author: Valentin Clement (バレンタイン クレメン)
Date: 2023-12-20T16:06:50-08:00
New Revision: 10df6087ffdacaab86f659c1fe080f0673517ba6
URL: https://github.com/llvm/llvm-project/commit/10df6087ffdacaab86f659c1fe080f0673517ba6
DIFF: https://github.com/llvm/llvm-project/commit/10df6087ffdacaab86f659c1fe080f0673517ba6.diff
LOG: Revert "[mlir][openacc] Add device_type support for compute operations (#75864)"
This reverts commit 8b885eb90ff14862b579b191c3f469a5a4fed1bc.
Added:
Modified:
flang/lib/Lower/OpenACC.cpp
flang/test/Lower/OpenACC/acc-kernels-loop.f90
flang/test/Lower/OpenACC/acc-kernels.f90
flang/test/Lower/OpenACC/acc-parallel-loop.f90
flang/test/Lower/OpenACC/acc-parallel.f90
flang/test/Lower/OpenACC/acc-serial-loop.f90
flang/test/Lower/OpenACC/acc-serial.f90
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/invalid.mlir
mlir/test/Dialect/OpenACC/ops.mlir
mlir/unittests/Dialect/CMakeLists.txt
Removed:
flang/test/Lower/OpenACC/acc-device-type.f90
mlir/unittests/Dialect/OpenACC/CMakeLists.txt
mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
################################################################################
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecf70818c4ac0f..c2ee1a44a06bde 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
return mlir::acc::DeviceType::Multicore;
}
- return mlir::acc::DeviceType::None;
+ return mlir::acc::DeviceType::Default;
}
static void gatherDeviceTypeAttrs(
@@ -1781,24 +1781,25 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
bool outerCombined = false) {
// Parallel operation operands
+ mlir::Value async;
+ mlir::Value numWorkers;
+ mlir::Value vectorLength;
mlir::Value ifCond;
mlir::Value selfCond;
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
- dataClauseOperands, numGangs, numWorkers, vectorLength, async;
- llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
- vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes;
- llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
+ dataClauseOperands, numGangs;
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
- // Self clause has optional values but can be present with
+ // Async, wait and self clause have optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
// represent the clause.
+ bool addAsyncAttr = false;
+ bool addWaitAttr = false;
bool addSelfAttr = false;
bool hasDefaultNone = false;
@@ -1806,11 +1807,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
- // device_type attribute is set to `none` until a device_type clause is
- // encountered.
- auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), mlir::acc::DeviceType::None);
-
// Lower clauses values mapped to operands.
// Keep track of each group of operands separatly as clauses can appear
// more than once.
@@ -1818,52 +1814,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
- const auto &asyncClauseValue = asyncClause->v;
- if (asyncClauseValue) { // async has a value.
- async.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
- asyncDeviceTypes.push_back(crtDeviceTypeAttr);
- } else {
- asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
- }
+ genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- const auto &waitClauseValue = waitClause->v;
- if (waitClauseValue) { // wait has a value.
- const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
- const auto &waitList =
- std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- auto crtWaitOperands = waitOperands.size();
- for (const Fortran::parser::ScalarIntExpr &value : waitList) {
- waitOperands.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(value), stmtCtx)));
- }
- waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
- waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
- } else {
- waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
- }
+ genWaitClause(converter, waitClause, waitOperands, waitDevnum,
+ addWaitAttr, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
- auto crtNumGangs = numGangs.size();
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
numGangs.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(expr), stmtCtx)));
- numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
- numGangsSegments.push_back(numGangs.size() - crtNumGangs);
} else if (const auto *numWorkersClause =
std::get_if<Fortran::parser::AccClause::NumWorkers>(
&clause.u)) {
- numWorkers.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
- numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+ numWorkers = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
} else if (const auto *vectorLengthClause =
std::get_if<Fortran::parser::AccClause::VectorLength>(
&clause.u)) {
- vectorLength.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
- vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+ vectorLength = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
} else if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2014,27 +1985,18 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
else if ((defaultClause->v).v ==
llvm::acc::DefaultValue::ACC_Default_present)
hasDefaultPresent = true;
- } else if (const auto *deviceTypeClause =
- std::get_if<Fortran::parser::AccClause::DeviceType>(
- &clause.u)) {
- const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
- deviceTypeClause->v;
- assert(deviceTypeExprList.v.size() == 1 &&
- "expect only one device_type expr");
- crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
}
}
// Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value, 8> operands;
llvm::SmallVector<int32_t, 8> operandSegments;
- addOperands(operands, operandSegments, async);
+ addOperand(operands, operandSegments, async);
addOperands(operands, operandSegments, waitOperands);
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
addOperands(operands, operandSegments, numGangs);
- addOperands(operands, operandSegments, numWorkers);
- addOperands(operands, operandSegments, vectorLength);
+ addOperand(operands, operandSegments, numWorkers);
+ addOperand(operands, operandSegments, vectorLength);
}
addOperand(operands, operandSegments, ifCond);
addOperand(operands, operandSegments, selfCond);
@@ -2055,6 +2017,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder, currentLocation, eval, operands, operandSegments,
outerCombined);
+ if (addAsyncAttr)
+ computeOp.setAsyncAttrAttr(builder.getUnitAttr());
+ if (addWaitAttr)
+ computeOp.setWaitAttrAttr(builder.getUnitAttr());
if (addSelfAttr)
computeOp.setSelfAttrAttr(builder.getUnitAttr());
@@ -2063,34 +2029,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (hasDefaultPresent)
computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
- if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
- if (!numWorkersDeviceTypes.empty())
- computeOp.setNumWorkersDeviceTypeAttr(
- mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
- if (!vectorLengthDeviceTypes.empty())
- computeOp.setVectorLengthDeviceTypeAttr(
- mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
- if (!numGangsDeviceTypes.empty())
- computeOp.setNumGangsDeviceTypeAttr(
- mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
- if (!numGangsSegments.empty())
- computeOp.setNumGangsSegmentsAttr(
- builder.getDenseI32ArrayAttr(numGangsSegments));
- }
- if (!asyncDeviceTypes.empty())
- computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
- if (!asyncOnlyDeviceTypes.empty())
- computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
-
- if (!waitOperandsDeviceTypes.empty())
- computeOp.setWaitOperandsDeviceTypeAttr(
- builder.getArrayAttr(waitOperandsDeviceTypes));
- if (!waitOperandsSegments.empty())
- computeOp.setWaitOperandsSegmentsAttr(
- builder.getDenseI32ArrayAttr(waitOperandsSegments));
- if (!waitOnlyDeviceTypes.empty())
- computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
-
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
if (!privatizations.empty())
computeOp.setPrivatizationsAttr(
diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
deleted file mode 100644
index 871dbc95f60fcb..00000000000000
--- a/flang/test/Lower/OpenACC/acc-device-type.f90
+++ /dev/null
@@ -1,44 +0,0 @@
-! This test checks lowering of OpenACC device_type clause on directive where its
-! position and the clauses that follow have special semantic
-
-! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
-
-subroutine sub1()
-
- !$acc parallel num_workers(16)
- !$acc end parallel
-
-! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {
-
- !$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
- !$acc end parallel
-
-! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])
-
- !$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
- !$acc end parallel
-
-! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])
-
- !$acc parallel vector_length(1)
- !$acc end parallel
-
-! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)
-
- !$acc parallel device_type(multicore) vector_length(1)
- !$acc end parallel
-
-! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])
-
- !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
- !$acc end parallel
-
-! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])
-
- !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
- !$acc end parallel
-
-! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
-
-
-end subroutine
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 93bc699031d550..34e72326972417 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {asyncAttr}
!$acc kernels loop async(1)
DO i = 1, n
@@ -103,7 +103,7 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {waitAttr}
!$acc kernels loop wait(1)
DO i = 1, n
@@ -111,7 +111,7 @@ subroutine acc_kernels_loop
END DO
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
+! CHECK: acc.kernels wait([[WAIT1]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
+! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
+! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
END DO
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
+! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
END DO
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
+! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 99629bb8351723..1f882c6df51061 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -40,7 +40,7 @@ subroutine acc_kernels
! CHECK: acc.kernels {
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {asyncAttr}
!$acc kernels async(1)
!$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
! CHECK: acc.kernels {
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {waitAttr}
!$acc kernels wait(1)
!$acc end kernels
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
+! CHECK: acc.kernels wait([[WAIT1]] : i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
@@ -78,7 +78,7 @@ subroutine acc_kernels
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
+! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
@@ -87,7 +87,7 @@ subroutine acc_kernels
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
+! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
@@ -95,7 +95,7 @@ subroutine acc_kernels
!$acc end kernels
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
+! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
@@ -103,7 +103,7 @@ subroutine acc_kernels
!$acc end kernels
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
+! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index deee7089033ead..1856215ce59d13 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -64,7 +64,7 @@ subroutine acc_parallel_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {asyncAttr}
!$acc parallel loop async(1)
DO i = 1, n
@@ -105,7 +105,7 @@ subroutine acc_parallel_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {waitAttr}
!$acc parallel loop wait(1)
DO i = 1, n
@@ -113,7 +113,7 @@ subroutine acc_parallel_loop
END DO
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.parallel wait({[[WAIT1]] : i32}) {
+! CHECK: acc.parallel wait([[WAIT1]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -128,7 +128,7 @@ subroutine acc_parallel_loop
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
+! CHECK: acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -143,7 +143,7 @@ subroutine acc_parallel_loop
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
+! CHECK: acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -157,7 +157,7 @@ subroutine acc_parallel_loop
END DO
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
+! CHECK: acc.parallel num_gangs([[NUMGANGS1]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -171,7 +171,7 @@ subroutine acc_parallel_loop
END DO
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
+! CHECK: acc.parallel num_gangs([[NUMGANGS2]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index a369bf01f25995..bbf51ba36a7dea 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -62,7 +62,7 @@ subroutine acc_parallel
! CHECK: acc.parallel {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {asyncAttr}
!$acc parallel async(1)
!$acc end parallel
@@ -85,13 +85,13 @@ subroutine acc_parallel
! CHECK: acc.parallel {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {waitAttr}
!$acc parallel wait(1)
!$acc end parallel
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.parallel wait({[[WAIT1]] : i32}) {
+! CHECK: acc.parallel wait([[WAIT1]] : i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -100,7 +100,7 @@ subroutine acc_parallel
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
+! CHECK: acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -109,7 +109,7 @@ subroutine acc_parallel
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
+! CHECK: acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -117,7 +117,7 @@ subroutine acc_parallel
!$acc end parallel
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
+! CHECK: acc.parallel num_gangs([[NUMGANGS1]] : i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -125,14 +125,14 @@ subroutine acc_parallel
!$acc end parallel
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
+! CHECK: acc.parallel num_gangs([[NUMGANGS2]] : i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
!$acc parallel num_gangs(1, 1, 1)
!$acc end parallel
-! CHECK: acc.parallel num_gangs({%{{.*}} : i32, %{{.*}} : i32, %{{.*}} : i32}) {
+! CHECK: acc.parallel num_gangs(%{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 712bfc80ce387c..4ed7bb8da29a1a 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -83,7 +83,7 @@ subroutine acc_serial_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {asyncAttr}
!$acc serial loop async(1)
DO i = 1, n
@@ -124,7 +124,7 @@ subroutine acc_serial_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {waitAttr}
!$acc serial loop wait(1)
DO i = 1, n
@@ -132,7 +132,7 @@ subroutine acc_serial_loop
END DO
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.serial wait({[[WAIT1]] : i32}) {
+! CHECK: acc.serial wait([[WAIT1]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -147,7 +147,7 @@ subroutine acc_serial_loop
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
+! CHECK: acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
@@ -162,7 +162,7 @@ subroutine acc_serial_loop
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
+! CHECK: acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index d05e51d3d274f4..ab3b0ccd545958 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -62,7 +62,7 @@ subroutine acc_serial
! CHECK: acc.serial {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {asyncAttr}
!$acc serial async(1)
!$acc end serial
@@ -85,13 +85,13 @@ subroutine acc_serial
! CHECK: acc.serial {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: } attributes {waitAttr}
!$acc serial wait(1)
!$acc end serial
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.serial wait({[[WAIT1]] : i32}) {
+! CHECK: acc.serial wait([[WAIT1]] : i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -100,7 +100,7 @@ subroutine acc_serial
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
+! CHECK: acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
@@ -109,7 +109,7 @@ subroutine acc_serial
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK: acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
+! CHECK: acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 234c1076e14e3b..a78c3e98c95517 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -156,46 +156,29 @@ def DeclareActionAttr : OpenACC_Attr<"DeclareAction", "declare_action"> {
}
// Device type enumeration.
-def OpenACC_DeviceTypeNone : I32EnumAttrCase<"None", 0, "none">;
-def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 1, "star">;
-def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 2, "default">;
-def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 3, "host">;
-def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 4, "multicore">;
-def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 5, "nvidia">;
-def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 6, "radeon">;
+def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 0, "star">;
+def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 1, "default">;
+def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 2, "host">;
+def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 3, "multicore">;
+def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 4, "nvidia">;
+def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 5, "radeon">;
+
def OpenACC_DeviceType : I32EnumAttr<"DeviceType",
"built-in device type supported by OpenACC",
- [OpenACC_DeviceTypeNone, OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault,
+ [OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault,
OpenACC_DeviceTypeHost, OpenACC_DeviceTypeMulticore,
OpenACC_DeviceTypeNvidia, OpenACC_DeviceTypeRadeon
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::acc";
}
-
-// Device type attribute is used to associate a value for for clauses that
-// appear after a device_type clause. The list of clauses allowed after the
-// device_type clause is defined per construct as follows:
-// Loop construct: collapse, gang, worker, vector, seq, independent, auto,
-// and tile
-// Compute construct: async, wait, num_gangs, num_workers, and vector_length
-// Data construct: async and wait
-// Routine: gang, worker, vector, seq and bind
-//
-// The `none` means that the value appears before any device_type clause.
-//
def OpenACC_DeviceTypeAttr : EnumAttr<OpenACC_Dialect,
OpenACC_DeviceType,
"device_type"> {
let assemblyFormat = [{ ```<` $value `>` }];
}
-def DeviceTypeArrayAttr :
- TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "device type array attribute"> {
- let constBuilderCall = ?;
-}
-
// Define a resource for the OpenACC runtime counters.
def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">;
@@ -880,32 +863,24 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
```
}];
- let arguments = (ins
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
- Variadic<IntOrIndex>:$waitOperands,
- OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
- OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
- OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
- Variadic<IntOrIndex>:$numGangs,
- OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
- OptionalAttr<DeviceTypeArrayAttr>:$numGangsDeviceType,
- Variadic<IntOrIndex>:$numWorkers,
- OptionalAttr<DeviceTypeArrayAttr>:$numWorkersDeviceType,
- Variadic<IntOrIndex>:$vectorLength,
- OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType,
- Optional<I1>:$ifCond,
- Optional<I1>:$selfCond,
- UnitAttr:$selfAttr,
- Variadic<AnyType>:$reductionOperands,
- OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
- Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
- OptionalAttr<SymbolRefArrayAttr>:$privatizations,
- Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
- OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ let arguments = (ins Optional<IntOrIndex>:$async,
+ UnitAttr:$asyncAttr,
+ Variadic<IntOrIndex>:$waitOperands,
+ UnitAttr:$waitAttr,
+ Variadic<IntOrIndex>:$numGangs,
+ Optional<IntOrIndex>:$numWorkers,
+ Optional<IntOrIndex>:$vectorLength,
+ Optional<I1>:$ifCond,
+ Optional<I1>:$selfCond,
+ UnitAttr:$selfAttr,
+ Variadic<AnyType>:$reductionOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizations,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ OptionalAttr<DefaultValueAttr>:$defaultAttr);
let regions = (region AnyRegion:$region);
@@ -915,69 +890,22 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
-
- /// Return true if the op has the async attribute for the
- /// mlir::acc::DeviceType::None device_type.
- bool hasAsyncOnly();
- /// Return true if the op has the async attribute for the given device_type.
- bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
- /// Return the value of the async clause if present.
- mlir::Value getAsyncValue();
- /// Return the value of the async clause for the given device_type if
- /// present.
- mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
-
- /// Return the value of the num_workers clause if present.
- mlir::Value getNumWorkersValue();
- /// Return the value of the num_workers clause for the given device_type if
- /// present.
- mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType);
-
- /// Return the value of the vector_length clause if present.
- mlir::Value getVectorLengthValue();
- /// Return the value of the vector_length clause for the given device_type
- /// if present.
- mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType);
-
- /// Return the values of the num_gangs clause if present.
- mlir::Operation::operand_range getNumGangsValues();
- /// Return the values of the num_gangs clause for the given device_type if
- /// present.
- mlir::Operation::operand_range
- getNumGangsValues(mlir::acc::DeviceType deviceType);
-
- /// Return true if the op has the wait attribute for the
- /// mlir::acc::DeviceType::None device_type.
- bool hasWaitOnly();
- /// Return true if the op has the wait attribute for the given device_type.
- bool hasWaitOnly(mlir::acc::DeviceType deviceType);
- /// Return the values of the wait clause if present.
- mlir::Operation::operand_range getWaitValues();
- /// Return the values of the wait clause for the given device_type if
- /// present.
- mlir::Operation::operand_range
- getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
+ | `async` `(` $async `:` type($async) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
type($gangFirstPrivateOperands), $firstprivatizations)
`)`
- | `num_gangs` `(` custom<NumGangs>($numGangs,
- type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
- | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
- type($numWorkers), $numWorkersDeviceType) `)`
+ | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
+ | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
| `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
`)`
- | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
- type($vectorLength), $vectorLengthDeviceType) `)`
- | `wait` `(` custom<WaitOperands>($waitOperands,
- type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
+ | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
+ | `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
| `reduction` `(` custom<SymOperandList>(
@@ -1011,25 +939,21 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
```
}];
- let arguments = (ins
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
- Variadic<IntOrIndex>:$waitOperands,
- OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
- OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
- OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
- Optional<I1>:$ifCond,
- Optional<I1>:$selfCond,
- UnitAttr:$selfAttr,
- Variadic<AnyType>:$reductionOperands,
- OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
- Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
- OptionalAttr<SymbolRefArrayAttr>:$privatizations,
- Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
- OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ let arguments = (ins Optional<IntOrIndex>:$async,
+ UnitAttr:$asyncAttr,
+ Variadic<IntOrIndex>:$waitOperands,
+ UnitAttr:$waitAttr,
+ Optional<I1>:$ifCond,
+ Optional<I1>:$selfCond,
+ UnitAttr:$selfAttr,
+ Variadic<AnyType>:$reductionOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizations,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ OptionalAttr<DefaultValueAttr>:$defaultAttr);
let regions = (region AnyRegion:$region);
@@ -1039,44 +963,19 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
-
- /// Return true if the op has the async attribute for the
- /// mlir::acc::DeviceType::None device_type.
- bool hasAsyncOnly();
- /// Return true if the op has the async attribute for the given device_type.
- bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
- /// Return the value of the async clause if present.
- mlir::Value getAsyncValue();
- /// Return the value of the async clause for the given device_type if
- /// present.
- mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
-
- /// Return true if the op has the wait attribute for the
- /// mlir::acc::DeviceType::None device_type.
- bool hasWaitOnly();
- /// Return true if the op has the wait attribute for the given device_type.
- bool hasWaitOnly(mlir::acc::DeviceType deviceType);
- /// Return the values of the wait clause if present.
- mlir::Operation::operand_range getWaitValues();
- /// Return the values of the wait clause for the given device_type if
- /// present.
- mlir::Operation::operand_range
- getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
+ | `async` `(` $async `:` type($async) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
type($gangFirstPrivateOperands), $firstprivatizations)
`)`
| `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
`)`
- | `wait` `(` custom<WaitOperands>($waitOperands,
- type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
+ | `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
| `reduction` `(` custom<SymOperandList>(
@@ -1112,26 +1011,18 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
```
}];
- let arguments = (ins
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
- Variadic<IntOrIndex>:$waitOperands,
- OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
- OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
- OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
- Variadic<IntOrIndex>:$numGangs,
- OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
- OptionalAttr<DeviceTypeArrayAttr>:$numGangsDeviceType,
- Variadic<IntOrIndex>:$numWorkers,
- OptionalAttr<DeviceTypeArrayAttr>:$numWorkersDeviceType,
- Variadic<IntOrIndex>:$vectorLength,
- OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType,
- Optional<I1>:$ifCond,
- Optional<I1>:$selfCond,
- UnitAttr:$selfAttr,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ let arguments = (ins Optional<IntOrIndex>:$async,
+ UnitAttr:$asyncAttr,
+ Variadic<IntOrIndex>:$waitOperands,
+ UnitAttr:$waitAttr,
+ Variadic<IntOrIndex>:$numGangs,
+ Optional<IntOrIndex>:$numWorkers,
+ Optional<IntOrIndex>:$vectorLength,
+ Optional<I1>:$ifCond,
+ Optional<I1>:$selfCond,
+ UnitAttr:$selfAttr,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ OptionalAttr<DefaultValueAttr>:$defaultAttr);
let regions = (region AnyRegion:$region);
@@ -1141,63 +1032,16 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
-
- /// Return true if the op has the async attribute for the
- /// mlir::acc::DeviceType::None device_type.
- bool hasAsyncOnly();
- /// Return true if the op has the async attribute for the given device_type.
- bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
- /// Return the value of the async clause if present.
- mlir::Value getAsyncValue();
- /// Return the value of the async clause for the given device_type if
- /// present.
- mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
-
- /// Return the value of the num_workers clause if present.
- mlir::Value getNumWorkersValue();
- /// Return the value of the num_workers clause for the given device_type if
- /// present.
- mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType);
-
- /// Return the value of the vector_length clause if present.
- mlir::Value getVectorLengthValue();
- /// Return the value of the vector_length clause for the given device_type
- /// if present.
- mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType);
-
- /// Return the values of the num_gangs clause if present.
- mlir::Operation::operand_range getNumGangsValues();
- /// Return the values of the num_gangs clause for the given device_type if
- /// present.
- mlir::Operation::operand_range
- getNumGangsValues(mlir::acc::DeviceType deviceType);
-
- /// Return true if the op has the wait attribute for the
- /// mlir::acc::DeviceType::None device_type.
- bool hasWaitOnly();
- /// Return true if the op has the wait attribute for the given device_type.
- bool hasWaitOnly(mlir::acc::DeviceType deviceType);
- /// Return the values of the wait clause if present.
- mlir::Operation::operand_range getWaitValues();
- /// Return the values of the wait clause for the given device_type if
- /// present.
- mlir::Operation::operand_range
- getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
- | `num_gangs` `(` custom<NumGangs>($numGangs,
- type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
- | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
- type($numWorkers), $numWorkersDeviceType) `)`
- | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
- type($vectorLength), $vectorLengthDeviceType) `)`
- | `wait` `(` custom<WaitOperands>($waitOperands,
- type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)`
+ | `async` `(` $async `:` type($async) `)`
+ | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
+ | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
+ | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
+ | `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 45e0632db5ef2b..df4f7825545c2b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -615,49 +615,15 @@ unsigned ParallelOp::getNumDataOperands() {
}
Value ParallelOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync().size();
+ unsigned numOptional = getAsync() ? 1 : 0;
numOptional += getNumGangs().size();
- numOptional += getNumWorkers().size();
- numOptional += getVectorLength().size();
+ numOptional += getNumWorkers() ? 1 : 0;
+ numOptional += getVectorLength() ? 1 : 0;
numOptional += getIfCond() ? 1 : 0;
numOptional += getSelfCond() ? 1 : 0;
return getOperand(getWaitOperands().size() + numOptional + i);
}
-template <typename Op>
-static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
- ArrayAttr deviceTypes,
- llvm::StringRef keyword) {
- if (operands.size() > 0 && deviceTypes.getValue().size() != operands.size())
- return op.emitOpError() << keyword << " operands count must match "
- << keyword << " device_type count";
- return success();
-}
-
-template <typename Op>
-static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
- Op op, OperandRange operands, DenseI32ArrayAttr segments,
- ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
- std::size_t numOperandsInSegments = 0;
-
- if (!segments)
- return success();
-
- for (auto segCount : segments.asArrayRef()) {
- if (maxInSegment != 0 && segCount > maxInSegment)
- return op.emitOpError() << keyword << " expects a maximum of "
- << maxInSegment << " values per segment";
- numOperandsInSegments += segCount;
- }
- if (numOperandsInSegments != operands.size())
- return op.emitOpError()
- << keyword << " operand count does not match count in segments";
- if (deviceTypes.getValue().size() != (size_t)segments.size())
- return op.emitOpError()
- << keyword << " segment count does not match device_type count";
- return success();
-}
-
LogicalResult acc::ParallelOp::verify() {
if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
*this, getPrivatizations(), getGangPrivateOperands(), "private",
@@ -667,322 +633,11 @@ LogicalResult acc::ParallelOp::verify() {
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions", false)))
return failure();
-
- if (failed(verifyDeviceTypeAndSegmentCountMatch(
- *this, getNumGangs(), getNumGangsSegmentsAttr(),
- getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
- return failure();
-
- if (failed(verifyDeviceTypeAndSegmentCountMatch(
- *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
- getWaitOperandsDeviceTypeAttr(), "wait")))
- return failure();
-
- if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
- getNumWorkersDeviceTypeAttr(),
- "num_workers")))
- return failure();
-
- if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
- getVectorLengthDeviceTypeAttr(),
- "vector_length")))
- return failure();
-
- if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
- getAsyncDeviceTypeAttr(), "async")))
- return failure();
-
+ if (getNumGangs().size() > 3)
+ return emitOpError() << "num_gangs expects a maximum of 3 values";
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
-static std::optional<unsigned> findSegment(ArrayAttr segments,
- mlir::acc::DeviceType deviceType) {
- unsigned segmentIdx = 0;
- for (auto attr : segments) {
- auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
- if (deviceTypeAttr.getValue() == deviceType)
- return std::make_optional(segmentIdx);
- ++segmentIdx;
- }
- return std::nullopt;
-}
-
-static mlir::Value
-getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
- mlir::Operation::operand_range range,
- mlir::acc::DeviceType deviceType) {
- if (!arrayAttr)
- return {};
- if (auto pos = findSegment(*arrayAttr, deviceType))
- return range[*pos];
- return {};
-}
-
-bool acc::ParallelOp::hasAsyncOnly() {
- return hasAsyncOnly(mlir::acc::DeviceType::None);
-}
-
-bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
-}
-
-mlir::Value acc::ParallelOp::getAsyncValue() {
- return getAsyncValue(mlir::acc::DeviceType::None);
-}
-
-mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
- deviceType);
-}
-
-mlir::Value acc::ParallelOp::getNumWorkersValue() {
- return getNumWorkersValue(mlir::acc::DeviceType::None);
-}
-
-mlir::Value
-acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
- deviceType);
-}
-
-mlir::Value acc::ParallelOp::getVectorLengthValue() {
- return getVectorLengthValue(mlir::acc::DeviceType::None);
-}
-
-mlir::Value
-acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
- getVectorLength(), deviceType);
-}
-
-mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
- return getNumGangsValues(mlir::acc::DeviceType::None);
-}
-
-static mlir::Operation::operand_range
-getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
- mlir::Operation::operand_range range,
- std::optional<llvm::ArrayRef<int32_t>> segments,
- mlir::acc::DeviceType deviceType) {
- if (!arrayAttr)
- return range.take_front(0);
- if (auto pos = findSegment(*arrayAttr, deviceType)) {
- int32_t nbOperandsBefore = 0;
- for (unsigned i = 0; i < *pos; ++i)
- nbOperandsBefore += (*segments)[i];
- return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
- }
- return range.take_front(0);
-}
-
-mlir::Operation::operand_range
-ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
- getNumGangsSegments(), deviceType);
-}
-
-bool acc::ParallelOp::hasWaitOnly() {
- return hasWaitOnly(mlir::acc::DeviceType::None);
-}
-
-bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
-}
-
-mlir::Operation::operand_range ParallelOp::getWaitValues() {
- return getWaitValues(mlir::acc::DeviceType::None);
-}
-
-mlir::Operation::operand_range
-ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
- getWaitOperandsSegments(), deviceType);
-}
-
-static ParseResult parseNumGangs(
- mlir::OpAsmParser &parser,
- llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
- llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
- mlir::DenseI32ArrayAttr &segments) {
- llvm::SmallVector<DeviceTypeAttr> attributes;
- llvm::SmallVector<int32_t> seg;
-
- do {
- if (failed(parser.parseLBrace()))
- return failure();
-
- if (failed(parser.parseCommaSeparatedList(
- mlir::AsmParser::Delimiter::None, [&]() {
- if (parser.parseOperand(operands.emplace_back()) ||
- parser.parseColonType(types.emplace_back()))
- return failure();
- return success();
- })))
- return failure();
-
- seg.push_back(operands.size());
-
- if (failed(parser.parseRBrace()))
- return failure();
-
- if (succeeded(parser.parseOptionalLSquare())) {
- if (parser.parseAttribute(attributes.emplace_back()) ||
- parser.parseRSquare())
- return failure();
- } else {
- attributes.push_back(mlir::acc::DeviceTypeAttr::get(
- parser.getContext(), mlir::acc::DeviceType::None));
- }
- } while (succeeded(parser.parseOptionalComma()));
-
- llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
- attributes.end());
- deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
- segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
-
- return success();
-}
-
-static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::OperandRange operands, mlir::TypeRange types,
- std::optional<mlir::ArrayAttr> deviceTypes,
- std::optional<mlir::DenseI32ArrayAttr> segments) {
- unsigned opIdx = 0;
- for (unsigned i = 0; i < deviceTypes->size(); ++i) {
- if (i != 0)
- p << ", ";
- p << "{";
- for (int32_t j = 0; j < (*segments)[i]; ++j) {
- if (j != 0)
- p << ", ";
- p << operands[opIdx] << " : " << operands[opIdx].getType();
- ++opIdx;
- }
- p << "}";
- auto deviceTypeAttr =
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
- if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
- p << " [" << (*deviceTypes)[i] << "]";
- }
-}
-
-static ParseResult parseWaitOperands(
- mlir::OpAsmParser &parser,
- llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
- llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
- mlir::DenseI32ArrayAttr &segments) {
- llvm::SmallVector<DeviceTypeAttr> attributes;
- llvm::SmallVector<int32_t> seg;
-
- do {
- if (failed(parser.parseLBrace()))
- return failure();
-
- if (failed(parser.parseCommaSeparatedList(
- mlir::AsmParser::Delimiter::None, [&]() {
- if (parser.parseOperand(operands.emplace_back()) ||
- parser.parseColonType(types.emplace_back()))
- return failure();
- return success();
- })))
- return failure();
-
- seg.push_back(operands.size());
-
- if (failed(parser.parseRBrace()))
- return failure();
-
- if (succeeded(parser.parseOptionalLSquare())) {
- if (parser.parseAttribute(attributes.emplace_back()) ||
- parser.parseRSquare())
- return failure();
- } else {
- attributes.push_back(mlir::acc::DeviceTypeAttr::get(
- parser.getContext(), mlir::acc::DeviceType::None));
- }
- } while (succeeded(parser.parseOptionalComma()));
-
- llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
- attributes.end());
- deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
- segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
-
- return success();
-}
-
-static void printWaitOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::OperandRange operands,
- mlir::TypeRange types,
- std::optional<mlir::ArrayAttr> deviceTypes,
- std::optional<mlir::DenseI32ArrayAttr> segments) {
- unsigned opIdx = 0;
- for (unsigned i = 0; i < deviceTypes->size(); ++i) {
- if (i != 0)
- p << ", ";
- p << "{";
- for (int32_t j = 0; j < (*segments)[i]; ++j) {
- if (j != 0)
- p << ", ";
- p << operands[opIdx] << " : " << operands[opIdx].getType();
- ++opIdx;
- }
- p << "}";
- auto deviceTypeAttr =
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
- if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
- p << " [" << (*deviceTypes)[i] << "]";
- }
-}
-
-static ParseResult parseDeviceTypeOperands(
- mlir::OpAsmParser &parser,
- llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
- llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
- llvm::SmallVector<DeviceTypeAttr> attributes;
- if (failed(parser.parseCommaSeparatedList([&]() {
- if (parser.parseOperand(operands.emplace_back()) ||
- parser.parseColonType(types.emplace_back()))
- return failure();
- if (succeeded(parser.parseOptionalLSquare())) {
- if (parser.parseAttribute(attributes.emplace_back()) ||
- parser.parseRSquare())
- return failure();
- } else {
- attributes.push_back(mlir::acc::DeviceTypeAttr::get(
- parser.getContext(), mlir::acc::DeviceType::None));
- }
- return success();
- })))
- return failure();
- llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
- attributes.end());
- deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
- return success();
-}
-
-static void
-printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::OperandRange operands, mlir::TypeRange types,
- std::optional<mlir::ArrayAttr> deviceTypes) {
- for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) {
- if (i != 0)
- p << ", ";
- p << operands[i] << " : " << operands[i].getType();
- auto deviceTypeAttr =
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
- if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
- p << " [" << (*deviceTypes)[i] << "]";
- }
-}
-
//===----------------------------------------------------------------------===//
// SerialOp
//===----------------------------------------------------------------------===//
@@ -993,55 +648,12 @@ unsigned SerialOp::getNumDataOperands() {
}
Value SerialOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync().size();
+ unsigned numOptional = getAsync() ? 1 : 0;
numOptional += getIfCond() ? 1 : 0;
numOptional += getSelfCond() ? 1 : 0;
return getOperand(getWaitOperands().size() + numOptional + i);
}
-bool acc::SerialOp::hasAsyncOnly() {
- return hasAsyncOnly(mlir::acc::DeviceType::None);
-}
-
-bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
-}
-
-mlir::Value acc::SerialOp::getAsyncValue() {
- return getAsyncValue(mlir::acc::DeviceType::None);
-}
-
-mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
- deviceType);
-}
-
-bool acc::SerialOp::hasWaitOnly() {
- return hasWaitOnly(mlir::acc::DeviceType::None);
-}
-
-bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
-}
-
-mlir::Operation::operand_range SerialOp::getWaitValues() {
- return getWaitValues(mlir::acc::DeviceType::None);
-}
-
-mlir::Operation::operand_range
-SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
- getWaitOperandsSegments(), deviceType);
-}
-
LogicalResult acc::SerialOp::verify() {
if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
*this, getPrivatizations(), getGangPrivateOperands(), "private",
@@ -1051,16 +663,6 @@ LogicalResult acc::SerialOp::verify() {
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions", false)))
return failure();
-
- if (failed(verifyDeviceTypeAndSegmentCountMatch(
- *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
- getWaitOperandsDeviceTypeAttr(), "wait")))
- return failure();
-
- if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
- getAsyncDeviceTypeAttr(), "async")))
- return failure();
-
return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
@@ -1073,114 +675,19 @@ unsigned KernelsOp::getNumDataOperands() {
}
Value KernelsOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsync().size();
+ unsigned numOptional = getAsync() ? 1 : 0;
numOptional += getWaitOperands().size();
numOptional += getNumGangs().size();
- numOptional += getNumWorkers().size();
- numOptional += getVectorLength().size();
+ numOptional += getNumWorkers() ? 1 : 0;
+ numOptional += getVectorLength() ? 1 : 0;
numOptional += getIfCond() ? 1 : 0;
numOptional += getSelfCond() ? 1 : 0;
return getOperand(numOptional + i);
}
-bool acc::KernelsOp::hasAsyncOnly() {
- return hasAsyncOnly(mlir::acc::DeviceType::None);
-}
-
-bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
-}
-
-mlir::Value acc::KernelsOp::getAsyncValue() {
- return getAsyncValue(mlir::acc::DeviceType::None);
-}
-
-mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
- deviceType);
-}
-
-mlir::Value acc::KernelsOp::getNumWorkersValue() {
- return getNumWorkersValue(mlir::acc::DeviceType::None);
-}
-
-mlir::Value
-acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
- deviceType);
-}
-
-mlir::Value acc::KernelsOp::getVectorLengthValue() {
- return getVectorLengthValue(mlir::acc::DeviceType::None);
-}
-
-mlir::Value
-acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
- return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
- getVectorLength(), deviceType);
-}
-
-mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
- return getNumGangsValues(mlir::acc::DeviceType::None);
-}
-
-mlir::Operation::operand_range
-KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
- getNumGangsSegments(), deviceType);
-}
-
-bool acc::KernelsOp::hasWaitOnly() {
- return hasWaitOnly(mlir::acc::DeviceType::None);
-}
-
-bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
-}
-
-mlir::Operation::operand_range KernelsOp::getWaitValues() {
- return getWaitValues(mlir::acc::DeviceType::None);
-}
-
-mlir::Operation::operand_range
-KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
- return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
- getWaitOperandsSegments(), deviceType);
-}
-
LogicalResult acc::KernelsOp::verify() {
- if (failed(verifyDeviceTypeAndSegmentCountMatch(
- *this, getNumGangs(), getNumGangsSegmentsAttr(),
- getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
- return failure();
-
- if (failed(verifyDeviceTypeAndSegmentCountMatch(
- *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
- getWaitOperandsDeviceTypeAttr(), "wait")))
- return failure();
-
- if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
- getNumWorkersDeviceTypeAttr(),
- "num_workers")))
- return failure();
-
- if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
- getVectorLengthDeviceTypeAttr(),
- "vector_length")))
- return failure();
-
- if (failed(verifyDeviceTypeCountMatch(*this, getAsync(),
- getAsyncDeviceTypeAttr(), "async")))
- return failure();
-
+ if (getNumGangs().size() > 3)
+ return emitOpError() << "num_gangs expects a maximum of 3 values";
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index c18d964b370f2c..b9ac68d0592c87 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -462,8 +462,8 @@ acc.loop gang() {
// -----
%i64value = arith.constant 1 : i64
-// expected-error at +1 {{num_gangs expects a maximum of 3 values per segment}}
-acc.parallel num_gangs({%i64value: i64, %i64value : i64, %i64value : i64, %i64value : i64}) {
+// expected-error at +1 {{num_gangs expects a maximum of 3 values}}
+acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) {
}
// -----
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 5a95811685f845..05b0450c7fb916 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -137,7 +137,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
%pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32>
acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
%private = acc.private varPtr(%c : memref<10xf32>) -> memref<10xf32>
- acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
+ acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
acc.loop gang {
scf.for %x = %lb to %c10 step %st {
acc.loop worker {
@@ -180,7 +180,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
// CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64
// CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
// CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) -> memref<10xf32>
-// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
+// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
// CHECK-NEXT: acc.loop gang {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: acc.loop worker {
@@ -439,25 +439,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
}
acc.parallel async(%idxValue: index) {
}
- acc.parallel wait({%i64value: i64}) {
+ acc.parallel wait(%i64value: i64) {
}
- acc.parallel wait({%i32value: i32}) {
+ acc.parallel wait(%i32value: i32) {
}
- acc.parallel wait({%idxValue: index}) {
+ acc.parallel wait(%idxValue: index) {
}
- acc.parallel wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
+ acc.parallel wait(%i64value, %i32value, %idxValue : i64, i32, index) {
}
- acc.parallel num_gangs({%i64value: i64}) {
+ acc.parallel num_gangs(%i64value: i64) {
}
- acc.parallel num_gangs({%i32value: i32}) {
+ acc.parallel num_gangs(%i32value: i32) {
}
- acc.parallel num_gangs({%idxValue: index}) {
+ acc.parallel num_gangs(%idxValue: index) {
}
- acc.parallel num_gangs({%i64value: i64, %i64value: i64, %idxValue: index}) {
+ acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) {
}
- acc.parallel num_workers(%i64value: i64 [#acc.device_type<nvidia>]) {
+ acc.parallel num_workers(%i64value: i64) {
}
- acc.parallel num_workers(%i32value: i32 [#acc.device_type<default>]) {
+ acc.parallel num_workers(%i32value: i32) {
}
acc.parallel num_workers(%idxValue: index) {
}
@@ -492,25 +492,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
// CHECK-NEXT: }
// CHECK: acc.parallel async([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.parallel wait({[[I64VALUE]] : i64}) {
+// CHECK: acc.parallel wait([[I64VALUE]] : i64) {
// CHECK-NEXT: }
-// CHECK: acc.parallel wait({[[I32VALUE]] : i32}) {
+// CHECK: acc.parallel wait([[I32VALUE]] : i32) {
// CHECK-NEXT: }
-// CHECK: acc.parallel wait({[[IDXVALUE]] : index}) {
+// CHECK: acc.parallel wait([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.parallel wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
+// CHECK: acc.parallel wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_gangs({[[I64VALUE]] : i64}) {
+// CHECK: acc.parallel num_gangs([[I64VALUE]] : i64) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_gangs({[[I32VALUE]] : i32}) {
+// CHECK: acc.parallel num_gangs([[I32VALUE]] : i32) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_gangs({[[IDXVALUE]] : index}) {
+// CHECK: acc.parallel num_gangs([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_gangs({[[I64VALUE]] : i64, [[I64VALUE]] : i64, [[IDXVALUE]] : index}) {
+// CHECK: acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_workers([[I64VALUE]] : i64 [#acc.device_type<nvidia>]) {
+// CHECK: acc.parallel num_workers([[I64VALUE]] : i64) {
// CHECK-NEXT: }
-// CHECK: acc.parallel num_workers([[I32VALUE]] : i32 [#acc.device_type<default>]) {
+// CHECK: acc.parallel num_workers([[I32VALUE]] : i32) {
// CHECK-NEXT: }
// CHECK: acc.parallel num_workers([[IDXVALUE]] : index) {
// CHECK-NEXT: }
@@ -590,13 +590,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
}
acc.serial async(%idxValue: index) {
}
- acc.serial wait({%i64value: i64}) {
+ acc.serial wait(%i64value: i64) {
}
- acc.serial wait({%i32value: i32}) {
+ acc.serial wait(%i32value: i32) {
}
- acc.serial wait({%idxValue: index}) {
+ acc.serial wait(%idxValue: index) {
}
- acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
+ acc.serial wait(%i64value, %i32value, %idxValue : i64, i32, index) {
}
%firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) -> memref<10xf32>
acc.serial private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) {
@@ -627,13 +627,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
// CHECK-NEXT: }
// CHECK: acc.serial async([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.serial wait({[[I64VALUE]] : i64}) {
+// CHECK: acc.serial wait([[I64VALUE]] : i64) {
// CHECK-NEXT: }
-// CHECK: acc.serial wait({[[I32VALUE]] : i32}) {
+// CHECK: acc.serial wait([[I32VALUE]] : i32) {
// CHECK-NEXT: }
-// CHECK: acc.serial wait({[[IDXVALUE]] : index}) {
+// CHECK: acc.serial wait([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.serial wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
+// CHECK: acc.serial wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
// CHECK-NEXT: }
// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) -> memref<10xf32>
// CHECK: acc.serial firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTP]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
@@ -665,13 +665,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
}
acc.kernels async(%idxValue: index) {
}
- acc.kernels wait({%i64value: i64}) {
+ acc.kernels wait(%i64value: i64) {
}
- acc.kernels wait({%i32value: i32}) {
+ acc.kernels wait(%i32value: i32) {
}
- acc.kernels wait({%idxValue: index}) {
+ acc.kernels wait(%idxValue: index) {
}
- acc.kernels wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
+ acc.kernels wait(%i64value, %i32value, %idxValue : i64, i32, index) {
}
acc.kernels {
} attributes {defaultAttr = #acc<defaultvalue none>}
@@ -699,13 +699,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10
// CHECK-NEXT: }
// CHECK: acc.kernels async([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.kernels wait({[[I64VALUE]] : i64}) {
+// CHECK: acc.kernels wait([[I64VALUE]] : i64) {
// CHECK-NEXT: }
-// CHECK: acc.kernels wait({[[I32VALUE]] : i32}) {
+// CHECK: acc.kernels wait([[I32VALUE]] : i32) {
// CHECK-NEXT: }
-// CHECK: acc.kernels wait({[[IDXVALUE]] : index}) {
+// CHECK: acc.kernels wait([[IDXVALUE]] : index) {
// CHECK-NEXT: }
-// CHECK: acc.kernels wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) {
+// CHECK: acc.kernels wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
// CHECK-NEXT: }
// CHECK: acc.kernels {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 13393569f36fe7..2dec4ba3c001e8 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -10,7 +10,6 @@ add_subdirectory(ArmSME)
add_subdirectory(Index)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
-add_subdirectory(OpenACC)
add_subdirectory(SCF)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
deleted file mode 100644
index 5133d7fc38296c..00000000000000
--- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-add_mlir_unittest(MLIROpenACCTests
- OpenACCOpsTest.cpp
-)
-target_link_libraries(MLIROpenACCTests
- PRIVATE
- MLIRIR
- MLIROpenACCDialect
-)
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
deleted file mode 100644
index dcf6c1240c55d3..00000000000000
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ /dev/null
@@ -1,275 +0,0 @@
-//===- OpenACCOpsTest.cpp - OpenACC ops extra functiosn Tests -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/OpenACC/OpenACC.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/OwningOpRef.h"
-#include "gtest/gtest.h"
-
-using namespace mlir;
-using namespace mlir::acc;
-
-//===----------------------------------------------------------------------===//
-// Test Fixture
-//===----------------------------------------------------------------------===//
-
-class OpenACCOpsTest : public ::testing::Test {
-protected:
- OpenACCOpsTest() : b(&context), loc(UnknownLoc::get(&context)) {
- context.loadDialect<acc::OpenACCDialect, arith::ArithDialect>();
- }
-
- MLIRContext context;
- OpBuilder b;
- Location loc;
- llvm::SmallVector<DeviceType> dtypes = {
- DeviceType::None, DeviceType::Star, DeviceType::Multicore,
- DeviceType::Default, DeviceType::Host, DeviceType::Nvidia,
- DeviceType::Radeon};
- llvm::SmallVector<DeviceType> dtypesWithoutNone = {
- DeviceType::Star, DeviceType::Multicore, DeviceType::Default,
- DeviceType::Host, DeviceType::Nvidia, DeviceType::Radeon};
-};
-
-template <typename Op>
-void testAsyncOnly(OpBuilder &b, MLIRContext &context, Location loc,
- llvm::SmallVector<DeviceType> &dtypes) {
- Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
- EXPECT_FALSE(op.hasAsyncOnly());
- for (auto d : dtypes)
- EXPECT_FALSE(op.hasAsyncOnly(d));
-
- auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
- op.setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
- EXPECT_TRUE(op.hasAsyncOnly());
- EXPECT_TRUE(op.hasAsyncOnly(DeviceType::None));
- op.removeAsyncOnlyAttr();
-
- auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
- op.setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
- EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Host));
- EXPECT_FALSE(op.hasAsyncOnly());
- op.removeAsyncOnlyAttr();
-
- auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
- op.setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
- EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Star));
- EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Host));
- EXPECT_FALSE(op.hasAsyncOnly());
-}
-
-TEST_F(OpenACCOpsTest, asyncOnlyTest) {
- testAsyncOnly<ParallelOp>(b, context, loc, dtypes);
- testAsyncOnly<KernelsOp>(b, context, loc, dtypes);
- testAsyncOnly<SerialOp>(b, context, loc, dtypes);
-}
-
-template <typename Op>
-void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
- llvm::SmallVector<DeviceType> &dtypes) {
- Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
-
- mlir::Value empty;
- EXPECT_EQ(op.getAsyncValue(), empty);
- for (auto d : dtypes)
- EXPECT_EQ(op.getAsyncValue(d), empty);
-
- mlir::Value val = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
- auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
- op.setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
- op.getAsyncMutable().assign(val);
- EXPECT_EQ(op.getAsyncValue(), empty);
- EXPECT_EQ(op.getAsyncValue(DeviceType::Nvidia), val);
-}
-
-TEST_F(OpenACCOpsTest, asyncValueTest) {
- testAsyncValue<ParallelOp>(b, context, loc, dtypes);
- testAsyncValue<KernelsOp>(b, context, loc, dtypes);
- testAsyncValue<SerialOp>(b, context, loc, dtypes);
-}
-
-template <typename Op>
-void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
- llvm::SmallVector<DeviceType> &dtypes,
- llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
- Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
- EXPECT_EQ(op.getNumGangsValues().begin(), op.getNumGangsValues().end());
-
- mlir::Value val1 = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
- mlir::Value val2 = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(4));
- auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
- op.getNumGangsMutable().assign(val1);
- op.setNumGangsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
- op.setNumGangsSegments(b.getDenseI32ArrayAttr({1}));
- EXPECT_EQ(op.getNumGangsValues().front(), val1);
- for (auto d : dtypesWithoutNone)
- EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end());
-
- op.getNumGangsMutable().clear();
- op.removeNumGangsDeviceTypeAttr();
- op.removeNumGangsSegmentsAttr();
- for (auto d : dtypes)
- EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end());
-
- op.getNumGangsMutable().append(val1);
- op.getNumGangsMutable().append(val2);
- op.setNumGangsDeviceTypeAttr(
- b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host),
- DeviceTypeAttr::get(&context, DeviceType::Star)}));
- op.setNumGangsSegments(b.getDenseI32ArrayAttr({1, 1}));
- EXPECT_EQ(op.getNumGangsValues(DeviceType::None).begin(),
- op.getNumGangsValues(DeviceType::None).end());
- EXPECT_EQ(op.getNumGangsValues(DeviceType::Host).front(), val1);
- EXPECT_EQ(op.getNumGangsValues(DeviceType::Star).front(), val2);
-
- op.getNumGangsMutable().clear();
- op.removeNumGangsDeviceTypeAttr();
- op.removeNumGangsSegmentsAttr();
- for (auto d : dtypes)
- EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end());
-
- op.getNumGangsMutable().append(val1);
- op.getNumGangsMutable().append(val2);
- op.getNumGangsMutable().append(val1);
- op.setNumGangsDeviceTypeAttr(
- b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default),
- DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
- op.setNumGangsSegments(b.getDenseI32ArrayAttr({2, 1}));
- EXPECT_EQ(op.getNumGangsValues(DeviceType::None).begin(),
- op.getNumGangsValues(DeviceType::None).end());
- EXPECT_EQ(op.getNumGangsValues(DeviceType::Default).front(), val1);
- EXPECT_EQ(op.getNumGangsValues(DeviceType::Default).drop_front().front(),
- val2);
- EXPECT_EQ(op.getNumGangsValues(DeviceType::Multicore).front(), val1);
-}
-
-TEST_F(OpenACCOpsTest, numGangsValuesTest) {
- testNumGangsValues<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
- testNumGangsValues<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
-}
-
-template <typename Op>
-void testVectorLength(OpBuilder &b, MLIRContext &context, Location loc,
- llvm::SmallVector<DeviceType> &dtypes) {
- auto op = b.create<Op>(loc, TypeRange{}, ValueRange{});
-
- mlir::Value empty;
- EXPECT_EQ(op.getVectorLengthValue(), empty);
- for (auto d : dtypes)
- EXPECT_EQ(op.getVectorLengthValue(d), empty);
-
- mlir::Value val = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
- auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
- op.setVectorLengthDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
- op.getVectorLengthMutable().assign(val);
- EXPECT_EQ(op.getVectorLengthValue(), empty);
- EXPECT_EQ(op.getVectorLengthValue(DeviceType::Nvidia), val);
-}
-
-TEST_F(OpenACCOpsTest, vectorLengthTest) {
- testVectorLength<ParallelOp>(b, context, loc, dtypes);
- testVectorLength<KernelsOp>(b, context, loc, dtypes);
-}
-
-template <typename Op>
-void testWaitOnly(OpBuilder &b, MLIRContext &context, Location loc,
- llvm::SmallVector<DeviceType> &dtypes,
- llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
- Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
- EXPECT_FALSE(op.hasWaitOnly());
- for (auto d : dtypes)
- EXPECT_FALSE(op.hasWaitOnly(d));
-
- auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
- op.setWaitOnlyAttr(b.getArrayAttr({dtypeNone}));
- EXPECT_TRUE(op.hasWaitOnly());
- EXPECT_TRUE(op.hasWaitOnly(DeviceType::None));
- for (auto d : dtypesWithoutNone)
- EXPECT_FALSE(op.hasWaitOnly(d));
- op.removeWaitOnlyAttr();
-
- auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
- op.setWaitOnlyAttr(b.getArrayAttr({dtypeHost}));
- EXPECT_TRUE(op.hasWaitOnly(DeviceType::Host));
- EXPECT_FALSE(op.hasWaitOnly());
- op.removeWaitOnlyAttr();
-
- auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
- op.setWaitOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
- EXPECT_TRUE(op.hasWaitOnly(DeviceType::Star));
- EXPECT_TRUE(op.hasWaitOnly(DeviceType::Host));
- EXPECT_FALSE(op.hasWaitOnly());
-}
-
-TEST_F(OpenACCOpsTest, waitOnlyTest) {
- testWaitOnly<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
- testWaitOnly<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
- testWaitOnly<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
-}
-
-template <typename Op>
-void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc,
- llvm::SmallVector<DeviceType> &dtypes,
- llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
- Op op = b.create<Op>(loc, TypeRange{}, ValueRange{});
- EXPECT_EQ(op.getWaitValues().begin(), op.getWaitValues().end());
-
- mlir::Value val1 = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1));
- mlir::Value val2 = b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(4));
- auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
- op.getWaitOperandsMutable().assign(val1);
- op.setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
- op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({1}));
- EXPECT_EQ(op.getWaitValues().front(), val1);
- for (auto d : dtypesWithoutNone)
- EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end());
-
- op.getWaitOperandsMutable().clear();
- op.removeWaitOperandsDeviceTypeAttr();
- op.removeWaitOperandsSegmentsAttr();
- for (auto d : dtypes)
- EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end());
-
- op.getWaitOperandsMutable().append(val1);
- op.getWaitOperandsMutable().append(val2);
- op.setWaitOperandsDeviceTypeAttr(
- b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host),
- DeviceTypeAttr::get(&context, DeviceType::Star)}));
- op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({1, 1}));
- EXPECT_EQ(op.getWaitValues(DeviceType::None).begin(),
- op.getWaitValues(DeviceType::None).end());
- EXPECT_EQ(op.getWaitValues(DeviceType::Host).front(), val1);
- EXPECT_EQ(op.getWaitValues(DeviceType::Star).front(), val2);
-
- op.getWaitOperandsMutable().clear();
- op.removeWaitOperandsDeviceTypeAttr();
- op.removeWaitOperandsSegmentsAttr();
- for (auto d : dtypes)
- EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end());
-
- op.getWaitOperandsMutable().append(val1);
- op.getWaitOperandsMutable().append(val2);
- op.getWaitOperandsMutable().append(val1);
- op.setWaitOperandsDeviceTypeAttr(
- b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default),
- DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
- op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({2, 1}));
- EXPECT_EQ(op.getWaitValues(DeviceType::None).begin(),
- op.getWaitValues(DeviceType::None).end());
- EXPECT_EQ(op.getWaitValues(DeviceType::Default).front(), val1);
- EXPECT_EQ(op.getWaitValues(DeviceType::Default).drop_front().front(), val2);
- EXPECT_EQ(op.getWaitValues(DeviceType::Multicore).front(), val1);
-}
-
-TEST_F(OpenACCOpsTest, waitValuesTest) {
- testWaitValues<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
- testWaitValues<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
- testWaitValues<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
-}
More information about the llvm-branch-commits
mailing list