[flang-commits] [mlir] [flang] [mlir][flang][openacc] Add device_type support for update op (PR #78764)
via flang-commits
flang-commits at lists.llvm.org
Fri Jan 19 11:19:44 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-mlir-openacc
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Add support for device_type information on the acc.update operation and update lowering from Flang.
---
Patch is 29.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78764.diff
6 Files Affected:
- (modified) flang/lib/Lower/OpenACC.cpp (+43-37)
- (modified) flang/test/Lower/OpenACC/acc-update.f90 (+8-13)
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+39-10)
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+189-40)
- (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+2-2)
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+6-6)
``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 682ca06cabd6f6..541ea2e114324f 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2840,27 +2840,42 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
}
}
+/// Returns an ArrayAttr holding \p attributes, or a null attribute when the
+/// list is empty so that the corresponding optional attribute is left unset.
+/// Takes an ArrayRef: the input is only read, so const containers and
+/// temporaries are accepted as well.
+static inline mlir::ArrayAttr
+getArrayAttr(fir::FirOpBuilder &b,
+             llvm::ArrayRef<mlir::Attribute> attributes) {
+  return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
+}
+
+/// Returns a DenseI32ArrayAttr holding \p values, or a null attribute when
+/// the list is empty so that the corresponding optional attribute is left
+/// unset. Takes an ArrayRef: the input is only read.
+static inline mlir::DenseI32ArrayAttr
+getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
+                     llvm::ArrayRef<int32_t> values) {
+  return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
+}
+
static void
genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
- mlir::Value ifCond, async, waitDevnum;
+ mlir::Value ifCond, waitDevnum;
llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
- waitOperands, deviceTypeOperands;
- llvm::SmallVector<mlir::Attribute> deviceTypes;
-
- // Async and wait clause have optional values but can be present with
- // no value as well. When there is no value, the op has an attribute to
- // represent the clause.
- bool addAsyncAttr = false;
- bool addWaitAttr = false;
- bool addIfPresentAttr = false;
+ waitOperands, deviceTypeOperands, asyncOperands;
+ llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
+ asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+ llvm::SmallVector<int32_t> waitOperandsSegments;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
- // Lower clauses values mapped to operands.
+ // device_type attribute is set to `none` until a device_type clause is
+ // encountered.
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None));
+
+ bool ifPresent = false;
+
+ // Lower clauses values mapped to operands and array attributes.
// Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2870,15 +2885,19 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
- genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+ genAsyncClause(converter, asyncClause, asyncOperands,
+ asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands, waitDevnum,
- addWaitAttr, stmtCtx);
+ genWaitClause(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
} else if (const auto *hostClause =
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2892,7 +2911,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
/*implicit=*/false);
} else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
- addIfPresentAttr = true;
+ ifPresent = true;
} else if (const auto *selfClause =
std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
@@ -2909,30 +2928,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
dataClauseOperands.append(updateHostOperands);
- // Prepare the operand segment size attribute and the operands value range.
- llvm::SmallVector<mlir::Value> operands;
- llvm::SmallVector<int32_t> operandSegments;
- addOperand(operands, operandSegments, ifCond);
- addOperand(operands, operandSegments, async);
- addOperand(operands, operandSegments, waitDevnum);
- addOperands(operands, operandSegments, waitOperands);
- addOperands(operands, operandSegments, dataClauseOperands);
-
- mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
- builder, currentLocation, operands, operandSegments);
- if (!deviceTypes.empty())
- updateOp.setDeviceTypesAttr(
- mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
+ builder.create<mlir::acc::UpdateOp>(
+ currentLocation, ifCond, asyncOperands,
+ getArrayAttr(builder, asyncOperandsDeviceTypes),
+ getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+ getDenseI32ArrayAttr(builder, waitOperandsSegments),
+ getArrayAttr(builder, waitOperandsDeviceTypes),
+ getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
+ ifPresent);
genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
builder, updateHostOperands, /*structured=*/false);
-
- if (addAsyncAttr)
- updateOp.setAsyncAttr(builder.getUnitAttr());
- if (addWaitAttr)
- updateOp.setWaitAttr(builder.getUnitAttr());
- if (addIfPresentAttr)
- updateOp.setIfPresentAttr(builder.getUnitAttr());
}
static void
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index d2b15f8bd258e7..ac7a56c56b1f20 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -61,17 +61,17 @@ subroutine acc_update
!$acc update host(a) async
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+! CHECK: acc.update async dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+! CHECK: acc.update wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) async wait
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+! CHECK: acc.update async wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) async(1)
@@ -89,14 +89,14 @@ subroutine acc_update
!$acc update host(a) wait(1)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait([[WAIT1]] : i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT1]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait(queues: 1, 2)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait(devnum: 1: queues: 1, 2)
@@ -104,17 +104,12 @@ subroutine acc_update
! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
- !$acc update host(a) device_type(default, host)
+ !$acc update host(a) device_type(host, nvidia) async
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]}
-! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
-
- !$acc update host(a) device_type(*)
-! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<star>]}
+! CHECK: acc.update async([#acc.device_type<host>, #acc.device_type<nvidia>]) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
end subroutine acc_update
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 7344ab2852b9ce..5b678e84b93ee4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2187,14 +2187,16 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
}];
let arguments = (ins Optional<I1>:$ifCond,
- Optional<IntOrIndex>:$asyncOperand,
- Optional<IntOrIndex>:$waitDevnum,
- Variadic<IntOrIndex>:$waitOperands,
- UnitAttr:$async,
- UnitAttr:$wait,
- OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- UnitAttr:$ifPresent);
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$async,
+ Optional<IntOrIndex>:$waitDevnum,
+ Variadic<IntOrIndex>:$waitOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$wait,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ UnitAttr:$ifPresent);
let extraClassDeclaration = [{
/// The number of data operands.
@@ -2202,14 +2204,41 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
+
+ /// Return true if the op has the async attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasAsync();
+ /// Return true if the op has the async attribute for the given device_type.
+ bool hasAsync(mlir::acc::DeviceType deviceType);
+ /// Return the value of the async clause if present.
+ mlir::Value getAsyncValue();
+ /// Return the value of the async clause for the given device_type if
+ /// present.
+ mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the wait attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasWait();
+ /// Return true if the op has the wait attribute for the given device_type.
+ bool hasWait(mlir::acc::DeviceType deviceType);
+ /// Return the values of the wait clause if present.
+ mlir::Operation::operand_range getWaitValues();
+ /// Return the values of the wait clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
- | `async` `(` $asyncOperand `:` type($asyncOperand) `)`
+ | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+ $asyncOperands, type($asyncOperands),
+ $asyncOperandsDeviceType, $async)
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
- | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+ | `wait` `` custom<WaitClause>($waitOperands,
+ type($waitOperands), $waitOperandsDeviceType,
+ $waitOperandsSegments, $wait)
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
)
attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bc03adbcae64df..4e31f7b163b9dc 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -936,6 +936,138 @@ static void printDeviceTypeOperandsWithSegment(
});
}
+/// Parses the body of a `wait` clause (the `wait` keyword itself has already
+/// been consumed). Accepted forms:
+///   wait                                   keyword only, device_type none
+///   wait([#acc.device_type<x>, ...])       keyword only for listed types
+///   wait({%v : i32, ...}[#acc.device_type<x>], {...}, ...)
+///   wait([#acc.device_type<x>], {%v : i32}, ...)
+/// Each `{...}` group contributes one entry to \p segments (its operand
+/// count) and one entry to \p deviceTypes (`none` when no `[...]` suffix is
+/// given). Attributes that would be empty are left null.
+static ParseResult parseWaitClause(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) {
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs;
+  llvm::SmallVector<int32_t> seg;
+
+  // Bare `wait` keyword: keyword-only for the `none` device_type.
+  if (failed(parser.parseOptionalLParen())) {
+    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+    return success();
+  }
+
+  // Optional leading list of device_types for which `wait` has no operands.
+  bool hasKeywordOnlyList = false;
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          return parser.parseAttribute(keywordAttrs.emplace_back());
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    hasKeywordOnlyList = true;
+  }
+
+  // Without a keyword-only list at least one operand group is required.
+  // After a keyword-only list, operand groups are optional and introduced by
+  // a comma, so that a clause carrying only keyword-only device_types
+  // (e.g. `wait([#acc.device_type<nvidia>])`) can be parsed.
+  bool parseOperandGroups =
+      !hasKeywordOnlyList || succeeded(parser.parseOptionalComma());
+  if (parseOperandGroups) {
+    do {
+      if (failed(parser.parseLBrace()))
+        return failure();
+
+      size_t groupStart = operands.size();
+      if (failed(parser.parseCommaSeparatedList(
+              mlir::AsmParser::Delimiter::None, [&]() {
+                if (parser.parseOperand(operands.emplace_back()) ||
+                    parser.parseColonType(types.emplace_back()))
+                  return failure();
+                return success();
+              })))
+        return failure();
+      seg.push_back(static_cast<int32_t>(operands.size() - groupStart));
+
+      if (failed(parser.parseRBrace()))
+        return failure();
+
+      // Optional `[#acc.device_type<x>]` suffix; defaults to `none`.
+      if (succeeded(parser.parseOptionalLSquare())) {
+        if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+            parser.parseRSquare())
+          return failure();
+      } else {
+        deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+            parser.getContext(), mlir::acc::DeviceType::None));
+      }
+    } while (succeeded(parser.parseOptionalComma()));
+  }
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  // Only materialize attributes that carry values; absent stays null.
+  if (!deviceTypeAttrs.empty()) {
+    deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+    segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+  }
+  if (!keywordAttrs.empty())
+    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+
+  return success();
+}
+
+/// Returns true when \p arrayAttr is present, holds a non-null ArrayAttr,
+/// and that array contains at least one element.
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (!arrayAttr || !*arrayAttr)
+    return false;
+  return arrayAttr->size() != 0;
+}
+
+/// Prints \p deviceTypes as a bracketed, comma-separated list of attributes
+/// (e.g. `[#acc.device_type<host>, #acc.device_type<nvidia>]`). Prints
+/// nothing when the array is absent, null, or empty.
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (hasDeviceTypeValues(deviceTypes)) {
+    p << "[";
+    llvm::interleaveComma(
+        deviceTypes->getValue(), p,
+        [&p](mlir::Attribute deviceType) { p << deviceType; });
+    p << "]";
+  }
+}
+
+/// Returns true when \p attrs holds exactly one element and that element is
+/// a DeviceTypeAttr whose value is mlir::acc::DeviceType::None.
+static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
+  if (!hasDeviceTypeValues(attrs) || attrs->size() != 1)
+    return false;
+  auto single = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]);
+  return single && single.getValue() == mlir::acc::DeviceType::None;
+}
+
+/// Prints the body of a `wait` clause in the form accepted by
+/// parseWaitClause: an optional bracketed keyword-only device_type list,
+/// followed by `{%v : type, ...}` operand groups, each with its device_type
+/// suffix. Nothing is printed for a bare `wait` (no operands and only the
+/// `none` keyword-only device_type).
+static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                            mlir::OperandRange operands, mlir::TypeRange types,
+                            std::optional<mlir::ArrayAttr> deviceTypes,
+                            std::optional<mlir::DenseI32ArrayAttr> segments,
+                            std::optional<mlir::ArrayAttr> keywordOnly) {
+  if (operands.empty() && hasOnlyDeviceTypeNone(keywordOnly))
+    return;
+
+  p << "(";
+
+  printDeviceTypes(p, keywordOnly);
+
+  // Operand groups exist only when both the per-group device_types and the
+  // segment sizes are set. A clause carrying only keyword-only device_types
+  // has neither, so both must be checked before being dereferenced.
+  bool hasOperandGroups =
+      hasDeviceTypeValues(deviceTypes) && segments && *segments;
+  if (hasOperandGroups) {
+    if (hasDeviceTypeValues(keywordOnly))
+      p << ", ";
+
+    unsigned opIdx = 0;
+    llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto group) {
+      p << "{";
+      llvm::interleaveComma(
+          llvm::seq<int32_t>(0, (*segments)[group.index()]), p, [&](auto) {
+            p << operands[opIdx] << " : " << operands[opIdx].getType();
+            ++opIdx;
+          });
+      p << "}";
+      printSingleDeviceType(p, group.value());
+    });
+  }
+
+  p << ")";
+}
+
static ParseResult parseDeviceTypeOperands(
mlir::OpAsmP...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/78764
More information about the flang-commits mailing list