[flang-commits] [flang] [mlir] [mlir][acc] Added async to data clause operations. (PR #97307)
via flang-commits
flang-commits at lists.llvm.org
Mon Jul 1 08:18:51 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
Because the data clause operations are not tightly
"associated" with the compute/data operations (e.g.,
they can be optimized as SSA producers and turned into block
arguments), the information about the original async()
clause should be attached to the data clause operations themselves
to make it easier to generate the proper runtime actions
for them. This change propagates the async() information
from the OpenACC data/compute constructs to the data clause
operations. This change also adds the CurrentDeviceIdResource
to guarantee proper ordering of the operations that read
and write the current device identifier.
---
Patch is 102.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97307.diff
10 Files Affected:
- (modified) flang/lib/Lower/OpenACC.cpp (+276-135)
- (modified) flang/test/Lower/OpenACC/acc-data.f90 (+4-2)
- (modified) flang/test/Lower/OpenACC/acc-enter-data.f90 (+6-6)
- (modified) flang/test/Lower/OpenACC/acc-exit-data.f90 (+8-8)
- (modified) flang/test/Lower/OpenACC/acc-parallel.f90 (+7-7)
- (modified) flang/test/Lower/OpenACC/acc-serial.f90 (+2-2)
- (modified) flang/test/Lower/OpenACC/acc-update.f90 (+12-12)
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACC.h (+20)
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+215-20)
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+30)
``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 166fa686cd883..6266a5056ace8 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -58,13 +58,34 @@ genOperandLocation(Fortran::lower::AbstractConverter &converter,
return loc;
}
+static void addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<int32_t> &operandSegments,
+ llvm::ArrayRef<mlir::Value> clauseOperands) {
+ operands.append(clauseOperands.begin(), clauseOperands.end());
+ operandSegments.push_back(clauseOperands.size());
+}
+
+static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<int32_t> &operandSegments,
+ const mlir::Value &clauseOperand) {
+ if (clauseOperand) {
+ operands.push_back(clauseOperand);
+ operandSegments.push_back(1);
+ } else {
+ operandSegments.push_back(0);
+ }
+}
+
template <typename Op>
-static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Value baseAddr, std::stringstream &name,
- mlir::SmallVector<mlir::Value> bounds,
- bool structured, bool implicit,
- mlir::acc::DataClause dataClause, mlir::Type retTy,
- mlir::Value isPresent = {}) {
+static Op
+createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value baseAddr, std::stringstream &name,
+ mlir::SmallVector<mlir::Value> bounds, bool structured,
+ bool implicit, mlir::acc::DataClause dataClause,
+ mlir::Type retTy, llvm::ArrayRef<mlir::Value> async,
+ llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
+ llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
+ mlir::Value isPresent = {}) {
mlir::Value varPtrPtr;
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
if (isPresent) {
@@ -92,20 +113,25 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
retTy = baseAddr.getType();
}
- Op op = builder.create<Op>(loc, retTy, baseAddr);
+ llvm::SmallVector<mlir::Value, 8> operands;
+ llvm::SmallVector<int32_t, 8> operandSegments;
+
+ addOperand(operands, operandSegments, baseAddr);
+ addOperand(operands, operandSegments, varPtrPtr);
+ addOperands(operands, operandSegments, bounds);
+ addOperands(operands, operandSegments, async);
+
+ Op op = builder.create<Op>(loc, retTy, operands);
op.setNameAttr(builder.getStringAttr(name.str()));
op.setStructured(structured);
op.setImplicit(implicit);
op.setDataClause(dataClause);
-
- unsigned insPos = 1;
- if (varPtrPtr)
- op->insertOperands(insPos++, varPtrPtr);
- if (bounds.size() > 0)
- op->insertOperands(insPos, bounds);
op->setAttr(Op::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr(
- {1, varPtrPtr ? 1 : 0, static_cast<int32_t>(bounds.size())}));
+ builder.getDenseI32ArrayAttr(operandSegments));
+ if (!asyncDeviceTypes.empty())
+ op.setAsyncOperandsDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ if (!asyncOnlyDeviceTypes.empty())
+ op.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
return op;
}
@@ -174,7 +200,8 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
createDataEntryOp<mlir::acc::UpdateDeviceOp>(
builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
- mlir::acc::DataClause::acc_update_device, descTy);
+ mlir::acc::DataClause::acc_update_device, descTy,
+ /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -185,7 +212,8 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
EntryOp entryOp = createDataEntryOp<EntryOp>(
builder, loc, boxAddrOp.getResult(), asFortran, bounds,
- /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
+ /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
+ /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
builder.create<mlir::acc::DeclareEnterOp>(
loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
mlir::ValueRange(entryOp.getAccPtr()));
@@ -217,8 +245,8 @@ static void createDeclareDeallocFuncWithArg(
mlir::acc::GetDevicePtrOp entryOp =
createDataEntryOp<mlir::acc::GetDevicePtrOp>(
builder, loc, boxAddrOp.getResult(), asFortran, bounds,
- /*structured=*/false, /*implicit=*/false, clause,
- boxAddrOp.getType());
+ /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
+ /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
builder.create<mlir::acc::DeclareExitOp>(
loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
@@ -226,12 +254,16 @@ static void createDeclareDeallocFuncWithArg(
std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
entryOp.getVarPtr(), entryOp.getBounds(),
- entryOp.getDataClause(),
+ entryOp.getAsyncOperands(),
+ entryOp.getAsyncOperandsDeviceTypeAttr(),
+ entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
/*structured=*/false, /*implicit=*/false,
builder.getStringAttr(*entryOp.getName()));
else
builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
- entryOp.getBounds(), entryOp.getDataClause(),
+ entryOp.getBounds(), entryOp.getAsyncOperands(),
+ entryOp.getAsyncOperandsDeviceTypeAttr(),
+ entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
/*structured=*/false, /*implicit=*/false,
builder.getStringAttr(*entryOp.getName()));
@@ -248,7 +280,8 @@ static void createDeclareDeallocFuncWithArg(
createDataEntryOp<mlir::acc::UpdateDeviceOp>(
builder, loc, loadOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
- mlir::acc::DataClause::acc_update_device, loadOp.getType());
+ mlir::acc::DataClause::acc_update_device, loadOp.getType(),
+ /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -290,7 +323,10 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
mlir::acc::DataClause dataClause, bool structured,
- bool implicit, bool setDeclareAttr = false) {
+ bool implicit, llvm::ArrayRef<mlir::Value> async,
+ llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
+ llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
+ bool setDeclareAttr = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
@@ -316,7 +352,8 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
: info.addr;
Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
bounds, structured, implicit, dataClause,
- baseAddr.getType(), info.isPresent);
+ baseAddr.getType(), async, asyncDeviceTypes,
+ asyncOnlyDeviceTypes, info.isPresent);
dataOperands.push_back(op.getAccPtr());
}
}
@@ -345,7 +382,8 @@ static void genDeclareDataOperandOperations(
operandLocation, asFortran, bounds);
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, info.addr, asFortran, bounds, structured,
- implicit, dataClause, info.addr.getType());
+ implicit, dataClause, info.addr.getType(),
+ /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
dataOperands.push_back(op.getAccPtr());
addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
@@ -397,13 +435,16 @@ static void genDataExitOperations(fir::FirOpBuilder &builder,
std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
builder.create<ExitOp>(
entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(),
- entryOp.getBounds(), entryOp.getDataClause(), structured,
- entryOp.getImplicit(), builder.getStringAttr(*entryOp.getName()));
+ entryOp.getBounds(), entryOp.getAsyncOperands(),
+ entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
+ entryOp.getDataClause(), structured, entryOp.getImplicit(),
+ builder.getStringAttr(*entryOp.getName()));
else
- builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
- entryOp.getBounds(), entryOp.getDataClause(),
- structured, entryOp.getImplicit(),
- builder.getStringAttr(*entryOp.getName()));
+ builder.create<ExitOp>(
+ entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getBounds(),
+ entryOp.getAsyncOperands(), entryOp.getAsyncOperandsDeviceTypeAttr(),
+ entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), structured,
+ entryOp.getImplicit(), builder.getStringAttr(*entryOp.getName()));
}
}
@@ -783,7 +824,10 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
- llvm::SmallVector<mlir::Attribute> &privatizations) {
+ llvm::SmallVector<mlir::Attribute> &privatizations,
+ llvm::ArrayRef<mlir::Value> async,
+ llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
+ llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
@@ -808,7 +852,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
operandLocation, retTy);
auto op = createDataEntryOp<mlir::acc::PrivateOp>(
builder, operandLocation, info.addr, asFortran, bounds, true,
- /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy);
+ /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async,
+ asyncDeviceTypes, asyncOnlyDeviceTypes);
dataOperands.push_back(op.getAccPtr());
} else {
std::string suffix =
@@ -819,7 +864,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
builder, recipeName, operandLocation, retTy, bounds);
auto op = createDataEntryOp<mlir::acc::FirstprivateOp>(
builder, operandLocation, info.addr, asFortran, bounds, true,
- /*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy);
+ /*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy,
+ async, asyncDeviceTypes, asyncOnlyDeviceTypes);
dataOperands.push_back(op.getAccPtr());
}
privatizations.push_back(mlir::SymbolRefAttr::get(
@@ -1354,7 +1400,10 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
- llvm::SmallVector<mlir::Attribute> &reductionRecipes) {
+ llvm::SmallVector<mlir::Attribute> &reductionRecipes,
+ llvm::ArrayRef<mlir::Value> async,
+ llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
+ llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
const auto &op = std::get<Fortran::parser::ReductionOperator>(objectList.t);
@@ -1383,7 +1432,8 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
auto op = createDataEntryOp<mlir::acc::ReductionOp>(
builder, operandLocation, info.addr, asFortran, bounds,
/*structured=*/true, /*implicit=*/false,
- mlir::acc::DataClause::acc_reduction, info.addr.getType());
+ mlir::acc::DataClause::acc_reduction, info.addr.getType(), async,
+ asyncDeviceTypes, asyncOnlyDeviceTypes);
mlir::Type ty = op.getAccPtr().getType();
if (!areAllBoundConstant(bounds) ||
fir::isAssumedShape(info.addr.getType()) ||
@@ -1404,25 +1454,6 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
}
}
-static void
-addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<int32_t> &operandSegments,
- const llvm::SmallVectorImpl<mlir::Value> &clauseOperands) {
- operands.append(clauseOperands.begin(), clauseOperands.end());
- operandSegments.push_back(clauseOperands.size());
-}
-
-static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<int32_t> &operandSegments,
- const mlir::Value &clauseOperand) {
- if (clauseOperand) {
- operands.push_back(clauseOperand);
- operandSegments.push_back(1);
- } else {
- operandSegments.push_back(0);
- }
-}
-
template <typename Op, typename Terminator>
static Op
createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1656,7 +1687,8 @@ static void privatizeIv(Fortran::lower::AbstractConverter &converter,
std::stringstream asFortran;
auto op = createDataEntryOp<mlir::acc::PrivateOp>(
builder, loc, ivValue, asFortran, {}, true, /*implicit=*/true,
- mlir::acc::DataClause::acc_private, ivValue.getType());
+ mlir::acc::DataClause::acc_private, ivValue.getType(),
+ /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
privateOperands.push_back(op.getAccPtr());
privatizations.push_back(mlir::SymbolRefAttr::get(builder.getContext(),
@@ -1897,12 +1929,14 @@ static mlir::acc::LoopOp createLoopOp(
&clause.u)) {
genPrivatizations<mlir::acc::PrivateRecipeOp>(
privateClause->v, converter, semanticsContext, stmtCtx,
- privateOperands, privatizations);
+ privateOperands, privatizations, /*async=*/{},
+ /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
} else if (const auto *reductionClause =
std::get_if<Fortran::parser::AccClause::Reduction>(
&clause.u)) {
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
- reductionOperands, reductionRecipes);
+ reductionOperands, reductionRecipes, /*async=*/{},
+ /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
} else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
for (auto crtDeviceTypeAttr : crtDeviceTypes)
seqDeviceTypes.push_back(crtDeviceTypeAttr);
@@ -2088,6 +2122,9 @@ static void genDataOperandOperationsWithModifier(
llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
const mlir::acc::DataClause clause,
const mlir::acc::DataClause clauseWithModifier,
+ llvm::ArrayRef<mlir::Value> async,
+ llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
+ llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
bool setDeclareAttr = false) {
const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
const auto &accObjectList =
@@ -2099,7 +2136,8 @@ static void genDataOperandOperationsWithModifier(
(modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
genDataOperandOperations<Op>(accObjectList, converter, semanticsContext,
stmtCtx, dataClauseOperands, dataClause,
- /*structured=*/true, /*implicit=*/false,
+ /*structured=*/true, /*implicit=*/false, async,
+ asyncDeviceTypes, asyncOnlyDeviceTypes,
setDeclareAttr);
}
@@ -2150,8 +2188,9 @@ static Op createComputeOp(
// Lower clauses values mapped to operands and array attributes.
// Keep track of each group of operands separately as clauses can appear
// more than once.
+
+ // Process the clauses that may have a specified device_type first.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
- mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
@@ -2193,8 +2232,19 @@ static Op createComputeOp(
vectorLength.push_back(vectorLengthValue);
vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
}
- } else if (const auto *ifClause =
- std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
+ } else if (const auto *deviceTypeClause =
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
+ &clause.u)) {
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
+ }
+ }
+
+ // Process the clauses independent of device_type.
+ for (const Fortran::parser::AccClause &clause : accClauseList.v) {
+ mlir::Location clauseLocation = converter.genLocation(clause.source);
+ if (const auto *ifClause =
+ std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
} else if (const auto *selfClause =
std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
@@ -2237,7 +2287,8 @@ static Op createComputeOp(
genDataOperandOperations<mlir::acc::CopyinOp>(
copyClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_copy,
- /*structured=*/true, /*implicit=*/false);
+ /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
+ asyncOnlyDeviceTypes);
copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *copyinClause =
@@ -2247,7 +2298,8 @@ static Op createComputeOp(
copyinClause, converter, semanticsContext, stmtCtx,
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
dataClauseOperands, mlir::acc::DataClause::acc_copyin,
- mlir::acc::DataClause::acc_copyin_readonly);
+ mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes,
+ asyncOnlyDeviceTypes);
} else if (const auto *copyoutClause =
std::get_if<Fortran::parser::AccClause::Copyout>(
&clause.u)) {
@@ -2257,...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/97307
More information about the flang-commits
mailing list