[Mlir-commits] [mlir] [flang] [mlir][flang][openacc] Add device_type support for update op (PR #78764)
Valentin Clement バレンタイン クレメン
llvmlistbot at llvm.org
Fri Jan 19 13:16:02 PST 2024
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/78764
>From 427a6f93269fb50a7d1d29045e42e802131f9b92 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 18 Jan 2024 10:09:31 -0800
Subject: [PATCH 1/2] [mlir][flang][openacc] Add device_type support for update
op
---
flang/lib/Lower/OpenACC.cpp | 80 +++---
flang/test/Lower/OpenACC/acc-update.f90 | 21 +-
.../mlir/Dialect/OpenACC/OpenACCOps.td | 49 +++-
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 229 +++++++++++++++---
mlir/test/Dialect/OpenACC/invalid.mlir | 4 +-
mlir/test/Dialect/OpenACC/ops.mlir | 12 +-
6 files changed, 287 insertions(+), 108 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 682ca06cabd6f6..541ea2e114324f 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2840,27 +2840,42 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
}
}
+static inline mlir::ArrayAttr
+getArrayAttr(fir::FirOpBuilder &b,
+ llvm::SmallVector<mlir::Attribute> &attributes) {
+ return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
+}
+
+static inline mlir::DenseI32ArrayAttr
+getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
+ llvm::SmallVector<int32_t> &values) {
+ return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
+}
+
static void
genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
- mlir::Value ifCond, async, waitDevnum;
+ mlir::Value ifCond, waitDevnum;
llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
- waitOperands, deviceTypeOperands;
- llvm::SmallVector<mlir::Attribute> deviceTypes;
-
- // Async and wait clause have optional values but can be present with
- // no value as well. When there is no value, the op has an attribute to
- // represent the clause.
- bool addAsyncAttr = false;
- bool addWaitAttr = false;
- bool addIfPresentAttr = false;
+ waitOperands, deviceTypeOperands, asyncOperands;
+ llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
+ asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+ llvm::SmallVector<int32_t> waitOperandsSegments;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
- // Lower clauses values mapped to operands.
+ // device_type attribute is set to `none` until a device_type clause is
+ // encountered.
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None));
+
+ bool ifPresent = false;
+
+ // Lower clauses values mapped to operands and array attributes.
// Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2870,15 +2885,19 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
- genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+ genAsyncClause(converter, asyncClause, asyncOperands,
+ asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands, waitDevnum,
- addWaitAttr, stmtCtx);
+ genWaitClause(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
} else if (const auto *hostClause =
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2892,7 +2911,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
/*implicit=*/false);
} else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
- addIfPresentAttr = true;
+ ifPresent = true;
} else if (const auto *selfClause =
std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
@@ -2909,30 +2928,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
dataClauseOperands.append(updateHostOperands);
- // Prepare the operand segment size attribute and the operands value range.
- llvm::SmallVector<mlir::Value> operands;
- llvm::SmallVector<int32_t> operandSegments;
- addOperand(operands, operandSegments, ifCond);
- addOperand(operands, operandSegments, async);
- addOperand(operands, operandSegments, waitDevnum);
- addOperands(operands, operandSegments, waitOperands);
- addOperands(operands, operandSegments, dataClauseOperands);
-
- mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
- builder, currentLocation, operands, operandSegments);
- if (!deviceTypes.empty())
- updateOp.setDeviceTypesAttr(
- mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
+ builder.create<mlir::acc::UpdateOp>(
+ currentLocation, ifCond, asyncOperands,
+ getArrayAttr(builder, asyncOperandsDeviceTypes),
+ getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+ getDenseI32ArrayAttr(builder, waitOperandsSegments),
+ getArrayAttr(builder, waitOperandsDeviceTypes),
+ getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
+ ifPresent);
genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
builder, updateHostOperands, /*structured=*/false);
-
- if (addAsyncAttr)
- updateOp.setAsyncAttr(builder.getUnitAttr());
- if (addWaitAttr)
- updateOp.setWaitAttr(builder.getUnitAttr());
- if (addIfPresentAttr)
- updateOp.setIfPresentAttr(builder.getUnitAttr());
}
static void
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index d2b15f8bd258e7..ac7a56c56b1f20 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -61,17 +61,17 @@ subroutine acc_update
!$acc update host(a) async
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+! CHECK: acc.update async dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+! CHECK: acc.update wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) async wait
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+! CHECK: acc.update async wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) async(1)
@@ -89,14 +89,14 @@ subroutine acc_update
!$acc update host(a) wait(1)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait([[WAIT1]] : i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT1]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait(queues: 1, 2)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait(devnum: 1: queues: 1, 2)
@@ -104,17 +104,12 @@ subroutine acc_update
! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
- !$acc update host(a) device_type(default, host)
+ !$acc update host(a) device_type(host, nvidia) async
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]}
-! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
-
- !$acc update host(a) device_type(*)
-! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<star>]}
+! CHECK: acc.update async([#acc.device_type<host>, #acc.device_type<nvidia>]) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
end subroutine acc_update
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 7344ab2852b9ce..5b678e84b93ee4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2187,14 +2187,16 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
}];
let arguments = (ins Optional<I1>:$ifCond,
- Optional<IntOrIndex>:$asyncOperand,
- Optional<IntOrIndex>:$waitDevnum,
- Variadic<IntOrIndex>:$waitOperands,
- UnitAttr:$async,
- UnitAttr:$wait,
- OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- UnitAttr:$ifPresent);
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$async,
+ Optional<IntOrIndex>:$waitDevnum,
+ Variadic<IntOrIndex>:$waitOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$wait,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ UnitAttr:$ifPresent);
let extraClassDeclaration = [{
/// The number of data operands.
@@ -2202,14 +2204,41 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
+
+ /// Return true if the op has the async attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasAsync();
+ /// Return true if the op has the async attribute for the given device_type.
+ bool hasAsync(mlir::acc::DeviceType deviceType);
+ /// Return the value of the async clause if present.
+ mlir::Value getAsyncValue();
+ /// Return the value of the async clause for the given device_type if
+ /// present.
+ mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the wait attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasWait();
+ /// Return true if the op has the wait attribute for the given device_type.
+ bool hasWait(mlir::acc::DeviceType deviceType);
+ /// Return the values of the wait clause if present.
+ mlir::Operation::operand_range getWaitValues();
+ /// Return the values of the wait clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
- | `async` `(` $asyncOperand `:` type($asyncOperand) `)`
+ | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+ $asyncOperands, type($asyncOperands),
+ $asyncOperandsDeviceType, $async)
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
- | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+ | `wait` `` custom<WaitClause>($waitOperands,
+ type($waitOperands), $waitOperandsDeviceType,
+ $waitOperandsSegments, $wait)
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
)
attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bc03adbcae64df..4e31f7b163b9dc 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -936,6 +936,138 @@ static void printDeviceTypeOperandsWithSegment(
});
}
+static ParseResult parseWaitClause(
+ mlir::OpAsmParser &parser,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+ llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+ mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) {
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs;
+ llvm::SmallVector<int32_t> seg;
+
+ bool needCommaBeforeOperands = false;
+
+ // Keyword only
+ if (failed(parser.parseOptionalLParen())) {
+ keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+ parser.getContext(), mlir::acc::DeviceType::None));
+ keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+ return success();
+ }
+
+ // Parse keyword only attributes
+ if (succeeded(parser.parseOptionalLSquare())) {
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (parser.parseAttribute(keywordAttrs.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+ if (parser.parseRSquare())
+ return failure();
+ needCommaBeforeOperands = true;
+ }
+
+ if (needCommaBeforeOperands && failed(parser.parseComma()))
+ return failure();
+
+ do {
+ if (failed(parser.parseLBrace()))
+ return failure();
+
+ int32_t crtOperandsSize = operands.size();
+
+ if (failed(parser.parseCommaSeparatedList(
+ mlir::AsmParser::Delimiter::None, [&]() {
+ if (parser.parseOperand(operands.emplace_back()) ||
+ parser.parseColonType(types.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+
+ seg.push_back(operands.size() - crtOperandsSize);
+
+ if (failed(parser.parseRBrace()))
+ return failure();
+
+ if (succeeded(parser.parseOptionalLSquare())) {
+ if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+ parser.parseRSquare())
+ return failure();
+ } else {
+ deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+ parser.getContext(), mlir::acc::DeviceType::None));
+ }
+ } while (succeeded(parser.parseOptionalComma()));
+
+ if (failed(parser.parseRParen()))
+ return failure();
+
+ deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+ keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+ segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+
+ return success();
+}
+
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+ if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+ return true;
+ return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+ std::optional<mlir::ArrayAttr> deviceTypes) {
+ if (!hasDeviceTypeValues(deviceTypes))
+ return;
+
+ p << "[";
+ llvm::interleaveComma(*deviceTypes, p,
+ [&](mlir::Attribute attr) { p << attr; });
+ p << "]";
+}
+
+static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
+ if (!hasDeviceTypeValues(attrs))
+ return false;
+ if (attrs->size() != 1)
+ return false;
+ if (auto deviceTypeAttr =
+ mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
+ return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
+ return false;
+}
+
+static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::OperandRange operands, mlir::TypeRange types,
+ std::optional<mlir::ArrayAttr> deviceTypes,
+ std::optional<mlir::DenseI32ArrayAttr> segments,
+ std::optional<mlir::ArrayAttr> keywordOnly) {
+
+ if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
+ return;
+
+ p << "(";
+
+ printDeviceTypes(p, keywordOnly);
+ if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
+ p << ", ";
+
+ unsigned opIdx = 0;
+ llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
+ p << "{";
+ llvm::interleaveComma(
+ llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
+ p << operands[opIdx] << " : " << operands[opIdx].getType();
+ ++opIdx;
+ });
+ p << "}";
+ printSingleDeviceType(p, it.value());
+ });
+
+ p << ")";
+}
+
static ParseResult parseDeviceTypeOperands(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
@@ -966,6 +1098,8 @@ static void
printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
mlir::OperandRange operands, mlir::TypeRange types,
std::optional<mlir::ArrayAttr> deviceTypes) {
+ if (!hasDeviceTypeValues(deviceTypes))
+ return;
llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
p << std::get<1>(it) << " : " << std::get<1>(it).getType();
printSingleDeviceType(p, std::get<0>(it));
@@ -1033,35 +1167,14 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
return success();
}
-static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
- if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
- return true;
- return false;
-}
-
-static void printDeviceTypes(mlir::OpAsmPrinter &p,
- std::optional<mlir::ArrayAttr> deviceTypes) {
- if (!hasDeviceTypeValues(deviceTypes))
- return;
-
- p << "[";
- llvm::interleaveComma(*deviceTypes, p,
- [&](mlir::Attribute attr) { p << attr; });
- p << "]";
-}
-
static void printDeviceTypeOperandsWithKeywordOnly(
mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
- if (operands.begin() == operands.end() && keywordOnlyDeviceTypes &&
- keywordOnlyDeviceTypes->size() == 1) {
- auto deviceTypeAttr =
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*keywordOnlyDeviceTypes)[0]);
- if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
- return;
- }
+ if (operands.begin() == operands.end() &&
+ hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes))
+ return;
p << "(";
printDeviceTypes(p, keywordOnlyDeviceTypes);
@@ -1450,13 +1563,8 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
if (operands.begin() == operands.end() &&
- hasDeviceTypeValues(gangOnlyDeviceTypes) &&
- gangOnlyDeviceTypes->size() == 1) {
- auto deviceTypeAttr =
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
- if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
- return;
- }
+ hasOnlyDeviceTypeNone(gangOnlyDeviceTypes))
+ return;
p << "(";
printDeviceTypes(p, gangOnlyDeviceTypes);
@@ -2416,15 +2524,20 @@ LogicalResult acc::UpdateOp::verify() {
if (getDataClauseOperands().empty())
return emitError("at least one value must be present in dataOperands");
- // The async attribute represent the async clause without value. Therefore the
- // attribute and operand cannot appear at the same time.
- if (getAsyncOperand() && getAsync())
- return emitError("async attribute cannot appear with asyncOperand");
+ for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
+ ++dtypeInt) {
+ auto dtype = static_cast<acc::DeviceType>(dtypeInt);
- // The wait attribute represent the wait clause without values. Therefore the
- // attribute and operands cannot appear at the same time.
- if (!getWaitOperands().empty() && getWait())
- return emitError("wait attribute cannot appear with waitOperands");
+ // The async attribute represent the async clause without value. Therefore
+ // the attribute and operand cannot appear at the same time.
+ if (getAsyncValue(dtype) && hasAsync(dtype))
+ return emitError("async attribute cannot appear with asyncOperand");
+
+ // The wait attribute represent the wait clause without values. Therefore
+ // the attribute and operands cannot appear at the same time.
+ if (!getWaitValues(dtype).empty() && hasWait(dtype))
+ return emitError("wait attribute cannot appear with waitOperands");
+ }
if (getWaitDevnum() && getWaitOperands().empty())
return emitError("wait_devnum cannot appear without waitOperands");
@@ -2443,7 +2556,7 @@ unsigned UpdateOp::getNumDataOperands() {
}
Value UpdateOp::getDataOperand(unsigned i) {
- unsigned numOptional = getAsyncOperand() ? 1 : 0;
+ unsigned numOptional = getAsyncOperands().size();
numOptional += getWaitDevnum() ? 1 : 0;
numOptional += getIfCond() ? 1 : 0;
return getOperand(getWaitOperands().size() + numOptional + i);
@@ -2454,6 +2567,42 @@ void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveConstantIfCondition<UpdateOp>>(context);
}
+bool UpdateOp::hasAsync() { return hasAsync(mlir::acc::DeviceType::None); }
+
+bool UpdateOp::hasAsync(mlir::acc::DeviceType deviceType) {
+ return hasDeviceType(getAsync(), deviceType);
+}
+
+mlir::Value UpdateOp::getAsyncValue() {
+ return getAsyncValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+ if (!hasDeviceTypeValues(getAsyncOperandsDeviceType()))
+ return {};
+
+ if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
+ return getAsyncOperands()[*pos];
+
+ return {};
+}
+
+bool UpdateOp::hasWait() { return hasWait(mlir::acc::DeviceType::None); }
+
+bool UpdateOp::hasWait(mlir::acc::DeviceType deviceType) {
+ return hasDeviceType(getWait(), deviceType);
+}
+
+mlir::Operation::operand_range UpdateOp::getWaitValues() {
+ return getWaitValues(mlir::acc::DeviceType::None);
+}
+
+mlir::Operation::operand_range
+UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+ return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
+ getWaitOperandsSegments(), deviceType);
+}
+
//===----------------------------------------------------------------------===//
// WaitOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 5dcdb3a37e4e3b..1fd06fe8cab752 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -122,7 +122,7 @@ acc.update wait_devnum(%cst: index) dataOperands(%0: memref<f32>)
%value = memref.alloc() : memref<f32>
%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
// expected-error@+1 {{async attribute cannot appear with asyncOperand}}
-acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async}
+acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async = [#acc.device_type<none>]}
// -----
@@ -130,7 +130,7 @@ acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async}
%value = memref.alloc() : memref<f32>
%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
// expected-error@+1 {{wait attribute cannot appear with waitOperands}}
-acc.update wait(%cst: index) dataOperands(%0: memref<f32>) attributes {wait}
+acc.update wait({%cst: index}) dataOperands(%0: memref<f32>) attributes {wait = [#acc.device_type<none>]}
// -----
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 99b44183758d95..f1603a21a14238 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -972,12 +972,12 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
acc.update async(%idxValue: index) dataOperands(%0: memref<f32>)
- acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) dataOperands(%0: memref<f32>)
+ acc.update wait_devnum(%i64Value: i64) wait({%i32Value : i32, %idxValue : index}) dataOperands(%0: memref<f32>)
acc.update if(%ifCond) dataOperands(%0: memref<f32>)
acc.update dataOperands(%0: memref<f32>) attributes {acc.device_types = [#acc.device_type<star>]}
acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
- acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {async}
- acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
+ acc.update async dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
+ acc.update wait dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {ifPresent}
return
}
@@ -991,12 +991,12 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
// CHECK: acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref<f32>)
-// CHECK: acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) dataOperands(%{{.*}} : memref<f32>)
+// CHECK: acc.update wait_devnum([[I64VALUE]] : i64) wait({[[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref<f32>)
// CHECK: acc.update dataOperands(%{{.*}} : memref<f32>) attributes {acc.device_types = [#acc.device_type<star>]}
// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
-// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {async}
-// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
+// CHECK: acc.update async dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
+// CHECK: acc.update wait dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {ifPresent}
// -----
>From 8efff129a04c4a188f4c3e5f738923fe0e5a2712 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 19 Jan 2024 13:15:31 -0800
Subject: [PATCH 2/2] Switch hasAsync to hasAsyncOnly and hasWait to
hasWaitOnly
---
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td | 8 ++++----
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 12 ++++++++----
2 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 5b678e84b93ee4..f904aac8c37263 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2207,9 +2207,9 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
/// Return true if the op has the async attribute for the
/// mlir::acc::DeviceType::None device_type.
- bool hasAsync();
+ bool hasAsyncOnly();
/// Return true if the op has the async attribute for the given device_type.
- bool hasAsync(mlir::acc::DeviceType deviceType);
+ bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
/// Return the value of the async clause if present.
mlir::Value getAsyncValue();
/// Return the value of the async clause for the given device_type if
@@ -2218,9 +2218,9 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
/// Return true if the op has the wait attribute for the
/// mlir::acc::DeviceType::None device_type.
- bool hasWait();
+ bool hasWaitOnly();
/// Return true if the op has the wait attribute for the given device_type.
- bool hasWait(mlir::acc::DeviceType deviceType);
+ bool hasWaitOnly(mlir::acc::DeviceType deviceType);
/// Return the values of the wait clause if present.
mlir::Operation::operand_range getWaitValues();
/// Return the values of the wait clause for the given device_type if
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 4e31f7b163b9dc..80ce419971759e 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2567,9 +2567,11 @@ void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveConstantIfCondition<UpdateOp>>(context);
}
-bool UpdateOp::hasAsync() { return hasAsync(mlir::acc::DeviceType::None); }
+bool UpdateOp::hasAsyncOnly() {
+ return hasAsyncOnly(mlir::acc::DeviceType::None);
+}
-bool UpdateOp::hasAsync(mlir::acc::DeviceType deviceType) {
+bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
return hasDeviceType(getAsync(), deviceType);
}
@@ -2587,9 +2589,11 @@ mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
return {};
}
-bool UpdateOp::hasWait() { return hasWait(mlir::acc::DeviceType::None); }
+bool UpdateOp::hasWaitOnly() {
+ return hasWaitOnly(mlir::acc::DeviceType::None);
+}
-bool UpdateOp::hasWait(mlir::acc::DeviceType deviceType) {
+bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
return hasDeviceType(getWait(), deviceType);
}
More information about the Mlir-commits
mailing list