[flang-commits] [flang] [flang][acc] Improve lowering of Fortran optional in data clause (PR #102224)
via flang-commits
flang-commits at lists.llvm.org
Tue Aug 6 14:08:49 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-flang-fir-hlfir
Author: Razvan Lupusoru (razvanlupusoru)
<details>
<summary>Changes</summary>
An absent Fortran optional argument is effectively a null reference. To handle this possibility, flang's lowering of OpenACC data clauses creates three if-else regions when preparing the data pointer for a data clause:
1) Load box value from box reference
2) Load box addr from box value
3) Load box dims from box value
However, this pattern makes it harder to recover the original box reference. The first if-else region, which only loads the box value, is unnecessary: the box can instead be loaded inside the if-else regions that already guard the corresponding `fir.box_addr` and `fir.box_dims` operations, immediately before those operations are created. This change therefore reduces the number of if-else regions by deferring the box load to its use sites.
For non-optional arguments, the existing behavior is unchanged: the box value is still loaded once, up front.
---
Full diff: https://github.com/llvm/llvm-project/pull/102224.diff
3 Files Affected:
- (modified) flang/lib/Lower/DirectivesCommon.h (+49-26)
- (modified) flang/lib/Lower/OpenACC.cpp (+36-15)
- (modified) flang/test/Lower/OpenACC/acc-bounds.f90 (+3-8)
``````````diff
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 24cb7c2fcf7db8..f6b16e9adbbba2 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -58,9 +58,20 @@ struct AddrAndBoundsInfo {
explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
mlir::Value isPresent)
: addr(addr), rawInput(rawInput), isPresent(isPresent) {}
+ explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
+ mlir::Value isPresent, mlir::Type boxType)
+ : addr(addr), rawInput(rawInput), isPresent(isPresent), boxType(boxType) {
+ }
mlir::Value addr = nullptr;
mlir::Value rawInput = nullptr;
mlir::Value isPresent = nullptr;
+ mlir::Type boxType = nullptr;
+ void dump(llvm::raw_ostream &os) {
+ os << "AddrAndBoundsInfo addr: " << addr << "\n";
+ os << "AddrAndBoundsInfo rawInput: " << rawInput << "\n";
+ os << "AddrAndBoundsInfo isPresent: " << isPresent << "\n";
+ os << "AddrAndBoundsInfo boxType: " << boxType << "\n";
+ }
};
/// Checks if the assignment statement has a single variable on the RHS.
@@ -674,27 +685,18 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
if (mlir::isa<fir::RecordType>(boxTy.getEleTy()))
TODO(loc, "derived type");
- // Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
- // `fir.ref<fir.class<T>>` type.
- if (mlir::isa<fir::ReferenceType>(symAddr.getType())) {
- if (Fortran::semantics::IsOptional(sym)) {
- mlir::Value addr =
- builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
- .genThen([&]() {
- mlir::Value load = builder.create<fir::LoadOp>(loc, symAddr);
- builder.create<fir::ResultOp>(loc, mlir::ValueRange{load});
- })
- .genElse([&] {
- mlir::Value absent =
- builder.create<fir::AbsentOp>(loc, boxTy);
- builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
- })
- .getResults()[0];
- return AddrAndBoundsInfo(addr, rawInput, isPresent);
- }
+ // In case of a box reference, load it here to get the box value.
+ // This is preferrable because then the same box value can then be used for
+ // all address/dimension retrievals. For Fortran optional though, leave
+ // the load generation for later so it can be done in the appropriate
+ // if branches.
+ if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
+ !Fortran::semantics::IsOptional(sym)) {
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
- return AddrAndBoundsInfo(addr, rawInput, isPresent);
+ return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
}
+
+ return AddrAndBoundsInfo(symAddr, rawInput, isPresent, boxTy);
}
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
@@ -704,6 +706,7 @@ llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
fir::ExtendedValue dataExv, mlir::Value box,
bool collectValuesOnly = false) {
+ assert(box && "box must exist");
llvm::SmallVector<mlir::Value> values;
mlir::Value byteStride;
mlir::Type idxTy = builder.getIndexType();
@@ -748,8 +751,10 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type idxTy = builder.getIndexType();
mlir::Type boundTy = builder.getType<BoundsType>();
- assert(mlir::isa<fir::BaseBoxType>(info.addr.getType()) &&
+ assert(mlir::isa<fir::BaseBoxType>(info.boxType) &&
"expect fir.box or fir.class");
+ assert(fir::unwrapRefType(info.addr.getType()) == info.boxType &&
+ "expected box type consistency");
if (info.isPresent) {
llvm::SmallVector<mlir::Type> resTypes;
@@ -760,9 +765,13 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Operation::result_range ifRes =
builder.genIfOp(loc, resTypes, info.isPresent, /*withElseRegion=*/true)
.genThen([&]() {
+ mlir::Value box =
+ !fir::isBoxAddress(info.addr.getType())
+ ? info.addr
+ : builder.create<fir::LoadOp>(loc, info.addr);
llvm::SmallVector<mlir::Value> boundValues =
gatherBoundsOrBoundValues<BoundsOp, BoundsType>(
- builder, loc, dataExv, info.addr,
+ builder, loc, dataExv, box,
/*collectValuesOnly=*/true);
builder.create<fir::ResultOp>(loc, boundValues);
})
@@ -790,8 +799,11 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
bounds.push_back(bound);
}
} else {
+ mlir::Value box = !fir::isBoxAddress(info.addr.getType())
+ ? info.addr
+ : builder.create<fir::LoadOp>(loc, info.addr);
bounds = gatherBoundsOrBoundValues<BoundsOp, BoundsType>(
- builder, loc, dataExv, info.addr);
+ builder, loc, dataExv, box);
}
return bounds;
}
@@ -941,10 +953,14 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
builder
.genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
.genThen([&]() {
+ mlir::Value box =
+ !fir::isBoxAddress(info.addr.getType())
+ ? info.addr
+ : builder.create<fir::LoadOp>(loc, info.addr);
mlir::Value d =
builder.createIntegerConstant(loc, idxTy, dimension);
auto dimInfo = builder.create<fir::BoxDimsOp>(
- loc, idxTy, idxTy, idxTy, info.addr, d);
+ loc, idxTy, idxTy, idxTy, box, d);
builder.create<fir::ResultOp>(loc, dimInfo.getByteStride());
})
.genElse([&] {
@@ -954,9 +970,12 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
})
.getResults()[0];
} else {
+ mlir::Value box = !fir::isBoxAddress(info.addr.getType())
+ ? info.addr
+ : builder.create<fir::LoadOp>(loc, info.addr);
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
- auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy,
- idxTy, info.addr, d);
+ auto dimInfo =
+ builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, d);
stride = dimInfo.getByteStride();
}
strideInBytes = true;
@@ -1197,8 +1216,10 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
if (auto loadOp =
mlir::dyn_cast_or_null<fir::LoadOp>(info.addr.getDefiningOp())) {
if (fir::isAllocatableType(loadOp.getType()) ||
- fir::isPointerType(loadOp.getType()))
+ fir::isPointerType(loadOp.getType())) {
+ info.boxType = info.addr.getType();
info.addr = builder.create<fir::BoxAddrOp>(operandLocation, info.addr);
+ }
info.rawInput = info.addr;
}
@@ -1209,6 +1230,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
if (auto boxAddrOp =
mlir::dyn_cast_or_null<fir::BoxAddrOp>(info.addr.getDefiningOp())) {
info.addr = boxAddrOp.getVal();
+ info.boxType = info.addr.getType();
info.rawInput = info.addr;
bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
builder, operandLocation, compExv, info);
@@ -1227,6 +1249,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
getDataOperandBaseAddr(converter, builder, *symRef, operandLocation);
if (mlir::isa<fir::BaseBoxType>(
fir::unwrapRefType(info.addr.getType()))) {
+ info.boxType = fir::unwrapRefType(info.addr.getType());
bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
builder, operandLocation, dataExv, info);
}
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 6266a5056ace85..40ba20d707e2c9 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -32,6 +32,9 @@
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "llvm/Frontend/OpenACC/ACC.h.inc"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "flang-lower-openacc"
// Special value for * passed in device_type or gang clauses.
static constexpr std::int64_t starCst = -1;
@@ -85,11 +88,17 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type retTy, llvm::ArrayRef<mlir::Value> async,
llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
- mlir::Value isPresent = {}) {
+ bool unwrapBoxAddr = false, mlir::Value isPresent = {}) {
mlir::Value varPtrPtr;
- if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
+ // The data clause may apply to either the box reference itself or the
+ // pointer to the data it holds. So use `unwrapBoxAddr` to decide.
+ // When we have a box value - assume it refers to the data inside box.
+ if ((fir::isBoxAddress(baseAddr.getType()) && unwrapBoxAddr) ||
+ fir::isa_box_type(baseAddr.getType())) {
if (isPresent) {
- mlir::Type ifRetTy = boxTy.getEleTy();
+ mlir::Type ifRetTy =
+ mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(baseAddr.getType()))
+ .getEleTy();
if (!fir::isa_ref_type(ifRetTy))
ifRetTy = fir::ReferenceType::get(ifRetTy);
baseAddr =
@@ -97,6 +106,9 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
.genIfOp(loc, {ifRetTy}, isPresent,
/*withElseRegion=*/true)
.genThen([&]() {
+ if (fir::isBoxAddress(baseAddr.getType())) {
+ baseAddr = builder.create<fir::LoadOp>(loc, baseAddr);
+ }
mlir::Value boxAddr =
builder.create<fir::BoxAddrOp>(loc, baseAddr);
builder.create<fir::ResultOp>(loc, mlir::ValueRange{boxAddr});
@@ -108,6 +120,9 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
})
.getResults()[0];
} else {
+ if (fir::isBoxAddress(baseAddr.getType())) {
+ baseAddr = builder.create<fir::LoadOp>(loc, baseAddr);
+ }
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
}
retTy = baseAddr.getType();
@@ -342,18 +357,19 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds,
/*treatIndexAsSection=*/true);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
- mlir::Value baseAddr =
- ((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) &&
- info.isPresent)
- ? info.rawInput
- : info.addr;
- Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
- bounds, structured, implicit, dataClause,
- baseAddr.getType(), async, asyncDeviceTypes,
- asyncOnlyDeviceTypes, info.isPresent);
+ mlir::Value baseAddr = ((fir::unwrapRefType(info.addr.getType()) !=
+ fir::unwrapRefType(info.rawInput.getType())) &&
+ info.isPresent)
+ ? info.rawInput
+ : info.addr;
+ Op op = createDataEntryOp<Op>(
+ builder, operandLocation, baseAddr, asFortran, bounds, structured,
+ implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes,
+ asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
dataOperands.push_back(op.getAccPtr());
}
}
@@ -380,6 +396,7 @@ static void genDeclareDataOperandOperations(
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, info.addr, asFortran, bounds, structured,
implicit, dataClause, info.addr.getType(),
@@ -842,6 +859,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
+
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
@@ -853,7 +872,7 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
auto op = createDataEntryOp<mlir::acc::PrivateOp>(
builder, operandLocation, info.addr, asFortran, bounds, true,
/*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async,
- asyncDeviceTypes, asyncOnlyDeviceTypes);
+ asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
dataOperands.push_back(op.getAccPtr());
} else {
std::string suffix =
@@ -865,7 +884,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
auto op = createDataEntryOp<mlir::acc::FirstprivateOp>(
builder, operandLocation, info.addr, asFortran, bounds, true,
/*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy,
- async, asyncDeviceTypes, asyncOnlyDeviceTypes);
+ async, asyncDeviceTypes, asyncOnlyDeviceTypes,
+ /*unwrapBoxAddr=*/true);
dataOperands.push_back(op.getAccPtr());
}
privatizations.push_back(mlir::SymbolRefAttr::get(
@@ -1421,6 +1441,7 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
@@ -1433,7 +1454,7 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
builder, operandLocation, info.addr, asFortran, bounds,
/*structured=*/true, /*implicit=*/false,
mlir::acc::DataClause::acc_reduction, info.addr.getType(), async,
- asyncDeviceTypes, asyncOnlyDeviceTypes);
+ asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
mlir::Type ty = op.getAccPtr().getType();
if (!areAllBoundConstant(bounds) ||
fir::isAssumedShape(info.addr.getType()) ||
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index a83de91a67aed1..e44c786e629645 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -128,14 +128,8 @@ subroutine acc_optional_data(a)
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{[0-9]+}} {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
-! CHECK: %[[BOX:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
-! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-! CHECK: fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
-! CHECK: } else {
-! CHECK: %[[ABSENT:.*]] = fir.absent !fir.box<!fir.ptr<!fir.array<?xf32>>>
-! CHECK: fir.result %[[ABSENT]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
-! CHECK: }
! CHECK: %[[RES:.*]]:5 = fir.if %[[IS_PRESENT]] -> (index, index, index, index, index) {
+! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
! CHECK: fir.result %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : index, index, index, index, index
! CHECK: } else {
! CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -144,7 +138,8 @@ subroutine acc_optional_data(a)
! CHECK: }
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[RES]]#0 : index) upperbound(%[[RES]]#1 : index) extent(%[[RES]]#2 : index) stride(%[[RES]]#3 : index) startIdx(%[[RES]]#4 : index) {strideInBytes = true}
! CHECK: %[[BOX_ADDR:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.ptr<!fir.array<?xf32>>) {
-! CHECK: %[[ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
+! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+! CHECK: %[[ADDR:.*]] = fir.box_addr %[[LOAD]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
! CHECK: fir.result %[[ADDR]] : !fir.ptr<!fir.array<?xf32>>
! CHECK: } else {
! CHECK: %[[ABSENT:.*]] = fir.absent !fir.ptr<!fir.array<?xf32>>
``````````
</details>
https://github.com/llvm/llvm-project/pull/102224
More information about the flang-commits mailing list