[flang-commits] [flang] [flang][openacc/mp] Do not read bounds on absent box (PR #75252)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Tue Dec 12 15:02:33 PST 2023
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/75252
Make sure we only load the box and read its bounds when it is present.
Also fix some template parameter ordering issues.
>From c927c793a06cb4b5dec987e91a87f0bdac5c9e4f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 12 Dec 2023 13:36:23 -0800
Subject: [PATCH] [flang][openacc] Do not load optional box if not present
---
flang/lib/Lower/DirectivesCommon.h | 104 ++++++++++++++++++++----
flang/lib/Lower/OpenACC.cpp | 30 ++++---
flang/lib/Lower/OpenMP.cpp | 4 +-
flang/test/Lower/OpenACC/acc-bounds.f90 | 31 +++++++
flang/test/Lower/OpenACC/acc-data.f90 | 1 -
5 files changed, 137 insertions(+), 33 deletions(-)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 88a8916663df75..39f87202f90f5f 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -620,25 +620,36 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
// Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
// `fir.ref<fir.class<T>>` type.
- if (symAddr.getType().isa<fir::ReferenceType>())
+ if (symAddr.getType().isa<fir::ReferenceType>()) {
+ if (Fortran::semantics::IsOptional(sym)) {
+ mlir::Value isPresent =
+ builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), symAddr);
+ return builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value load = builder.create<fir::LoadOp>(loc, symAddr);
+ builder.create<fir::ResultOp>(loc, mlir::ValueRange{load});
+ })
+ .genElse([&] {
+ mlir::Value absent = builder.create<fir::AbsentOp>(loc, boxTy);
+ builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
+ })
+ .getResults()[0];
+ }
return builder.create<fir::LoadOp>(loc, symAddr);
+ }
}
return symAddr;
}
-/// Generate the bounds operation from the descriptor information.
template <typename BoundsOp, typename BoundsType>
-llvm::SmallVector<mlir::Value>
-genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
- Fortran::lower::AbstractConverter &converter,
+static llvm::SmallVector<mlir::Value>
+gatherBoundsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
fir::ExtendedValue dataExv, mlir::Value box) {
+ mlir::Value byteStride;
llvm::SmallVector<mlir::Value> bounds;
mlir::Type idxTy = builder.getIndexType();
mlir::Type boundTy = builder.getType<BoundsType>();
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
- assert(box.getType().isa<fir::BaseBoxType>() &&
- "expect fir.box or fir.class");
- mlir::Value byteStride;
for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
mlir::Value baseLb =
@@ -660,6 +671,58 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
return bounds;
}
+/// Generate the bounds operation from the descriptor information.
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
+ Fortran::lower::AbstractConverter &converter,
+ fir::ExtendedValue dataExv, mlir::Value box,
+ bool isOptional = false) {
+ llvm::SmallVector<mlir::Value> bounds;
+ mlir::Type idxTy = builder.getIndexType();
+ mlir::Type boundTy = builder.getType<BoundsType>();
+
+ assert(box.getType().isa<fir::BaseBoxType>() &&
+ "expect fir.box or fir.class");
+
+ if (isOptional) {
+ mlir::Value isPresent =
+ builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), box);
+
+ llvm::SmallVector<mlir::Type> resTypes;
+ for (unsigned dim = 0; dim < dataExv.rank(); ++dim)
+ resTypes.push_back(boundTy);
+
+ auto ifOp =
+ builder.genIfOp(loc, resTypes, isPresent, /*withElseRegion=*/true)
+ .genThen([&]() {
+ llvm::SmallVector<mlir::Value> tempBounds =
+ gatherBoundsFromBox<BoundsOp, BoundsType>(builder, loc,
+ dataExv, box);
+ builder.create<fir::ResultOp>(loc, tempBounds);
+ })
+ .genElse([&] {
+ llvm::SmallVector<mlir::Value> tempBounds;
+ mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+ mlir::Value minusOne =
+ builder.createIntegerConstant(loc, idxTy, -1);
+ for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
+ mlir::Value bound = builder.create<BoundsOp>(
+ loc, boundTy, zero, minusOne, zero, mlir::Value(), false,
+ mlir::Value{});
+ tempBounds.push_back(bound);
+ }
+ builder.create<fir::ResultOp>(loc, tempBounds);
+ });
+ bounds.append(ifOp.getResults().begin(), ifOp.getResults().end());
+ } else {
+ llvm::SmallVector<mlir::Value> tempBounds =
+ gatherBoundsFromBox<BoundsOp, BoundsType>(builder, loc, dataExv, box);
+ bounds.append(tempBounds.begin(), tempBounds.end());
+ }
+ return bounds;
+}
+
/// Generate bounds operation for base array without any subscripts
/// provided.
template <typename BoundsOp, typename BoundsType>
@@ -885,20 +948,20 @@ mlir::Value gatherDataOperandAddrAndBounds(
if (!arrayElement->subscripts.empty()) {
asFortran << '(';
- bounds = genBoundsOps<BoundsType, BoundsOp>(
+ bounds = genBoundsOps<BoundsOp, BoundsType>(
builder, operandLocation, converter, stmtCtx,
arrayElement->subscripts, asFortran, dataExv, baseAddr,
treatIndexAsSection);
}
asFortran << ')';
- } else if (Fortran::parser::Unwrap<
+ } else if (auto structComp = Fortran::parser::Unwrap<
Fortran::parser::StructureComponent>(designator)) {
fir::ExtendedValue compExv =
converter.genExprAddr(operandLocation, *expr, stmtCtx);
baseAddr = fir::getBase(compExv);
if (fir::unwrapRefType(baseAddr.getType())
.isa<fir::SequenceType>())
- bounds = genBaseBoundsOps<BoundsType, BoundsOp>(
+ bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
builder, operandLocation, converter, compExv, baseAddr);
asFortran << (*expr).AsFortran();
@@ -917,8 +980,11 @@ mlir::Value gatherDataOperandAddrAndBounds(
if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
baseAddr.getDefiningOp())) {
baseAddr = boxAddrOp.getVal();
- bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>(
- builder, operandLocation, converter, compExv, baseAddr);
+ bool isOptional = Fortran::semantics::IsOptional(
+ *Fortran::parser::GetLastName(*structComp).symbol);
+ bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+ builder, operandLocation, converter, compExv, baseAddr,
+ isOptional);
}
} else {
if (Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
@@ -943,12 +1009,16 @@ mlir::Value gatherDataOperandAddrAndBounds(
baseAddr = getDataOperandBaseAddr(
converter, builder, *name.symbol, operandLocation);
if (fir::unwrapRefType(baseAddr.getType())
- .isa<fir::BaseBoxType>())
- bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>(
- builder, operandLocation, converter, dataExv, baseAddr);
+ .isa<fir::BaseBoxType>()) {
+ bool isOptional =
+ Fortran::semantics::IsOptional(*name.symbol);
+ bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+ builder, operandLocation, converter, dataExv, baseAddr,
+ isOptional);
+ }
if (fir::unwrapRefType(baseAddr.getType())
.isa<fir::SequenceType>())
- bounds = genBaseBoundsOps<BoundsType, BoundsOp>(
+ bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
builder, operandLocation, converter, dataExv, baseAddr);
asFortran << name.ToString();
} else { // Unsupported
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e2abed1b9f4f67..531685948bc843 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -266,10 +266,11 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsType,
- mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
- accObject, operandLocation, asFortran, bounds,
- /*treatIndexAsSection=*/true);
+ Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+ mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+ stmtCtx, accObject, operandLocation,
+ asFortran, bounds,
+ /*treatIndexAsSection=*/true);
Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
bounds, structured, implicit, dataClause,
baseAddr.getType());
@@ -291,9 +292,10 @@ static void genDeclareDataOperandOperations(
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsType,
- mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
- accObject, operandLocation, asFortran, bounds);
+ Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+ mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+ stmtCtx, accObject, operandLocation,
+ asFortran, bounds);
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, baseAddr, asFortran, bounds, structured,
implicit, dataClause, baseAddr.getType());
@@ -748,9 +750,10 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsType,
- mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
- accObject, operandLocation, asFortran, bounds);
+ Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+ mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+ stmtCtx, accObject, operandLocation,
+ asFortran, bounds);
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, baseAddr.getType());
@@ -1324,9 +1327,10 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsType,
- mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
- accObject, operandLocation, asFortran, bounds);
+ Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+ mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+ stmtCtx, accObject, operandLocation,
+ asFortran, bounds);
mlir::Type reductionTy = fir::unwrapRefType(baseAddr.getType());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index eeba87fcd15116..59e06e8458e6c0 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1794,8 +1794,8 @@ bool ClauseProcessor::processMap(
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::OmpObject, mlir::omp::DataBoundsType,
- mlir::omp::DataBoundsOp>(
+ Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
+ mlir::omp::DataBoundsType>(
converter, firOpBuilder, semanticsContext, stmtCtx, ompObject,
clauseLocation, asFortran, bounds, treatIndexAsSection);
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index 8db18ab5aa9c4b..c8787c5e118f97 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -116,4 +116,35 @@ subroutine acc_multi_strides(a)
! CHECK: %[[PRESENT:.*]] = acc.present varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?x?x?xf32>>) bounds(%29, %33, %37) -> !fir.ref<!fir.array<?x?x?xf32>> {name = "a"}
! CHECK: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref<!fir.array<?x?x?xf32>>) {
+ subroutine acc_optional_data(a)
+ real, pointer, optional :: a(:)
+ !$acc data attach(a)
+ !$acc end data
+ end subroutine
+
+ ! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data(
+ ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
+ ! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+ ! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
+ ! CHECK: %[[ADDR:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
+ ! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ ! CHECK: fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ ! CHECK: } else {
+ ! CHECK: %[[ABSENT:.*]] = fir.absent !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ ! CHECK: fir.result %[[ABSENT]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ ! CHECK: }
+ ! CHECK: %[[BOUNDS:.*]] = fir.if %{{.*}} -> (!acc.data_bounds_ty) {
+ ! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}}#1 : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
+ ! CHECK: fir.result %[[BOUND]] : !acc.data_bounds_ty
+ ! CHECK: } else {
+ ! CHECK: %[[C0:.*]] = arith.constant 0 : index
+ ! CHECK: %[[CM1:.*]] = arith.constant -1 : index
+ ! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[C0]] : index) upperbound(%[[CM1]] : index) extent(%[[C0]] : index)
+ ! CHECK: fir.result %[[BOUND]] : !acc.data_bounds_ty
+ ! CHECK: }
+ ! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[ADDR]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
+ ! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUNDS]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
+ ! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr<!fir.array<?xf32>>)
+
+
end module
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index d302be85c5df46..a6572e14707606 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -198,4 +198,3 @@ subroutine acc_data
! CHECK-NOT: acc.data
end subroutine acc_data
-
More information about the flang-commits
mailing list