[flang-commits] [flang] [mlir] [mlir][openacc] Add device_type support for data operation (PR #76126)
via flang-commits
flang-commits at lists.llvm.org
Wed Dec 20 22:31:39 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Following #75864, this patch adds device_type support to the data operation for its async and wait operands and attributes.
---
Full diff: https://github.com/llvm/llvm-project/pull/76126.diff
5 Files Affected:
- (modified) flang/lib/Lower/OpenACC.cpp (+88-37)
- (modified) flang/test/Lower/OpenACC/acc-data.f90 (+4-4)
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+38-9)
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+42-1)
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+4-4)
``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecf70818c4ac0f..d10e56e5d11779 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1464,6 +1464,24 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
}
}
+static void
+genAsyncClause(Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::AccClause::Async *asyncClause,
+ llvm::SmallVector<mlir::Value> &async,
+ llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
+ llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
+ mlir::acc::DeviceTypeAttr deviceTypeAttr,
+ Fortran::lower::StatementContext &stmtCtx) {
+ const auto &asyncClauseValue = asyncClause->v;
+ if (asyncClauseValue) { // async has a value.
+ async.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
+ asyncDeviceTypes.push_back(deviceTypeAttr);
+ } else {
+ asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
+ }
+}
+
static mlir::acc::DeviceType
getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
switch (device) {
@@ -1533,6 +1551,39 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
}
}
+static void
+genWaitClause(Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::AccClause::Wait *waitClause,
+ llvm::SmallVector<mlir::Value> &waitOperands,
+ llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
+ llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
+ llvm::SmallVector<int32_t> &waitOperandsSegments,
+ mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
+ Fortran::lower::StatementContext &stmtCtx) {
+ const auto &waitClauseValue = waitClause->v;
+ if (waitClauseValue) { // wait has a value.
+ const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+ const auto &waitList =
+ std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+ auto crtWaitOperands = waitOperands.size();
+ for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+ waitOperands.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(value), stmtCtx)));
+ }
+ waitOperandsDeviceTypes.push_back(deviceTypeAttr);
+ waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+
+ // TODO: move to device_type model.
+ const auto &waitDevnumValue =
+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+ if (waitDevnumValue)
+ waitDevnum = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
+ } else {
+ waitOnlyDeviceTypes.push_back(deviceTypeAttr);
+ }
+}
+
static mlir::acc::LoopOp
createLoopOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
@@ -1795,6 +1846,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
+ mlir::Value waitDevnum; // TODO not yet implemented on compute op.
// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
@@ -1818,31 +1870,14 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
- const auto &asyncClauseValue = asyncClause->v;
- if (asyncClauseValue) { // async has a value.
- async.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
- asyncDeviceTypes.push_back(crtDeviceTypeAttr);
- } else {
- asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
- }
+ genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
+ asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- const auto &waitClauseValue = waitClause->v;
- if (waitClauseValue) { // wait has a value.
- const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
- const auto &waitList =
- std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- auto crtWaitOperands = waitOperands.size();
- for (const Fortran::parser::ScalarIntExpr &value : waitList) {
- waitOperands.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(value), stmtCtx)));
- }
- waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
- waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
- } else {
- waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
- }
+ genWaitClause(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
+ stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
@@ -2126,21 +2161,24 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
- mlir::Value ifCond, async, waitDevnum;
+ mlir::Value ifCond, waitDevnum;
llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
- copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands;
-
- // Async and wait have an optional value but can be present with
- // no value as well. When there is no value, the op has an attribute to
- // represent the clause.
- bool addAsyncAttr = false;
- bool addWaitAttr = false;
+ copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands,
+ async;
+ llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+ llvm::SmallVector<int32_t> waitOperandsSegments;
bool hasDefaultNone = false;
bool hasDefaultPresent = false;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ // device_type attribute is set to `none` until a device_type clause is
+ // encountered.
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None);
+
// Lower clauses values mapped to operands.
// Keep track of each group of operands separately as clauses can appear
// more than once.
@@ -2221,11 +2259,14 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
dataClauseOperands.end());
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
- genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+ genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
+ asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands, waitDevnum,
- addWaitAttr, stmtCtx);
+ genWaitClause(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
+ stmtCtx);
} else if(const auto *defaultClause =
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2239,7 +2280,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> operands;
llvm::SmallVector<int32_t> operandSegments;
addOperand(operands, operandSegments, ifCond);
- addOperand(operands, operandSegments, async);
+ addOperands(operands, operandSegments, async);
addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
addOperands(operands, operandSegments, dataClauseOperands);
@@ -2250,8 +2291,18 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
builder, currentLocation, eval, operands, operandSegments);
- dataOp.setAsyncAttr(addAsyncAttr);
- dataOp.setWaitAttr(addWaitAttr);
+ if (!asyncDeviceTypes.empty())
+ dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ if (!asyncOnlyDeviceTypes.empty())
+ dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
+ if (!waitOperandsDeviceTypes.empty())
+ dataOp.setWaitOperandsDeviceTypeAttr(
+ builder.getArrayAttr(waitOperandsDeviceTypes));
+ if (!waitOperandsSegments.empty())
+ dataOp.setWaitOperandsSegmentsAttr(
+ builder.getDenseI32ArrayAttr(waitOperandsSegments));
+ if (!waitOnlyDeviceTypes.empty())
+ dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
if (hasDefaultNone)
dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index a6572e14707606..75ffd1fc3fcab2 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -153,7 +153,7 @@ subroutine acc_data
!$acc end data
! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {asyncAttr}
+! CHECK: } attributes {asyncOnly = [#acc.device_type<none>]}
!$acc data present(a) async(1)
!$acc end data
@@ -165,18 +165,18 @@ subroutine acc_data
!$acc end data
! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {waitAttr}
+! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
!$acc data present(a) wait(1)
!$acc end data
-! CHECK: acc.data dataOperands(%{{.*}}) wait(%{{.*}} : i32) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait({%{{.*}} : i32}) {
! CHECK: }{{$}}
!$acc data present(a) wait(devnum: 0: 1)
!$acc end data
-! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait(%{{.*}} : i32) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
! CHECK: }{{$}}
!$acc data default(none)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 234c1076e14e3b..d84dfbc1613caa 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1235,13 +1235,16 @@ def OpenACC_DataOp : OpenACC_Op<"data",
let arguments = (ins Optional<I1>:$ifCond,
- Optional<IntOrIndex>:$async,
- UnitAttr:$asyncAttr,
- Optional<IntOrIndex>:$waitDevnum,
- Variadic<IntOrIndex>:$waitOperands,
- UnitAttr:$waitAttr,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ Variadic<IntOrIndex>:$async,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+ Optional<IntOrIndex>:$waitDevnum,
+ Variadic<IntOrIndex>:$waitOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ OptionalAttr<DefaultValueAttr>:$defaultAttr);
let regions = (region AnyRegion:$region);
@@ -1251,15 +1254,41 @@ def OpenACC_DataOp : OpenACC_Op<"data",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
+
+ /// Return true if the op has the async attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasAsyncOnly();
+ /// Return true if the op has the async attribute for the given device_type.
+ bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
+ /// Return the value of the async clause if present.
+ mlir::Value getAsyncValue();
+ /// Return the value of the async clause for the given device_type if
+ /// present.
+ mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the wait attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasWaitOnly();
+ /// Return true if the op has the wait attribute for the given device_type.
+ bool hasWaitOnly(mlir::acc::DeviceType deviceType);
+ /// Return the values of the wait clause if present.
+ mlir::Operation::operand_range getWaitValues();
+ /// Return the values of the wait clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
- | `async` `(` $async `:` type($async) `)`
+ | `async` `(` custom<DeviceTypeOperands>($async,
+ type($async), $asyncDeviceType) `)`
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
- | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+ | `wait` `(` custom<WaitOperands>($waitOperands,
+ type($waitOperands), $waitOperandsDeviceType,
+ $waitOperandsSegments) `)`
)
$region attr-dict-with-keyword
}];
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 45e0632db5ef2b..449157033e91ca 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1418,11 +1418,52 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
Value DataOp::getDataOperand(unsigned i) {
unsigned numOptional = getIfCond() ? 1 : 0;
- numOptional += getAsync() ? 1 : 0;
+ numOptional += getAsync().size() ? 1 : 0;
numOptional += getWaitOperands().size();
return getOperand(numOptional + i);
}
+bool acc::DataOp::hasAsyncOnly() {
+ return hasAsyncOnly(mlir::acc::DeviceType::None);
+}
+
+bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
+ if (auto arrayAttr = getAsyncOnly()) {
+ if (findSegment(*arrayAttr, deviceType))
+ return true;
+ }
+ return false;
+}
+
+mlir::Value DataOp::getAsyncValue() {
+ return getAsyncValue(mlir::acc::DeviceType::None);
+}
+
+mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
+ return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
+ deviceType);
+}
+
+bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
+
+bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
+ if (auto arrayAttr = getWaitOnly()) {
+ if (findSegment(*arrayAttr, deviceType))
+ return true;
+ }
+ return false;
+}
+
+mlir::Operation::operand_range DataOp::getWaitValues() {
+ return getWaitValues(mlir::acc::DeviceType::None);
+}
+
+mlir::Operation::operand_range
+DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
+ return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
+ getWaitOperandsSegments(), deviceType);
+}
+
//===----------------------------------------------------------------------===//
// ExitDataOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 5a95811685f845..52375b1af3141c 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -836,11 +836,11 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
} attributes { defaultAttr = #acc<defaultvalue none>, wait }
%w1 = arith.constant 1 : i64
- acc.data wait(%w1 : i64) {
+ acc.data wait({%w1 : i64}) {
} attributes { defaultAttr = #acc<defaultvalue none>, wait }
%wd1 = arith.constant 1 : i64
- acc.data wait_devnum(%wd1 : i64) wait(%w1 : i64) {
+ acc.data wait_devnum(%wd1 : i64) wait({%w1 : i64}) {
} attributes { defaultAttr = #acc<defaultvalue none>, wait }
return
@@ -951,10 +951,10 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
// CHECK: acc.data {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
-// CHECK: acc.data wait(%{{.*}} : i64) {
+// CHECK: acc.data wait({%{{.*}} : i64}) {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
-// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait(%{{.*}} : i64) {
+// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait({%{{.*}} : i64}) {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/76126
More information about the flang-commits
mailing list