[Mlir-commits] [mlir] [flang][mlir][openacc] Switch device_type representation to an enum (PR #70250)
Valentin Clement バレンタイン クレメン
llvmlistbot at llvm.org
Wed Oct 25 13:29:27 PDT 2023
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/70250
Switch the device_type representation from a scalar integer to an enumeration. The parser transforms the string in the input into the corresponding enumeration value.
From 605777b63a74aebc1c50948764e4e9a3ac6341c7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 25 Oct 2023 13:26:17 -0700
Subject: [PATCH] [flang][mlir][openacc] Switch device_type representation to
an enum
---
flang/include/flang/Parser/dump-parse-tree.h | 1 +
flang/include/flang/Parser/parse-tree.h | 4 +-
flang/lib/Lower/OpenACC.cpp | 93 +++++++++++--------
flang/lib/Parser/openacc-parsers.cpp | 12 ++-
flang/test/Lower/OpenACC/acc-init.f90 | 10 +-
flang/test/Lower/OpenACC/acc-set.f90 | 8 +-
flang/test/Lower/OpenACC/acc-shutdown.f90 | 6 +-
flang/test/Lower/OpenACC/acc-update.f90 | 9 +-
flang/test/Semantics/OpenACC/acc-data.f90 | 2 +-
.../Semantics/OpenACC/acc-init-validity.f90 | 8 +-
.../Semantics/OpenACC/acc-kernels-loop.f90 | 4 +-
flang/test/Semantics/OpenACC/acc-kernels.f90 | 4 +-
flang/test/Semantics/OpenACC/acc-parallel.f90 | 6 +-
.../Semantics/OpenACC/acc-set-validity.f90 | 8 +-
.../OpenACC/acc-shutdown-validity.f90 | 8 +-
.../Semantics/OpenACC/acc-update-validity.f90 | 2 +-
.../mlir/Dialect/OpenACC/OpenACCOps.td | 45 ++++++---
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 5 +-
mlir/test/Dialect/OpenACC/invalid.mlir | 2 +-
mlir/test/Dialect/OpenACC/ops.mlir | 20 ++--
20 files changed, 144 insertions(+), 113 deletions(-)
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 494e54faa64c841..7c479a2334ea555 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -101,6 +101,7 @@ class ParseTreeDumper {
NODE(parser, AccSelfClause)
NODE(parser, AccStandaloneDirective)
NODE(parser, AccDeviceTypeExpr)
+ NODE_ENUM(parser::AccDeviceTypeExpr, Device)
NODE(parser, AccDeviceTypeExprList)
NODE(parser, AccTileExpr)
NODE(parser, AccTileExprList)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 83c8db936934a03..4806fc49f3441de 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4061,9 +4061,9 @@ struct AccWaitArgument {
};
struct AccDeviceTypeExpr {
- TUPLE_CLASS_BOILERPLATE(AccDeviceTypeExpr);
+ ENUM_CLASS(Device, Star, Default, Nvidia, Radeon, Host, Multicore)
+ WRAPPER_CLASS_BOILERPLATE(AccDeviceTypeExpr, Device);
CharBlock source;
- std::tuple<std::optional<ScalarIntExpr>> t; // if null then *
};
struct AccDeviceTypeExprList {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 90644279d9e78ce..024c38c96277313 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -150,7 +150,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, descTy);
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -219,7 +219,7 @@ static void createDeclareDeallocFuncWithArg(
builder, loc, loadOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, loadOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1416,27 +1416,35 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
}
}
-static void genDeviceTypeClause(
- Fortran::lower::AbstractConverter &converter, mlir::Location clauseLocation,
+static mlir::acc::DeviceType
+getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
+ switch (device) {
+ case Fortran::parser::AccDeviceTypeExpr::Device::Star:
+ return mlir::acc::DeviceType::Star;
+ case Fortran::parser::AccDeviceTypeExpr::Device::Default:
+ return mlir::acc::DeviceType::Default;
+ case Fortran::parser::AccDeviceTypeExpr::Device::Nvidia:
+ return mlir::acc::DeviceType::Nvidia;
+ case Fortran::parser::AccDeviceTypeExpr::Device::Radeon:
+ return mlir::acc::DeviceType::Radeon;
+ case Fortran::parser::AccDeviceTypeExpr::Device::Host:
+ return mlir::acc::DeviceType::Host;
+ case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
+ return mlir::acc::DeviceType::Multicore;
+ }
+ return mlir::acc::DeviceType::Default;
+}
+
+static void gatherDeviceTypeAttrs(
+ fir::FirOpBuilder &builder, mlir::Location clauseLocation,
const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
- llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVector<mlir::Attribute> &deviceTypes,
Fortran::lower::StatementContext &stmtCtx) {
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
deviceTypeClause->v;
- for (const auto &deviceTypeExpr : deviceTypeExprList.v) {
- const auto &expr = std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
- deviceTypeExpr.t);
- if (expr) {
- operands.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(expr), stmtCtx, &clauseLocation)));
- } else {
- // * was passed as value and will be represented as a special constant.
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::Value star = firOpBuilder.createIntegerConstant(
- clauseLocation, firOpBuilder.getIndexType(), starCst);
- operands.push_back(star);
- }
- }
+ for (const auto &deviceTypeExpr : deviceTypeExprList.v)
+ deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), getDeviceType(deviceTypeExpr.v)));
}
static void genIfClause(Fortran::lower::AbstractConverter &converter,
@@ -2428,10 +2436,10 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::AccClauseList &accClauseList) {
mlir::Value ifCond, deviceNum;
- llvm::SmallVector<mlir::Value> deviceTypeOperands;
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
// Lower clauses values mapped to operands.
// Keep track of each group of operands separately as clauses can appear
@@ -2449,19 +2457,23 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
- deviceTypeOperands, stmtCtx);
+ gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+ deviceTypes, stmtCtx);
}
}
// Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value, 6> operands;
- llvm::SmallVector<int32_t, 3> operandSegments;
- addOperands(operands, operandSegments, deviceTypeOperands);
+ llvm::SmallVector<int32_t, 2> operandSegments;
+
addOperand(operands, operandSegments, deviceNum);
addOperand(operands, operandSegments, ifCond);
- createSimpleOp<Op>(firOpBuilder, currentLocation, operands, operandSegments);
+ Op op =
+ createSimpleOp<Op>(builder, currentLocation, operands, operandSegments);
+ if (!deviceTypes.empty())
+ op.setDeviceTypesAttr(
+ mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
}
void genACCSetOp(Fortran::lower::AbstractConverter &converter,
@@ -2470,8 +2482,9 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
mlir::Value ifCond, deviceNum, defaultAsync;
llvm::SmallVector<mlir::Value> deviceTypeOperands;
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
// Lower clauses values mapped to operands.
// Keep track of each group of operands separately as clauses can appear
@@ -2494,21 +2507,22 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
- deviceTypeOperands, stmtCtx);
+ gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+ deviceTypes, stmtCtx);
}
}
// Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value> operands;
- llvm::SmallVector<int32_t, 4> operandSegments;
- addOperands(operands, operandSegments, deviceTypeOperands);
+ llvm::SmallVector<int32_t, 3> operandSegments;
addOperand(operands, operandSegments, defaultAsync);
addOperand(operands, operandSegments, deviceNum);
addOperand(operands, operandSegments, ifCond);
- createSimpleOp<mlir::acc::SetOp>(firOpBuilder, currentLocation, operands,
- operandSegments);
+ auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
+ operandSegments);
+ if (!deviceTypes.empty())
+ op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0]));
}
static void
@@ -2520,6 +2534,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
mlir::Value ifCond, async, waitDevnum;
llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
waitOperands, deviceTypeOperands;
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
// Async and wait clause have optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
@@ -2548,8 +2563,8 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
- deviceTypeOperands, stmtCtx);
+ gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+ deviceTypes, stmtCtx);
} else if (const auto *hostClause =
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2587,11 +2602,13 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
addOperand(operands, operandSegments, async);
addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
- addOperands(operands, operandSegments, deviceTypeOperands);
addOperands(operands, operandSegments, dataClauseOperands);
mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
builder, currentLocation, operands, operandSegments);
+ if (!deviceTypes.empty())
+ updateOp.setDeviceTypesAttr(
+ mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
builder, updateHostOperands, /*structured=*/false, /*implicit=*/false);
@@ -2772,7 +2789,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -2848,7 +2865,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/lib/Parser/openacc-parsers.cpp b/flang/lib/Parser/openacc-parsers.cpp
index 131f7332a69701a..5b9267e0e17c6db 100644
--- a/flang/lib/Parser/openacc-parsers.cpp
+++ b/flang/lib/Parser/openacc-parsers.cpp
@@ -53,9 +53,15 @@ TYPE_PARSER(construct<AccSizeExpr>(scalarIntExpr) ||
construct<AccSizeExpr>("*" >> construct<std::optional<ScalarIntExpr>>()))
TYPE_PARSER(construct<AccSizeExprList>(nonemptyList(Parser<AccSizeExpr>{})))
-TYPE_PARSER(construct<AccDeviceTypeExpr>(scalarIntExpr) ||
- construct<AccDeviceTypeExpr>(
- "*" >> construct<std::optional<ScalarIntExpr>>()))
+TYPE_PARSER(sourced(construct<AccDeviceTypeExpr>(
+ first("*" >> pure(AccDeviceTypeExpr::Device::Star),
+ "DEFAULT" >> pure(AccDeviceTypeExpr::Device::Default),
+ "NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia),
+ "ACC_DEVICE_NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia),
+ "RADEON" >> pure(AccDeviceTypeExpr::Device::Radeon),
+ "HOST" >> pure(AccDeviceTypeExpr::Device::Host),
+ "MULTICORE" >> pure(AccDeviceTypeExpr::Device::Multicore)))))
+
TYPE_PARSER(
construct<AccDeviceTypeExprList>(nonemptyList(Parser<AccDeviceTypeExpr>{})))
diff --git a/flang/test/Lower/OpenACC/acc-init.f90 b/flang/test/Lower/OpenACC/acc-init.f90
index de940426b6f1c0b..d1fd638c7ac0e8a 100644
--- a/flang/test/Lower/OpenACC/acc-init.f90
+++ b/flang/test/Lower/OpenACC/acc-init.f90
@@ -4,6 +4,7 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
subroutine acc_init
+ implicit none
logical :: ifCondition = .TRUE.
integer :: ifInt = 1
@@ -23,15 +24,16 @@ subroutine acc_init
!CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
!CHECK: acc.init device_num([[DEVNUM]] : i32){{$}}
- !$acc init device_num(1) device_type(1, 2)
+ !$acc init device_num(1) device_type(host, multicore)
!CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-!CHECK: acc.init device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}}
+!CHECK: acc.init device_num([[DEVNUM]] : i32) attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
!$acc init if(ifInt)
!CHECK: %[[IFINT:.*]] = fir.load %{{.*}} : !fir.ref<i32>
!CHECK: %[[CONV:.*]] = fir.convert %[[IFINT]] : (i32) -> i1
!CHECK: acc.init if(%[[CONV]])
+ !$acc init device_type(nvidia)
+!CHECK: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
+
end subroutine acc_init
diff --git a/flang/test/Lower/OpenACC/acc-set.f90 b/flang/test/Lower/OpenACC/acc-set.f90
index 52baedeafecb2bb..39bf26e0072b7ca 100644
--- a/flang/test/Lower/OpenACC/acc-set.f90
+++ b/flang/test/Lower/OpenACC/acc-set.f90
@@ -14,7 +14,7 @@ program test_acc_set
!$acc set device_type(*)
-!$acc set device_type(0)
+!$acc set device_type(multicore)
end
@@ -34,10 +34,8 @@ program test_acc_set
! CHECK: %[[C0:.*]] = arith.constant 0 : i32
! CHECK: acc.set device_num(%[[C0]] : i32)
-! CHECK: %[[C_1:.*]] = arith.constant -1 : index
-! CHECK: acc.set device_type(%[[C_1]] : index)
+! CHECK: acc.set attributes {device_type = #acc.device_type<*>}
-! CHECK: %[[C0:.*]] = arith.constant 0 : i32
-! CHECK: acc.set device_type(%[[C0]] : i32)
+! CHECK: acc.set attributes {device_type = #acc.device_type<multicore>}
diff --git a/flang/test/Lower/OpenACC/acc-shutdown.f90 b/flang/test/Lower/OpenACC/acc-shutdown.f90
index 49e1acc546d900c..f63f5d62b4fe921 100644
--- a/flang/test/Lower/OpenACC/acc-shutdown.f90
+++ b/flang/test/Lower/OpenACC/acc-shutdown.f90
@@ -22,10 +22,8 @@ subroutine acc_shutdown
!CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
!CHECK: acc.shutdown device_num([[DEVNUM]] : i32){{$}}
- !$acc shutdown device_num(1) device_type(1, 2)
+ !$acc shutdown device_num(1) device_type(default, nvidia)
!CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-!CHECK: acc.shutdown device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}}
+!CHECK: acc.shutdown device_num([[DEVNUM]] : i32) attributes {device_types = [#acc.device_type<default>, #acc.device_type<nvidia>]}
end subroutine acc_shutdown
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index f7343a69285f85a..5d5f5733ef7f1a6 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -145,20 +145,17 @@ subroutine acc_update
! FIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
- !$acc update host(a) device_type(1, 2)
+ !$acc update host(a) device_type(default, host)
! FIR: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! HLFIR: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-! CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]}
! FIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) device_type(*)
! FIR: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! HLFIR: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[DEVTYPE3:%.*]] = arith.constant -1 : index
-! CHECK: acc.update device_type([[DEVTYPE3]] : index) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<*>]}
! FIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
diff --git a/flang/test/Semantics/OpenACC/acc-data.f90 b/flang/test/Semantics/OpenACC/acc-data.f90
index 17e0624b8cf24d4..1a7a6f95f3d891e 100644
--- a/flang/test/Semantics/OpenACC/acc-data.f90
+++ b/flang/test/Semantics/OpenACC/acc-data.f90
@@ -184,7 +184,7 @@ program openacc_data_validity
!$acc data copy(aa) wait
!$acc end data
- !$acc data copy(aa) device_type(1) wait
+ !$acc data copy(aa) device_type(default) wait
!$acc end data
end program openacc_data_validity
diff --git a/flang/test/Semantics/OpenACC/acc-init-validity.f90 b/flang/test/Semantics/OpenACC/acc-init-validity.f90
index f54898f73fdce28..3b594a25217c094 100644
--- a/flang/test/Semantics/OpenACC/acc-init-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-init-validity.f90
@@ -20,9 +20,9 @@ program openacc_init_validity
!$acc init if(ifInt)
!$acc init device_num(1)
!$acc init device_num(i)
- !$acc init device_type(i)
- !$acc init device_type(2, i, j)
- !$acc init device_num(i) device_type(i, j) if(ifCondition)
+ !$acc init device_type(default)
+ !$acc init device_type(nvidia, radeon)
+ !$acc init device_num(i) device_type(host, multicore) if(ifCondition)
!$acc parallel
!ERROR: Directive INIT may not be called within a compute region
@@ -94,7 +94,7 @@ program openacc_init_validity
!$acc init device_num(1) device_num(i)
!ERROR: At most one DEVICE_TYPE clause can appear on the INIT directive
- !$acc init device_type(2) device_type(i, j)
+ !$acc init device_type(nvidia) device_type(default, *)
!ERROR: Must have LOGICAL or INTEGER type
!$acc init if(ifReal)
diff --git a/flang/test/Semantics/OpenACC/acc-kernels-loop.f90 b/flang/test/Semantics/OpenACC/acc-kernels-loop.f90
index 5facd4737788030..1a280f7c54f5cd3 100644
--- a/flang/test/Semantics/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Semantics/OpenACC/acc-kernels-loop.f90
@@ -264,12 +264,12 @@ program openacc_kernels_loop_validity
a(i) = 3.14
end do
- !$acc kernels loop device_type(1)
+ !$acc kernels loop device_type(multicore)
do i = 1, N
a(i) = 3.14
end do
- !$acc kernels loop device_type(1, 3)
+ !$acc kernels loop device_type(host, multicore)
do i = 1, N
a(i) = 3.14
end do
diff --git a/flang/test/Semantics/OpenACC/acc-kernels.f90 b/flang/test/Semantics/OpenACC/acc-kernels.f90
index a2c9c9e8be99b17..de220f7c7ddf7cf 100644
--- a/flang/test/Semantics/OpenACC/acc-kernels.f90
+++ b/flang/test/Semantics/OpenACC/acc-kernels.f90
@@ -122,10 +122,10 @@ program openacc_kernels_validity
!$acc kernels device_type(*)
!$acc end kernels
- !$acc kernels device_type(1)
+ !$acc kernels device_type(default)
!$acc end kernels
- !$acc kernels device_type(1, 3)
+ !$acc kernels device_type(default, host)
!$acc end kernels
!$acc kernels device_type(*) async wait num_gangs(8) num_workers(8) vector_length(128)
diff --git a/flang/test/Semantics/OpenACC/acc-parallel.f90 b/flang/test/Semantics/OpenACC/acc-parallel.f90
index e85922e37c63e06..0e8d240d019983f 100644
--- a/flang/test/Semantics/OpenACC/acc-parallel.f90
+++ b/flang/test/Semantics/OpenACC/acc-parallel.f90
@@ -111,10 +111,10 @@ program openacc_parallel_validity
!$acc parallel device_type(*)
!$acc end parallel
- !$acc parallel device_type(1)
+ !$acc parallel device_type(default)
!$acc end parallel
- !$acc parallel device_type(1, 3)
+ !$acc parallel device_type(default, host)
!$acc end parallel
!ERROR: Clause PRIVATE is not allowed after clause DEVICE_TYPE on the PARALLEL directive
@@ -131,7 +131,7 @@ program openacc_parallel_validity
!$acc parallel device_type(*) num_gangs(8)
!$acc end parallel
- !$acc parallel device_type(1) async device_type(2) wait
+ !$acc parallel device_type(*) async device_type(host) wait
!$acc end parallel
!ERROR: Clause IF is not allowed after clause DEVICE_TYPE on the PARALLEL directive
diff --git a/flang/test/Semantics/OpenACC/acc-set-validity.f90 b/flang/test/Semantics/OpenACC/acc-set-validity.f90
index 896e39df6535c9a..74522b30d11bc49 100644
--- a/flang/test/Semantics/OpenACC/acc-set-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-set-validity.f90
@@ -90,17 +90,17 @@ program openacc_clause_validity
!$acc set device_num(1) device_num(i)
!ERROR: At most one DEVICE_TYPE clause can appear on the SET directive
- !$acc set device_type(i) device_type(2)
+ !$acc set device_type(*) device_type(nvidia)
!$acc set default_async(2)
!$acc set default_async(i)
!$acc set device_num(1)
!$acc set device_num(i)
- !$acc set device_type(i)
- !$acc set device_num(1) default_async(2) device_type(2)
+ !$acc set device_type(default)
+ !$acc set device_num(1) default_async(2) device_type(*)
!ERROR: The DEVICE_TYPE clause on the SET directive accepts only one value
- !$acc set device_type(1, 2)
+ !$acc set device_type(*, default)
!ERROR: At least one of DEFAULT_ASYNC, DEVICE_NUM, DEVICE_TYPE clause must appear on the SET directive
!$acc set
diff --git a/flang/test/Semantics/OpenACC/acc-shutdown-validity.f90 b/flang/test/Semantics/OpenACC/acc-shutdown-validity.f90
index de40963f99e0486..43aed4fc98f42eb 100644
--- a/flang/test/Semantics/OpenACC/acc-shutdown-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-shutdown-validity.f90
@@ -80,9 +80,9 @@ program openacc_shutdown_validity
!$acc shutdown if(ifCondition)
!$acc shutdown device_num(1)
!$acc shutdown device_num(i)
- !$acc shutdown device_type(i)
- !$acc shutdown device_type(2, i, j)
- !$acc shutdown device_num(i) device_type(i, j) if(ifCondition)
+ !$acc shutdown device_type(*)
+ !$acc shutdown device_type(*, default, host)
+ !$acc shutdown device_num(i) device_type(default, host) if(ifCondition)
!ERROR: At most one IF clause can appear on the SHUTDOWN directive
!$acc shutdown if(.TRUE.) if(ifCondition)
@@ -91,6 +91,6 @@ program openacc_shutdown_validity
!$acc shutdown device_num(1) device_num(i)
!ERROR: At most one DEVICE_TYPE clause can appear on the SHUTDOWN directive
- !$acc shutdown device_type(2) device_type(i, j)
+ !$acc shutdown device_type(*) device_type(host, default)
end program openacc_shutdown_validity
diff --git a/flang/test/Semantics/OpenACC/acc-update-validity.f90 b/flang/test/Semantics/OpenACC/acc-update-validity.f90
index a409ba5ea549f80..1e75742e63e97b1 100644
--- a/flang/test/Semantics/OpenACC/acc-update-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-update-validity.f90
@@ -53,7 +53,7 @@ program openacc_update_validity
!$acc update host(bb) device_type(*) wait
- !$acc update self(cc) device_type(1,2) async device_type(3) wait
+ !$acc update self(cc) device_type(host,multicore) async device_type(*) wait
!ERROR: At most one IF clause can appear on the UPDATE directive
!$acc update device(aa) if(.true.) if(ifCondition)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 10018c9fc7e27e8..3c5173fdda7f66a 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -155,6 +155,30 @@ def DeclareActionAttr : OpenACC_Attr<"DeclareAction", "declare_action"> {
let assemblyFormat = "`<` struct(params) `>`";
}
+// Device type enumeration.
+def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 0, "*">;
+def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 1, "default">;
+def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 2, "host">;
+def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 3, "multicore">;
+def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 4, "nvidia">;
+def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 5, "radeon">;
+
+
+def OpenACC_DeviceType : I32EnumAttr<"DeviceType",
+ "built-in device type supported by OpenACC",
+ [OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault,
+ OpenACC_DeviceTypeHost, OpenACC_DeviceTypeMulticore,
+ OpenACC_DeviceTypeNvidia, OpenACC_DeviceTypeRadeon
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::acc";
+}
+def OpenACC_DeviceTypeAttr : EnumAttr<OpenACC_Dialect,
+ OpenACC_DeviceType,
+ "device_type"> {
+ let assemblyFormat = [{ ```<` $value `>` }];
+}
+
// Used for data specification in data clauses (2.7.1).
// Either (or both) extent and upperbound must be specified.
def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
@@ -1624,14 +1648,12 @@ def OpenACC_InitOp : OpenACC_Op<"init", [AttrSizedOperandSegments]> {
```
}];
- let arguments = (ins Variadic<AnyInteger>:$deviceTypeOperands,
+ let arguments = (ins OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
Optional<IntOrIndex>:$deviceNumOperand,
Optional<I1>:$ifCond);
let assemblyFormat = [{
- oilist(
- `device_type` `(` $deviceTypeOperands `:` type($deviceTypeOperands) `)`
- | `device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
+ oilist(`device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
| `if` `(` $ifCond `)`
) attr-dict-with-keyword
}];
@@ -1657,13 +1679,12 @@ def OpenACC_ShutdownOp : OpenACC_Op<"shutdown", [AttrSizedOperandSegments]> {
```
}];
- let arguments = (ins Variadic<AnyInteger>:$deviceTypeOperands,
+ let arguments = (ins OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
Optional<IntOrIndex>:$deviceNumOperand,
Optional<I1>:$ifCond);
let assemblyFormat = [{
- oilist(`device_type` `(` $deviceTypeOperands `:` type($deviceTypeOperands) `)`
- |`device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
+ oilist(`device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
|`if` `(` $ifCond `)`
) attr-dict-with-keyword
}];
@@ -1687,15 +1708,13 @@ def OpenACC_SetOp : OpenACC_Op<"set", [AttrSizedOperandSegments]> {
```
}];
- let arguments = (ins Optional<IntOrIndex>:$deviceType,
+ let arguments = (ins OptionalAttr<OpenACC_DeviceTypeAttr>:$device_type,
Optional<IntOrIndex>:$defaultAsync,
Optional<IntOrIndex>:$deviceNum,
Optional<I1>:$ifCond);
let assemblyFormat = [{
- oilist(
- `device_type` `(` $deviceType `:` type($deviceType) `)`
- | `default_async` `(` $defaultAsync `:` type($defaultAsync) `)`
+ oilist(`default_async` `(` $defaultAsync `:` type($defaultAsync) `)`
| `device_num` `(` $deviceNum `:` type($deviceNum) `)`
| `if` `(` $ifCond `)`
) attr-dict-with-keyword
@@ -1729,7 +1748,7 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
Variadic<IntOrIndex>:$waitOperands,
UnitAttr:$async,
UnitAttr:$wait,
- Variadic<IntOrIndex>:$deviceTypeOperands,
+ OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
UnitAttr:$ifPresent);
@@ -1746,8 +1765,6 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
`if` `(` $ifCond `)`
| `async` `(` $asyncOperand `:` type($asyncOperand) `)`
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
- | `device_type` `(` $deviceTypeOperands `:`
- type($deviceTypeOperands) `)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index b7e2aec6a4e6a92..639f0b85450f591 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1238,7 +1238,7 @@ LogicalResult acc::SetOp::verify() {
while ((currOp = currOp->getParentOp()))
if (isComputeOperation(currOp))
return emitOpError("cannot be nested in a compute operation");
- if (!getDeviceType() && !getDefaultAsync() && !getDeviceNum())
+ if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
return emitOpError("at least one default_async, device_num, or device_type "
"operand must appear");
return success();
@@ -1283,8 +1283,7 @@ Value UpdateOp::getDataOperand(unsigned i) {
unsigned numOptional = getAsyncOperand() ? 1 : 0;
numOptional += getWaitDevnum() ? 1 : 0;
numOptional += getIfCond() ? 1 : 0;
- return getOperand(getWaitOperands().size() + getDeviceTypeOperands().size() +
- numOptional + i);
+ return getOperand(getWaitOperands().size() + numOptional + i);
}
void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index ff92eab478bb4f5..b5241a8e4dc47fa 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -475,7 +475,7 @@ acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i6
%i64value = arith.constant 1 : i64
acc.parallel {
// expected-error@+1 {{'acc.set' op cannot be nested in a compute operation}}
- acc.set device_type(%i64value : i64)
+ acc.set attributes {device_type = #acc.device_type<nvidia>}
acc.yield
}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index d1950b1fb3f2916..cf7a838f55ef855 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -974,7 +974,7 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
acc.update async(%idxValue: index) dataOperands(%0: memref<f32>)
acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) dataOperands(%0: memref<f32>)
acc.update if(%ifCond) dataOperands(%0: memref<f32>)
- acc.update device_type(%i32Value : i32) dataOperands(%0: memref<f32>)
+ acc.update dataOperands(%0: memref<f32>) attributes {acc.device_types = [#acc.device_type<nvidia>]}
acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {async}
acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
@@ -993,7 +993,7 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
// CHECK: acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref<f32>)
-// CHECK: acc.update device_type([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
+// CHECK: acc.update dataOperands(%{{.*}} : memref<f32>) attributes {acc.device_types = [#acc.device_type<nvidia>]}
// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {async}
// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
@@ -1047,8 +1047,7 @@ acc.wait if(%ifCond)
%idxValue = arith.constant 1 : index
%ifCond = arith.constant true
acc.init
-acc.init device_type(%i32Value : i32)
-acc.init device_type(%i32Value, %i32Value2 : i32, i32)
+acc.init attributes {acc.device_types = [#acc.device_type<nvidia>]}
acc.init device_num(%i64Value : i64)
acc.init device_num(%i32Value : i32)
acc.init device_num(%idxValue : index)
@@ -1062,8 +1061,7 @@ acc.init device_num(%idxValue : index) if(%ifCond)
// CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index
// CHECK: [[IFCOND:%.*]] = arith.constant true
// CHECK: acc.init
-// CHECK: acc.init device_type([[I32VALUE]] : i32)
-// CHECK: acc.init device_type([[I32VALUE]], [[I32VALUE2]] : i32, i32)
+// CHECK: acc.init attributes {acc.device_types = [#acc.device_type<nvidia>]}
// CHECK: acc.init device_num([[I64VALUE]] : i64)
// CHECK: acc.init device_num([[I32VALUE]] : i32)
// CHECK: acc.init device_num([[IDXVALUE]] : index)
@@ -1079,8 +1077,7 @@ acc.init device_num(%idxValue : index) if(%ifCond)
%idxValue = arith.constant 1 : index
%ifCond = arith.constant true
acc.shutdown
-acc.shutdown device_type(%i32Value : i32)
-acc.shutdown device_type(%i32Value, %i32Value2 : i32, i32)
+acc.shutdown attributes {acc.device_types = [#acc.device_type<default>]}
acc.shutdown device_num(%i64Value : i64)
acc.shutdown device_num(%i32Value : i32)
acc.shutdown device_num(%idxValue : index)
@@ -1094,8 +1091,7 @@ acc.shutdown device_num(%idxValue : index) if(%ifCond)
// CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index
// CHECK: [[IFCOND:%.*]] = arith.constant true
// CHECK: acc.shutdown
-// CHECK: acc.shutdown device_type([[I32VALUE]] : i32)
-// CHECK: acc.shutdown device_type([[I32VALUE]], [[I32VALUE2]] : i32, i32)
+// CHECK: acc.shutdown attributes {acc.device_types = [#acc.device_type<default>]}
// CHECK: acc.shutdown device_num([[I64VALUE]] : i64)
// CHECK: acc.shutdown device_num([[I32VALUE]] : i32)
// CHECK: acc.shutdown device_num([[IDXVALUE]] : index)
@@ -1718,7 +1714,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
%i32Value2 = arith.constant 2 : i32
%idxValue = arith.constant 1 : index
%ifCond = arith.constant true
-acc.set device_type(%i32Value : i32)
+acc.set attributes {device_type = #acc.device_type<nvidia>}
acc.set device_num(%i64Value : i64)
acc.set device_num(%i32Value : i32)
acc.set device_num(%idxValue : index)
@@ -1730,7 +1726,7 @@ acc.set default_async(%i32Value : i32)
// CHECK: [[I32VALUE2:%.*]] = arith.constant 2 : i32
// CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index
// CHECK: [[IFCOND:%.*]] = arith.constant true
-// CHECK: acc.set device_type([[I32VALUE]] : i32)
+// CHECK: acc.set attributes {device_type = #acc.device_type<nvidia>}
// CHECK: acc.set device_num([[I64VALUE]] : i64)
// CHECK: acc.set device_num([[I32VALUE]] : i32)
// CHECK: acc.set device_num([[IDXVALUE]] : index)
More information about the Mlir-commits
mailing list