[flang-commits] [flang] [mlir] [flang][openacc] Support multiple device_type when lowering (PR #78634)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Thu Jan 18 21:05:22 PST 2024
Valentin Clement (バレンタイン クレメン)
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/78634 at github.com>
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/78634
>From 81a54990280d797b7da912249eaf8dd818540c0b Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 18 Jan 2024 13:47:26 -0800
Subject: [PATCH 1/2] [flang][openacc] Support multiple device_type when lowering
---
flang/lib/Lower/OpenACC.cpp | 262 +++++++++++--------
flang/test/Lower/OpenACC/acc-device-type.f90 | 4 +
flang/test/Lower/OpenACC/acc-loop.f90 | 6 +
flang/test/Lower/OpenACC/acc-routine.f90 | 5 +
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 5 +-
5 files changed, 175 insertions(+), 107 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index fd89d27db74dc0..682ca06cabd6f6 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1470,15 +1470,19 @@ genAsyncClause(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> &async,
llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
- mlir::acc::DeviceTypeAttr deviceTypeAttr,
+ llvm::SmallVector<mlir::Attribute> &deviceTypeAttrs,
Fortran::lower::StatementContext &stmtCtx) {
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
- async.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
- asyncDeviceTypes.push_back(deviceTypeAttr);
+ mlir::Value asyncValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
+ for (auto deviceTypeAttr : deviceTypeAttrs) {
+ async.push_back(asyncValue);
+ asyncDeviceTypes.push_back(deviceTypeAttr);
+ }
} else {
- asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
+ for (auto deviceTypeAttr : deviceTypeAttrs)
+ asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
}
}
@@ -1504,10 +1508,9 @@ getDeviceType(Fortran::common::OpenACCDeviceType device) {
}
static void gatherDeviceTypeAttrs(
- fir::FirOpBuilder &builder, mlir::Location clauseLocation,
+ fir::FirOpBuilder &builder,
const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
- llvm::SmallVector<mlir::Attribute> &deviceTypes,
- Fortran::lower::StatementContext &stmtCtx) {
+ llvm::SmallVector<mlir::Attribute> &deviceTypes) {
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
deviceTypeClause->v;
for (const auto &deviceTypeExpr : deviceTypeExprList.v)
@@ -1560,20 +1563,25 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
llvm::SmallVector<int32_t> &waitOperandsSegments,
- mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
+ mlir::Value &waitDevnum,
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
Fortran::lower::StatementContext &stmtCtx) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- auto crtWaitOperands = waitOperands.size();
+ llvm::SmallVector<mlir::Value> waitValues;
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
- waitOperands.push_back(fir::getBase(converter.genExprValue(
+ waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}
- waitOperandsDeviceTypes.push_back(deviceTypeAttr);
- waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+ for (auto deviceTypeAttr : deviceTypeAttrs) {
+ for (auto value : waitValues)
+ waitOperands.push_back(value);
+ waitOperandsDeviceTypes.push_back(deviceTypeAttr);
+ waitOperandsSegments.push_back(waitValues.size());
+ }
// TODO: move to device_type model.
const auto &waitDevnumValue =
@@ -1582,7 +1590,8 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
waitDevnum = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
- waitOnlyDeviceTypes.push_back(deviceTypeAttr);
+ for (auto deviceTypeAttr : deviceTypeAttrs)
+ waitOnlyDeviceTypes.push_back(deviceTypeAttr);
}
}
@@ -1610,91 +1619,112 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
// device_type attribute is set to `none` until a device_type clause is
// encountered.
- auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), mlir::acc::DeviceType::None);
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None));
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *gangClause =
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
if (gangClause->v) {
- auto crtGangOperands = gangOperands.size();
const Fortran::parser::AccGangArgList &x = *gangClause->v;
+ mlir::SmallVector<mlir::Value> gangValues;
+ mlir::SmallVector<mlir::Attribute> gangArgs;
for (const Fortran::parser::AccGangArg &gangArg : x.v) {
if (const auto *num =
std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
- gangOperands.push_back(fir::getBase(converter.genExprValue(
+ gangValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(num->v), stmtCtx)));
- gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+ gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
builder.getContext(), mlir::acc::GangArgType::Num));
} else if (const auto *staticArg =
std::get_if<Fortran::parser::AccGangArg::Static>(
&gangArg.u)) {
const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
if (sizeExpr.v) {
- gangOperands.push_back(fir::getBase(converter.genExprValue(
+ gangValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
} else {
// * was passed as value and will be represented as a special
// constant.
- gangOperands.push_back(builder.createIntegerConstant(
+ gangValues.push_back(builder.createIntegerConstant(
clauseLocation, builder.getIndexType(), starCst));
}
- gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+ gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
builder.getContext(), mlir::acc::GangArgType::Static));
} else if (const auto *dim =
std::get_if<Fortran::parser::AccGangArg::Dim>(
&gangArg.u)) {
- gangOperands.push_back(fir::getBase(converter.genExprValue(
+ gangValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(dim->v), stmtCtx)));
- gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+ gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
builder.getContext(), mlir::acc::GangArgType::Dim));
}
}
- gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands);
- gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ for (const auto &pair : llvm::zip(gangValues, gangArgs)) {
+ gangOperands.push_back(std::get<0>(pair));
+ gangArgTypes.push_back(std::get<1>(pair));
+ }
+ gangOperandsSegments.push_back(gangValues.size());
+ gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else {
- gangDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ gangDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *workerClause =
std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
if (workerClause->v) {
- workerNumOperands.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)));
- workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Value workerNumValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ workerNumOperands.push_back(workerNumValue);
+ workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else {
- workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *vectorClause =
std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
if (vectorClause->v) {
- vectorOperands.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)));
- vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Value vectorValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ vectorOperands.push_back(vectorValue);
+ vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else {
- vectorDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ vectorDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *tileClause =
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
- auto crtTileOperands = tileOperands.size();
+ llvm::SmallVector<mlir::Value> tileValues;
for (const auto &accTileExpr : accTileExprList.v) {
const auto &expr =
std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
accTileExpr.t);
if (expr) {
- tileOperands.push_back(fir::getBase(converter.genExprValue(
+ tileValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*expr), stmtCtx)));
} else {
// * was passed as value and will be represented as a special
// constant.
mlir::Value tileStar = builder.createIntegerConstant(
clauseLocation, builder.getIntegerType(32), starCst);
- tileOperands.push_back(tileStar);
+ tileValues.push_back(tileStar);
}
}
- tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
- tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ for (auto value : tileValues)
+ tileOperands.push_back(value);
+ tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ tileOperandsSegments.push_back(tileValues.size());
+ }
} else if (const auto *privateClause =
std::get_if<Fortran::parser::AccClause::Private>(
&clause.u)) {
@@ -1707,21 +1737,20 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
reductionOperands, reductionRecipes);
} else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
- seqDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ seqDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Independent>(
&clause.u)) {
- independentDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ independentDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
- autoDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ autoDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
- deviceTypeClause->v;
- assert(deviceTypeExprList.v.size() == 1 &&
- "expect only one device_type expr");
- crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
} else if (const auto *collapseClause =
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
@@ -1729,14 +1758,18 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
const auto &force = std::get<bool>(arg.t);
if (force)
TODO(clauseLocation, "OpenACC collapse force modifier");
+
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
const std::optional<int64_t> collapseValue =
Fortran::evaluate::ToInt64(*expr);
assert(collapseValue && "expect integer value for the collapse clause");
- collapseValues.push_back(*collapseValue);
- collapseDeviceTypes.push_back(crtDeviceTypeAttr);
+
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ collapseValues.push_back(*collapseValue);
+ collapseDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
}
}
@@ -1923,45 +1956,56 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
// device_type attribute is set to `none` until a device_type clause is
// encountered.
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None);
+ crtDeviceTypes.push_back(crtDeviceTypeAttr);
- // Lower clauses values mapped to operands.
- // Keep track of each group of operands separatly as clauses can appear
+ // Lower clauses values mapped to operands and array attributes.
+ // Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
+ asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
- stmtCtx);
+ waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
- auto crtNumGangs = numGangs.size();
+ llvm::SmallVector<mlir::Value> numGangValues;
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
- numGangs.push_back(fir::getBase(converter.genExprValue(
+ numGangValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(expr), stmtCtx)));
- numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
- numGangsSegments.push_back(numGangs.size() - crtNumGangs);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ for (auto value : numGangValues)
+ numGangs.push_back(value);
+ numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+ numGangsSegments.push_back(numGangValues.size());
+ }
} else if (const auto *numWorkersClause =
std::get_if<Fortran::parser::AccClause::NumWorkers>(
&clause.u)) {
- numWorkers.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
- numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Value numWorkerValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ numWorkers.push_back(numWorkerValue);
+ numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else if (const auto *vectorLengthClause =
std::get_if<Fortran::parser::AccClause::VectorLength>(
&clause.u)) {
- vectorLength.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
- vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ vectorLength.push_back(vectorLengthValue);
+ vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2115,12 +2159,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
- deviceTypeClause->v;
- assert(deviceTypeExprList.v.size() == 1 &&
- "expect only one device_type expr");
- crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
}
}
@@ -2239,10 +2279,11 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
// device_type attribute is set to `none` until a device_type clause is
// encountered.
- auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), mlir::acc::DeviceType::None);
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None));
- // Lower clauses values mapped to operands.
+ // Lower clauses values mapped to operands and array attributes.
// Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2323,19 +2364,23 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
+ asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
- stmtCtx);
+ waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
} else if(const auto *defaultClause =
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
hasDefaultNone = true;
else if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_present)
hasDefaultPresent = true;
+ } else if (const auto *deviceTypeClause =
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
+ &clause.u)) {
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
}
}
@@ -2727,8 +2772,7 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
- deviceTypes, stmtCtx);
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
}
}
@@ -2777,8 +2821,7 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
- deviceTypes, stmtCtx);
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
}
}
@@ -2835,8 +2878,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
- deviceTypes, stmtCtx);
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
} else if (const auto *hostClause =
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -3592,15 +3634,16 @@ void Fortran::lower::genOpenACCRoutineConstruct(
// device_type attribute is set to `none` until a device_type clause is
// encountered.
- auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), mlir::acc::DeviceType::None);
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None));
for (const Fortran::parser::AccClause &clause : clauses.v) {
if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
- seqDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ seqDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *gangClause =
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
-
if (gangClause->v) {
const Fortran::parser::AccGangArgList &x = *gangClause->v;
for (const Fortran::parser::AccGangArg &gangArg : x.v) {
@@ -3611,27 +3654,36 @@ void Fortran::lower::genOpenACCRoutineConstruct(
if (!dimValue)
mlir::emitError(loc,
"dim value must be a constant positive integer");
- gangDimValues.push_back(
- builder.getIntegerAttr(builder.getI64Type(), *dimValue));
- gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Attribute gangDimAttr =
+ builder.getIntegerAttr(builder.getI64Type(), *dimValue);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ gangDimValues.push_back(gangDimAttr);
+ gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
}
}
} else {
- gangDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ gangDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
- vectorDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ vectorDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
- workerDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ workerDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
hasNohost = true;
} else if (const auto *bindClause =
std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
if (const auto *name =
std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
- bindNames.push_back(
- builder.getStringAttr(converter.mangleName(*name->symbol)));
- bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Attribute bindNameAttr =
+ builder.getStringAttr(converter.mangleName(*name->symbol));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ bindNames.push_back(bindNameAttr);
+ bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else if (const auto charExpr =
std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
&bindClause->v.u)) {
@@ -3640,18 +3692,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
*charExpr);
if (!name)
mlir::emitError(loc, "Could not retrieve the bind name");
- bindNames.push_back(builder.getStringAttr(*name));
- bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
+
+ mlir::Attribute bindNameAttr = builder.getStringAttr(*name);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ bindNames.push_back(bindNameAttr);
+ bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
}
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
- deviceTypeClause->v;
- assert(deviceTypeExprList.v.size() == 1 &&
- "expect only one device_type expr");
- crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
}
}
diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
index 871dbc95f60fcb..b8feaa5c5f41ff 100644
--- a/flang/test/Lower/OpenACC/acc-device-type.f90
+++ b/flang/test/Lower/OpenACC/acc-device-type.f90
@@ -40,5 +40,9 @@ subroutine sub1()
! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+ !$acc parallel device_type(nvidia, default) num_gangs(1, 1, 1)
+ !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c1_i32_9 : i32, %c1_i32_10 : i32, %c1_i32_11 : i32} [#acc.device_type<nvidia>], {%c1_i32_9 : i32, %c1_i32_10 : i32, %c1_i32_11 : i32} [#acc.device_type<default>])
end subroutine
diff --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index 42e14afb35f522..59c2513332a976 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -326,4 +326,10 @@ program acc_loop
! CHECK: acc.loop gang([#acc.device_type<none>], {num=%c8{{.*}} : i32} [#acc.device_type<nvidia>])
+ !$acc loop device_type(nvidia, default) gang
+ DO i = 1, n
+ END DO
+
+! CHECK: acc.loop gang([#acc.device_type<nvidia>, #acc.device_type<default>]) {
+
end program
diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 2fe150e70b0cfb..1170af18bc3341 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -2,6 +2,7 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+! CHECK: acc.routine @acc_routine_17 func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine17" [#acc.device_type<default>], "_QPacc_routine16" [#acc.device_type<multicore>])
! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
@@ -120,3 +121,7 @@ subroutine acc_routine17()
subroutine acc_routine18()
!$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16)
end subroutine
+
+subroutine acc_routine19()
+ !$acc routine device_type(host,default) bind(acc_routine17) dtype(multicore) bind(acc_routine16)
+end subroutine
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 20465f6bb86ed1..bc03adbcae64df 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1449,7 +1449,8 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
std::optional<mlir::DenseI32ArrayAttr> segments,
std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
- if (operands.begin() == operands.end() && gangOnlyDeviceTypes &&
+ if (operands.begin() == operands.end() &&
+ hasDeviceTypeValues(gangOnlyDeviceTypes) &&
gangOnlyDeviceTypes->size() == 1) {
auto deviceTypeAttr =
mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
@@ -1464,7 +1465,7 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
hasDeviceTypeValues(deviceTypes))
p << ", ";
- if (deviceTypes) {
+ if (hasDeviceTypeValues(deviceTypes)) {
unsigned opIdx = 0;
llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
p << "{";
>From 10e23db8a458f0716594e10185834fd3b60ea52a Mon Sep 17 00:00:00 2001
From: Valentin Clement (バレンタイン クレメン) <clementval at gmail.com>
Date: Thu, 18 Jan 2024 21:05:14 -0800
Subject: [PATCH 2/2] Update flang/test/Lower/OpenACC/acc-device-type.f90
---
flang/test/Lower/OpenACC/acc-device-type.f90 | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
index b8feaa5c5f41ff..ae01d0dc5fcde3 100644
--- a/flang/test/Lower/OpenACC/acc-device-type.f90
+++ b/flang/test/Lower/OpenACC/acc-device-type.f90
@@ -43,6 +43,6 @@ subroutine sub1()
!$acc parallel device_type(nvidia, default) num_gangs(1, 1, 1)
!$acc end parallel
-! CHECK: acc.parallel num_gangs({%c1_i32_9 : i32, %c1_i32_10 : i32, %c1_i32_11 : i32} [#acc.device_type<nvidia>], {%c1_i32_9 : i32, %c1_i32_10 : i32, %c1_i32_11 : i32} [#acc.device_type<default>])
+! CHECK: acc.parallel num_gangs({%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<default>])
end subroutine
More information about the flang-commits mailing list