[flang-commits] [flang] 7634a96 - [flang][acc] Improve lowering of Fortran optional in data clause (#102224)
via flang-commits
flang-commits at lists.llvm.org
Wed Aug 7 08:04:10 PDT 2024
Author: Razvan Lupusoru
Date: 2024-08-07T08:04:06-07:00
New Revision: 7634a96589637186b640a0441c0544a9868d9913
URL: https://github.com/llvm/llvm-project/commit/7634a96589637186b640a0441c0544a9868d9913
DIFF: https://github.com/llvm/llvm-project/commit/7634a96589637186b640a0441c0544a9868d9913.diff
LOG: [flang][acc] Improve lowering of Fortran optional in data clause (#102224)
Fortran optional arguments are effectively null references. To deal with
this possibility, flang lowering of OpenACC data clauses creates three
if-else regions when preparing the data pointer for the data clause:
1) Load box value from box reference
2) Load box addr from box value
3) Load box dims from box value
However, this pattern makes it more complicated to find the original box
reference. The first if-else region, which only loads the box value, is
not needed, since the value can instead be loaded right before the
corresponding `fir.box_addr` and `fir.box_dims` operations. Thus, reduce
the number of if-else regions by deferring the box load to the use sites.
For non-optional cases, the old behavior, which preloads the box value,
is left unchanged.
Added:
Modified:
flang/lib/Lower/DirectivesCommon.h
flang/lib/Lower/OpenACC.cpp
flang/test/Lower/OpenACC/acc-bounds.f90
Removed:
################################################################################
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 24cb7c2fcf7db8..d8b1f1f3e43621 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -58,9 +58,20 @@ struct AddrAndBoundsInfo {
explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
mlir::Value isPresent)
: addr(addr), rawInput(rawInput), isPresent(isPresent) {}
+ explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
+ mlir::Value isPresent, mlir::Type boxType)
+ : addr(addr), rawInput(rawInput), isPresent(isPresent), boxType(boxType) {
+ }
mlir::Value addr = nullptr;
mlir::Value rawInput = nullptr;
mlir::Value isPresent = nullptr;
+ mlir::Type boxType = nullptr;
+ void dump(llvm::raw_ostream &os) {
+ os << "AddrAndBoundsInfo addr: " << addr << "\n";
+ os << "AddrAndBoundsInfo rawInput: " << rawInput << "\n";
+ os << "AddrAndBoundsInfo isPresent: " << isPresent << "\n";
+ os << "AddrAndBoundsInfo boxType: " << boxType << "\n";
+ }
};
/// Checks if the assignment statement has a single variable on the RHS.
@@ -674,27 +685,18 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
if (mlir::isa<fir::RecordType>(boxTy.getEleTy()))
TODO(loc, "derived type");
- // Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
- // `fir.ref<fir.class<T>>` type.
- if (mlir::isa<fir::ReferenceType>(symAddr.getType())) {
- if (Fortran::semantics::IsOptional(sym)) {
- mlir::Value addr =
- builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
- .genThen([&]() {
- mlir::Value load = builder.create<fir::LoadOp>(loc, symAddr);
- builder.create<fir::ResultOp>(loc, mlir::ValueRange{load});
- })
- .genElse([&] {
- mlir::Value absent =
- builder.create<fir::AbsentOp>(loc, boxTy);
- builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
- })
- .getResults()[0];
- return AddrAndBoundsInfo(addr, rawInput, isPresent);
- }
+ // In case of a box reference, load it here to get the box value.
+ // This is preferable because the same box value can then be used for
+ // all address/dimension retrievals. For Fortran optional though, leave
+ // the load generation for later so it can be done in the appropriate
+ // if branches.
+ if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
+ !Fortran::semantics::IsOptional(sym)) {
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
- return AddrAndBoundsInfo(addr, rawInput, isPresent);
+ return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
}
+
+ return AddrAndBoundsInfo(symAddr, rawInput, isPresent, boxTy);
}
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
@@ -704,6 +706,7 @@ llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
fir::ExtendedValue dataExv, mlir::Value box,
bool collectValuesOnly = false) {
+ assert(box && "box must exist");
llvm::SmallVector<mlir::Value> values;
mlir::Value byteStride;
mlir::Type idxTy = builder.getIndexType();
@@ -748,8 +751,10 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type idxTy = builder.getIndexType();
mlir::Type boundTy = builder.getType<BoundsType>();
- assert(mlir::isa<fir::BaseBoxType>(info.addr.getType()) &&
+ assert(mlir::isa<fir::BaseBoxType>(info.boxType) &&
"expect fir.box or fir.class");
+ assert(fir::unwrapRefType(info.addr.getType()) == info.boxType &&
+ "expected box type consistency");
if (info.isPresent) {
llvm::SmallVector<mlir::Type> resTypes;
@@ -760,9 +765,13 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Operation::result_range ifRes =
builder.genIfOp(loc, resTypes, info.isPresent, /*withElseRegion=*/true)
.genThen([&]() {
+ mlir::Value box =
+ !fir::isBoxAddress(info.addr.getType())
+ ? info.addr
+ : builder.create<fir::LoadOp>(loc, info.addr);
llvm::SmallVector<mlir::Value> boundValues =
gatherBoundsOrBoundValues<BoundsOp, BoundsType>(
- builder, loc, dataExv, info.addr,
+ builder, loc, dataExv, box,
/*collectValuesOnly=*/true);
builder.create<fir::ResultOp>(loc, boundValues);
})
@@ -790,8 +799,11 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
bounds.push_back(bound);
}
} else {
- bounds = gatherBoundsOrBoundValues<BoundsOp, BoundsType>(
- builder, loc, dataExv, info.addr);
+ mlir::Value box = !fir::isBoxAddress(info.addr.getType())
+ ? info.addr
+ : builder.create<fir::LoadOp>(loc, info.addr);
+ bounds = gatherBoundsOrBoundValues<BoundsOp, BoundsType>(builder, loc,
+ dataExv, box);
}
return bounds;
}
@@ -941,10 +953,14 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
builder
.genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
.genThen([&]() {
+ mlir::Value box =
+ !fir::isBoxAddress(info.addr.getType())
+ ? info.addr
+ : builder.create<fir::LoadOp>(loc, info.addr);
mlir::Value d =
builder.createIntegerConstant(loc, idxTy, dimension);
auto dimInfo = builder.create<fir::BoxDimsOp>(
- loc, idxTy, idxTy, idxTy, info.addr, d);
+ loc, idxTy, idxTy, idxTy, box, d);
builder.create<fir::ResultOp>(loc, dimInfo.getByteStride());
})
.genElse([&] {
@@ -954,9 +970,12 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
})
.getResults()[0];
} else {
+ mlir::Value box = !fir::isBoxAddress(info.addr.getType())
+ ? info.addr
+ : builder.create<fir::LoadOp>(loc, info.addr);
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
- auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy,
- idxTy, info.addr, d);
+ auto dimInfo =
+ builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, d);
stride = dimInfo.getByteStride();
}
strideInBytes = true;
@@ -1197,8 +1216,10 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
if (auto loadOp =
mlir::dyn_cast_or_null<fir::LoadOp>(info.addr.getDefiningOp())) {
if (fir::isAllocatableType(loadOp.getType()) ||
- fir::isPointerType(loadOp.getType()))
+ fir::isPointerType(loadOp.getType())) {
+ info.boxType = info.addr.getType();
info.addr = builder.create<fir::BoxAddrOp>(operandLocation, info.addr);
+ }
info.rawInput = info.addr;
}
@@ -1209,6 +1230,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
if (auto boxAddrOp =
mlir::dyn_cast_or_null<fir::BoxAddrOp>(info.addr.getDefiningOp())) {
info.addr = boxAddrOp.getVal();
+ info.boxType = info.addr.getType();
info.rawInput = info.addr;
bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
builder, operandLocation, compExv, info);
@@ -1227,6 +1249,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
getDataOperandBaseAddr(converter, builder, *symRef, operandLocation);
if (mlir::isa<fir::BaseBoxType>(
fir::unwrapRefType(info.addr.getType()))) {
+ info.boxType = fir::unwrapRefType(info.addr.getType());
bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
builder, operandLocation, dataExv, info);
}
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 6266a5056ace85..be184aeead6ee5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -32,6 +32,9 @@
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "llvm/Frontend/OpenACC/ACC.h.inc"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "flang-lower-openacc"
// Special value for * passed in device_type or gang clauses.
static constexpr std::int64_t starCst = -1;
@@ -85,11 +88,17 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type retTy, llvm::ArrayRef<mlir::Value> async,
llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
- mlir::Value isPresent = {}) {
+ bool unwrapBoxAddr = false, mlir::Value isPresent = {}) {
mlir::Value varPtrPtr;
- if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
+ // The data clause may apply to either the box reference itself or the
+ // pointer to the data it holds. So use `unwrapBoxAddr` to decide.
+ // When we have a box value - assume it refers to the data inside box.
+ if ((fir::isBoxAddress(baseAddr.getType()) && unwrapBoxAddr) ||
+ fir::isa_box_type(baseAddr.getType())) {
if (isPresent) {
- mlir::Type ifRetTy = boxTy.getEleTy();
+ mlir::Type ifRetTy =
+ mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(baseAddr.getType()))
+ .getEleTy();
if (!fir::isa_ref_type(ifRetTy))
ifRetTy = fir::ReferenceType::get(ifRetTy);
baseAddr =
@@ -97,6 +106,8 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
.genIfOp(loc, {ifRetTy}, isPresent,
/*withElseRegion=*/true)
.genThen([&]() {
+ if (fir::isBoxAddress(baseAddr.getType()))
+ baseAddr = builder.create<fir::LoadOp>(loc, baseAddr);
mlir::Value boxAddr =
builder.create<fir::BoxAddrOp>(loc, baseAddr);
builder.create<fir::ResultOp>(loc, mlir::ValueRange{boxAddr});
@@ -108,6 +119,8 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
})
.getResults()[0];
} else {
+ if (fir::isBoxAddress(baseAddr.getType()))
+ baseAddr = builder.create<fir::LoadOp>(loc, baseAddr);
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
}
retTy = baseAddr.getType();
@@ -342,18 +355,19 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds,
/*treatIndexAsSection=*/true);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
- mlir::Value baseAddr =
- ((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) &&
- info.isPresent)
- ? info.rawInput
- : info.addr;
- Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
- bounds, structured, implicit, dataClause,
- baseAddr.getType(), async, asyncDeviceTypes,
- asyncOnlyDeviceTypes, info.isPresent);
+ mlir::Value baseAddr = ((fir::unwrapRefType(info.addr.getType()) !=
+ fir::unwrapRefType(info.rawInput.getType())) &&
+ info.isPresent)
+ ? info.rawInput
+ : info.addr;
+ Op op = createDataEntryOp<Op>(
+ builder, operandLocation, baseAddr, asFortran, bounds, structured,
+ implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes,
+ asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
dataOperands.push_back(op.getAccPtr());
}
}
@@ -380,6 +394,7 @@ static void genDeclareDataOperandOperations(
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, info.addr, asFortran, bounds, structured,
implicit, dataClause, info.addr.getType(),
@@ -842,6 +857,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
+
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
@@ -853,7 +870,7 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
auto op = createDataEntryOp<mlir::acc::PrivateOp>(
builder, operandLocation, info.addr, asFortran, bounds, true,
/*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async,
- asyncDeviceTypes, asyncOnlyDeviceTypes);
+ asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
dataOperands.push_back(op.getAccPtr());
} else {
std::string suffix =
@@ -865,7 +882,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
auto op = createDataEntryOp<mlir::acc::FirstprivateOp>(
builder, operandLocation, info.addr, asFortran, bounds, true,
/*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy,
- async, asyncDeviceTypes, asyncOnlyDeviceTypes);
+ async, asyncDeviceTypes, asyncOnlyDeviceTypes,
+ /*unwrapBoxAddr=*/true);
dataOperands.push_back(op.getAccPtr());
}
privatizations.push_back(mlir::SymbolRefAttr::get(
@@ -1421,6 +1439,7 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
@@ -1433,7 +1452,7 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
builder, operandLocation, info.addr, asFortran, bounds,
/*structured=*/true, /*implicit=*/false,
mlir::acc::DataClause::acc_reduction, info.addr.getType(), async,
- asyncDeviceTypes, asyncOnlyDeviceTypes);
+ asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
mlir::Type ty = op.getAccPtr().getType();
if (!areAllBoundConstant(bounds) ||
fir::isAssumedShape(info.addr.getType()) ||
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index a83de91a67aed1..e44c786e629645 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -128,14 +128,8 @@ subroutine acc_optional_data(a)
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{[0-9]+}} {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
-! CHECK: %[[BOX:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
-! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-! CHECK: fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
-! CHECK: } else {
-! CHECK: %[[ABSENT:.*]] = fir.absent !fir.box<!fir.ptr<!fir.array<?xf32>>>
-! CHECK: fir.result %[[ABSENT]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
-! CHECK: }
! CHECK: %[[RES:.*]]:5 = fir.if %[[IS_PRESENT]] -> (index, index, index, index, index) {
+! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
! CHECK: fir.result %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : index, index, index, index, index
! CHECK: } else {
! CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -144,7 +138,8 @@ subroutine acc_optional_data(a)
! CHECK: }
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[RES]]#0 : index) upperbound(%[[RES]]#1 : index) extent(%[[RES]]#2 : index) stride(%[[RES]]#3 : index) startIdx(%[[RES]]#4 : index) {strideInBytes = true}
! CHECK: %[[BOX_ADDR:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.ptr<!fir.array<?xf32>>) {
-! CHECK: %[[ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
+! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+! CHECK: %[[ADDR:.*]] = fir.box_addr %[[LOAD]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
! CHECK: fir.result %[[ADDR]] : !fir.ptr<!fir.array<?xf32>>
! CHECK: } else {
! CHECK: %[[ABSENT:.*]] = fir.absent !fir.ptr<!fir.array<?xf32>>
More information about the flang-commits
mailing list